# Smart-process image pipeline: base Processor plus CLIP captioning,
# DeepDanbooru tagging, subject-aware cropping, and SD upscaling.
import gc
|
|
import random
|
|
from typing import List, Any
|
|
|
|
import torch
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
|
|
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printm
|
|
from extensions.sd_smartprocess import super_resolution
|
|
from extensions.sd_smartprocess.clipcrop import CropClip
|
|
from extensions.sd_smartprocess.interrogators.clip_interrogator import CLIPInterrogator
|
|
from extensions.sd_smartprocess.interrogators.booru_interrogator import BooruInterrogator
|
|
from modules import shared
|
|
|
|
# Base processor
|
|
class Processor:
    """Base class for all image processors.

    Subclasses load their model in ``__init__`` (then call ``super().__init__()``)
    and override ``process``; ``unload`` releases model memory.
    """

    def __init__(self):
        # Subclasses invoke this after their model has been loaded.
        printm("Model loaded.")

    # Unload models
    def unload(self):
        """Release model memory: collect garbage, then free cached CUDA blocks."""
        # Collect unreachable objects first so any CUDA tensors they held can
        # actually be released by empty_cache() below.
        gc.collect()
        # torch.has_cuda is deprecated; torch.cuda.is_available() is the
        # supported check.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        printm("Model unloaded.")

    # Process images
    def process(self, images: List[Image.Image]) -> List[Any]:
        """Process a list of images; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NotImplementedError is the idiomatic exception here; it is still an
        # Exception subclass, so existing broad handlers keep working.
        raise NotImplementedError("Not Implemented")
|
|
|
|
# CLIP Processing
|
|
class ClipProcessor(Processor):
    """Generates text captions for images via the CLIP interrogator."""

    def __init__(
            self,
            clip_use_v2,
            clip_append_artist,
            clip_append_medium,
            clip_append_movement,
            clip_append_flavor,
            clip_append_trending,
            num_beams,
            min_clip_tokens,
            max_clip_tokens,
            max_flavors
    ):
        self.description = "Processing CLIP"
        # Release the webui's built-in interrogator before loading our own.
        if shared.interrogator is not None:
            shared.interrogator.unload()

        self.max_flavors = max_flavors
        shared.state.textinfo = "Loading CLIP Model..."
        self.model = CLIPInterrogator(
            clip_use_v2,
            clip_append_artist,
            clip_append_medium,
            clip_append_movement,
            clip_append_flavor,
            clip_append_trending,
            num_beams,
            min_clip_tokens,
            max_clip_tokens
        )
        super().__init__()

    def process(self, images: List[Image.Image], short:bool=False) -> List[str]:
        """Interrogate every image and return one caption per image, in order."""
        shared.state.job_count = len(images)
        shared.state.textinfo = f"{self.description}..."
        captions = []
        for image in tqdm(images, desc=self.description):
            caption = self.model.interrogate(image, short=short, max_flavors=self.max_flavors)
            captions.append(caption)
            # Keep the UI progress state in sync with the loop.
            shared.state.current_image = image
            shared.state.job_no += 1
        return captions

    def unload(self):
        """Drop the CLIP and BLIP sub-models, then run the base cleanup."""
        for attr in ("clip_model", "blip_model"):
            if getattr(self.model, attr):
                delattr(self.model, attr)
        super().unload()
|
|
|
|
# Danbooru Processing
|
|
class BooruProcessor(Processor):
    """Tags images with DeepDanbooru, keeping tags that score >= min_score.

    Args:
        min_score: minimum interrogator confidence for a tag to be kept.
    """

    def __init__(self, min_score: float):
        self.description = "Processing Danbooru"
        shared.state.textinfo = "Loading DeepDanbooru Model..."
        self.model = BooruInterrogator()
        self.min_score = min_score
        super().__init__()

    def process(self, images: List[Image.Image]) -> List[List[str]]:
        """Return, for each image, its qualifying tags sorted by score (best first)."""
        output = []
        shared.state.job_count = len(images)
        shared.state.textinfo = f"{self.description}..."
        for img in tqdm(images, desc=self.description):
            tags = self.model.interrogate(img)
            # Keep tags at or above the threshold, highest-scoring first.
            out_tags = [tag for tag in sorted(tags, key=tags.get, reverse=True)
                        if tags[tag] >= self.min_score]
            output.append(out_tags)
            # BUGFIX: was "shared.state.job_count += 1", which inflated the
            # job total instead of advancing progress like the other processors.
            shared.state.job_no += 1
        # BUGFIX: the original never returned, so callers always received None.
        return output

    def unload(self):
        """Unload the Danbooru model, then run the base cleanup."""
        self.model.unload()
        super().unload()
|
|
|
|
# WD14 Processing
|
|
|
|
# Crop Processing
|
|
class CropProcessor(Processor):
    """Smart-crops (and optionally pads) images around a CLIP-detected subject.

    Args:
        subject_class: if non-empty, overrides the per-image caption as the
            subject prompt used for crop centering.
        pad: pad non-square images onto a square canvas before cropping.
        crop: run CropClip subject detection and crop a square around it.
    """

    def __init__(self, subject_class: str, pad: bool, crop: bool):
        self.description = "Cropping"
        if crop:
            shared.state.textinfo = "Loading CROP Model..."
        # CropClip is only needed (and only loaded) when cropping is enabled.
        self.model = CropClip() if crop else None
        self.subject_class = subject_class
        self.pad = pad
        self.crop = crop
        super().__init__()

    def process(self, images: List[Image.Image], captions: List[str] = None) -> List[Image.Image]:
        """Pad/crop each image; captions guide the subject detection.

        Returns the processed images in input order.
        """
        # BUGFIX: captions defaults to None but was zipped directly, raising
        # TypeError whenever the default was used. Fall back to per-image None
        # captions (subject_class, when set, replaces the caption anyway).
        if captions is None:
            captions = [None] * len(images)
        output = []
        shared.state.job_count = len(images)
        shared.state.textinfo = f"{self.description}..."
        for img, caption in tqdm(zip(images, captions), desc=self.description):
            cropped = self._process_img(img, caption)
            output.append(cropped)
            shared.state.job_no += 1
        return output

    def _process_img(self, img, short_caption):
        """Pad and/or crop a single image; returns the (possibly new) image."""
        # A configured subject class takes precedence over the caption.
        if self.subject_class is not None and self.subject_class != "":
            short_caption = self.subject_class

        src_ratio = img.width / img.height

        # Pad image before cropping?
        if src_ratio != 1 and self.pad:
            # Square canvas sized to the longer edge; image pasted centered.
            if img.width > img.height:
                pad_width = img.width
                pad_height = img.width
            else:
                pad_width = img.height
                pad_height = img.height
            res = Image.new("RGB", (pad_width, pad_height))
            res.paste(img, box=(pad_width // 2 - img.width // 2, pad_height // 2 - img.height // 2))
            img = res

        if self.crop:
            # Do the actual crop clip.
            # im_data presumably is (left, right, top, bottom) pixel bounds of
            # the detected subject — TODO confirm against CropClip.get_center.
            im_data = self.model.get_center(img, prompt=short_caption)
            crop_width = im_data[1] - im_data[0]
            center_x = im_data[0] + (crop_width / 2)
            crop_height = im_data[3] - im_data[2]
            center_y = im_data[2] + (crop_height / 2)
            crop_ratio = crop_width / crop_height
            # Target output is square (ratio 1): expand the shorter side of
            # the detected box to match the longer one.
            dest_ratio = 1
            tgt_width = crop_width
            tgt_height = crop_height

            if crop_ratio != dest_ratio:
                if crop_width > crop_height:
                    tgt_height = crop_width / dest_ratio
                    tgt_width = crop_width
                else:
                    tgt_width = crop_height / dest_ratio
                    tgt_height = crop_height

            # Reverse the above if dest is too big.
            if tgt_width > img.width or tgt_height > img.height:
                if tgt_width > img.width:
                    tgt_width = img.width
                    tgt_height = tgt_width / dest_ratio
                else:
                    tgt_height = img.height
                    tgt_width = tgt_height / dest_ratio

            tgt_height = int(tgt_height)
            tgt_width = int(tgt_width)
            # Center the target box on the subject, clamped to image bounds.
            left = max(center_x - (tgt_width / 2), 0)
            right = min(center_x + (tgt_width / 2), img.width)
            top = max(center_y - (tgt_height / 2), 0)
            bottom = min(center_y + (tgt_height / 2), img.height)
            img = img.crop((left, top, right, bottom))
        return img

    def unload(self):
        """Unload the crop model (if loaded), then run the base cleanup."""
        if self.model is not None:
            self.model.unload()
        super().unload()
|
|
|
|
# Upscale Processing
|
|
class UpscaleProcessor(Processor):
    """Upscales images with the Stable-Diffusion super-resolution model."""

    def __init__(self):
        self.description = "Upscaling"
        shared.state.textinfo = "Loading Stable-Diffusion Upscaling Model..."
        self.sampler, self.model = super_resolution.initialize_model()
        super().__init__()

    def process(self, images: List[Image.Image], captions: List[str] = None) -> List[Image.Image]:
        """Upscale each image, using its caption as the conditioning prompt."""
        # BUGFIX: captions defaults to None but was zipped directly, raising
        # TypeError whenever the default was used; fall back to empty prompts.
        if captions is None:
            captions = [""] * len(images)
        output = []
        shared.state.job_count = len(images)
        shared.state.textinfo = f"{self.description}..."
        for img, caption in tqdm(zip(images, captions), desc=self.description):
            # Fresh random seed per image, within the signed 32-bit range.
            seed = int(random.randrange(2147483647))
            # Positional args presumably map to steps/samples/scale/seed/eta/
            # noise level — TODO confirm against super_resolution.predict.
            img = super_resolution.predict(self.sampler, img, caption, 75, 1, 10, seed, 0, 20)
            output.append(img)
            shared.state.job_no += 1
        return output

    def unload(self):
        """Release the sampler and model, then run the base cleanup."""
        del self.sampler
        del self.model
        super().unload()
|
|
|