# Source: unprompted/shortcodes/stable_diffusion/txt2mask.py
# (web-listing metadata removed: 658 lines, 30 KiB, Python)

class Shortcode():
    def __init__(self, Unprompted):
        """Shortcode that creates an image mask from the block content for inpainting.

        Stores the Unprompted engine reference and initializes the model-cache
        slots to their "empty" sentinel values (-1 / "").
        """
        self.Unprompted = Unprompted
        self.image_mask = None  # final PIL mask produced by run_block, consumed by after()
        self.show = False  # when True, after() appends the mask to the output gallery
        self.description = "Creates an image mask from the content for use with inpainting."
        # Cache slots for the segmentation model/transform/predictor so repeated
        # runs can skip reloading weights; -1 is the "not loaded" sentinel.
        # NOTE: the previous try/del-attribute block here was dead code — on a
        # fresh instance these attributes never exist, so every `del` raised
        # AttributeError and was swallowed by a bare `except`. Plain reassignment
        # drops any old references just as well, so the block was removed.
        self.cached_model = -1
        self.cached_transform = -1
        self.cached_model_method = ""
        self.cached_predictor = -1
def run_block(self, pargs, kwargs, context, content):
    """Create an inpainting mask from the block content.

    The delimiter-separated prompts in ``content`` are segmented out of the
    init image using one of three methods ("clipseg", "clip_surgery",
    "grounded_sam"), optionally combined with a negative mask and the user's
    brush mask, and the result is written into the Unprompted shortcode user
    vars (``image_mask``, ``mask``, ``init_img_with_mask``, ...) for the
    WebUI processing pipeline.

    Returns "" on success, the mask image itself when the "return_image"
    parg is present, or None when no init image is available.
    """
    from PIL import ImageChops, Image, ImageOps
    import os.path  # binds "os" as well; os.makedirs is used below
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt
    import cv2
    import numpy
    import gc
    from modules.images import flatten
    from modules.shared import opts
    from torchvision.transforms.functional import pil_to_tensor, to_pil_image
    gc.collect()
    # Resolve the image to mask: explicit kwarg wins, otherwise the first init image.
    if "txt2mask_init_image" in kwargs:
        self.init_image = kwargs["txt2mask_init_image"].copy()
    elif "init_images" not in self.Unprompted.shortcode_user_vars:
        self.Unprompted.log("No init_images found...")
        return
    else: self.init_image = self.Unprompted.shortcode_user_vars["init_images"][0].copy()
    method = self.Unprompted.parse_advanced(kwargs["method"],context) if "method" in kwargs else "clipseg"
    if method == "clipseg":
        # clipseg predicts on a fixed 512x512 grid (see the Resize transform below).
        mask_width = 512
        mask_height = 512
    else:
        if method == "grounded_sam":
            import launch
            # GroundingDINO is an optional dependency; install it on first use.
            if not launch.is_installed("groundingdino"):
                self.Unprompted.log("Attempting to install GroundingDINO library. Buckle up bro")
                try:
                    launch.run_pip("install git+https://github.com/IDEA-Research/GroundingDINO","requirements for Unprompted - txt2mask SAM method")
                except Exception as e:
                    self.Unprompted.log(f"GroundingDINO problem: {e}",context="ERROR")
                    self.Unprompted.log(f"Please open an issue on their repo, not mine.",context="ERROR")
                    return ""
        # The SAM-based methods work at the init image's native resolution.
        mask_width = self.init_image.size[0]
        mask_height = self.init_image.size[1]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): `device` is a torch.device object; whether `== "cuda"`
    # matches depends on the torch version's string-comparison semantics —
    # verify this branch actually fires when intended.
    if device == "cuda": torch.cuda.empty_cache()
    # Optionally paste one or more "stamp" overlay images onto the init image
    # before segmentation runs.
    if "stamp" in kwargs:
        stamps = (self.Unprompted.parse_advanced(kwargs["stamp"],context)).split(self.Unprompted.Config.syntax.delimiter)
        stamp_x = int(float(self.Unprompted.parse_advanced(kwargs["stamp_x"],context))) if "stamp_x" in kwargs else 0
        stamp_y = int(float(self.Unprompted.parse_advanced(kwargs["stamp_y"],context))) if "stamp_y" in kwargs else 0
        # Remember the user-supplied offsets; "center" mode adds to them per stamp.
        stamp_x_orig = stamp_x
        stamp_y_orig = stamp_y
        stamp_method = self.Unprompted.parse_advanced(kwargs["stamp_method"],context) if "stamp_method" in kwargs else "stretch"
        for stamp in stamps:
            # Checks for file in images/stamps, otherwise assumes absolute path
            stamp_path = f"{self.Unprompted.base_dir}/images/stamps/{stamp}.png"
            if not os.path.exists(stamp_path): stamp_path = stamp
            if not os.path.exists(stamp_path):
                self.Unprompted.log(f"Stamp not found: {stamp_path}",context="ERROR")
                continue
            stamp_img = Image.open(stamp_path).convert("RGBA")
            if stamp_method == "stretch":
                stamp_img = stamp_img.resize((self.init_image.size[0],self.init_image.size[1]))
            elif stamp_method == "center":
                stamp_x = stamp_x_orig + int((mask_width - stamp_img.size[0]) / 2)
                stamp_y = stamp_y_orig + int((mask_height - stamp_img.size[1]) / 2)
            stamp_blur = int(float(self.Unprompted.parse_advanced(kwargs["stamp_blur"],context))) if "stamp_blur" in kwargs else 0
            if stamp_blur:
                from PIL import ImageFilter
                blur = ImageFilter.GaussianBlur(stamp_blur)
                stamp_img = stamp_img.filter(blur)
            # The stamp's own alpha channel doubles as the paste mask.
            self.init_image.paste(stamp_img,(stamp_x,stamp_y),stamp_img)
    # How to combine with an existing user brush mask: "add", "subtract" or "discard".
    brush_mask_mode = self.Unprompted.parse_advanced(kwargs["mode"],context) if "mode" in kwargs else "add"
    self.show = True if "show" in pargs else False
    self.legacy_weights = True if "legacy_weights" in pargs else False
    # Box-blur kernels used to feather the positive and negative masks.
    smoothing = int(self.Unprompted.parse_advanced(kwargs["smoothing"],context)) if "smoothing" in kwargs else 20
    smoothing_kernel = None
    if smoothing > 0:
        smoothing_kernel = numpy.ones((smoothing,smoothing),numpy.float32)/(smoothing*smoothing)
    neg_smoothing = int(self.Unprompted.parse_advanced(kwargs["neg_smoothing"],context)) if "neg_smoothing" in kwargs else 20
    neg_smoothing_kernel = None
    if neg_smoothing > 0:
        neg_smoothing_kernel = numpy.ones((neg_smoothing,neg_smoothing),numpy.float32)/(neg_smoothing*neg_smoothing)
    # Pad the mask by applying a dilation or erosion
    mask_padding = int(self.Unprompted.parse_advanced(kwargs["padding"],context) if "padding" in kwargs else 0)
    neg_mask_padding = int(self.Unprompted.parse_advanced(kwargs["neg_padding"],context) if "neg_padding" in kwargs else 0)
    padding_dilation_kernel = None
    if (mask_padding != 0):
        padding_dilation_kernel = numpy.ones((abs(mask_padding), abs(mask_padding)), numpy.uint8)
    neg_padding_dilation_kernel = None
    if (neg_mask_padding != 0):
        neg_padding_dilation_kernel = numpy.ones((abs(neg_mask_padding), abs(neg_mask_padding)), numpy.uint8)
    # The block content holds delimiter-separated mask prompts.
    prompts = content.split(self.Unprompted.Config.syntax.delimiter)
    prompt_parts = len(prompts)
    if "negative_mask" in kwargs:
        neg_parsed = self.Unprompted.parse_advanced(kwargs["negative_mask"],context)
        if len(neg_parsed) < 1: negative_prompts = None
        else:
            negative_prompts = neg_parsed.split(self.Unprompted.Config.syntax.delimiter)
            negative_prompt_parts = len(negative_prompts)
    else: negative_prompts = None
    # Grayscale threshold (0-255) above which a pixel counts as masked.
    mask_precision = min(255,int(self.Unprompted.parse_advanced(kwargs["precision"],context) if "precision" in kwargs else 100))
    neg_mask_precision = min(255,int(self.Unprompted.parse_advanced(kwargs["neg_precision"],context) if "neg_precision" in kwargs else 100))
    def overlay_mask_part(img_a,img_b,mode):
        # Combine two mask layers: "discard" keeps the darker pixels
        # (intersection-like), anything else keeps the lighter (union-like).
        if (mode == "discard"): img_a = ImageChops.darker(img_a, img_b)
        else: img_a = ImageChops.lighter(img_a, img_b)
        return(img_a)
    def gray_to_pil(img):
        # Convert a single-channel OpenCV image to an RGBA PIL image.
        return (Image.fromarray(cv2.cvtColor(img,cv2.COLOR_GRAY2RGBA)))
    def process_mask_parts(masks, mode, final_img = None, mask_precision=100, mask_padding=0, padding_dilation_kernel=None, smoothing_kernel=None):
        # Threshold, pad, smooth and composite each per-prompt prediction
        # into a single black/white PIL mask. Predictions are round-tripped
        # through a PNG on disk (matplotlib save -> cv2 read).
        for i, mask in enumerate(masks):
            filename = f"mask_{mode}_{i}.png"
            if method == "clipseg":
                plt.imsave(filename,torch.sigmoid(mask[0]))
                img = cv2.imread(filename)
                # TODO: Figure out how to convert the plot above to numpy instead of re-loading image
            else:
                plt.imsave(filename,mask)
                import random  # NOTE(review): unused import — candidate for removal
                img = cv2.imread(filename)
            img = cv2.resize(img,(mask_width,mask_height))
            if padding_dilation_kernel is not None:
                # Positive padding grows the mask, negative padding shrinks it.
                if (mask_padding > 0): img = cv2.dilate(img,padding_dilation_kernel,iterations=1)
                else: img = cv2.erode(img,padding_dilation_kernel,iterations=1)
            if smoothing_kernel is not None: img = cv2.filter2D(img,-1,smoothing_kernel)
            #if method == "clip_surgery":
            #gray_image = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_BGR2LUV), cv2.COLOR_BGR2GRAY)
            #else: gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            # NOTE(review): debug artifact written on every part — verify intentional.
            Image.fromarray(gray_image).save("mask_gray_test.png")
            (thresh, bw_image) = cv2.threshold(gray_image, mask_precision, 255, cv2.THRESH_BINARY)
            if (mode == "discard"): bw_image = numpy.invert(bw_image)
            # overlay mask parts
            bw_image = gray_to_pil(bw_image)
            if (i > 0 or final_img is not None): bw_image = overlay_mask_part(bw_image,final_img,mode)
            final_img = bw_image
        return(final_img)
    def get_mask():
        # Run the selected segmentation method over the (flattened) init image
        # and return the final composited PIL mask.
        preds = []
        negative_preds = []
        image_pil = flatten(self.init_image, opts.img2img_background_color)
        if method == "clip_surgery":
            from lib_unprompted import clip_surgery as clip
            import cv2
            import numpy as np
            from PIL import Image
            from matplotlib import pyplot as plt
            from torchvision.transforms import Compose, Resize, ToTensor, Normalize
            from torchvision.transforms import InterpolationMode
            BICUBIC = InterpolationMode.BICUBIC
            from segment_anything import sam_model_registry, SamPredictor
            # default imagenet redundant features
            redundants = ['a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
            if "redundant_features" in kwargs: redundants.extend(kwargs["redundant_features"].split(self.Unprompted.Config.syntax.delimiter))
            self.bypass_sam = True if "bypass_sam" in pargs else False
            ### Init CLIP and data
            if self.cached_model == -1 or self.cached_model_method != method:
                model, preprocess = clip.load("CS-ViT-B/16", device=device)
                model.eval()
                # Cache for future runs
                self.cached_model = model
                self.cached_transform = preprocess
            else:
                self.Unprompted.log("Using cached model(s) for CLIP_Surgery method")
                model = self.cached_model
                preprocess = self.cached_transform
            image = preprocess(image_pil).unsqueeze(0).to(device)
            cv2_img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
            ### CLIP Surgery for a single text, without fixed label sets
            with torch.no_grad():
                # CLIP architecture surgery acts on the image encoder
                image_features = model.encode_image(image)
                image_features = image_features / image_features.norm(dim=1, keepdim=True)
                # Prompt ensemble for text features with normalization
                text_features = clip.encode_text_with_prompt_ensemble(model, prompts, device)
                if (negative_prompts):
                    negative_text_features = clip.encode_text_with_prompt_ensemble(model, negative_prompts, device)
                # Extract redundant features from an empty string
                redundant_features = clip.encode_text_with_prompt_ensemble(model, [""], device, redundants)
                # no sam
                if self.bypass_sam:
                    def reg_inference(text_features):
                        # Similarity-map inference without SAM refinement:
                        # one grayscale prediction per prompt.
                        preds = []
                        # Apply feature surgery for single text
                        similarity = clip.clip_feature_surgery(image_features, text_features, redundant_features)
                        similarity_map = clip.get_similarity_map(similarity[:, 1:, :], cv2_img.shape[:2])
                        # Draw similarity map
                        for b in range(similarity_map.shape[0]):
                            for n in range(similarity_map.shape[-1]):
                                vis = (similarity_map[b, :, :, n].cpu().numpy() * 255).astype('uint8')
                                preds.append(vis)
                        return(preds)
                    preds = reg_inference(text_features)
                    if (negative_prompts): negative_preds = reg_inference(negative_text_features)
                else:
                    point_thresh = float(self.Unprompted.parse_advanced(kwargs["point_threshold"],context)) if "point_threshold" in kwargs else 0.98
                    multimask_output = True if "multimask_output" in pargs else False
                    # Init SAM
                    if self.cached_predictor == -1 or self.cached_model_method != method:
                        sam_model_dir = f"{self.Unprompted.base_dir}/models/segment_anything"
                        os.makedirs(sam_model_dir, exist_ok=True)
                        sam_filename = "sam_vit_h_4b8939.pth"
                        sam_file = f"{sam_model_dir}/{sam_filename}"
                        # Download model weights if we don't have them yet
                        if not os.path.exists(sam_file):
                            print("Downloading SAM model weights...")
                            self.Unprompted.download_file(sam_file,f"https://dl.fbaipublicfiles.com/segment_anything/{sam_filename}")
                        model_type = "vit_h"
                        sam = sam_model_registry[model_type](checkpoint=sam_file)
                        sam.to(device=device)
                        predictor = SamPredictor(sam)
                        self.cached_predictor = predictor
                    else:
                        predictor = self.cached_predictor
                    predictor.set_image(np.array(image_pil))
                    self.cached_model_method = method
                    def sam_inference(text_features):
                        # Turn CLIP Surgery similarity maps into point prompts,
                        # then let SAM produce the mask.
                        preds = []
                        # Combine features after removing redundant features and min-max norm
                        sm = clip.clip_feature_surgery(image_features, text_features, redundant_features)[0, 1:, :]
                        sm_norm = (sm - sm.min(0, keepdim=True)[0]) / (sm.max(0, keepdim=True)[0] - sm.min(0, keepdim=True)[0])
                        sm_mean = sm_norm.mean(-1, keepdim=True)
                        # get positive points from individual maps, and negative points from the mean map
                        p, l = clip.similarity_map_to_points(sm_mean, cv2_img.shape[:2], t=point_thresh)
                        num = len(p) // 2
                        points = p[num:] # negatives in the second half
                        labels = [l[num:]]
                        for i in range(sm.shape[-1]):
                            p, l = clip.similarity_map_to_points(sm[:, i], cv2_img.shape[:2], t=point_thresh)
                            num = len(p) // 2
                            points = points + p[:num] # positive in first half
                            labels.append(l[:num])
                        labels = np.concatenate(labels, 0)
                        # Inference SAM with points from CLIP Surgery
                        masks, scores, logits = predictor.predict(point_labels=labels, point_coords=np.array(points), multimask_output=multimask_output)
                        # Keep only the highest-scoring mask.
                        mask = masks[np.argmax(scores)]
                        mask = mask.astype('uint8')
                        # Render the chosen mask as a black/white image.
                        vis = cv2_img.copy()
                        vis[mask > 0] = np.array([255, 255, 255], dtype=np.uint8)
                        vis[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
                        preds.append(vis)
                        if self.show:
                            for idx,mask in enumerate(masks):
                                plt.imsave(f"mask{idx}.png",mask)
                        return(preds)
                    preds = sam_inference(text_features)
                    if negative_prompts: negative_preds = sam_inference(negative_text_features)
        elif method == "grounded_sam":
            box_thresh = float(self.Unprompted.parse_advanced(kwargs["box_threshold"],context)) if "box_threshold" in kwargs else 0.3
            text_thresh = float(self.Unprompted.parse_advanced(kwargs["text_threshold"],context)) if "text_threshold" in kwargs else 0.25
            # Grounding DINO
            import groundingdino.datasets.transforms as T
            from groundingdino.models import build_model
            from groundingdino.util import box_ops
            from groundingdino.util.slconfig import SLConfig
            from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
            # segment anything
            from segment_anything import build_sam, SamPredictor
            import cv2
            import numpy as np
            import matplotlib.pyplot as plt
            def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
                # Run GroundingDINO on one caption and return the boxes (cxcywh,
                # normalized) and phrases that clear the thresholds.
                caption = caption.lower()
                caption = caption.strip()
                if not caption.endswith("."):
                    caption = caption + "."
                model = model.to(device)
                image = image.to(device)
                with torch.no_grad():
                    outputs = model(image[None], captions=[caption])
                logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
                boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
                logits.shape[0]
                # filter output
                logits_filt = logits.clone()
                boxes_filt = boxes.clone()
                filt_mask = logits_filt.max(dim=1)[0] > box_threshold
                logits_filt = logits_filt[filt_mask] # num_filt, 256
                boxes_filt = boxes_filt[filt_mask] # num_filt, 4
                logits_filt.shape[0]
                # get phrase
                tokenlizer = model.tokenizer
                tokenized = tokenlizer(caption)
                # build pred
                pred_phrases = []
                for logit, box in zip(logits_filt, boxes_filt):
                    pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
                    if with_logits:
                        pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
                    else:
                        pred_phrases.append(pred_phrase)
                return boxes_filt, pred_phrases
            sam_model_dir = f"{self.Unprompted.base_dir}/models/segment_anything"
            os.makedirs(sam_model_dir, exist_ok=True)
            sam_filename = "sam_vit_h_4b8939.pth"
            sam_file = f"{sam_model_dir}/{sam_filename}"
            # Download model weights if we don't have them yet
            if not os.path.exists(sam_file):
                print("Downloading SAM model weights...")
                self.Unprompted.download_file(sam_file,f"https://dl.fbaipublicfiles.com/segment_anything/{sam_filename}")
            dino_model_dir = f"{self.Unprompted.base_dir}/models/groundingdino"
            os.makedirs(dino_model_dir, exist_ok=True)
            dino_filename = "groundingdino_swint_ogc.pth"
            dino_file = f"{dino_model_dir}/{dino_filename}"
            if not os.path.exists(dino_file):
                print("Downloading GroundingDINO model weights...")
                self.Unprompted.download_file(dino_file,f"https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/{dino_filename}")
            model_config_path = f"{self.Unprompted.base_dir}/lib_unprompted/groundingdino/config/GroundingDINO_SwinT_OGC.py"
            # load model
            if self.cached_model == -1 or self.cached_model_method != method:
                args = SLConfig.fromfile(model_config_path)
                args.device = device
                model = build_model(args)
                checkpoint = torch.load(dino_file, map_location="cpu")
                load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
                print(load_res)
                _ = model.eval()
                transform = T.Compose(
                    [
                        T.RandomResize([800], max_size=1333),
                        T.ToTensor(),
                        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                    ]
                )
                # Cache for future runs
                self.cached_model = model
                self.cached_transform = transform
                self.cached_model_method = method
            else:
                self.Unprompted.log("Using cached GroundingDINO model.")
                model = self.cached_model
                transform = self.cached_transform
            def sam_infer(boxes_filt):
                # Convert normalized cxcywh boxes to absolute xyxy pixel
                # coordinates (using the enclosing W/H), then run SAM on them.
                for i in range(boxes_filt.size(0)):
                    boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
                    boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
                    boxes_filt[i][2:] += boxes_filt[i][:2]
                boxes_filt = boxes_filt.cpu()
                transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, img.shape[:2]).to(device)
                masks, _, _ = predictor.predict_torch(
                    point_coords = None,
                    point_labels = None,
                    boxes = transformed_boxes.to(device),
                    multimask_output = False,
                )
                preds = []
                value = 0
                mask_img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
                for idx, mask in enumerate(masks):
                    # mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
                    mask_img[mask.cpu().numpy()[0] >= 1] = np.array([255, 255, 255], dtype=np.uint8)
                    mask_img[mask.cpu().numpy()[0] < 1] = np.array([0, 0, 0], dtype=np.uint8)
                    # TODO: Figure out if we can take advantage of individual mask layers rather than stacking as composite
                    preds.append(mask_img)
                return(preds)
            # run grounding dino model
            img, _ = transform(image_pil,None)
            boxes_filt, pred_phrases = get_grounding_output(model, img, prompts[0], box_thresh, text_thresh, device=device)
            if (negative_prompts):
                neg_boxes_filt, pred_phrases = get_grounding_output(model, img, negative_prompts[0], box_thresh, text_thresh, device=device)
            # initialize SAM
            predictor = SamPredictor(build_sam(checkpoint=sam_file).to(device))
            img = numpy.array(image_pil) # cv2.imread(image_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            predictor.set_image(img)
            size = image_pil.size
            # NOTE(review): PIL .size is (width, height); here H gets size[0]
            # and W gets size[1], i.e. H=width, W=height — confirm this is
            # the intended orientation for the box rescale in sam_infer.
            H, W = size[0], size[1]
            preds = sam_infer(boxes_filt)
            if (negative_prompts): negative_preds = sam_infer(neg_boxes_filt)
        # clipseg method
        else:
            from lib_unprompted.stable_diffusion.clipseg.models.clipseg import CLIPDensePredT
            model_dir = f"{self.Unprompted.base_dir}/models/clipseg"
            os.makedirs(model_dir, exist_ok=True)
            d64_filename = "rd64-uni.pth" if self.legacy_weights else "rd64-uni-refined.pth"
            d64_file = f"{model_dir}/{d64_filename}"
            d16_file = f"{model_dir}/rd16-uni.pth"
            # Download model weights if we don't have them yet
            if not os.path.exists(d64_file):
                print("Downloading clipseg model weights...")
                self.Unprompted.download_file(d64_file,f"https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download?path=%2F&files={d64_filename}")
                self.Unprompted.download_file(d16_file,"https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download?path=%2F&files=rd16-uni.pth")
            # load model
            if self.cached_model == -1 or self.cached_model_method != method:
                self.Unprompted.log("Loading clipseg model...")
                model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=not self.legacy_weights)
                # non-strict, because we only stored decoder weights (not CLIP weights)
                model.load_state_dict(torch.load(d64_file, map_location=device), strict=False)
                model = model.eval().to(device=device)
                transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                    transforms.Resize((512, 512)),
                ])
                # Cache for future runs
                self.cached_model = model
                self.cached_transform = transform
                self.cached_model_method = method
            else:
                self.Unprompted.log("Using cached clipseg model.")
                model = self.cached_model
                transform = self.cached_transform
            img = transform(image_pil).unsqueeze(0)
            # predict
            with torch.no_grad():
                if "image_prompt" in kwargs:
                    from PIL import Image
                    # NOTE(review): hardcoded local debug path — the image_prompt
                    # branch appears experimental and will fail on other machines.
                    img_mask = flatten(Image.open(r"A:/inbox/test_mask.png"), opts.img2img_background_color)
                    img_mask = transform(img_mask).unsqueeze(0)
                    preds = model(img.to(device=device), img_mask.to(device=device))[0].cpu()
                else:
                    # One batched prediction per prompt (and per negative prompt).
                    preds = model(img.repeat(prompt_parts,1,1,1).to(device=device), prompts)[0].cpu()
                    if (negative_prompts): negative_preds = model(img.repeat(negative_prompt_parts,1,1,1).to(device=device), negative_prompts)[0].cpu()
        # All of the below logic applies to both clipseg and sam
        if "image_mask" not in self.Unprompted.shortcode_user_vars: self.Unprompted.shortcode_user_vars["image_mask"] = None
        # In "add" mode, start from the user's existing brush mask (if any).
        if (brush_mask_mode == "add" and self.Unprompted.shortcode_user_vars["image_mask"] is not None):
            final_img = self.Unprompted.shortcode_user_vars["image_mask"].convert("RGBA").resize((mask_width,mask_height))
        else: final_img = None
        # process masking
        final_img = process_mask_parts(preds,"add",final_img, mask_precision, mask_padding, padding_dilation_kernel, smoothing_kernel)
        # process negative masking
        if (brush_mask_mode == "subtract" and self.Unprompted.shortcode_user_vars["image_mask"] is not None):
            self.Unprompted.shortcode_user_vars["image_mask"] = ImageOps.invert(self.Unprompted.shortcode_user_vars["image_mask"])
            self.Unprompted.shortcode_user_vars["image_mask"] = self.Unprompted.shortcode_user_vars["image_mask"].convert("RGBA").resize((mask_width,mask_height))
            final_img = overlay_mask_part(final_img,self.Unprompted.shortcode_user_vars["image_mask"],"discard")
        if (negative_prompts): final_img = process_mask_parts(negative_preds,"discard",final_img, neg_mask_precision,neg_mask_padding, neg_padding_dilation_kernel, neg_smoothing_kernel)
        # Optionally export the fraction of non-black (masked) pixels to a user var.
        if "size_var" in kwargs:
            img_data = final_img.load()
            # Count number of transparent pixels
            black_pixels = 0
            total_pixels = mask_width * mask_height
            for y in range(mask_height):
                for x in range(mask_width):
                    pixel_data = img_data[x,y]
                    if (pixel_data[0] == 0 and pixel_data[1] == 0 and pixel_data[2] == 0): black_pixels += 1
            subject_size = 1 - black_pixels / total_pixels
            self.Unprompted.shortcode_user_vars[kwargs["size_var"]] = subject_size
        # Inpaint sketch compatibility
        if "sketch_color" in kwargs:
            self.Unprompted.shortcode_user_vars["mode"] = 3
            this_color = kwargs["sketch_color"]
            # Convert to tuple for use with colorize
            if this_color[0].isdigit(): this_color = tuple(map(int,this_color.split(',')))
            paste_mask = ImageOps.colorize(final_img.convert("L"),black="black",white=this_color)
            # Convert black pixels to transparent
            paste_mask = paste_mask.convert('RGBA')
            mask_data = paste_mask.load()
            width, height = paste_mask.size
            for y in range(height):
                for x in range(width):
                    if mask_data[x, y] == (0, 0, 0, 255): mask_data[x, y] = (0, 0, 0, 0)
            # Match size just in case
            paste_mask = paste_mask.resize((image_pil.size[0],image_pil.size[1]))
            # Workaround for A1111 not including mask_alpha in p object
            if "sketch_alpha" in kwargs:
                alpha_channel = paste_mask.getchannel('A')
                new_alpha = alpha_channel.point(lambda i: int(float(kwargs["sketch_alpha"])) if i>0 else 0)
                paste_mask.putalpha(new_alpha)
            # Workaround for A1111 bug, not accepting inpaint_color_sketch param w/ blur
            if (self.Unprompted.shortcode_user_vars["mask_blur"] > 0):
                from PIL import ImageFilter
                blur = ImageFilter.GaussianBlur(self.Unprompted.shortcode_user_vars["mask_blur"])
                paste_mask = paste_mask.filter(blur)
                self.Unprompted.shortcode_user_vars["mask_blur"] = 0
            # Paste mask on
            image_pil.paste(paste_mask,box=None,mask=paste_mask)
            self.Unprompted.shortcode_user_vars["init_images"][0] = image_pil
            # not used by SD, just used to append to our GUI later
            self.Unprompted.shortcode_user_vars["colorized_mask"] = paste_mask
            # Assign webui vars, note - I think it should work this way but A1111 doesn't appear to store some of these in p obj
            # note: inpaint_color_sketch = flattened image with mask on top
            # self.Unprompted.shortcode_user_vars["inpaint_color_sketch"] = image_pil
            # note: inpaint_color_sketch_orig = the init image
            # self.Unprompted.shortcode_user_vars["inpaint_color_sketch_orig"] = self.Unprompted.shortcode_user_vars["init_images"][0]
            # return image_pil
        else:
            self.Unprompted.shortcode_user_vars["mode"] = 4 # "mask upload" mode to avoid unnecessary processing
            if ("mask_blur" in self.Unprompted.shortcode_user_vars and self.Unprompted.shortcode_user_vars["mask_blur"] > 0):
                from PIL import ImageFilter
                blur = ImageFilter.GaussianBlur(self.Unprompted.shortcode_user_vars["mask_blur"])
                final_img = final_img.filter(blur)
                self.Unprompted.shortcode_user_vars["mask_blur"] = 0
        # Drop cached weights so low-memory devices can reclaim VRAM/RAM.
        if "unload_model" in pargs:
            self.model = -1
            self.cached_model = -1
            self.cached_model_method = ""
            self.cached_predictor = -1
        return final_img
    # Set up processor parameters correctly
    self.image_mask = get_mask().resize((self.init_image.width,self.init_image.height))
    if "return_image" in pargs: return(self.image_mask)
    # NOTE(review): this runs after get_mask(), so a sketch_color run's
    # mode=3 is bumped back up by max(4, ...) here — verify that is intended.
    self.Unprompted.shortcode_user_vars["mode"] = max(4,self.Unprompted.shortcode_user_vars["mode"])
    self.Unprompted.shortcode_user_vars["image_mask"] =self.image_mask
    self.Unprompted.shortcode_user_vars["mask"]=self.image_mask
    self.Unprompted.shortcode_user_vars["mask_for_overlay"] = self.image_mask
    self.Unprompted.shortcode_user_vars["latent_mask"] = None # fixes inpainting full resolution
    arr = {}
    arr["image"] = self.init_image
    arr["mask"] = self.image_mask
    self.Unprompted.shortcode_user_vars["init_img_with_mask"] = arr
    self.Unprompted.shortcode_user_vars["init_mask"] = self.image_mask
    if "save" in kwargs: self.image_mask.save(f"{self.Unprompted.parse_advanced(kwargs['save'],context)}.png")
    return ""
def after(self,p=None,processed=None):
    # Post-processing hook: when [txt2mask] ran with the "show" parg, append
    # the generated mask (or the colorized sketch mask) plus a segmentation
    # overlay of the init image to the output gallery.
    from torchvision.transforms.functional import pil_to_tensor, to_pil_image
    from torchvision.utils import draw_segmentation_masks
    if self.image_mask and self.show:
        # Mode 4 is "mask upload": show the raw mask. Otherwise show the
        # colorized sketch mask produced by the sketch_color path.
        if self.Unprompted.shortcode_user_vars["mode"] == 4: processed.images.append(self.image_mask)
        else: processed.images.append(self.Unprompted.shortcode_user_vars["colorized_mask"])
        overlayed_init_img = draw_segmentation_masks(pil_to_tensor(self.Unprompted.shortcode_user_vars["init_images"][0]), pil_to_tensor(self.image_mask.convert("L")) > 0)
        processed.images.append(to_pil_image(overlayed_init_img))
        # Reset per-run state so a later generation without txt2mask does not
        # re-display a stale mask.
        # NOTE(review): indentation was lost in this listing — confirm whether
        # these two resets belong inside this branch or at method level.
        self.image_mask = None
        self.show = False
    return processed
def ui(self, gr):
    """Declare the wizard UI controls for [txt2mask].

    Each widget's label encodes the shortcode argument it maps to
    (after the arrow). Widgets are created in a fixed order.
    """
    gr.Radio(label="Mask blend mode 🡢 mode", choices=["add", "subtract", "discard"], value="add", interactive=True)
    gr.Radio(label="Masking tech method 🡢 method", choices=["clipseg", "clip_surgery", "grounded_sam"], value="clipseg", interactive=True)
    gr.Checkbox(label="Show mask in output 🡢 show")
    gr.Checkbox(label="Use clipseg legacy weights 🡢 legacy_weights")
    # Positive-mask tuning knobs.
    for knob_label, knob_default in (
        ("Precision of selected area 🡢 precision", 100),
        ("Padding radius in pixels 🡢 padding", 0),
        ("Smoothing radius in pixels 🡢 smoothing", 20),
    ):
        gr.Number(label=knob_label, value=knob_default, interactive=True)
    gr.Textbox(label="Negative mask prompt 🡢 negative_mask", max_lines=1)
    # Negative-mask tuning knobs (mirror the positive ones).
    for knob_label, knob_default in (
        ("Negative mask precision of selected area 🡢 neg_precision", 100),
        ("Negative mask padding radius in pixels 🡢 neg_padding", 0),
        ("Negative mask smoothing radius in pixels 🡢 neg_smoothing", 20),
    ):
        gr.Number(label=knob_label, value=knob_default, interactive=True)
    gr.Textbox(label="Mask color, enables Inpaint Sketch mode 🡢 sketch_color", max_lines=1, placeholder="e.g. tan or 127,127,127")
    gr.Number(label="Mask alpha, must be used in conjunction with mask color 🡢 sketch_alpha", value=0, interactive=True)
    gr.Textbox(label="Save the mask size to the following variable 🡢 size_var", max_lines=1)
    gr.Checkbox(label="Unload model after inference (for low memory devices) 🡢 unload_model")