fix a lot of problems in AutoSAM ControlNet but still a lot of problem
parent
ae36032134
commit
738d70ae27
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import gc
|
import gc
|
||||||
import glob
|
import glob
|
||||||
|
import copy
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -11,7 +12,7 @@ from modules.paths import extensions_dir
|
||||||
from modules.devices import torch_gc
|
from modules.devices import torch_gc
|
||||||
|
|
||||||
|
|
||||||
global_sam = None
|
global_sam: SamAutomaticMaskGenerator = None
|
||||||
sem_seg_cache = OrderedDict()
|
sem_seg_cache = OrderedDict()
|
||||||
sam_annotator_dir = os.path.join(scripts.basedir(), "annotator")
|
sam_annotator_dir = os.path.join(scripts.basedir(), "annotator")
|
||||||
original_uniformer_inference_segmentor = None
|
original_uniformer_inference_segmentor = None
|
||||||
|
|
@ -19,7 +20,7 @@ original_oneformer_draw_sem_seg = None
|
||||||
|
|
||||||
|
|
||||||
def blend_image_and_seg(image, seg, alpha=0.5):
|
def blend_image_and_seg(image, seg, alpha=0.5):
|
||||||
image_blend = np.array(image) * (1 - alpha) + np.array(seg) * alpha
|
image_blend = image * (1 - alpha) + np.array(seg) * alpha
|
||||||
return Image.fromarray(image_blend.astype(np.uint8))
|
return Image.fromarray(image_blend.astype(np.uint8))
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -40,52 +41,59 @@ def clear_sem_sam_cache():
|
||||||
|
|
||||||
def sem_sam_garbage_collect():
|
def sem_sam_garbage_collect():
|
||||||
if shared.cmd_opts.lowvram:
|
if shared.cmd_opts.lowvram:
|
||||||
for _, model in sem_seg_cache:
|
for model_key, model in sem_seg_cache:
|
||||||
|
if model_key == "uniformer":
|
||||||
|
from annotator.uniformer import unload_uniformer_model
|
||||||
|
unload_uniformer_model()
|
||||||
|
else:
|
||||||
model.unload_model()
|
model.unload_model()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
|
|
||||||
def strengthen_sem_seg(class_ids, img):
|
def strengthen_sem_seg(class_ids, img):
|
||||||
|
print("Auto SAM strengthening semantic segmentation")
|
||||||
import pycocotools.mask as maskUtils
|
import pycocotools.mask as maskUtils
|
||||||
semantc_mask = class_ids.clone()
|
semantc_mask = copy.deepcopy(class_ids)
|
||||||
annotations = global_sam(img)
|
annotations = global_sam.generate(img)
|
||||||
annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
|
annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
|
||||||
|
print(f"Auto SAM generated {len(annotations)} masks")
|
||||||
for ann in annotations:
|
for ann in annotations:
|
||||||
valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
|
valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
|
||||||
propose_classes_ids = class_ids[valid_mask]
|
propose_classes_ids = torch.tensor(class_ids[valid_mask])
|
||||||
num_class_proposals = len(torch.unique(propose_classes_ids))
|
num_class_proposals = len(torch.unique(propose_classes_ids))
|
||||||
if num_class_proposals == 1:
|
if num_class_proposals == 1:
|
||||||
semantc_mask[valid_mask] = propose_classes_ids[0]
|
semantc_mask[valid_mask] = propose_classes_ids[0].numpy()
|
||||||
continue
|
continue
|
||||||
top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices
|
top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices
|
||||||
semantc_mask[valid_mask] = top_1_propose_class_ids
|
semantc_mask[valid_mask] = top_1_propose_class_ids.numpy()
|
||||||
|
print("Auto SAM strengthen process end")
|
||||||
return semantc_mask
|
return semantc_mask
|
||||||
|
|
||||||
|
|
||||||
def random_segmentation(img):
|
def random_segmentation(img):
|
||||||
print("Generating random segmentation for Edit-Anything")
|
print("Auto SAM generating random segmentation for Edit-Anything")
|
||||||
img_np = np.array(img)
|
img_np = np.array(img.convert("RGB"))
|
||||||
annotations = global_sam(img_np)
|
annotations = global_sam.generate(img_np)
|
||||||
annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
|
annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
|
||||||
if len(annotations) == 0:
|
print(f"Auto SAM generated {len(annotations)} masks")
|
||||||
return []
|
|
||||||
H, W, C = img_np.shape
|
H, W, C = img_np.shape
|
||||||
cnet_input = np.zeros((H, W), dtype=np.uint16)
|
cnet_input = np.zeros((H, W), dtype=np.uint16)
|
||||||
for idx, annotation in enumerate(annotations):
|
for idx, annotation in enumerate(annotations):
|
||||||
current_seg = annotation['segmentation']
|
current_seg = annotation['segmentation']
|
||||||
cnet_input[current_seg] = idx + 1
|
cnet_input[current_seg] = idx + 1 # TODO: Add random mask, not the ugly detected map
|
||||||
detected_map = np.zeros((cnet_input.shape[0], cnet_input.shape[1], 3))
|
detected_map = np.zeros((cnet_input.shape[0], cnet_input.shape[1], 3))
|
||||||
detected_map[:, :, 0] = cnet_input % 256
|
detected_map[:, :, 0] = cnet_input % 256
|
||||||
detected_map[:, :, 1] = cnet_input // 256
|
detected_map[:, :, 1] = cnet_input // 256
|
||||||
from annotator.util import HWC3
|
from annotator.util import HWC3
|
||||||
detected_map = HWC3(detected_map.astype(np.uint8))
|
detected_map = HWC3(detected_map.astype(np.uint8))
|
||||||
return [blend_image_and_seg(img, detected_map), Image.fromarray(detected_map)], "Random segmentation done. Left is blended image, right is ControlNet input."
|
print("Auto SAM generation process end")
|
||||||
|
return [blend_image_and_seg(img_np, detected_map), Image.fromarray(detected_map)], "Random segmentation done. Left is blended image, right is ControlNet input."
|
||||||
|
|
||||||
|
|
||||||
def image_layer_image(layout_input_image, layout_output_path):
|
def image_layer_image(layout_input_image, layout_output_path):
|
||||||
img_np = np.array(layout_input_image)
|
img_np = np.array(layout_input_image)
|
||||||
annotations = global_sam(img_np)
|
annotations = global_sam.generate(img_np)
|
||||||
print(f"AutoSAM generated {len(annotations)} annotations")
|
print(f"AutoSAM generated {len(annotations)} annotations")
|
||||||
annotations = sorted(annotations, key=lambda x: x['area'])
|
annotations = sorted(annotations, key=lambda x: x['area'])
|
||||||
for idx, annotation in enumerate(annotations):
|
for idx, annotation in enumerate(annotations):
|
||||||
|
|
@ -123,7 +131,7 @@ def inject_inference_segmentor(model, img):
|
||||||
|
|
||||||
def inject_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8, is_text=True, edge_color=(1.0, 1.0, 1.0)):
|
def inject_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8, is_text=True, edge_color=(1.0, 1.0, 1.0)):
|
||||||
if isinstance(sem_seg, torch.Tensor):
|
if isinstance(sem_seg, torch.Tensor):
|
||||||
sem_seg = sem_seg.numpy()
|
sem_seg = sem_seg.numpy() # TODO: inject another function for oneformer
|
||||||
return original_oneformer_draw_sem_seg(self, strengthen_sem_seg(sem_seg), area_threshold, alpha, is_text, edge_color)
|
return original_oneformer_draw_sem_seg(self, strengthen_sem_seg(sem_seg), area_threshold, alpha, is_text, edge_color)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -140,7 +148,7 @@ def _uniformer(img):
|
||||||
from annotator.uniformer import apply_uniformer
|
from annotator.uniformer import apply_uniformer
|
||||||
sem_seg_cache["uniformer"] = apply_uniformer
|
sem_seg_cache["uniformer"] = apply_uniformer
|
||||||
result = sem_seg_cache["uniformer"](img)
|
result = sem_seg_cache["uniformer"](img)
|
||||||
return result, True
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _oneformer(img, dataset="coco"):
|
def _oneformer(img, dataset="coco"):
|
||||||
|
|
@ -149,10 +157,10 @@ def _oneformer(img, dataset="coco"):
|
||||||
from annotator.oneformer import OneformerDetector
|
from annotator.oneformer import OneformerDetector
|
||||||
sem_seg_cache[oneformer_key] = OneformerDetector(OneformerDetector.configs[dataset])
|
sem_seg_cache[oneformer_key] = OneformerDetector(OneformerDetector.configs[dataset])
|
||||||
result = sem_seg_cache[oneformer_key](img)
|
result = sem_seg_cache[oneformer_key](img)
|
||||||
return result, True
|
return result
|
||||||
|
|
||||||
|
|
||||||
def semantic_segmentation(input_image, annotator_name):
|
def semantic_segmentation(input_image, annotator_name, processor_res):
|
||||||
if input_image is None:
|
if input_image is None:
|
||||||
return [], "No input image."
|
return [], "No input image."
|
||||||
if "seg" in annotator_name:
|
if "seg" in annotator_name:
|
||||||
|
|
@ -160,26 +168,27 @@ def semantic_segmentation(input_image, annotator_name):
|
||||||
return [], "ControlNet extension not found."
|
return [], "ControlNet extension not found."
|
||||||
global original_uniformer_inference_segmentor
|
global original_uniformer_inference_segmentor
|
||||||
global original_oneformer_draw_sem_seg
|
global original_oneformer_draw_sem_seg
|
||||||
input_image_np = np.array(input_image)
|
from annotator.util import resize_image, HWC3
|
||||||
|
input_image = resize_image(HWC3(np.array(input_image)), processor_res)
|
||||||
print("Generating semantic segmentation without SAM")
|
print("Generating semantic segmentation without SAM")
|
||||||
if annotator_name == "seg_ufade20k":
|
if annotator_name == "seg_ufade20k":
|
||||||
original_semseg = _uniformer(input_image_np)
|
original_semseg = _uniformer(input_image)
|
||||||
print("Generating semantic segmentation with SAM")
|
print("Generating semantic segmentation with SAM")
|
||||||
import annotator.uniformer as uniformer
|
import annotator.uniformer as uniformer
|
||||||
original_uniformer_inference_segmentor = uniformer.inference_segmentor
|
original_uniformer_inference_segmentor = uniformer.inference_segmentor
|
||||||
uniformer.inference_segmentor = inject_inference_segmentor
|
uniformer.inference_segmentor = inject_inference_segmentor
|
||||||
sam_semseg = _uniformer(input_image_np)
|
sam_semseg = _uniformer(input_image)
|
||||||
uniformer.inference_segmentor = original_uniformer_inference_segmentor
|
uniformer.inference_segmentor = original_uniformer_inference_segmentor
|
||||||
output_gallery = [original_semseg, sam_semseg, blend_image_and_seg(input_image, original_semseg), blend_image_and_seg(input_image, sam_semseg)]
|
output_gallery = [original_semseg, sam_semseg, blend_image_and_seg(input_image, original_semseg), blend_image_and_seg(input_image, sam_semseg)]
|
||||||
return output_gallery, "Uniformer semantic segmentation of ade20k done. Left is segmentation before SAM, right is segmentation after SAM."
|
return output_gallery, "Uniformer semantic segmentation of ade20k done. Left is segmentation before SAM, right is segmentation after SAM."
|
||||||
else:
|
else:
|
||||||
dataset = annotator_name.split('_')[-1][2:]
|
dataset = annotator_name.split('_')[-1][2:]
|
||||||
original_semseg = _oneformer(input_image_np, dataset)
|
original_semseg = _oneformer(input_image, dataset=dataset)
|
||||||
print("Generating semantic segmentation with SAM")
|
print("Generating semantic segmentation with SAM")
|
||||||
from annotator.oneformer.oneformer.demo.visualizer import Visualizer
|
from annotator.oneformer.oneformer.demo.visualizer import Visualizer
|
||||||
original_oneformer_draw_sem_seg = Visualizer.draw_sem_seg
|
original_oneformer_draw_sem_seg = Visualizer.draw_sem_seg
|
||||||
Visualizer.draw_sem_seg = inject_sem_seg
|
Visualizer.draw_sem_seg = inject_sem_seg
|
||||||
sam_semseg = _oneformer(input_image_np, dataset)
|
sam_semseg = _oneformer(input_image, dataset=dataset)
|
||||||
Visualizer.draw_sem_seg = original_oneformer_draw_sem_seg
|
Visualizer.draw_sem_seg = original_oneformer_draw_sem_seg
|
||||||
output_gallery = [original_semseg, sam_semseg, blend_image_and_seg(input_image, original_semseg), blend_image_and_seg(input_image, sam_semseg)]
|
output_gallery = [original_semseg, sam_semseg, blend_image_and_seg(input_image, original_semseg), blend_image_and_seg(input_image, sam_semseg)]
|
||||||
return output_gallery, f"Oneformer semantic segmentation of {dataset} done. Left is segmentation before SAM, right is segmentation after SAM."
|
return output_gallery, f"Oneformer semantic segmentation of {dataset} done. Left is segmentation before SAM, right is segmentation after SAM."
|
||||||
|
|
@ -187,7 +196,7 @@ def semantic_segmentation(input_image, annotator_name):
|
||||||
return random_segmentation(input_image)
|
return random_segmentation(input_image)
|
||||||
|
|
||||||
|
|
||||||
def categorical_mask_image(crop_processor, crop_category_input, crop_input_image):
|
def categorical_mask_image(crop_processor, crop_processor_res, crop_category_input, crop_input_image):
|
||||||
if crop_input_image is None:
|
if crop_input_image is None:
|
||||||
return "No input image."
|
return "No input image."
|
||||||
if not os.path.isdir(os.path.join(scripts.basedir(), "annotator")) and not create_symbolic_link():
|
if not os.path.isdir(os.path.join(scripts.basedir(), "annotator")) and not create_symbolic_link():
|
||||||
|
|
@ -199,9 +208,10 @@ def categorical_mask_image(crop_processor, crop_category_input, crop_input_image
|
||||||
filter_classes = [int(i) for i in filter_classes]
|
filter_classes = [int(i) for i in filter_classes]
|
||||||
except:
|
except:
|
||||||
return "Illegal class id. You may have input some string."
|
return "Illegal class id. You may have input some string."
|
||||||
|
from annotator.util import resize_image, HWC3
|
||||||
|
crop_input_image = resize_image(HWC3(np.array(crop_input_image)), crop_processor_res)
|
||||||
global original_uniformer_inference_segmentor
|
global original_uniformer_inference_segmentor
|
||||||
global original_oneformer_draw_sem_seg
|
global original_oneformer_draw_sem_seg
|
||||||
input_image_np = np.array(crop_input_image)
|
|
||||||
print(f"Generating categories with processor {crop_processor}")
|
print(f"Generating categories with processor {crop_processor}")
|
||||||
if crop_processor == "seg_ufade20k":
|
if crop_processor == "seg_ufade20k":
|
||||||
import annotator.uniformer as uniformer
|
import annotator.uniformer as uniformer
|
||||||
|
|
@ -209,7 +219,7 @@ def categorical_mask_image(crop_processor, crop_category_input, crop_input_image
|
||||||
uniformer.inference_segmentor = inject_inference_segmentor
|
uniformer.inference_segmentor = inject_inference_segmentor
|
||||||
tmp_ouis = uniformer.show_result_pyplot
|
tmp_ouis = uniformer.show_result_pyplot
|
||||||
uniformer.show_result_pyplot = inject_show_result_pyplot
|
uniformer.show_result_pyplot = inject_show_result_pyplot
|
||||||
sam_semseg = _uniformer(input_image_np)
|
sam_semseg = _uniformer(crop_input_image)
|
||||||
uniformer.inference_segmentor = original_uniformer_inference_segmentor
|
uniformer.inference_segmentor = original_uniformer_inference_segmentor
|
||||||
uniformer.show_result_pyplot = tmp_ouis
|
uniformer.show_result_pyplot = tmp_ouis
|
||||||
else:
|
else:
|
||||||
|
|
@ -218,7 +228,7 @@ def categorical_mask_image(crop_processor, crop_category_input, crop_input_image
|
||||||
tmp_oodss = Visualizer.draw_sem_seg
|
tmp_oodss = Visualizer.draw_sem_seg
|
||||||
Visualizer.draw_sem_seg = inject_sem_seg
|
Visualizer.draw_sem_seg = inject_sem_seg
|
||||||
original_oneformer_draw_sem_seg = inject_oodss
|
original_oneformer_draw_sem_seg = inject_oodss
|
||||||
sam_semseg = _oneformer(input_image_np, dataset)
|
sam_semseg = _oneformer(crop_input_image, dataset=dataset)
|
||||||
Visualizer.draw_sem_seg = tmp_oodss
|
Visualizer.draw_sem_seg = tmp_oodss
|
||||||
mask = np.zeros(sam_semseg.shape, dtype=np.bool_)
|
mask = np.zeros(sam_semseg.shape, dtype=np.bool_)
|
||||||
for i in filter_classes:
|
for i in filter_classes:
|
||||||
|
|
|
||||||
|
|
@ -298,7 +298,7 @@ def dino_batch_process(
|
||||||
|
|
||||||
|
|
||||||
def cnet_seg(
|
def cnet_seg(
|
||||||
sam_model_name, cnet_seg_input_image, cnet_seg_processor,
|
sam_model_name, cnet_seg_input_image, cnet_seg_processor, cnet_seg_processor_res,
|
||||||
auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
|
auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
|
||||||
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
||||||
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
|
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
|
||||||
|
|
@ -310,7 +310,7 @@ def cnet_seg(
|
||||||
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
||||||
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
|
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
|
||||||
auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, auto_sam_output_mode)
|
auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, auto_sam_output_mode)
|
||||||
outputs = semantic_segmentation(cnet_seg_input_image, cnet_seg_processor)
|
outputs = semantic_segmentation(cnet_seg_input_image, cnet_seg_processor, cnet_seg_processor_res)
|
||||||
sem_sam_garbage_collect()
|
sem_sam_garbage_collect()
|
||||||
garbage_collect(sam)
|
garbage_collect(sam)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
@ -335,7 +335,7 @@ def image_layout(
|
||||||
|
|
||||||
|
|
||||||
def categorical_mask(
|
def categorical_mask(
|
||||||
sam_model_name, crop_processor, crop_category_input, crop_input_image,
|
sam_model_name, crop_processor, crop_processor_res, crop_category_input, crop_input_image,
|
||||||
auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
|
auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
|
||||||
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
||||||
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
|
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
|
||||||
|
|
@ -346,7 +346,7 @@ def categorical_mask(
|
||||||
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
||||||
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
|
auto_sam_crop_n_layers, auto_sam_crop_nms_thresh, auto_sam_crop_overlap_ratio,
|
||||||
auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, "coco_rle")
|
auto_sam_crop_n_points_downscale_factor, auto_sam_min_mask_region_area, "coco_rle")
|
||||||
outputs = categorical_mask_image(crop_processor, crop_category_input, crop_input_image)
|
outputs = categorical_mask_image(crop_processor, crop_processor_res, crop_category_input, crop_input_image)
|
||||||
sem_sam_garbage_collect()
|
sem_sam_garbage_collect()
|
||||||
garbage_collect(sam)
|
garbage_collect(sam)
|
||||||
if isinstance(outputs, str):
|
if isinstance(outputs, str):
|
||||||
|
|
@ -356,7 +356,8 @@ def categorical_mask(
|
||||||
|
|
||||||
|
|
||||||
def categorical_mask_batch(
|
def categorical_mask_batch(
|
||||||
sam_model_name, crop_processor, crop_category_input, crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir,
|
sam_model_name, crop_processor, crop_processor_res,
|
||||||
|
crop_category_input, crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir,
|
||||||
crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background,
|
crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background,
|
||||||
auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
|
auto_sam_points_per_side, auto_sam_points_per_batch, auto_sam_pred_iou_thresh,
|
||||||
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
auto_sam_stability_score_thresh, auto_sam_stability_score_offset, auto_sam_box_nms_thresh,
|
||||||
|
|
@ -374,7 +375,7 @@ def categorical_mask_batch(
|
||||||
print(f"Processing {image_index}/{len(all_files)} {input_image_file}")
|
print(f"Processing {image_index}/{len(all_files)} {input_image_file}")
|
||||||
try:
|
try:
|
||||||
crop_input_image = Image.open(input_image_file).convert("RGB")
|
crop_input_image = Image.open(input_image_file).convert("RGB")
|
||||||
outputs = categorical_mask_image(crop_processor, crop_category_input, crop_input_image)
|
outputs = categorical_mask_image(crop_processor, crop_processor_res, crop_category_input, crop_input_image)
|
||||||
if isinstance(outputs, str):
|
if isinstance(outputs, str):
|
||||||
outputs = f"Image {image_index}: {outputs}"
|
outputs = f"Image {image_index}: {outputs}"
|
||||||
print(outputs)
|
print(outputs)
|
||||||
|
|
@ -480,6 +481,16 @@ def ui_batch(is_dino):
|
||||||
return dino_batch_dilation_amt, dino_batch_source_dir, dino_batch_dest_dir, dino_batch_output_per_image, dino_batch_save_image, dino_batch_save_mask, dino_batch_save_image_with_mask, dino_batch_save_background, dino_batch_run_button, dino_batch_progress
|
return dino_batch_dilation_amt, dino_batch_source_dir, dino_batch_dest_dir, dino_batch_output_per_image, dino_batch_save_image, dino_batch_save_mask, dino_batch_save_image_with_mask, dino_batch_save_background, dino_batch_run_button, dino_batch_progress
|
||||||
|
|
||||||
|
|
||||||
|
def ui_processor(use_random=True):
|
||||||
|
processor_choices = ["seg_ufade20k", "seg_ofade20k", "seg_ofcoco"]
|
||||||
|
if use_random:
|
||||||
|
processor_choices.append("random")
|
||||||
|
with gr.Row(): # TODO: Add pixel perfect, preprocessor_res > 64
|
||||||
|
cnet_seg_processor = gr.Radio(choices=processor_choices, value="seg_ufade20k", label="Choose preprocessor for semantic segmentation: ")
|
||||||
|
cnet_seg_processor_res = gr.Slider(label="Preprocessor res", value=512, minimum=64, maximum=2048, step=1)
|
||||||
|
return cnet_seg_processor, cnet_seg_processor_res
|
||||||
|
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
|
|
@ -604,14 +615,14 @@ class Script(scripts.Script):
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
"You can enhance semantic segmentation for control_v11p_sd15_seg from lllyasviel. "
|
"You can enhance semantic segmentation for control_v11p_sd15_seg from lllyasviel. "
|
||||||
"Non-semantic segmentation for [Edit-Anything](https://github.com/sail-sg/EditAnything) will be supported [when they convert their models to lllyasviel format](https://github.com/sail-sg/EditAnything/issues/14).")
|
"Non-semantic segmentation for [Edit-Anything](https://github.com/sail-sg/EditAnything) will be supported [when they convert their models to lllyasviel format](https://github.com/sail-sg/EditAnything/issues/14).")
|
||||||
cnet_seg_processor = gr.Radio(choices=["seg_ufade20k", "seg_ofade20k", "seg_ofcoco", "random"], value="seg_ufade20k", label="Choose preprocessor for semantic segmentation: ")
|
cnet_seg_processor, cnet_seg_processor_res = ui_processor()
|
||||||
cnet_seg_input_image = gr.Image(label="Image for Auto Segmentation", source="upload", type="pil", image_mode="RGBA")
|
cnet_seg_input_image = gr.Image(label="Image for Auto Segmentation", source="upload", type="pil", image_mode="RGBA")
|
||||||
cnet_seg_output_gallery = gr.Gallery(label="Auto segmentation output").style(grid=2)
|
cnet_seg_output_gallery = gr.Gallery(label="Auto segmentation output").style(grid=2)
|
||||||
cnet_seg_submit = gr.Button(value="Generate segmentation image")
|
cnet_seg_submit = gr.Button(value="Generate segmentation image")
|
||||||
cnet_seg_status = gr.Text(value="", label="Segmentation status")
|
cnet_seg_status = gr.Text(value="", label="Segmentation status")
|
||||||
cnet_seg_submit.click(
|
cnet_seg_submit.click(
|
||||||
fn=cnet_seg,
|
fn=cnet_seg,
|
||||||
inputs=[sam_model_name, cnet_seg_input_image, cnet_seg_processor, *auto_sam_config],
|
inputs=[sam_model_name, cnet_seg_input_image, cnet_seg_processor, cnet_seg_processor_res, *auto_sam_config],
|
||||||
outputs=[cnet_seg_output_gallery, cnet_seg_status])
|
outputs=[cnet_seg_output_gallery, cnet_seg_status])
|
||||||
with gr.Row(visible=(max_cn_num() > 0)):
|
with gr.Row(visible=(max_cn_num() > 0)):
|
||||||
cnet_seg_enable_copy = gr.Checkbox(value=False, label='Copy to ControlNet Segmentation')
|
cnet_seg_enable_copy = gr.Checkbox(value=False, label='Copy to ControlNet Segmentation')
|
||||||
|
|
@ -649,7 +660,7 @@ class Script(scripts.Script):
|
||||||
"You can mask images by their categories via semantic segmentation. Please enter category ids (integers), separated by `+`. "
|
"You can mask images by their categories via semantic segmentation. Please enter category ids (integers), separated by `+`. "
|
||||||
"Visit [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/annotator/oneformer/oneformer/data/datasets/register_ade20k_panoptic.py#L12-L207) for ade20k "
|
"Visit [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/annotator/oneformer/oneformer/data/datasets/register_ade20k_panoptic.py#L12-L207) for ade20k "
|
||||||
"and [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/annotator/oneformer/detectron2/data/datasets/builtin_meta.py#L20-L153) for coco to get category->id map.")
|
"and [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/annotator/oneformer/detectron2/data/datasets/builtin_meta.py#L20-L153) for coco to get category->id map.")
|
||||||
crop_processor = gr.Radio(choices=["seg_ufade20k", "seg_ofade20k", "seg_ofcoco"], value="seg_ufade20k", label="Choose preprocessor for semantic segmentation: ")
|
crop_processor, crop_processor_res = ui_processor(False)
|
||||||
crop_category_input = gr.Textbox(placeholder="Enter categody ids, separated by +. For example, if you want bed+person, your input should be 7+12 for ade20k and 65+1 for coco.", label="Enter category IDs")
|
crop_category_input = gr.Textbox(placeholder="Enter categody ids, separated by +. For example, if you want bed+person, your input should be 7+12 for ade20k and 65+1 for coco.", label="Enter category IDs")
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
with gr.TabItem(label="Single Image"):
|
with gr.TabItem(label="Single Image"):
|
||||||
|
|
@ -660,7 +671,7 @@ class Script(scripts.Script):
|
||||||
crop_result = gr.Text(value="", label="Categorical mask status")
|
crop_result = gr.Text(value="", label="Categorical mask status")
|
||||||
crop_submit.click(
|
crop_submit.click(
|
||||||
fn=categorical_mask,
|
fn=categorical_mask,
|
||||||
inputs=[sam_model_name, crop_processor, crop_category_input, crop_input_image, *auto_sam_config],
|
inputs=[sam_model_name, crop_processor, crop_processor_res, crop_category_input, crop_input_image, *auto_sam_config],
|
||||||
outputs=[crop_output_gallery, crop_result])
|
outputs=[crop_output_gallery, crop_result])
|
||||||
crop_inpaint_enable, crop_cnet_inpaint_invert, crop_cnet_inpaint_idx = ui_inpaint(is_img2img, max_cn_num())
|
crop_inpaint_enable, crop_cnet_inpaint_invert, crop_cnet_inpaint_idx = ui_inpaint(is_img2img, max_cn_num())
|
||||||
crop_dilation_checkbox, crop_dilation_output_gallery = ui_dilation(crop_output_gallery, crop_padding, crop_input_image)
|
crop_dilation_checkbox, crop_dilation_output_gallery = ui_dilation(crop_output_gallery, crop_padding, crop_input_image)
|
||||||
|
|
@ -676,8 +687,8 @@ class Script(scripts.Script):
|
||||||
crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir, _, crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background, crop_batch_run_button, crop_batch_progress = ui_batch(False)
|
crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir, _, crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background, crop_batch_run_button, crop_batch_progress = ui_batch(False)
|
||||||
crop_batch_run_button.click(
|
crop_batch_run_button.click(
|
||||||
fn=categorical_mask_batch,
|
fn=categorical_mask_batch,
|
||||||
inputs=[sam_model_name, crop_processor, crop_category_input, crop_batch_dilation_amt,
|
inputs=[sam_model_name, crop_processor, crop_processor_res,
|
||||||
crop_batch_source_dir, crop_batch_dest_dir,
|
crop_category_input, crop_batch_dilation_amt, crop_batch_source_dir, crop_batch_dest_dir,
|
||||||
crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background, *auto_sam_config],
|
crop_batch_save_image, crop_batch_save_mask, crop_batch_save_image_with_mask, crop_batch_save_background, *auto_sam_config],
|
||||||
outputs=[crop_batch_progress])
|
outputs=[crop_batch_progress])
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue