# sd_smartprocess/smartprocess.py (376 lines, 14 KiB, Python)
import os
import sys
import traceback
from io import StringIO
from pathlib import Path
import numpy as np
import tqdm
from PIL import Image, ImageOps, features
import modules.codeformer_model
import modules.gfpgan_model
import reallysafe
from clipcrop import CropClip
from extensions.sd_smartprocess.clipinterrogator import ClipInterrogator
from extensions.sd_smartprocess.interrogator import WaifuDiffusionInterrogator, BooruInterrogator
from modules import shared, images, safe
def printi(message):
    """Echo *message* to the console and publish it to the shared UI status line."""
    print(message)
    shared.state.textinfo = message
def list_features():
    """Return the file-extension list supported by the installed PIL build.

    Extensions are collected in the order pilinfo() reports them, without
    duplicates.
    """
    # Capture pilinfo() output in a buffer instead of letting it hit stdout.
    sink = StringIO()
    features.pilinfo(out=sink)

    supported = []
    for info_line in sink.getvalue().splitlines():
        if "Extensions:" not in info_line:
            continue
        ext_field = info_line.split(": ")[1]
        for ext in ext_field.split(", "):
            if ext not in supported:
                supported.append(ext)
    return supported
def is_image(path: Path, feats=None):
    """Return True if *path* is an existing file whose (lower-cased) suffix
    appears in *feats*.

    feats: list of extensions (with leading dot); when None or empty, the
    supported-extension list is queried from PIL via list_features().
    """
    # A single falsiness test covers both the None default and an empty list,
    # replacing the original redundant None-then-len() double check.
    if not feats:
        feats = list_features()
    return path.is_file() and path.suffix.lower() in feats
def preprocess(rename,
               src,
               dst,
               pad,
               crop,
               max_size,
               txt_action,
               flip,
               caption,
               caption_length,
               caption_clip,
               clip_use_v2,
               clip_append_flavor,
               clip_max_flavors,
               clip_append_medium,
               clip_append_movement,
               clip_append_artist,
               clip_append_trending,
               caption_wd14,
               wd14_min_score,
               caption_deepbooru,
               booru_min_score,
               subject_class,
               subject,
               replace_class,
               restore_faces,
               face_model,
               upscale,
               upscale_ratio,
               scaler
               ):
    """Smart-process every image in *src* and write results to *dst*.

    Depending on the flags, each image is smart-cropped (CLIP + YOLO),
    face-restored (gfpgan/codeformer), upscaled, padded to square, resized
    to max_size, captioned (CLIP / wd14 / deepbooru interrogators), and
    saved as PNG with an optional .txt caption sidecar.

    Returns a (message, message) tuple for the two UI output fields.
    """
    try:
        shared.state.textinfo = "Initializing smart processing..."
        safe.RestrictedUnpickler = reallysafe.RestrictedUnpickler

        if not crop and not caption and not restore_faces and not upscale and not pad:
            msg = "Nothing to do."
            printi(msg)
            return msg, msg

        wd_interrogator = None
        db_interrogator = None
        clip_interrogator = None
        crop_clip = None

        if caption or crop:
            printi("\rLoading captioning models...")
            if caption_clip or crop:
                printi("\rLoading CLIP interrogator...")
                # Free VRAM held by the webui's built-in interrogator first.
                if shared.interrogator is not None:
                    shared.interrogator.unload()
                clip_interrogator = ClipInterrogator(clip_use_v2,
                                                     clip_append_artist,
                                                     clip_append_medium,
                                                     clip_append_movement,
                                                     clip_append_flavor,
                                                     clip_append_trending)
            if caption_deepbooru:
                printi("\rLoading Deepbooru interrogator...")
                db_interrogator = BooruInterrogator()
            if caption_wd14:
                printi("\rLoading wd14 interrogator...")
                wd_interrogator = WaifuDiffusionInterrogator()

        if crop:
            printi("Loading YOLOv5 interrogator...")
            # A stale 'models' module clashes with YOLOv5 hub loading; drop it.
            # BUGFIX: narrowed from a bare 'except:' to the only error del raises.
            try:
                del sys.modules['models']
            except KeyError:
                pass
            crop_clip = CropClip()

        src = os.path.abspath(src)
        dst = os.path.abspath(dst)
        if src == dst:
            msg = "Source and destination are the same, returning."
            printi(msg)
            return msg, msg

        os.makedirs(dst, exist_ok=True)
        files = os.listdir(src)
        printi("Preprocessing...")
        shared.state.job_count = len(files)

        def build_caption(image):
            """Assemble a caption for *image* from interrogators and any existing sidecar."""
            # Seed from the .txt sidecar if present, else from a cleaned filename.
            existing_caption_txt_filename = os.path.splitext(filename)[0] + '.txt'
            if os.path.exists(existing_caption_txt_filename):
                with open(existing_caption_txt_filename, 'r', encoding="utf8") as file:
                    existing_caption_txt = file.read()
            else:
                # BUGFIX: the allowed-char list contained ", " (two characters),
                # which a single char can never equal, so commas were always lost.
                existing_caption_txt = ''.join(c for c in filename if c.isalpha() or c in (" ", ","))

            out_tags = []
            if clip_interrogator is not None and caption_clip:
                for tag in clip_interrogator.interrogate(image, max_flavors=clip_max_flavors):
                    out_tags.append(tag)
            if wd_interrogator is not None:
                ratings, tags = wd_interrogator.interrogate(image)
                for tag in sorted(tags, key=tags.get, reverse=True):
                    # Tags are score-sorted, so the first miss ends the scan.
                    if tags[tag] < wd14_min_score:
                        break
                    out_tags.append(tag)
            if caption_deepbooru:
                tags = db_interrogator.interrogate(image)
                for tag in sorted(tags, key=tags.get, reverse=True):
                    if tags[tag] >= booru_min_score:
                        out_tags.append(tag)

            # De-duplicate while preserving order.
            # BUGFIX: compare the stripped tag (what actually gets stored), so
            # whitespace-variant duplicates are caught too.
            unique_tags = []
            for tag in out_tags:
                tag = tag.strip()
                if tag not in unique_tags:
                    unique_tags.append(tag)
            caption_txt = ", ".join(unique_tags)

            if txt_action == 'prepend' and existing_caption_txt:
                caption_txt = existing_caption_txt + ' ' + caption_txt
            elif txt_action == 'append' and existing_caption_txt:
                caption_txt = caption_txt + ' ' + existing_caption_txt
            elif txt_action == 'copy' and existing_caption_txt:
                caption_txt = existing_caption_txt
            caption_txt = caption_txt.strip()

            if replace_class and subject is not None and subject_class is not None:
                # Replace "a SUBJECT_CLASS" (then any bare class mention) with the subject.
                if f"a {subject_class}" in caption_txt:
                    caption_txt = caption_txt.replace(f"a {subject_class}", subject)
                if subject_class in caption_txt:
                    caption_txt = caption_txt.replace(subject_class, subject)

            if 0 < caption_length < len(caption_txt):
                # Trim to the longest word-boundary prefix under caption_length.
                # BUGFIX: the original index-advancing loop could spin forever
                # once the running length reached the limit; this form always
                # terminates.
                split_cap = caption_txt.split(" ")
                caption_txt = ""
                cap_test = ""
                for word in split_cap:
                    cap_test += f" {word}"
                    if len(cap_test) >= caption_length:
                        break
                    caption_txt = cap_test
            return caption_txt.strip()

        def save_pic(image, src_name, img_index, existing_caption=None, flipped=False):
            """Save *image* (and optional caption sidecar) into dst."""
            if rename:
                basename = f"{img_index:05}"
            else:
                # BUGFIX: splitext returns a (root, ext) tuple; take the root —
                # the original assigned the tuple, breaking the flipped-name
                # concatenation and producing tuple-repr filenames.
                basename = os.path.splitext(src_name)[0]
            if flipped:
                basename += "_flipped"
            # BUGFIX: preview the image actually being saved (the original
            # showed the outer 'img' even when saving the flipped copy).
            shared.state.current_image = image
            image.save(os.path.join(dst, f"{basename}.png"))
            if existing_caption is not None and len(existing_caption) > 0:
                with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
                    file.write(existing_caption)

        image_index = 0
        pil_features = list_features()
        for index, src_image in enumerate(tqdm.tqdm(files)):
            # Quit on cancel.
            if shared.state.interrupted:
                msg = f"Processing interrupted, {index}/{len(files)}"
                return msg, msg
            filename = os.path.join(src, src_image)
            if not is_image(Path(filename), pil_features):
                continue
            try:
                img = Image.open(filename).convert("RGB")
            except Exception as e:
                msg = f"Exception processing: {e}"
                printi(msg)
                traceback.print_exc()
                return msg, msg

            if crop:
                # Interrogate once for a short prompt to guide the crop, unless
                # an explicit subject class was supplied.
                short_caption = clip_interrogator.interrogate(img, short=True)
                if subject_class is not None and subject_class != "":
                    short_caption = subject_class
                src_ratio = img.width / img.height
                # Pad to square before cropping so the subject can stay centered.
                if src_ratio != 1 and pad:
                    if img.width > img.height:
                        pad_width = img.width
                        pad_height = img.width
                    else:
                        pad_width = img.height
                        pad_height = img.height
                    res = Image.new("RGB", (pad_width, pad_height))
                    res.paste(img, box=(pad_width // 2 - img.width // 2, pad_height // 2 - img.height // 2))
                    img = res
                # Locate the subject, then crop a square region around its center.
                im_data = crop_clip.get_center(img, prompt=short_caption)
                crop_width = im_data[1] - im_data[0]
                center_x = im_data[0] + (crop_width / 2)
                crop_height = im_data[3] - im_data[2]
                center_y = im_data[2] + (crop_height / 2)
                crop_ratio = crop_width / crop_height
                dest_ratio = 1
                tgt_width = crop_width
                tgt_height = crop_height
                if crop_ratio != dest_ratio:
                    # Grow the short side to reach the square dest ratio.
                    if crop_width > crop_height:
                        tgt_height = crop_width / dest_ratio
                        tgt_width = crop_width
                    else:
                        tgt_width = crop_height / dest_ratio
                        tgt_height = crop_height
                    # Reverse the above if dest is too big for the source image.
                    if tgt_width > img.width or tgt_height > img.height:
                        if tgt_width > img.width:
                            tgt_width = img.width
                            tgt_height = tgt_width / dest_ratio
                        else:
                            tgt_height = img.height
                            tgt_width = tgt_height / dest_ratio
                tgt_height = int(tgt_height)
                tgt_width = int(tgt_width)
                left = max(center_x - (tgt_width / 2), 0)
                right = min(center_x + (tgt_width / 2), img.width)
                top = max(center_y - (tgt_height / 2), 0)
                bottom = min(center_y + (tgt_height / 2), img.height)
                img = img.crop((left, top, right, bottom))
                shared.state.current_image = img

            if restore_faces:
                shared.state.textinfo = f"Restoring faces using {face_model}..."
                if face_model == "gfpgan":
                    restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(img, dtype=np.uint8))
                else:
                    restored_img = modules.codeformer_model.codeformer.restore(np.array(img, dtype=np.uint8),
                                                                               w=1.0)
                img = Image.fromarray(restored_img)
                shared.state.current_image = img

            if upscale:
                shared.state.textinfo = "Upscaling..."
                upscaler = shared.sd_upscalers[scaler]
                img = upscaler.scaler.upscale(img, upscale_ratio, upscaler.data_path)
                shared.state.current_image = img

            if pad:
                # Letterbox onto a max_size x max_size canvas, centered.
                ratio = 1
                src_ratio = img.width / img.height
                src_w = max_size if ratio < src_ratio else img.width * max_size // img.height
                src_h = max_size if ratio >= src_ratio else img.height * max_size // img.width
                resized = images.resize_image(0, img, src_w, src_h)
                res = Image.new("RGB", (max_size, max_size))
                res.paste(resized, box=(max_size // 2 - src_w // 2, max_size // 2 - src_h // 2))
                img = res

            # Resize again if image is not at the right size.
            if img.width != max_size or img.height != max_size:
                img = images.resize_image(1, img, max_size, max_size)

            # Build a caption, if enabled.
            full_caption = build_caption(img) if caption else None
            shared.state.current_image = img
            # BUGFIX: removed a stray quote/paren that garbled this log line.
            printi(f"Processed: {src_image} - {full_caption}")
            save_pic(img, src_image, image_index, existing_caption=full_caption)
            image_index += 1
            if flip:
                save_pic(ImageOps.flip(img), src_image, image_index, existing_caption=full_caption, flipped=True)
                image_index += 1
            shared.state.nextjob()

        if caption_clip or crop:
            printi("Unloading CLIP interrogator...")
            shared.interrogator.send_blip_to_ram()
        if caption_deepbooru:
            printi("Unloading Deepbooru interrogator...")
            db_interrogator.unload()
        if caption_wd14:
            printi("Unloading wd14 interrogator...")
            wd_interrogator.unload()
        done_msg = f"Successfully processed {len(files)} images."
        return done_msg, done_msg
    except Exception as e:
        # Top-level boundary: report the failure to the UI instead of crashing.
        msg = f"Exception processing: {e}"
        traceback.print_exc()
        return msg, msg