# sd_smartprocess/processors.py

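"""Batch image processors for the sd_smartprocess extension.

Each Processor subclass wraps one model: CLIP/BLIP captioning, DeepDanbooru
tagging, CLIP-guided subject cropping, or Stable Diffusion upscaling.
"""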
import gc
import random
from typing import List, Any
import torch
from PIL import Image
from tqdm import tqdm
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printm
from extensions.sd_smartprocess import super_resolution
from extensions.sd_smartprocess.clipcrop import CropClip
from extensions.sd_smartprocess.interrogators.clip_interrogator import CLIPInterrogator
from extensions.sd_smartprocess.interrogators.booru_interrogator import BooruInterrogator
from modules import shared

# Base processor
class Processor:
    def __init__(self):
        printm("Model loaded.")

    # Unload the model and reclaim VRAM
    def unload(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        printm("Model unloaded.")

    # Process images
    def process(self, images: List[Image.Image]) -> List[Any]:
        raise NotImplementedError
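
# job_count, job_no and textinfo on shared.state feed the webui progress bar;
# current_image feeds the live preview.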

# CLIP Processing
class ClipProcessor(Processor):
    def __init__(
        self,
        clip_use_v2,
        clip_append_artist,
        clip_append_medium,
        clip_append_movement,
        clip_append_flavor,
        clip_append_trending,
        num_beams,
        min_clip_tokens,
        max_clip_tokens,
        max_flavors
    ):
        self.description = "Processing CLIP"
        # Unload the webui's built-in interrogator first to free VRAM.
        if shared.interrogator is not None:
            shared.interrogator.unload()
        self.max_flavors = max_flavors
        shared.state.textinfo = "Loading CLIP Model..."
        self.model = CLIPInterrogator(
            clip_use_v2,
            clip_append_artist,
            clip_append_medium,
            clip_append_movement,
            clip_append_flavor,
            clip_append_trending,
            num_beams,
            min_clip_tokens,
            max_clip_tokens
        )
        super().__init__()
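
    # short=True asks CLIPInterrogator for a trimmed caption; see
    # interrogate() for exactly which appended terms get dropped.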
    def process(self, images: List[Image.Image], short: bool = False) -> List[str]:
        output = []
        shared.state.job_count = len(images)
        shared.state.textinfo = f"{self.description}..."
        for img in tqdm(images, desc=self.description):
            short_caption = self.model.interrogate(img, short=short, max_flavors=self.max_flavors)
            output.append(short_caption)
            shared.state.current_image = img
            shared.state.job_no += 1
        return output

    def unload(self):
        if self.model.clip_model:
            del self.model.clip_model
        if self.model.blip_model:
            del self.model.blip_model
        super().unload()

# Danbooru Processing
class BooruProcessor(Processor):
    def __init__(self, min_score: float):
        self.description = "Processing Danbooru"
        shared.state.textinfo = "Loading DeepDanbooru Model..."
        self.model = BooruInterrogator()
        self.min_score = min_score
        super().__init__()
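
    # interrogate() returns a tag -> confidence dict; keep tags scoring at
    # least min_score, ordered from highest to lowest confidence.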
    def process(self, images: List[Image.Image]) -> List[List[str]]:
        output = []
        shared.state.job_count = len(images)
        shared.state.textinfo = f"{self.description}..."
        for img in tqdm(images, desc=self.description):
            out_tags = []
            tags = self.model.interrogate(img)
            for tag in sorted(tags, key=tags.get, reverse=True):
                if tags[tag] >= self.min_score:
                    out_tags.append(tag)
            output.append(out_tags)
            shared.state.job_no += 1
        return output

    def unload(self):
        self.model.unload()
        super().unload()

# Crop Processing
class CropProcessor(Processor):
    def __init__(self, subject_class: str, pad: bool, crop: bool):
        self.description = "Cropping"
        if crop:
            shared.state.textinfo = "Loading CROP Model..."
        self.model = CropClip() if crop else None
        self.subject_class = subject_class
        self.pad = pad
        self.crop = crop
        super().__init__()
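
    # Captions steer CropClip's subject detection; a non-empty subject_class
    # overrides them in _process_img.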
    def process(self, images: List[Image.Image], captions: List[str] = None) -> List[Image.Image]:
        if captions is None:
            # zip() would raise on None; fall back to empty captions.
            captions = [""] * len(images)
        output = []
        shared.state.job_count = len(images)
        shared.state.textinfo = f"{self.description}..."
        for img, caption in tqdm(zip(images, captions), total=len(images), desc=self.description):
            cropped = self._process_img(img, caption)
            output.append(cropped)
            shared.state.job_no += 1
        return output
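
    # CropClip.get_center appears to return the detected subject's bounds as
    # [left, right, top, bottom] in pixels; the math below assumes that layout.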
    def _process_img(self, img, short_caption):
        if self.subject_class is not None and self.subject_class != "":
            short_caption = self.subject_class
        src_ratio = img.width / img.height

        # Pad image to a square before cropping?
        if src_ratio != 1 and self.pad:
            pad_size = max(img.width, img.height)
            res = Image.new("RGB", (pad_size, pad_size))
            res.paste(img, box=(pad_size // 2 - img.width // 2, pad_size // 2 - img.height // 2))
            img = res

        if self.crop:
            # Find the subject and center the crop on it.
            im_data = self.model.get_center(img, prompt=short_caption)
            crop_width = im_data[1] - im_data[0]
            center_x = im_data[0] + (crop_width / 2)
            crop_height = im_data[3] - im_data[2]
            center_y = im_data[2] + (crop_height / 2)
            crop_ratio = crop_width / crop_height
            dest_ratio = 1  # target aspect ratio (width / height): square output
            tgt_width = crop_width
            tgt_height = crop_height

            # Grow the shorter side so the crop matches dest_ratio.
            if crop_ratio != dest_ratio:
                if crop_width > crop_height:
                    tgt_width = crop_width
                    tgt_height = crop_width / dest_ratio
                else:
                    tgt_height = crop_height
                    tgt_width = crop_height * dest_ratio

            # Reverse the above if the target overflows the image.
            if tgt_width > img.width or tgt_height > img.height:
                if tgt_width > img.width:
                    tgt_width = img.width
                    tgt_height = tgt_width / dest_ratio
                else:
                    tgt_height = img.height
                    tgt_width = tgt_height * dest_ratio

            tgt_height = int(tgt_height)
            tgt_width = int(tgt_width)

            # Clamp the crop box to the image bounds.
            left = max(center_x - (tgt_width / 2), 0)
            right = min(center_x + (tgt_width / 2), img.width)
            top = max(center_y - (tgt_height / 2), 0)
            bottom = min(center_y + (tgt_height / 2), img.height)
            img = img.crop((left, top, right, bottom))
        return img

    def unload(self):
        if self.model is not None:
            self.model.unload()
        super().unload()

# Upscale Processing
class UpscaleProcessor(Processor):
    def __init__(self):
        self.description = "Upscaling"
        shared.state.textinfo = "Loading Stable-Diffusion Upscaling Model..."
        self.sampler, self.model = super_resolution.initialize_model()
        super().__init__()
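
    # predict() is called with the caption plus fixed settings; by analogy with
    # the SD x4-upscaler reference script these look like steps=75,
    # num_samples=1, scale=10, seed, eta=0 and noise_level=20, but
    # super_resolution.predict is the source of truth.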
    def process(self, images: List[Image.Image], captions: List[str] = None) -> List[Image.Image]:
        if captions is None:
            # predict() expects a prompt per image; default to empty strings.
            captions = [""] * len(images)
        output = []
        shared.state.job_count = len(images)
        shared.state.textinfo = f"{self.description}..."
        for img, caption in tqdm(zip(images, captions), total=len(images), desc=self.description):
            seed = random.randrange(2147483647)
            img = super_resolution.predict(self.sampler, img, caption, 75, 1, 10, seed, 0, 20)
            output.append(img)
            shared.state.job_no += 1
        return output

    def unload(self):
        del self.sampler
        del self.model
        super().unload()
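

# Typical usage (illustrative sketch; argument values here are examples, and
# the extension's own driver script wires these together with UI settings):
#
#   cropper = CropProcessor(subject_class="person", pad=True, crop=True)
#   images = cropper.process(images, captions)
#   cropper.unload()
#
#   captioner = ClipProcessor(True, True, True, True, True, True, 8, 24, 48, 4)
#   captions = captioner.process(images)
#   captioner.unload()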