From 738d70ae2790cb7ea27d3858e49baf21c54df666 Mon Sep 17 00:00:00 2001
From: Chengsong Zhang
Date: Sun, 23 Apr 2023 09:47:53 +0800
Subject: [PATCH] fix many problems in AutoSAM ControlNet; many problems remain

---
 scripts/auto.py | 70 ++++++++++++++++++++++++++++---------------------
 scripts/sam.py  | 35 ++++++++++++++++---------
 2 files changed, 63 insertions(+), 42 deletions(-)

diff --git a/scripts/auto.py b/scripts/auto.py
index 3f9bdf7..230d538 100644
--- a/scripts/auto.py
+++ b/scripts/auto.py
@@ -1,6 +1,7 @@
 import os
 import gc
 import glob
+import copy
 from PIL import Image
 from collections import OrderedDict
 import numpy as np
@@ -11,7 +12,7 @@
 from modules.paths import extensions_dir
 from modules.devices import torch_gc
 
-global_sam = None
+global_sam: SamAutomaticMaskGenerator = None
 sem_seg_cache = OrderedDict()
 sam_annotator_dir = os.path.join(scripts.basedir(), "annotator")
 original_uniformer_inference_segmentor = None
@@ -19,7 +20,7 @@ original_oneformer_draw_sem_seg = None
 
 
 def blend_image_and_seg(image, seg, alpha=0.5):
-    image_blend = np.array(image) * (1 - alpha) + np.array(seg) * alpha
+    image_blend = image * (1 - alpha) + np.array(seg) * alpha
     return Image.fromarray(image_blend.astype(np.uint8))
 
 
@@ -40,52 +41,59 @@ def clear_sem_sam_cache():
 
 def sem_sam_garbage_collect():
     if shared.cmd_opts.lowvram:
-        for _, model in sem_seg_cache:
-            model.unload_model()
+        for model_key, model in sem_seg_cache.items():
+            if model_key == "uniformer":
+                from annotator.uniformer import unload_uniformer_model
+                unload_uniformer_model()
+            else:
+                model.unload_model()
     gc.collect()
     torch_gc()
 
 
 def strengthen_sem_seg(class_ids, img):
+    print("Auto SAM strengthening semantic segmentation")
     import pycocotools.mask as maskUtils
-    semantc_mask = class_ids.clone()
-    annotations = global_sam(img)
+    semantc_mask = copy.deepcopy(class_ids)
+    annotations = global_sam.generate(img)
     annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
+    print(f"Auto SAM generated {len(annotations)} masks")
     for ann in annotations:
         valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
-        propose_classes_ids = class_ids[valid_mask]
+        propose_classes_ids = torch.tensor(class_ids[valid_mask])
         num_class_proposals = len(torch.unique(propose_classes_ids))
         if num_class_proposals == 1:
-            semantc_mask[valid_mask] = propose_classes_ids[0]
+            semantc_mask[valid_mask] = propose_classes_ids[0].numpy()
             continue
         top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices
-        semantc_mask[valid_mask] = top_1_propose_class_ids
+        semantc_mask[valid_mask] = top_1_propose_class_ids.numpy()
+    print("Auto SAM strengthen process end")
     return semantc_mask
 
 
 def random_segmentation(img):
-    print("Generating random segmentation for Edit-Anything")
-    img_np = np.array(img)
-    annotations = global_sam(img_np)
+    print("Auto SAM generating random segmentation for Edit-Anything")
+    img_np = np.array(img.convert("RGB"))
+    annotations = global_sam.generate(img_np)
     annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
-    if len(annotations) == 0:
-        return []
+    print(f"Auto SAM generated {len(annotations)} masks")
     H, W, C = img_np.shape
     cnet_input = np.zeros((H, W), dtype=np.uint16)
     for idx, annotation in enumerate(annotations):
         current_seg = annotation['segmentation']
-        cnet_input[current_seg] = idx + 1
+        cnet_input[current_seg] = idx + 1  # TODO: Add random mask, not the ugly detected map
     detected_map = np.zeros((cnet_input.shape[0], cnet_input.shape[1], 3))
     detected_map[:, :, 0] = cnet_input % 256
     detected_map[:, :, 1] = cnet_input // 256
     from annotator.util import HWC3
     detected_map = HWC3(detected_map.astype(np.uint8))
-    return [blend_image_and_seg(img, detected_map), Image.fromarray(detected_map)], "Random segmentation done. Left is blended image, right is ControlNet input."
+    print("Auto SAM generation process end")
+    return [blend_image_and_seg(img_np, detected_map), Image.fromarray(detected_map)], "Random segmentation done. Left is blended image, right is ControlNet input."
 
 
 def image_layer_image(layout_input_image, layout_output_path):
     img_np = np.array(layout_input_image)
-    annotations = global_sam(img_np)
+    annotations = global_sam.generate(img_np)
     print(f"AutoSAM generated {len(annotations)} annotations")
     annotations = sorted(annotations, key=lambda x: x['area'])
     for idx, annotation in enumerate(annotations):
@@ -123,7 +131,7 @@ def inject_inference_segmentor(model, img):
 
 def inject_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8, is_text=True, edge_color=(1.0, 1.0, 1.0)):
     if isinstance(sem_seg, torch.Tensor):
-        sem_seg = sem_seg.numpy()
+        sem_seg = sem_seg.numpy()  # TODO: inject another function for oneformer
     return original_oneformer_draw_sem_seg(self, strengthen_sem_seg(sem_seg), area_threshold, alpha, is_text, edge_color)
 
 
@@ -140,7 +148,7 @@ def _uniformer(img):
         from annotator.uniformer import apply_uniformer
         sem_seg_cache["uniformer"] = apply_uniformer
     result = sem_seg_cache["uniformer"](img)
-    return result, True
+    return result
 
 
 def _oneformer(img, dataset="coco"):
@@ -149,10 +157,10 @@ def _oneformer(img, dataset="coco"):
         from annotator.oneformer import OneformerDetector
         sem_seg_cache[oneformer_key] = OneformerDetector(OneformerDetector.configs[dataset])
     result = sem_seg_cache[oneformer_key](img)
-    return result, True
+    return result
 
 
-def semantic_segmentation(input_image, annotator_name):
+def semantic_segmentation(input_image, annotator_name, processor_res):
     if input_image is None:
         return [], "No input image."
     if "seg" in annotator_name:
@@ -160,26 +168,27 @@ def semantic_segmentation(input_image, annotator_name):
             return [], "ControlNet extension not found."
         global original_uniformer_inference_segmentor
         global original_oneformer_draw_sem_seg
-        input_image_np = np.array(input_image)
+        from annotator.util import resize_image, HWC3
+        input_image = resize_image(HWC3(np.array(input_image)), processor_res)
         print("Generating semantic segmentation without SAM")
        if annotator_name == "seg_ufade20k":
-            original_semseg = _uniformer(input_image_np)
+            original_semseg = _uniformer(input_image)
             print("Generating semantic segmentation with SAM")
             import annotator.uniformer as uniformer
             original_uniformer_inference_segmentor = uniformer.inference_segmentor
             uniformer.inference_segmentor = inject_inference_segmentor
-            sam_semseg = _uniformer(input_image_np)
+            sam_semseg = _uniformer(input_image)
             uniformer.inference_segmentor = original_uniformer_inference_segmentor
             output_gallery = [original_semseg, sam_semseg, blend_image_and_seg(input_image, original_semseg), blend_image_and_seg(input_image, sam_semseg)]
             return output_gallery, "Uniformer semantic segmentation of ade20k done. Left is segmentation before SAM, right is segmentation after SAM."
         else:
             dataset = annotator_name.split('_')[-1][2:]
-            original_semseg = _oneformer(input_image_np, dataset)
+            original_semseg = _oneformer(input_image, dataset=dataset)
             print("Generating semantic segmentation with SAM")
             from annotator.oneformer.oneformer.demo.visualizer import Visualizer
             original_oneformer_draw_sem_seg = Visualizer.draw_sem_seg
             Visualizer.draw_sem_seg = inject_sem_seg
-            sam_semseg = _oneformer(input_image_np, dataset)
+            sam_semseg = _oneformer(input_image, dataset=dataset)
             Visualizer.draw_sem_seg = original_oneformer_draw_sem_seg
             output_gallery = [original_semseg, sam_semseg, blend_image_and_seg(input_image, original_semseg), blend_image_and_seg(input_image, sam_semseg)]
             return output_gallery, f"Oneformer semantic segmentation of {dataset} done. Left is segmentation before SAM, right is segmentation after SAM."
@@ -187,7 +196,7 @@ def semantic_segmentation(input_image, annotator_name):
         return random_segmentation(input_image)
 
 
-def categorical_mask_image(crop_processor, crop_category_input, crop_input_image):
+def categorical_mask_image(crop_processor, crop_processor_res, crop_category_input, crop_input_image):
     if crop_input_image is None:
         return "No input image."
     if not os.path.isdir(os.path.join(scripts.basedir(), "annotator")) and not create_symbolic_link():
@@ -199,9 +208,10 @@ def categorical_mask_image(crop_processor, crop_category_input, crop_input_image):
         filter_classes = [int(i) for i in filter_classes]
     except:
         return "Illegal class id. You may have input some string."
+    from annotator.util import resize_image, HWC3
+    crop_input_image = resize_image(HWC3(np.array(crop_input_image)), crop_processor_res)
     global original_uniformer_inference_segmentor
     global original_oneformer_draw_sem_seg
-    input_image_np = np.array(crop_input_image)
     print(f"Generating categories with processor {crop_processor}")
     if crop_processor == "seg_ufade20k":
         import annotator.uniformer as uniformer
@@ -209,7 +219,7 @@ def categorical_mask_image(crop_processor, crop_category_input, crop_input_image):
         uniformer.inference_segmentor = inject_inference_segmentor
         tmp_ouis = uniformer.show_result_pyplot
         uniformer.show_result_pyplot = inject_show_result_pyplot
-        sam_semseg = _uniformer(input_image_np)
+        sam_semseg = _uniformer(crop_input_image)
         uniformer.inference_segmentor = original_uniformer_inference_segmentor
         uniformer.show_result_pyplot = tmp_ouis
     else:
@@ -218,7 +228,7 @@ def categorical_mask_image(crop_processor, crop_category_input, crop_input_image):
         tmp_oodss = Visualizer.draw_sem_seg
         Visualizer.draw_sem_seg = inject_sem_seg
         original_oneformer_draw_sem_seg = inject_oodss
-        sam_semseg = _oneformer(input_image_np, dataset)
+        sam_semseg = _oneformer(crop_input_image, dataset=dataset)
         Visualizer.draw_sem_seg = tmp_oodss
     mask = np.zeros(sam_semseg.shape, dtype=np.bool_)
     for i in filter_classes:
diff --git a/scripts/sam.py b/scripts/sam.py
index 08e1110..5df18ba 100644
--- a/scripts/sam.py
+++ b/scripts/sam.py
@@ -298,7 +298,7 @@ def dino_batch_process(
 
 
 def cnet_seg(
-        sam_model_name, cnet_seg_input_image, cnet_seg_processor,
+        sam_model_name, cnet_seg_input_image, cnet_seg_processor, cnet_seg_processor_res,
         auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
         auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
         auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
@@ -310,7 +310,7 @@ def cnet_seg(
         auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh, auto_sam_crop_n_layers,
         auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio, auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, auto_sam_output_mode)
-    outputs = semantic_segmentation(cnet_seg_input_image, cnet_seg_processor)
+    outputs = semantic_segmentation(cnet_seg_input_image, cnet_seg_processor, cnet_seg_processor_res)
     sem_sam_garbage_collect()
     garbage_collect(sam)
     return outputs
@@ -335,7 +335,7 @@ def image_layout(
 
 
 def categorical_mask(
-        sam_model_name, crop_processor, crop_category_input, crop_input_image,
+        sam_model_name, crop_processor, crop_processor_res, crop_category_input, crop_input_image,
         auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
         auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
         auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
@@ -346,7 +346,7 @@ def categorical_mask(
         auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh, auto_sam_crop_n_layers,
         auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio, auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, "coco_rle")
-    outputs = categorical_mask_image(crop_processor, crop_category_input, crop_input_image)
+    outputs = categorical_mask_image(crop_processor, crop_processor_res, crop_category_input, crop_input_image)
     sem_sam_garbage_collect()
     garbage_collect(sam)
     if isinstance(outputs, str):
@@ -356,7 +356,8 @@ def categorical_mask(
 
 
 def categorical_mask_batch(
-        sam_model_name, crop_processor, crop_category_input, crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir,
+        sam_model_name, crop_processor, crop_processor_res,
+        crop_category_input, crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir,
         crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background,
         auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
         auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
@@ -374,7 +375,7 @@ def categorical_mask_batch(
         print(f"Processing {image_index}/{len(all_files)} {input_image_file}")
         try:
             crop_input_image = Image.open(input_image_file).convert("RGB")
-            outputs = categorical_mask_image(crop_processor, crop_category_input, crop_input_image)
+            outputs = categorical_mask_image(crop_processor, crop_processor_res, crop_category_input, crop_input_image)
             if isinstance(outputs, str):
                 outputs = f"Image {image_index}: {outputs}"
                 print(outputs)
@@ -480,6 +481,16 @@ def ui_batch(is_dino):
     return dino_batch_dilation_amt, dino_batch_source_dir, dino_batch_dest_dir, dino_batch_output_per_image, dino_batch_save_image, dino_batch_save_mask, dino_batch_save_image_with_mask, dino_batch_save_background, dino_batch_run_button, dino_batch_progress
 
 
+def ui_processor(use_random=True):
+    processor_choices = ["seg_ufade20k", "seg_ofade20k", "seg_ofcoco"]
+    if use_random:
+        processor_choices.append("random")
+    with gr.Row():  # TODO: Add pixel perfect, preprocessor_res > 64
+        cnet_seg_processor = gr.Radio(choices=processor_choices, value="seg_ufade20k", label="Choose preprocessor for semantic segmentation: ")
+        cnet_seg_processor_res = gr.Slider(label="Preprocessor res", value=512, minimum=64, maximum=2048, step=1)
+    return cnet_seg_processor, cnet_seg_processor_res
+
+
 class Script(scripts.Script):
 
     def title(self):
@@ -604,14 +615,14 @@ class Script(scripts.Script):
                     gr.Markdown(
                         "You can enhance semantic segmentation for control_v11p_sd15_seg from lllyasviel. "
                         "Non-semantic segmentation for [Edit-Anything](https://github.com/sail-sg/EditAnything) will be supported [when they convert their models to lllyasviel format](https://github.com/sail-sg/EditAnything/issues/14).")
-                    cnet_seg_processor = gr.Radio(choices=["seg_ufade20k", "seg_ofade20k", "seg_ofcoco", "random"], value="seg_ufade20k", label="Choose preprocessor for semantic segmentation: ")
+                    cnet_seg_processor, cnet_seg_processor_res = ui_processor()
                     cnet_seg_input_image = gr.Image(label="Image for Auto Segmentation", source="upload", type="pil", image_mode="RGBA")
                     cnet_seg_output_gallery = gr.Gallery(label="Auto segmentation output").style(grid=2)
                     cnet_seg_submit = gr.Button(value="Generate segmentation image")
                     cnet_seg_status = gr.Text(value="", label="Segmentation status")
                     cnet_seg_submit.click(
                         fn=cnet_seg,
-                        inputs=[sam_model_name, cnet_seg_input_image, cnet_seg_processor, *auto_sam_config],
+                        inputs=[sam_model_name, cnet_seg_input_image, cnet_seg_processor, cnet_seg_processor_res, *auto_sam_config],
                         outputs=[cnet_seg_output_gallery, cnet_seg_status])
                     with gr.Row(visible=(max_cn_num() > 0)):
                         cnet_seg_enable_copy = gr.Checkbox(value=False, label='Copy to ControlNet Segmentation')
@@ -649,7 +660,7 @@ class Script(scripts.Script):
                        "You can mask images by their categories via semantic segmentation. Please enter category ids (integers), separated by `+`. "
                        "Visit [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/annotator/oneformer/oneformer/data/datasets/register_ade20k_panoptic.py#L12-L207) for ade20k "
                        "and [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/annotator/oneformer/detectron2/data/datasets/builtin_meta.py#L20-L153) for coco to get category->id map.")
-                    crop_processor = gr.Radio(choices=["seg_ufade20k", "seg_ofade20k", "seg_ofcoco"], value="seg_ufade20k", label="Choose preprocessor for semantic segmentation: ")
+                    crop_processor, crop_processor_res = ui_processor(False)
                    crop_category_input = gr.Textbox(placeholder="Enter categody ids, separated by +. For example, if you want bed+person, your input should be 7+12 for ade20k and 65+1 for coco.", label="Enter category IDs")
                    with gr.Tabs():
                        with gr.TabItem(label="Single Image"):
@@ -660,7 +671,7 @@ class Script(scripts.Script):
                            crop_result = gr.Text(value="", label="Categorical mask status")
                            crop_submit.click(
                                fn=categorical_mask,
-                                inputs=[sam_model_name, crop_processor, crop_category_input, crop_input_image, *auto_sam_config],
+                                inputs=[sam_model_name, crop_processor, crop_processor_res, crop_category_input, crop_input_image, *auto_sam_config],
                                outputs=[crop_output_gallery, crop_result])
                            crop_inpaint_enable, crop_cnet_inpaint_invert, crop_cnet_inpaint_idx = ui_inpaint(is_img2img, max_cn_num())
                            crop_dilation_checkbox, crop_dilation_output_gallery = ui_dilation(crop_output_gallery, crop_padding, crop_input_image)
@@ -676,8 +687,8 @@ class Script(scripts.Script):
                            crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir, _, crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background, crop_batch_run_button, crop_batch_progress = ui_batch(False)
                            crop_batch_run_button.click(
                                fn=categorical_mask_batch,
-                                inputs=[sam_model_name, crop_processor, crop_category_input, crop_batch_dilation_amt,
-                                        crop_batch_source_dir, crop_batch_dest_dir,
+                                inputs=[sam_model_name, crop_processor, crop_processor_res,
+                                        crop_category_input, crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir,
                                        crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background, *auto_sam_config],
                                outputs=[crop_batch_progress])
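
Note on the detected map built in random_segmentation above: mask indices can exceed 255, so the uint16 index buffer is split across two 8-bit channels of the ControlNet input (low byte in R via `cnet_input % 256`, high byte in G via `cnet_input // 256`). A minimal standalone sketch of that round trip, using illustrative names that are not part of the patch:

    import numpy as np

    def pack_mask_indices(cnet_input: np.ndarray) -> np.ndarray:
        # Pack uint16 mask indices (0 = background) into an 8-bit RGB map,
        # mirroring what random_segmentation hands to ControlNet.
        detected_map = np.zeros((*cnet_input.shape, 3), dtype=np.uint8)
        detected_map[:, :, 0] = cnet_input % 256   # low byte -> R channel
        detected_map[:, :, 1] = cnet_input // 256  # high byte -> G channel
        return detected_map

    def unpack_mask_indices(detected_map: np.ndarray) -> np.ndarray:
        # Invert the packing: R + 256 * G recovers the uint16 index.
        return (detected_map[:, :, 0].astype(np.uint16)
                + detected_map[:, :, 1].astype(np.uint16) * 256)

    indices = np.random.default_rng(0).integers(0, 65536, size=(4, 4), dtype=np.uint16)
    assert np.array_equal(unpack_mask_indices(pack_mask_indices(indices)), indices)

The scheme therefore distinguishes up to 65535 masks per image, far more than SamAutomaticMaskGenerator emits in practice.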
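
Similarly, the refinement performed by strengthen_sem_seg reduces to a majority vote: each SAM-proposed mask is repainted with the most frequent class id it covers in the semantic map, with larger masks processed first so that smaller masks applied later can sharpen their interiors. A NumPy-only sketch of that logic, without the patch's torch/numpy round-tripping (illustrative, not the patch's exact code):

    import numpy as np

    def refine_semantic_map(class_ids, sam_masks):
        # class_ids: (H, W) integer semantic map; sam_masks: iterable of (H, W) bool masks.
        refined = class_ids.copy()
        # Largest masks first, as in strengthen_sem_seg, so smaller masks
        # painted afterwards overwrite coarse regions with finer detail.
        for mask in sorted(sam_masks, key=lambda m: int(m.sum()), reverse=True):
            covered = class_ids[mask]
            if covered.size == 0:
                continue
            # Majority vote: the most frequent class id under this mask wins.
            refined[mask] = np.bincount(covered.ravel()).argmax()
        return refined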