Fix a lot of problems in AutoSAM ControlNet, but a lot of problems remain

pull/57/head
Chengsong Zhang 2023-04-23 09:47:53 +08:00
parent ae36032134
commit 738d70ae27
2 changed files with 63 additions and 42 deletions

@@ -1,6 +1,7 @@
 import os
 import gc
 import glob
+import copy
 from PIL import Image
 from collections import OrderedDict
 import numpy as np
@@ -11,7 +12,7 @@ from modules.paths import extensions_dir
 from modules.devices import torch_gc
-global_sam = None
+global_sam: SamAutomaticMaskGenerator = None
 sem_seg_cache = OrderedDict()
 sam_annotator_dir = os.path.join(scripts.basedir(), "annotator")
 original_uniformer_inference_segmentor = None
@@ -19,7 +20,7 @@ original_oneformer_draw_sem_seg = None
 def blend_image_and_seg(image, seg, alpha=0.5):
-    image_blend = np.array(image) * (1 - alpha) + np.array(seg) * alpha
+    image_blend = image * (1 - alpha) + np.array(seg) * alpha
     return Image.fromarray(image_blend.astype(np.uint8))
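
The updated blend_image_and_seg is a plain per-pixel convex combination of the input image and the segmentation map. A minimal NumPy sketch with toy values (illustration only, not the extension's code):

    import numpy as np

    # alpha = 0.5 averages image and segmentation map per pixel.
    image = np.array([[200, 100]], dtype=np.float64)
    seg = np.array([[0, 255]], dtype=np.float64)
    alpha = 0.5
    blended = image * (1 - alpha) + seg * alpha  # [[100. , 177.5]]
    print(blended.astype(np.uint8))              # [[100 177]]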
@@ -40,52 +41,59 @@ def clear_sem_sam_cache():
 def sem_sam_garbage_collect():
     if shared.cmd_opts.lowvram:
-        for _, model in sem_seg_cache:
-            model.unload_model()
+        for model_key, model in sem_seg_cache:
+            if model_key == "uniformer":
+                from annotator.uniformer import unload_uniformer_model
+                unload_uniformer_model()
+            else:
+                model.unload_model()
     gc.collect()
     torch_gc()
 
 def strengthen_sem_seg(class_ids, img):
+    print("Auto SAM strengthening semantic segmentation")
     import pycocotools.mask as maskUtils
-    semantc_mask = class_ids.clone()
-    annotations = global_sam(img)
+    semantc_mask = copy.deepcopy(class_ids)
+    annotations = global_sam.generate(img)
     annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
+    print(f"Auto SAM generated {len(annotations)} masks")
     for ann in annotations:
         valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
-        propose_classes_ids = class_ids[valid_mask]
+        propose_classes_ids = torch.tensor(class_ids[valid_mask])
         num_class_proposals = len(torch.unique(propose_classes_ids))
         if num_class_proposals == 1:
-            semantc_mask[valid_mask] = propose_classes_ids[0]
+            semantc_mask[valid_mask] = propose_classes_ids[0].numpy()
             continue
         top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices
-        semantc_mask[valid_mask] = top_1_propose_class_ids
+        semantc_mask[valid_mask] = top_1_propose_class_ids.numpy()
+    print("Auto SAM strengthen process end")
     return semantc_mask
 
 def random_segmentation(img):
-    print("Generating random segmentation for Edit-Anything")
-    img_np = np.array(img)
-    annotations = global_sam(img_np)
+    print("Auto SAM generating random segmentation for Edit-Anything")
+    img_np = np.array(img.convert("RGB"))
+    annotations = global_sam.generate(img_np)
     annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
-    if len(annotations) == 0:
-        return []
+    print(f"Auto SAM generated {len(annotations)} masks")
     H, W, C = img_np.shape
     cnet_input = np.zeros((H, W), dtype=np.uint16)
     for idx, annotation in enumerate(annotations):
         current_seg = annotation['segmentation']
-        cnet_input[current_seg] = idx + 1
+        cnet_input[current_seg] = idx + 1  # TODO: Add random mask, not the ugly detected map
     detected_map = np.zeros((cnet_input.shape[0], cnet_input.shape[1], 3))
     detected_map[:, :, 0] = cnet_input % 256
     detected_map[:, :, 1] = cnet_input // 256
     from annotator.util import HWC3
     detected_map = HWC3(detected_map.astype(np.uint8))
-    return [blend_image_and_seg(img, detected_map), Image.fromarray(detected_map)], "Random segmentation done. Left is blended image, right is ControlNet input."
+    print("Auto SAM generation process end")
+    return [blend_image_and_seg(img_np, detected_map), Image.fromarray(detected_map)], "Random segmentation done. Left is blended image, right is ControlNet input."
 
 def image_layer_image(layout_input_image, layout_output_path):
     img_np = np.array(layout_input_image)
-    annotations = global_sam(img_np)
+    annotations = global_sam.generate(img_np)
     print(f"AutoSAM generated {len(annotations)} annotations")
     annotations = sorted(annotations, key=lambda x: x['area'])
     for idx, annotation in enumerate(annotations):
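
random_segmentation packs each segment id (up to 65535 segments) into the red and green channels of the detected map that ControlNet consumes: low byte in R, high byte in G. A minimal sketch of the encoding plus a hypothetical decode helper to verify the round trip (toy data, not part of the extension):

    import numpy as np

    # Encode as in random_segmentation: R = id % 256, G = id // 256.
    cnet_input = np.array([[0, 1], [300, 65535]], dtype=np.uint16)
    detected_map = np.zeros((*cnet_input.shape, 3), dtype=np.uint8)
    detected_map[:, :, 0] = cnet_input % 256   # id 300 -> R = 44
    detected_map[:, :, 1] = cnet_input // 256  # id 300 -> G = 1

    # Hypothetical decode, only to check that the packing is lossless.
    recovered = (detected_map[:, :, 0].astype(np.uint16)
                 + detected_map[:, :, 1].astype(np.uint16) * 256)
    assert (recovered == cnet_input).all()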
@@ -123,7 +131,7 @@ def inject_inference_segmentor(model, img):
 def inject_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8, is_text=True, edge_color=(1.0, 1.0, 1.0)):
     if isinstance(sem_seg, torch.Tensor):
-        sem_seg = sem_seg.numpy()
+        sem_seg = sem_seg.numpy()  # TODO: inject another function for oneformer
     return original_oneformer_draw_sem_seg(self, strengthen_sem_seg(sem_seg), area_threshold, alpha, is_text, edge_color)
@@ -140,7 +148,7 @@ def _uniformer(img):
         from annotator.uniformer import apply_uniformer
         sem_seg_cache["uniformer"] = apply_uniformer
     result = sem_seg_cache["uniformer"](img)
-    return result, True
+    return result
 
 def _oneformer(img, dataset="coco"):
@@ -149,10 +157,10 @@ def _oneformer(img, dataset="coco"):
         from annotator.oneformer import OneformerDetector
         sem_seg_cache[oneformer_key] = OneformerDetector(OneformerDetector.configs[dataset])
     result = sem_seg_cache[oneformer_key](img)
-    return result, True
+    return result
 
-def semantic_segmentation(input_image, annotator_name):
+def semantic_segmentation(input_image, annotator_name, processor_res):
     if input_image is None:
         return [], "No input image."
     if "seg" in annotator_name:
@@ -160,26 +168,27 @@ def semantic_segmentation(input_image, annotator_name):
             return [], "ControlNet extension not found."
         global original_uniformer_inference_segmentor
         global original_oneformer_draw_sem_seg
-        input_image_np = np.array(input_image)
+        from annotator.util import resize_image, HWC3
+        input_image = resize_image(HWC3(np.array(input_image)), processor_res)
         print("Generating semantic segmentation without SAM")
         if annotator_name == "seg_ufade20k":
-            original_semseg = _uniformer(input_image_np)
+            original_semseg = _uniformer(input_image)
             print("Generating semantic segmentation with SAM")
             import annotator.uniformer as uniformer
             original_uniformer_inference_segmentor = uniformer.inference_segmentor
             uniformer.inference_segmentor = inject_inference_segmentor
-            sam_semseg = _uniformer(input_image_np)
+            sam_semseg = _uniformer(input_image)
             uniformer.inference_segmentor = original_uniformer_inference_segmentor
             output_gallery = [original_semseg, sam_semseg, blend_image_and_seg(input_image, original_semseg), blend_image_and_seg(input_image, sam_semseg)]
             return output_gallery, "Uniformer semantic segmentation of ade20k done. Left is segmentation before SAM, right is segmentation after SAM."
         else:
             dataset = annotator_name.split('_')[-1][2:]
-            original_semseg = _oneformer(input_image_np, dataset)
+            original_semseg = _oneformer(input_image, dataset=dataset)
             print("Generating semantic segmentation with SAM")
             from annotator.oneformer.oneformer.demo.visualizer import Visualizer
             original_oneformer_draw_sem_seg = Visualizer.draw_sem_seg
             Visualizer.draw_sem_seg = inject_sem_seg
-            sam_semseg = _oneformer(input_image_np, dataset)
+            sam_semseg = _oneformer(input_image, dataset=dataset)
             Visualizer.draw_sem_seg = original_oneformer_draw_sem_seg
             output_gallery = [original_semseg, sam_semseg, blend_image_and_seg(input_image, original_semseg), blend_image_and_seg(input_image, sam_semseg)]
             return output_gallery, f"Oneformer semantic segmentation of {dataset} done. Left is segmentation before SAM, right is segmentation after SAM."
@@ -187,7 +196,7 @@ def semantic_segmentation(input_image, annotator_name):
         return random_segmentation(input_image)
 
-def categorical_mask_image(crop_processor, crop_category_input, crop_input_image):
+def categorical_mask_image(crop_processor, crop_processor_res, crop_category_input, crop_input_image):
     if crop_input_image is None:
         return "No input image."
     if not os.path.isdir(os.path.join(scripts.basedir(), "annotator")) and not create_symbolic_link():
@@ -199,9 +208,10 @@ def categorical_mask_image(crop_processor, crop_category_input, crop_input_image
         filter_classes = [int(i) for i in filter_classes]
     except:
         return "Illegal class id. You may have input some string."
+    from annotator.util import resize_image, HWC3
+    crop_input_image = resize_image(HWC3(np.array(crop_input_image)), crop_processor_res)
     global original_uniformer_inference_segmentor
     global original_oneformer_draw_sem_seg
-    input_image_np = np.array(crop_input_image)
     print(f"Generating categories with processor {crop_processor}")
     if crop_processor == "seg_ufade20k":
         import annotator.uniformer as uniformer
@@ -209,7 +219,7 @@ def categorical_mask_image(crop_processor, crop_category_input, crop_input_image
         uniformer.inference_segmentor = inject_inference_segmentor
         tmp_ouis = uniformer.show_result_pyplot
         uniformer.show_result_pyplot = inject_show_result_pyplot
-        sam_semseg = _uniformer(input_image_np)
+        sam_semseg = _uniformer(crop_input_image)
         uniformer.inference_segmentor = original_uniformer_inference_segmentor
         uniformer.show_result_pyplot = tmp_ouis
     else:
@@ -218,7 +228,7 @@ def categorical_mask_image(crop_processor, crop_category_input, crop_input_image
         tmp_oodss = Visualizer.draw_sem_seg
         Visualizer.draw_sem_seg = inject_sem_seg
         original_oneformer_draw_sem_seg = inject_oodss
-        sam_semseg = _oneformer(input_image_np, dataset)
+        sam_semseg = _oneformer(crop_input_image, dataset=dataset)
         Visualizer.draw_sem_seg = tmp_oodss
     mask = np.zeros(sam_semseg.shape, dtype=np.bool_)
     for i in filter_classes:
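
Both paths above slip SAM strengthening into the annotators with the same save/patch/restore monkey-patching pattern seen in semantic_segmentation and categorical_mask_image. A stripped-down sketch with a dummy stand-in module (illustrative names, not the real annotator API):

    import types

    # Dummy stand-in for a module such as annotator.uniformer.
    uniformer = types.SimpleNamespace(
        inference_segmentor=lambda model, img: "plain segmentation")

    def inject_inference_segmentor(model, img):
        # The real hook reruns SAM (global_sam.generate) and strengthens the map.
        return "SAM-strengthened segmentation"

    # Save the original, patch, call, restore.
    original_inference_segmentor = uniformer.inference_segmentor
    uniformer.inference_segmentor = inject_inference_segmentor
    try:
        print(uniformer.inference_segmentor(None, "img"))
    finally:
        uniformer.inference_segmentor = original_inference_segmentor

The try/finally is the sketch's own addition; the extension restores the attribute inline after the call.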

@@ -298,7 +298,7 @@ def dino_batch_process(
 def cnet_seg(
-    sam_model_name, cnet_seg_input_image, cnet_seg_processor,
+    sam_model_name, cnet_seg_input_image, cnet_seg_processor, cnet_seg_processor_res,
     auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
     auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
     auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
@@ -310,7 +310,7 @@ def cnet_seg(
         auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
         auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
         auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, auto_sam_output_mode)
-    outputs = semantic_segmentation(cnet_seg_input_image, cnet_seg_processor)
+    outputs = semantic_segmentation(cnet_seg_input_image, cnet_seg_processor, cnet_seg_processor_res)
     sem_sam_garbage_collect()
     garbage_collect(sam)
     return outputs
@@ -335,7 +335,7 @@ def image_layout(
 def categorical_mask(
-    sam_model_name, crop_processor, crop_category_input, crop_input_image,
+    sam_model_name, crop_processor, crop_processor_res, crop_category_input, crop_input_image,
     auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
     auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
     auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
@@ -346,7 +346,7 @@ def categorical_mask(
         auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
         auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
         auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, "coco_rle")
-    outputs = categorical_mask_image(crop_processor, crop_category_input, crop_input_image)
+    outputs = categorical_mask_image(crop_processor, crop_processor_res, crop_category_input, crop_input_image)
    sem_sam_garbage_collect()
     garbage_collect(sam)
     if isinstance(outputs, str):
@@ -356,7 +356,8 @@ def categorical_mask(
 def categorical_mask_batch(
-    sam_model_name, crop_processor, crop_category_input, crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir,
+    sam_model_name, crop_processor, crop_processor_res,
+    crop_category_input, crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir,
     crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background,
     auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
     auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
@@ -374,7 +375,7 @@ def categorical_mask_batch(
         print(f"Processing {image_index}/{len(all_files)} {input_image_file}")
         try:
             crop_input_image = Image.open(input_image_file).convert("RGB")
-            outputs = categorical_mask_image(crop_processor, crop_category_input, crop_input_image)
+            outputs = categorical_mask_image(crop_processor, crop_processor_res, crop_category_input, crop_input_image)
             if isinstance(outputs, str):
                 outputs = f"Image {image_index}: {outputs}"
                 print(outputs)
@@ -480,6 +481,16 @@ def ui_batch(is_dino):
     return dino_batch_dilation_amt, dino_batch_source_dir, dino_batch_dest_dir, dino_batch_output_per_image, dino_batch_save_image, dino_batch_save_mask, dino_batch_save_image_with_mask, dino_batch_save_background, dino_batch_run_button, dino_batch_progress
 
+def ui_processor(use_random=True):
+    processor_choices = ["seg_ufade20k", "seg_ofade20k", "seg_ofcoco"]
+    if use_random:
+        processor_choices.append("random")
+    with gr.Row():  # TODO: Add pixel perfect, preprocessor_res > 64
+        cnet_seg_processor = gr.Radio(choices=processor_choices, value="seg_ufade20k", label="Choose preprocessor for semantic segmentation: ")
+        cnet_seg_processor_res = gr.Slider(label="Preprocessor res", value=512, minimum=64, maximum=2048, step=1)
+    return cnet_seg_processor, cnet_seg_processor_res
+
 class Script(scripts.Script):
     def title(self):
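
The new ui_processor factory keeps the preprocessor radio and the resolution slider consistent across the Auto SAM and categorical-mask tabs. A minimal sketch of how its two return values are wired into a click handler (Gradio Blocks assumed; dummy_handler stands in for cnet_seg / categorical_mask):

    import gradio as gr

    def dummy_handler(processor, processor_res):
        # Stand-in for cnet_seg / categorical_mask.
        return f"{processor} @ {processor_res}px"

    with gr.Blocks() as demo:
        processor, processor_res = ui_processor()  # pass False to drop "random"
        status = gr.Text(value="", label="Status")
        run = gr.Button(value="Run")
        run.click(fn=dummy_handler, inputs=[processor, processor_res], outputs=[status])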
@@ -604,14 +615,14 @@ class Script(scripts.Script):
                 gr.Markdown(
                     "You can enhance semantic segmentation for control_v11p_sd15_seg from lllyasviel. "
                     "Non-semantic segmentation for [Edit-Anything](https://github.com/sail-sg/EditAnything) will be supported [when they convert their models to lllyasviel format](https://github.com/sail-sg/EditAnything/issues/14).")
-                cnet_seg_processor = gr.Radio(choices=["seg_ufade20k", "seg_ofade20k", "seg_ofcoco", "random"], value="seg_ufade20k", label="Choose preprocessor for semantic segmentation: ")
+                cnet_seg_processor, cnet_seg_processor_res = ui_processor()
                 cnet_seg_input_image = gr.Image(label="Image for Auto Segmentation", source="upload", type="pil", image_mode="RGBA")
                 cnet_seg_output_gallery = gr.Gallery(label="Auto segmentation output").style(grid=2)
                 cnet_seg_submit = gr.Button(value="Generate segmentation image")
                 cnet_seg_status = gr.Text(value="", label="Segmentation status")
                 cnet_seg_submit.click(
                     fn=cnet_seg,
-                    inputs=[sam_model_name, cnet_seg_input_image, cnet_seg_processor, *auto_sam_config],
+                    inputs=[sam_model_name, cnet_seg_input_image, cnet_seg_processor, cnet_seg_processor_res, *auto_sam_config],
                     outputs=[cnet_seg_output_gallery, cnet_seg_status])
                 with gr.Row(visible=(max_cn_num() > 0)):
                     cnet_seg_enable_copy = gr.Checkbox(value=False, label='Copy to ControlNet Segmentation')
@ -649,7 +660,7 @@ class Script(scripts.Script):
"You can mask images by their categories via semantic segmentation. Please enter category ids (integers), separated by `+`. " "You can mask images by their categories via semantic segmentation. Please enter category ids (integers), separated by `+`. "
"Visit [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/annotator/oneformer/oneformer/data/datasets/register_ade20k_panoptic.py#L12-L207) for ade20k " "Visit [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/annotator/oneformer/oneformer/data/datasets/register_ade20k_panoptic.py#L12-L207) for ade20k "
"and [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/annotator/oneformer/detectron2/data/datasets/builtin_meta.py#L20-L153) for coco to get category->id map.") "and [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/annotator/oneformer/detectron2/data/datasets/builtin_meta.py#L20-L153) for coco to get category->id map.")
crop_processor = gr.Radio(choices=["seg_ufade20k", "seg_ofade20k", "seg_ofcoco"], value="seg_ufade20k", label="Choose preprocessor for semantic segmentation: ") crop_processor, crop_processor_res = ui_processor(False)
crop_category_input = gr.Textbox(placeholder="Enter categody ids, separated by +. For example, if you want bed+person, your input should be 7+12 for ade20k and 65+1 for coco.", label="Enter category IDs") crop_category_input = gr.Textbox(placeholder="Enter categody ids, separated by +. For example, if you want bed+person, your input should be 7+12 for ade20k and 65+1 for coco.", label="Enter category IDs")
with gr.Tabs(): with gr.Tabs():
with gr.TabItem(label="Single Image"): with gr.TabItem(label="Single Image"):
@ -660,7 +671,7 @@ class Script(scripts.Script):
crop_result = gr.Text(value="", label="Categorical mask status") crop_result = gr.Text(value="", label="Categorical mask status")
crop_submit.click( crop_submit.click(
fn=categorical_mask, fn=categorical_mask,
inputs=[sam_model_name, crop_processor, crop_category_input, crop_input_image, *auto_sam_config], inputs=[sam_model_name, crop_processor, crop_processor_res, crop_category_input, crop_input_image, *auto_sam_config],
outputs=[crop_output_gallery, crop_result]) outputs=[crop_output_gallery, crop_result])
crop_inpaint_enable, crop_cnet_inpaint_invert, crop_cnet_inpaint_idx = ui_inpaint(is_img2img, max_cn_num()) crop_inpaint_enable, crop_cnet_inpaint_invert, crop_cnet_inpaint_idx = ui_inpaint(is_img2img, max_cn_num())
crop_dilation_checkbox, crop_dilation_output_gallery = ui_dilation(crop_output_gallery, crop_padding, crop_input_image) crop_dilation_checkbox, crop_dilation_output_gallery = ui_dilation(crop_output_gallery, crop_padding, crop_input_image)
@ -676,8 +687,8 @@ class Script(scripts.Script):
crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir, _, crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background, crop_batch_run_button, crop_batch_progress = ui_batch(False) crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir, _, crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background, crop_batch_run_button, crop_batch_progress = ui_batch(False)
crop_batch_run_button.click( crop_batch_run_button.click(
fn=categorical_mask_batch, fn=categorical_mask_batch,
inputs=[sam_model_name, crop_processor, crop_category_input, crop_batch_dilation_amt, inputs=[sam_model_name, crop_processor, crop_processor_res,
crop_batch_source_dir, crop_batch_dest_dir, crop_category_input, crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir,
crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background, *auto_sam_config], crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background, *auto_sam_config],
outputs=[crop_batch_progress]) outputs=[crop_batch_progress])