Super Update

Add WD14 tagger.
Add CLIP v2.1 interrogator.
Allow filtering wd14, booru tags by score.
Add don't rename option.
Update ReallySafe
Bump BLIP version?
Remove split image options.
Completely overhaul smartprocess...
pull/12/head
d8ahazard 2022-12-10 17:33:34 -06:00
parent 93380435c0
commit 58bcff7f78
13 changed files with 107608 additions and 325 deletions

View File

@ -8,11 +8,9 @@ import numpy as np
import torch
from PIL import Image
from clip import clip
from transformers import CLIPProcessor, CLIPModel, pipeline
import modules.paths
from modules import shared, modelloader
from repositories.CodeFormer.facelib.detection.yolov5face.utils.general import xyxy2xywh, xywh2xyxy
def clip_boxes(boxes, shape):

362
clipinterrogator.py Normal file
View File

@ -0,0 +1,362 @@
import hashlib
import inspect
import math
import numpy as np
import open_clip
import os
import pickle
import time
import torch
from dataclasses import dataclass
from models.blip import blip_decoder, BLIP_Decoder
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from typing import List
from extensions.sd_smartprocess.interrogator import Interrogator
@dataclass
class Config:
# models can optionally be passed in directly
blip_model: BLIP_Decoder = None
clip_model = None
clip_preprocess = None
# blip settings
blip_image_eval_size: int = 384
blip_max_length: int = 32
blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
blip_num_beams: int = 8
blip_offload: bool = False
# clip settings
clip_model_name: str = 'ViT-L-14/openai'
clip_model_path: str = None
# interrogator settings
cache_path: str = 'cache'
chunk_size: int = 2048
data_path: str = os.path.join(os.path.dirname(__file__), 'data')
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
flavor_intermediate_count: int = 2048
quiet: bool = False # when quiet progress bars are not shown
class ClipInterrogator(Interrogator):
def __init__(self,
use_v2,
append_artist,
append_medium,
append_movement,
append_flavor,
append_trending):
if use_v2:
model_name = "ViT-H-14/laion2b_s32b_b79k"
else:
model_name = "ViT-L-14/openai"
self.append_artist = append_artist
self.append_medium = append_medium
self.append_movement = append_movement
self.append_trending = append_trending
self.append_flavor = append_flavor
self.artists = None
self.flavors = None
self.mediums = None
self.movements = None
self.tokenize = None
self.trendings = None
self.clip_model = None
self.clip_preprocess = None
config = Config
config.clip_model_name = model_name
self.config = config
self.device = config.device
if config.blip_model is None:
if not config.quiet:
print("Loading BLIP model...")
blip_path = os.path.dirname(inspect.getfile(blip_decoder))
configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
med_config = os.path.join(configs_path, 'med_config.json')
blip_model = blip_decoder(
pretrained=config.blip_model_url,
image_size=config.blip_image_eval_size,
vit='large',
med_config=med_config
)
blip_model.eval()
blip_model = blip_model.to(config.device)
self.blip_model = blip_model
else:
self.blip_model = config.blip_model
self.load_clip_model()
def load_clip_model(self):
start_time = time.time()
config = self.config
if config.clip_model is None:
if not config.quiet:
print("Loading CLIP model...")
clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
clip_model_name,
pretrained=clip_model_pretrained_name,
precision='fp16' if config.device == 'cuda' else 'fp32',
device=config.device,
jit=False,
cache_dir=config.clip_model_path
)
self.clip_model.to(config.device).eval()
else:
self.clip_model = config.clip_model
self.clip_preprocess = config.clip_preprocess
clip_model_name = config.clip_model_name
self.tokenize = open_clip.get_tokenizer(clip_model_name)
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram',
'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash',
'zbrush central']
trending_list = [site for site in sites]
trending_list.extend(["trending on " + site for site in sites])
trending_list.extend(["featured on " + site for site in sites])
trending_list.extend([site + " contest winner" for site in sites])
raw_artists = _load_list(config.data_path, 'artists.txt')
artists = [f"by {a}" for a in raw_artists]
artists.extend([f"inspired by {a}" for a in raw_artists])
if self.append_artist:
self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
if self.append_flavor:
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model,
self.tokenize, config)
if self.append_medium:
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model,
self.tokenize, config)
if self.append_movement:
self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model,
self.tokenize, config)
if self.append_trending:
self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
end_time = time.time()
if not config.quiet:
print(f"Loaded CLIP model and data in {end_time - start_time:.2f} seconds.")
def generate_caption(self, pil_image: Image) -> str:
if self.config.blip_offload:
self.blip_model = self.blip_model.to(self.device)
size = self.config.blip_image_eval_size
gpu_image = transforms.Compose([
transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(pil_image).unsqueeze(0).to(self.device)
with torch.no_grad():
caption = self.blip_model.generate(
gpu_image,
sample=False,
num_beams=self.config.blip_num_beams,
max_length=self.config.blip_max_length,
min_length=5
)
if self.config.blip_offload:
self.blip_model = self.blip_model.to("cpu")
return caption[0]
def image_to_features(self, image: Image) -> torch.Tensor:
images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
image_features = self.clip_model.encode_image(images)
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features
def interrogate(self, image: Image, max_flavors: int = 32, short=False) -> str:
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
best_prompt = caption
best_sim = self.similarity(image_features, best_prompt)
def check(addition: str) -> bool:
nonlocal best_prompt, best_sim
prompt = best_prompt + ", " + addition
sim = self.similarity(image_features, prompt)
if sim > best_sim:
best_sim = sim
best_prompt = prompt
return True
return False
def check_multi_batch(opts: List[str]):
nonlocal best_prompt, best_sim
prompts = []
for i in range(2 ** len(opts)):
prompt = best_prompt
for bit in range(len(opts)):
if i & (1 << bit):
prompt += ", " + opts[bit]
prompts.append(prompt)
t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
best_prompt = t.rank(image_features, 1)[0]
best_sim = self.similarity(image_features, best_prompt)
batch = []
if not short:
if self.append_artist:
batch.append(self.artists.rank(image_features, 1)[0])
if self.append_flavor:
best_flavors = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
extended_flavors = set(best_flavors)
for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
flave = best[len(best_prompt) + 2:]
if not check(flave):
break
if _prompt_at_max_len(best_prompt, self.tokenize):
break
extended_flavors.remove(flave)
if self.append_medium:
batch.append(self.mediums.rank(image_features, 1)[0])
if self.append_trending:
batch.append(self.trendings.rank(image_features, 1)[0])
if self.append_movement:
batch.append(self.movements.rank(image_features, 1)[0])
check_multi_batch(batch)
tags = best_prompt.split(",")
else:
tags = [best_prompt]
return tags
def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str:
text_tokens = self.tokenize([text for text in text_array]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features @ image_features.T
return text_array[similarity.argmax().item()]
def similarity(self, image_features: torch.Tensor, text: str) -> float:
text_tokens = self.tokenize([text]).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = self.clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features @ image_features.T
return similarity[0][0].item()
class LabelTable:
def __init__(self, labels: List[str], desc: str, clip_model, tokenize, config: Config):
self.chunk_size = config.chunk_size
self.config = config
self.device = config.device
self.embeds = []
self.labels = labels
self.tokenize = tokenize
hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
cache_filepath = None
if config.cache_path is not None and desc is not None:
os.makedirs(config.cache_path, exist_ok=True)
sanitized_name = config.clip_model_name.replace('/', '_').replace('@', '_')
cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl")
if desc is not None and os.path.exists(cache_filepath):
with open(cache_filepath, 'rb') as f:
try:
data = pickle.load(f)
if data.get('hash') == hash:
self.labels = data['labels']
self.embeds = data['embeds']
except Exception as e:
print(f"Error loading cached table {desc}: {e}")
if len(self.labels) != len(self.embeds):
self.embeds = []
chunks = np.array_split(self.labels, max(1, len(self.labels) / config.chunk_size))
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
text_tokens = self.tokenize(chunk).to(self.device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = clip_model.encode_text(text_tokens)
text_features /= text_features.norm(dim=-1, keepdim=True)
text_features = text_features.half().cpu().numpy()
for i in range(text_features.shape[0]):
self.embeds.append(text_features[i])
if cache_filepath is not None:
with open(cache_filepath, 'wb') as f:
pickle.dump({
"labels": self.labels,
"embeds": self.embeds,
"hash": hash,
"model": config.clip_model_name
}, f)
if self.device == 'cpu' or self.device == torch.device('cpu'):
self.embeds = [e.astype(np.float32) for e in self.embeds]
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int = 1) -> str:
top_count = min(top_count, len(text_embeds))
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
with torch.cuda.amp.autocast():
similarity = image_features @ text_embeds.T
_, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
return [top_labels[0][i].numpy() for i in range(top_count)]
def rank(self, image_features: torch.Tensor, top_count: int = 1) -> List[str]:
if len(self.labels) <= self.chunk_size:
tops = self._rank(image_features, self.embeds, top_count=top_count)
return [self.labels[i] for i in tops]
num_chunks = int(math.ceil(len(self.labels) / self.chunk_size))
keep_per_chunk = int(self.chunk_size / num_chunks)
top_labels, top_embeds = [], []
for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
start = chunk_idx * self.chunk_size
stop = min(start + self.chunk_size, len(self.embeds))
tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
top_labels.extend([self.labels[start + i] for i in tops])
top_embeds.extend([self.embeds[start + i] for i in tops])
tops = self._rank(image_features, top_embeds, top_count=top_count)
return [top_labels[i] for i in tops]
def _load_list(data_path: str, filename: str) -> List[str]:
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
return items
def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m = LabelTable([], None, None, None, config)
for table in tables:
m.labels.extend(table.labels)
m.embeds.extend(table.embeds)
return m
def _prompt_at_max_len(text: str, tokenize) -> bool:
tokens = tokenize([text])
return tokens[0][-1] != 0
def _truncate_to_fit(text: str, tokenize) -> str:
parts = text.split(', ')
new_text = parts[0]
for part in parts[1:]:
if _prompt_at_max_len(new_text + part, tokenize):
break
new_text += ', ' + part
return new_text

5265
data/artists.txt Normal file

File diff suppressed because it is too large Load Diff

100970
data/flavors.txt Normal file

File diff suppressed because it is too large Load Diff

95
data/mediums.txt Normal file
View File

@ -0,0 +1,95 @@
a 3D render
a black and white photo
a bronze sculpture
a cartoon
a cave painting
a character portrait
a charcoal drawing
a child's drawing
a color pencil sketch
a colorized photo
a comic book panel
a computer rendering
a cross stitch
a cubist painting
a detailed drawing
a detailed matte painting
a detailed painting
a diagram
a digital painting
a digital rendering
a drawing
a fine art painting
a flemish Baroque
a gouache
a hologram
a hyperrealistic painting
a jigsaw puzzle
a low poly render
a macro photograph
a manga drawing
a marble sculpture
a matte painting
a microscopic photo
a mid-nineteenth century engraving
a minimalist painting
a mosaic
a painting
a pastel
a pencil sketch
a photo
a photocopy
a photorealistic painting
a picture
a pointillism painting
a polaroid photo
a pop art painting
a portrait
a poster
a raytraced image
a renaissance painting
a screenprint
a screenshot
a silk screen
a sketch
a statue
a still life
a stipple
a stock photo
a storybook illustration
a surrealist painting
a surrealist sculpture
a tattoo
a tilt shift photo
a watercolor painting
a wireframe diagram
a woodcut
an abstract drawing
an abstract painting
an abstract sculpture
an acrylic painting
an airbrush painting
an album cover
an ambient occlusion render
an anime drawing
an art deco painting
an art deco sculpture
an engraving
an etching
an illustration of
an impressionist painting
an ink drawing
an oil on canvas painting
an oil painting
an ultrafine detailed painting
chalk art
computer graphics
concept art
cyberpunk art
digital art
egyptian art
graffiti art
lineart
pixel art
poster art
vector art

200
data/movements.txt Normal file
View File

@ -0,0 +1,200 @@
abstract art
abstract expressionism
abstract illusionism
academic art
action painting
aestheticism
afrofuturism
altermodern
american barbizon school
american impressionism
american realism
american romanticism
american scene painting
analytical art
antipodeans
arabesque
arbeitsrat für kunst
art & language
art brut
art deco
art informel
art nouveau
art photography
arte povera
arts and crafts movement
ascii art
ashcan school
assemblage
australian tonalism
auto-destructive art
barbizon school
baroque
bauhaus
bengal school of art
berlin secession
black arts movement
brutalism
classical realism
cloisonnism
cobra
color field
computer art
conceptual art
concrete art
constructivism
context art
crayon art
crystal cubism
cubism
cubo-futurism
cynical realism
dada
danube school
dau-al-set
de stijl
deconstructivism
digital art
ecological art
environmental art
excessivism
expressionism
fantastic realism
fantasy art
fauvism
feminist art
figuration libre
figurative art
figurativism
fine art
fluxus
folk art
funk art
furry art
futurism
generative art
geometric abstract art
german romanticism
gothic art
graffiti
gutai group
happening
harlem renaissance
heidelberg school
holography
hudson river school
hurufiyya
hypermodernism
hyperrealism
impressionism
incoherents
institutional critique
interactive art
international gothic
international typographic style
kinetic art
kinetic pointillism
kitsch movement
land art
les automatistes
les nabis
letterism
light and space
lowbrow
lyco art
lyrical abstraction
magic realism
magical realism
mail art
mannerism
massurrealism
maximalism
metaphysical painting
mingei
minimalism
modern european ink painting
modernism
modular constructivism
naive art
naturalism
neo-dada
neo-expressionism
neo-fauvism
neo-figurative
neo-primitivism
neo-romanticism
neoclassicism
neogeo
neoism
neoplasticism
net art
new objectivity
new sculpture
northwest school
nuclear art
objective abstraction
op art
optical illusion
orphism
panfuturism
paris school
photorealism
pixel art
plasticien
plein air
pointillism
pop art
pop surrealism
post-impressionism
postminimalism
pre-raphaelitism
precisionism
primitivism
private press
process art
psychedelic art
purism
qajar art
quito school
rasquache
rayonism
realism
regionalism
remodernism
renaissance
retrofuturism
rococo
romanesque
romanticism
samikshavad
serial art
shin hanga
shock art
socialist realism
sots art
space art
street art
stuckism
sumatraism
superflat
suprematism
surrealism
symbolism
synchromism
synthetism
sōsaku hanga
tachisme
temporary art
tonalism
toyism
transgressive art
ukiyo-e
underground comix
unilalianism
vancouver school
vanitas
verdadism
video art
viennese actionism
visual art
vorticism

55
dbimutils.py Normal file
View File

@ -0,0 +1,55 @@
# DanBooru IMage Utility functions, borrowed from
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/master/tagger/dbimutils.py
import cv2
import numpy as np
from PIL import Image
def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
if img.endswith(".gif"):
img = Image.open(img)
img = img.convert("RGB")
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
else:
img = cv2.imread(img, flag)
return img
def smart_24bit(img):
if img.dtype is np.dtype(np.uint16):
img = (img / 257).astype(np.uint8)
if len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif img.shape[2] == 4:
trans_mask = img[:, :, 3] == 0
img[trans_mask] = [255, 255, 255, 255]
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
return img
def make_square(img, target_size):
old_size = img.shape[:2]
desired_size = max(old_size)
desired_size = max(desired_size, target_size)
delta_w = desired_size - old_size[1]
delta_h = desired_size - old_size[0]
top, bottom = delta_h // 2, delta_h - (delta_h // 2)
left, right = delta_w // 2, delta_w - (delta_w // 2)
color = [255, 255, 255]
new_im = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
)
return new_im
def smart_resize(img, size):
# Assumes the image has already gone through make_square
if img.shape[0] > size:
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
elif img.shape[0] < size:
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
return img

View File

@ -1,9 +1,12 @@
import os
import sys
from launch import run
from launch import run, git_clone, repo_dir
name = "Smart Crop"
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
print(f"loading {name} reqs from {req_file}")
run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements.", f"Couldn't install {name} requirements.")
run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements.",
f"Couldn't install {name} requirements.")
blip_repo = "https://github.com/pharmapsychotic/BLIP"
git_clone(blip_repo, repo_dir('BLIP'), "BLIP")

252
interrogator.py Normal file
View File

@ -0,0 +1,252 @@
# Borrowed from https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/master/tagger/interrogator.py
import os
import re
import sys
import traceback
from collections import namedtuple
from pathlib import Path
from typing import Tuple, Dict
import numpy as np
import pandas as pd
import torch
import open_clip
from PIL import Image
from huggingface_hub import hf_hub_download
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.deepbooru
import modules.shared as shared
from extensions.sd_smartprocess import dbimutils
from modules import devices, paths, lowvram, modelloader
from modules import images
from modules.deepbooru import re_special as tag_escape_pattern
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
use_cpu = shared.cmd_opts.use_cpu == 'all' or shared.cmd_opts.use_cpu == 'interrogate'
onyx_providers = []
if use_cpu:
tf_device_name = '/cpu:0'
onyx_providers = ['CPUExecutionProvider']
else:
tf_device_name = '/gpu:0'
onyx_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
if shared.cmd_opts.device_id is not None:
try:
tf_device_name = f'/gpu:{int(shared.cmd_opts.device_id)}'
except ValueError:
print('--device-id is not a integer')
class Interrogator:
@staticmethod
def postprocess_tags(
tags: Dict[str, float],
threshold=0.35,
additional_tags=None,
exclude_tags=None,
sort_by_alphabetical_order=False,
add_confident_as_weight=False,
replace_underscore=False,
replace_underscore_excludes=None,
escape_tag=False
) -> Dict[str, float]:
if replace_underscore_excludes is None:
replace_underscore_excludes = []
if exclude_tags is None:
exclude_tags = []
if additional_tags is None:
additional_tags = []
tags = {
**{t: 1.0 for t in additional_tags},
**tags
}
# those lines are totally not "pythonic" but looks better to me
tags = {
t: c
# sort by tag name or confident
for t, c in sorted(
tags.items(),
key=lambda i: i[0 if sort_by_alphabetical_order else 1],
reverse=not sort_by_alphabetical_order
)
# filter tags
if (
c >= threshold
and t not in exclude_tags
)
}
for tag in list(tags):
new_tag = tag
if replace_underscore and tag not in replace_underscore_excludes:
new_tag = new_tag.replace('_', ' ')
if escape_tag:
new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)
if add_confident_as_weight:
new_tag = f'({new_tag}:{tags[tag]})'
if new_tag != tag:
tags[new_tag] = tags.pop(tag)
return tags
def interrogate(
self,
image: Image
) -> Tuple[
Dict[str, float], # rating confidence
Dict[str, float] # tag confidence
]:
pass
re_special = re.compile(r'([\\()])')
class BooruInterrogator(Interrogator):
def __init__(self) -> None:
self.tags = None
self.booru = modules.deepbooru.DeepDanbooru()
self.booru.start()
self.model = self.booru.model
def unload(self):
self.booru.stop()
def interrogate(self, pil_image) -> Dict[str, float]:
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
with torch.no_grad(), devices.autocast():
x = torch.from_numpy(a).to(devices.device)
y = self.model(x)[0].detach().cpu().numpy()
probability_dict = {}
for tag, probability in zip(self.model.tags, y):
if tag.startswith("rating:"):
continue
probability_dict[tag] = probability
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
output = {}
for tag in tags:
probability = probability_dict[tag]
tag_outformat = tag
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
output[tag_outformat] = probability
return output
class WaifuDiffusionInterrogator(Interrogator):
def __init__(
self,
repo='SmilingWolf/wd-v1-4-vit-tagger',
model_path='model.onnx',
tags_path='selected_tags.csv'
) -> None:
self.tags = None
self.model = None
self.repo = repo
self.model_path = model_path
self.tags_path = tags_path
self.load()
def download(self) -> Tuple[os.PathLike, os.PathLike]:
print(f'Loading Waifu Diffusion tagger model file from {self.repo}')
model_path = Path(hf_hub_download(self.repo, filename=self.model_path))
tags_path = Path(hf_hub_download(self.repo, filename=self.tags_path))
return model_path, tags_path
def load(self) -> None:
model_path, tags_path = self.download()
from launch import is_installed, run_pip
if not is_installed('onnxruntime'):
package_name = 'onnxruntime-gpu'
if use_cpu or not torch.cuda.is_available():
package_name = 'onnxruntime'
package = os.environ.get(
'ONNXRUNTIME_PACKAGE',
package_name
)
run_pip(f'install {package}', package_name)
from onnxruntime import InferenceSession
self.model = InferenceSession(str(model_path), providers=onyx_providers)
print(f'Loaded Waifu Diffusion tagger model from {model_path}')
self.tags = pd.read_csv(tags_path)
def unload(self):
if self.model is not None:
del self.model
def interrogate(
self,
image: Image
) -> Tuple[
Dict[str, float], # rating confidence
Dict[str, float] # tag confidence
]:
# code for converting the image and running the model is taken from the link below
# thanks, SmilingWolf!
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py
# convert an image to fit the model
_, height, _, _ = self.model.get_inputs()[0].shape
# alpha to white
image = image.convert('RGBA')
new_image = Image.new('RGBA', image.size, 'WHITE')
new_image.paste(image, mask=image)
image = new_image.convert('RGB')
image = np.asarray(image)
# PIL RGB to OpenCV BGR
image = image[:, :, ::-1]
image = dbimutils.make_square(image, height)
image = dbimutils.smart_resize(image, height)
image = image.astype(np.float32)
image = np.expand_dims(image, 0)
# evaluate model
input_name = self.model.get_inputs()[0].name
label_name = self.model.get_outputs()[0].name
confidence = self.model.run([label_name], {input_name: image})[0]
tags = self.tags[:][['name']]
tags['confidence'] = confidence[0]
# first 4 items are for rating (general, sensitive, questionable, explicit)
ratings = dict(tags[:4].values)
# rest are regular tags
tags = dict(tags[4:].values)
return ratings, tags

View File

@ -1,12 +1,16 @@
import _codecs
import collections
import pickle
import re
import sys
import traceback
import zipfile
import numpy
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
import torch
from modules import safe
from modules.safe import TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args):
@ -15,17 +19,24 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler):
extra_handler = None
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
return TypedStorage()
def find_class(self, module, name):
if self.extra_handler is not None:
res = self.extra_handler(module, name)
if res is not None:
return res
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage',
'ByteStorage']:
'ByteStorage', 'BFloat16Storage']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict', 'Sequential']:
return getattr(torch.nn.modules.container, name)
@ -54,7 +65,98 @@ class RestrictedUnpickler(pickle.Unpickler):
return set
# Forbid everything else.
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
raise Exception(f"global '{module}/{name}' is forbidden")
safe.RestrictedUnpickler = RestrictedUnpickler
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names):
for name in names:
if allowed_zip_names_re.match(name):
continue
raise Exception(f"bad file inside {filename}: {name}")
def check_pt(filename, extra_handler):
try:
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
if len(data_pkl_filenames) == 0:
raise Exception(f"data.pkl not found in {filename}")
if len(data_pkl_filenames) > 1:
raise Exception(f"Multiple data.pkl found in {filename}")
with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
unpickler.load()
except zipfile.BadZipfile:
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
for i in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
return load_with_extra(filename, *args, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
"""
this functon is intended to be used by extensions that want to load models with
some extra classes in them that the usual unpickler would find suspicious.
Use the extra_handler argument to specify a function that takes module and field name as text,
and returns that field's value:
```python
def extra(module, name):
if module == 'collections' and name == 'OrderedDict':
return collections.OrderedDict
return None
safe.load_with_extra('model.pt', extra_handler=extra)
```
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
definitely unsafe.
"""
from modules import shared
try:
if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
return None
except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None
return unsafe_torch_load(filename, *args, **kwargs)
unsafe_torch_load = torch.load
torch.load = load

View File

@ -1,2 +1,4 @@
ipython==8.6.0
seaborn==0.12.1
seaborn==0.12.1
tensorflow
open_clip_torch

View File

@ -1,16 +1,17 @@
import gradio as gr
from extensions.sd_smartprocess import smartprocess
from modules import script_callbacks, shared
from modules.shared import cmd_opts
from modules.ui import setup_progressbar, gr_show
from modules.ui import setup_progressbar
from webui import wrap_gradio_gpu_call
import smartprocess
def on_ui_tabs():
with gr.Blocks() as sp_interface:
with gr.Row(equal_height=True):
with gr.Column(variant="panel"):
sp_rename = gr.Checkbox(label="Rename images", value=False)
with gr.Tab("Directories"):
sp_src = gr.Textbox(label='Source directory')
sp_dst = gr.Textbox(label='Destination directory')
@ -20,23 +21,27 @@ def on_ui_tabs():
sp_pad = gr.Checkbox(label="Pad Images")
sp_crop = gr.Checkbox(label='Crop Images')
sp_flip = gr.Checkbox(label='Create flipped copies')
sp_split = gr.Checkbox(label='Split over-sized images')
sp_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0,
maximum=1.0,
step=0.05)
sp_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0,
maximum=0.9, step=0.05)
with gr.Tab("Captions"):
sp_caption = gr.Checkbox(label='Generate Captions')
sp_caption_length = gr.Number(label='Max Caption length (0=unlimited)', value=0, precision=0)
sp_txt_action = gr.Dropdown(label='Existing Caption Action', value="ignore",
choices=["ignore", "copy", "prepend", "append"])
sp_caption_append_file = gr.Checkbox(label="Append Caption to File Name", value=True)
sp_caption_save_txt = gr.Checkbox(label="Save Caption to .txt File", value=False)
sp_caption_deepbooru = gr.Checkbox(label='Append DeepDanbooru to Caption',
sp_caption_clip = gr.Checkbox(label="Add CLIP results to Caption")
sp_clip_use_v2 = gr.Checkbox(label="Use v2 CLIP Model", value=True)
sp_clip_append_flavor = gr.Checkbox(label="Append Flavor tags from CLIP")
sp_clip_append_medium = gr.Checkbox(label="Append Medium tags from CLIP")
sp_clip_append_movement = gr.Checkbox(label="Append Movement tags from CLIP")
sp_clip_append_artist = gr.Checkbox(label="Append Artist tags from CLIP")
sp_clip_append_trending = gr.Checkbox(label="Append Trending tags from CLIP")
sp_caption_wd14 = gr.Checkbox(label="Add WD14 Tags to Caption")
sp_wd14_min_score = gr.Slider(label="Minimum Score for WD14 Tags", value=0.75, minimum=0.01, maximum=1,
step=0.01)
sp_caption_deepbooru = gr.Checkbox(label='Add DeepDanbooru Tags to Caption',
visible=True if cmd_opts.deepdanbooru else False)
sp_replace_class = gr.Checkbox(label='Replace Class with Subject in Caption', value=True)
sp_booru_min_score = gr.Slider(label="Minimum Score for DeepDanbooru Tags", value=0.75,
minimum=0.01, maximum=1, step=0.01)
sp_replace_class = gr.Checkbox(label='Replace Class with Subject in Caption', value=False)
sp_class = gr.Textbox(label='Subject Class', placeholder='Subject class to crop (leave '
'blank to auto-detect)')
sp_subject = gr.Textbox(label='Subject Name', placeholder='Subject Name to replace class '
@ -73,21 +78,27 @@ def on_ui_tabs():
fn=wrap_gradio_gpu_call(smartprocess.preprocess, extra_outputs=[gr.update()]),
_js="start_smart_process",
inputs=[
sp_rename,
sp_src,
sp_dst,
sp_pad,
sp_crop,
sp_size,
sp_caption_append_file,
sp_caption_save_txt,
sp_txt_action,
sp_flip,
sp_split,
sp_caption,
sp_caption_length,
sp_caption_clip,
sp_clip_use_v2,
sp_clip_append_flavor,
sp_clip_append_medium,
sp_clip_append_movement,
sp_clip_append_artist,
sp_clip_append_trending,
sp_caption_wd14,
sp_wd14_min_score,
sp_caption_deepbooru,
sp_split_threshold,
sp_overlap_ratio,
sp_booru_min_score,
sp_class,
sp_subject,
sp_replace_class,

View File

@ -1,56 +1,51 @@
import math
import os
import sys
import traceback
import PIL
import numpy as np
import tqdm
from PIL import Image, ImageOps
from clipcrop import CropClip
import reallysafe
from modules import shared, images, safe
import modules.gfpgan_model
import modules.codeformer_model
from modules.shared import opts, cmd_opts
import modules.gfpgan_model
import reallysafe
from clipcrop import CropClip
from extensions.sd_smartprocess.clipinterrogator import ClipInterrogator
from extensions.sd_smartprocess.interrogator import WaifuDiffusionInterrogator, BooruInterrogator
from modules import shared, images, safe
from modules.shared import cmd_opts
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
def interrogate_image(image: Image, full=False):
if not full:
prev_artists = shared.opts.interrogate_use_builtin_artists
prev_max = shared.opts.interrogate_clip_max_length
prev_min = shared.opts.interrogate_clip_min_length
shared.opts.interrogate_clip_min_length = 10
shared.opts.interrogate_clip_max_length = 20
shared.opts.interrogate_use_builtin_artists = False
caption = shared.interrogator.interrogate(image)
shared.opts.interrogate_clip_min_length = prev_min
shared.opts.interrogate_clip_max_length = prev_max
shared.opts.interrogate_use_builtin_artists = prev_artists
else:
caption = shared.interrogator.interrogate(image)
return caption
def printi(message):
shared.state.textinfo = message
print(message)
def preprocess(src,
def preprocess(rename,
src,
dst,
pad,
crop,
width,
append_filename,
save_txt,
pretxt_action,
max_size,
txt_action,
flip,
split,
caption,
caption_length,
caption_clip,
clip_use_v2,
clip_append_flavor,
clip_append_medium,
clip_append_movement,
clip_append_artist,
clip_append_trending,
caption_wd14,
wd14_min_score,
caption_deepbooru,
split_threshold,
overlap_ratio,
booru_min_score,
subject_class,
subject,
replace_class,
@ -60,324 +55,297 @@ def preprocess(src,
upscale_ratio,
scaler
):
try:
if pad and crop:
crop = False
shared.state.textinfo = "Loading models for smart processing..."
shared.state.textinfo = "Initializing smart processing..."
safe.RestrictedUnpickler = reallysafe.RestrictedUnpickler
if caption:
shared.interrogator.load()
if caption_deepbooru:
deepbooru.model.start()
if not crop and not caption and not restore_faces and not upscale and not pad:
msg = "Nothing to do."
printi(msg)
return msg, msg
prework(src,
dst,
pad,
crop,
width,
append_filename,
save_txt,
pretxt_action,
flip,
split,
caption,
caption_length,
caption_deepbooru,
split_threshold,
overlap_ratio,
subject_class,
subject,
replace_class,
restore_faces,
face_model,
upscale,
upscale_ratio,
scaler)
wd_interrogator = None
db_interrogator = None
clip_interrogator = None
crop_clip = None
finally:
if caption or crop:
printi("\rLoading captioning models...")
if caption_clip or crop:
printi("\rLoading CLIP interrogator...")
if shared.interrogator is not None:
shared.interrogator.unload()
clip_interrogator = ClipInterrogator(clip_use_v2,
clip_append_artist,
clip_append_medium,
clip_append_movement,
clip_append_flavor,
clip_append_trending)
if caption:
shared.interrogator.send_blip_to_ram()
if caption_deepbooru:
printi("\rLoading Deepbooru interrogator...")
db_interrogator = BooruInterrogator()
if caption_deepbooru:
deepbooru.model.stop()
if caption_wd14:
printi("\rLoading wd14 interrogator...")
wd_interrogator = WaifuDiffusionInterrogator()
return "Processing complete.", ""
if crop:
printi("Loading YOLOv5 interrogator...")
crop_clip = CropClip()
del sys.modules['models']
def prework(src,
dst,
pad_image,
crop_image,
width,
append_filename,
save_txt,
pretxt_action,
flip,
split,
caption_image,
caption_length,
caption_deepbooru,
split_threshold,
overlap_ratio,
subject_class,
subject,
replace_class,
restore_faces,
face_model,
upscale,
upscale_ratio,
scaler):
try:
del sys.modules['models']
except:
pass
width = width
height = width
src = os.path.abspath(src)
dst = os.path.abspath(dst)
src = os.path.abspath(src)
dst = os.path.abspath(dst)
if not crop_image and not caption_image and not restore_faces and not upscale and not pad_image:
print("Nothing to do.")
shared.state.textinfo = "Nothing to do!"
return
if src == dst:
msg = "Source and destination are the same, returning."
printi(msg)
return msg, msg
assert src != dst, 'same directory specified as source and destination'
os.makedirs(dst, exist_ok=True)
os.makedirs(dst, exist_ok=True)
files = os.listdir(src)
files = os.listdir(src)
printi("Preprocessing...")
shared.state.job_count = len(files)
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
def build_caption(image):
# Read existing caption from path/txt file
existing_caption_txt_filename = os.path.splitext(filename)[0] + '.txt'
if os.path.exists(existing_caption_txt_filename):
with open(existing_caption_txt_filename, 'r', encoding="utf8") as file:
existing_caption_txt = file.read()
else:
existing_caption_txt = ''.join(c for c in filename if c.isalpha() or c in [" ", ", "])
def build_caption(image):
existing_caption = None
if not append_filename:
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
if os.path.exists(existing_caption_filename):
with open(existing_caption_filename, 'r', encoding="utf8") as file:
existing_caption = file.read()
else:
existing_caption = ''.join(c for c in filename if c.isalpha() or c in [" ", ","])
out_tags = []
if clip_interrogator is not None:
if caption_clip:
tags = clip_interrogator.interrogate(img)
for tag in tags:
#print(f"CLIPTag: {tag}")
out_tags.append(tag)
caption = ""
if caption_image:
caption = interrogate_image(img, True)
if wd_interrogator is not None:
ratings, tags = wd_interrogator.interrogate(img)
if caption_deepbooru:
if len(caption) > 0:
caption += ", "
caption += deepbooru.model.tag_multi(image)
for rating in ratings:
#print(f"Rating {rating} score is {ratings[rating]}")
if ratings[rating] >= wd14_min_score:
out_tags.append(rating)
if pretxt_action == 'prepend' and existing_caption:
caption = existing_caption + ' ' + caption
elif pretxt_action == 'append' and existing_caption:
caption = caption + ' ' + existing_caption
elif pretxt_action == 'copy' and existing_caption:
caption = existing_caption
for tag in sorted(tags, key=tags.get, reverse=True):
if tags[tag] >= wd14_min_score:
#print(f"WDTag {tag} score is {tags[tag]}")
out_tags.append(tag)
else:
break
caption = caption.strip()
if replace_class and subject is not None and subject_class is not None:
# Find and replace "a SUBJECT CLASS" in caption with subject name
if f"a {subject_class}" in caption:
caption = caption.replace(f"a {subject_class}", subject)
if caption_deepbooru:
tags = db_interrogator.interrogate(image)
for tag in sorted(tags, key=tags.get, reverse=True):
if tags[tag] >= booru_min_score:
#print(f"DBTag {tag} score is {tags[tag]}")
out_tags.append(tag)
if subject_class in caption:
caption = caption.replace(subject_class, subject)
# Remove duplicates
unique_tags = []
for tag in out_tags:
if not tag in unique_tags:
unique_tags.append(tag.strip())
caption_txt = ", ".join(unique_tags)
if 0 < caption_length < len(caption):
split_cap = caption.split(" ")
caption = ""
cap_test = ""
split_idx = 0
while True and split_idx < len(split_cap):
cap_test += f" {split_cap[split_idx]}"
if len(cap_test) < caption_length:
caption = cap_test
split_idx += 1
if txt_action == 'prepend' and existing_caption_txt:
caption_txt = existing_caption_txt + ' ' + caption_txt
elif txt_action == 'append' and existing_caption_txt:
caption_txt = caption_txt + ' ' + existing_caption_txt
elif txt_action == 'copy' and existing_caption_txt:
caption_txt = existing_caption_txt
caption = caption.strip()
return caption
caption_txt = caption_txt.strip()
if replace_class and subject is not None and subject_class is not None:
# Find and replace "a SUBJECT CLASS" in caption_txt with subject name
if f"a {subject_class}" in caption_txt:
caption_txt = caption_txt.replace(f"a {subject_class}", subject)
def save_pic_with_caption(image, img_index, existing_caption):
if subject_class in caption_txt:
caption_txt = caption_txt.replace(subject_class, subject)
if append_filename:
filename_part = existing_caption
basename = f"{img_index:05}-{subindex[0]}-{filename_part}"
else:
basename = f"{img_index:05}-{subindex[0]}"
if 0 < caption_length < len(caption_txt):
split_cap = caption_txt.split(" ")
caption_txt = ""
cap_test = ""
split_idx = 0
while True and split_idx < len(split_cap):
cap_test += f" {split_cap[split_idx]}"
if len(cap_test) < caption_length:
caption_txt = cap_test
split_idx += 1
shared.state.current_image = img
image.save(os.path.join(dst, f"{basename}.png"))
caption_txt = caption_txt.strip()
return caption_txt
if save_txt:
if len(existing_caption) > 0:
def save_pic(image, src_name, img_index, existing_caption=None, flipped=False):
if rename:
basename = f"{img_index:05}"
else:
basename = src_name
if flipped:
basename += "_flipped"
shared.state.current_image = img
image.save(os.path.join(dst, f"{basename}.png"))
if existing_caption is not None and len(existing_caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(existing_caption)
subindex[0] += 1
image_index = 0
def save_pic(image, img_index, existing_caption=None):
save_pic_with_caption(image, img_index, existing_caption=existing_caption)
# Enumerate images
for index, src_image in enumerate(tqdm.tqdm(files)):
# Quit on cancel
if shared.state.interrupted:
msg = f"Processing interrupted, {index}/{len(files)}"
return msg, msg
if flip:
save_pic_with_caption(ImageOps.mirror(image), img_index, existing_caption=existing_caption)
filename = os.path.join(src, src_image)
try:
img = Image.open(filename).convert("RGB")
except Exception as e:
msg = f"Exception processing: {e}"
printi(msg)
traceback.print_exc()
return msg, msg
def split_pic(image, img_inverse_xy):
if img_inverse_xy:
from_w, from_h = image.height, image.width
to_w, to_h = height, width
else:
from_w, from_h = image.width, image.height
to_w, to_h = width, height
h = from_h * to_w // from_w
if img_inverse_xy:
image = image.resize((h, to_w))
else:
image = image.resize((to_w, h))
if crop:
# Interrogate once
short_caption = clip_interrogator.interrogate(img, short=True)
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
y_step = (h - to_h) / (split_count - 1)
for i in range(split_count):
y = int(y_step * i)
if img_inverse_xy:
split_img = image.crop((y, 0, y + to_h, to_w))
else:
split_img = image.crop((0, y, to_w, y + to_h))
yield split_img
if subject_class is not None and subject_class != "":
short_caption = subject_class
crop_clip = None
shared.state.textinfo = f"Cropping: {short_caption}"
src_ratio = img.width / img.height
if crop_image:
split_threshold = max(0.0, min(1.0, split_threshold))
overlap_ratio = max(0.0, min(0.9, overlap_ratio))
crop_clip = CropClip()
# Pad image before cropping?
if src_ratio != 1 and pad:
if img.width > img.height:
pad_width = img.width
pad_height = img.width
else:
pad_width = img.height
pad_height = img.height
res = Image.new("RGB", (pad_width, pad_height))
res.paste(img, box=(pad_width // 2 - img.width // 2, pad_height // 2 - img.height // 2))
img = res
for index, imagefile in enumerate(tqdm.tqdm(files)):
# Do the actual crop clip
im_data = crop_clip.get_center(img, prompt=short_caption)
crop_width = im_data[1] - im_data[0]
center_x = im_data[0] + (crop_width / 2)
crop_height = im_data[3] - im_data[2]
center_y = im_data[2] + (crop_height / 2)
crop_ratio = crop_width / crop_height
dest_ratio = 1
tgt_width = crop_width
tgt_height = crop_height
if shared.state.interrupted:
break
if crop_ratio != dest_ratio:
if crop_width > crop_height:
tgt_height = crop_width / dest_ratio
tgt_width = crop_width
else:
tgt_width = crop_height / dest_ratio
tgt_height = crop_height
subindex = [0]
filename = os.path.join(src, imagefile)
try:
img = Image.open(filename).convert("RGB")
except Exception:
continue
# Reverse the above if dest is too big
if tgt_width > img.width or tgt_height > img.height:
if tgt_width > img.width:
tgt_width = img.width
tgt_height = tgt_width / dest_ratio
else:
tgt_height = img.height
tgt_width = tgt_height / dest_ratio
shared.state.textinfo = f"Processing: '({filename})"
if crop_image:
# Interrogate once
short_caption = interrogate_image(img)
tgt_height = int(tgt_height)
tgt_width = int(tgt_width)
left = max(center_x - (tgt_width / 2), 0)
right = min(center_x + (tgt_width / 2), img.width)
top = max(center_y - (tgt_height / 2), 0)
bottom = min(center_y + (tgt_height / 2), img.height)
img = img.crop((left, top, right, bottom))
shared.state.current_image = img
if subject_class is not None and subject_class != "":
short_caption = subject_class
shared.state.textinfo = f"Cropping: {short_caption}"
if img.height > img.width:
ratio = (img.width * height) / (img.height * width)
inverse_xy = False
else:
ratio = (img.height * width) / (img.width * height)
inverse_xy = True
if split and ratio < 1.0 and ratio <= split_threshold:
for splitted in split_pic(img, inverse_xy):
# Build our caption
full_caption = None
if caption_image:
full_caption = build_caption(splitted)
save_pic(splitted, index, existing_caption=full_caption)
src_ratio = img.width / img.height
# Pad image before cropping?
if src_ratio != 1:
if img.width > img.height:
pad_width = img.width
pad_height = img.width
if restore_faces:
shared.state.textinfo = f"Restoring faces using {face_model}..."
if face_model == "gfpgan":
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(img, dtype=np.uint8))
img = Image.fromarray(restored_img)
else:
pad_width = img.height
pad_height = img.height
res = Image.new("RGB", (pad_width, pad_height))
res.paste(img, box=(pad_width // 2 - img.width // 2, pad_height // 2 - img.height // 2))
restored_img = modules.codeformer_model.codeformer.restore(np.array(img, dtype=np.uint8),
w=1.0)
img = Image.fromarray(restored_img)
shared.state.current_image = img
if upscale:
shared.state.textinfo = "Upscaling..."
upscaler = shared.sd_upscalers[scaler]
res = upscaler.scaler.upscale(img, upscale_ratio, upscaler.data_path)
img = res
shared.state.current_image = img
if pad:
ratio = 1
src_ratio = img.width / img.height
src_w = max_size if ratio < src_ratio else img.width * max_size // img.height
src_h = max_size if ratio >= src_ratio else img.height * max_size // img.width
resized = images.resize_image(0, img, src_w, src_h)
res = Image.new("RGB", (max_size, max_size))
res.paste(resized, box=(max_size // 2 - src_w // 2, max_size // 2 - src_h // 2))
img = res
im_data = crop_clip.get_center(img, prompt=short_caption)
crop_width = im_data[1] - im_data[0]
center_x = im_data[0] + (crop_width / 2)
crop_height = im_data[3] - im_data[2]
center_y = im_data[2] + (crop_height / 2)
crop_ratio = crop_width / crop_height
dest_ratio = width / height
tgt_width = crop_width
tgt_height = crop_height
# Resize again if image is not at the right size.
if img.width != max_size or img.height != max_size:
img = images.resize_image(1, img, max_size, max_size)
if crop_ratio != dest_ratio:
if crop_width > crop_height:
tgt_height = crop_width / dest_ratio
tgt_width = crop_width
else:
tgt_width = crop_height / dest_ratio
tgt_height = crop_height
# Reverse the above if dest is too big
if tgt_width > img.width or tgt_height > img.height:
if tgt_width > img.width:
tgt_width = img.width
tgt_height = tgt_width / dest_ratio
else:
tgt_height = img.height
tgt_width = tgt_height / dest_ratio
tgt_height = int(tgt_height)
tgt_width = int(tgt_width)
left = max(center_x - (tgt_width / 2), 0)
right = min(center_x + (tgt_width / 2), img.width)
top = max(center_y - (tgt_height / 2), 0)
bottom = min(center_y + (tgt_height / 2), img.height)
img = img.crop((left, top, right, bottom))
default_resize = True
# Build a caption, if enabled
full_caption = build_caption(img) if caption else None
# Show our output
shared.state.current_image = img
else:
default_resize = False
printi(f"Processed: '({src_image} - {full_caption})")
if restore_faces:
shared.state.textinfo = f"Restoring faces using {face_model}..."
if face_model == "gfpgan":
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(img, dtype=np.uint8))
img = Image.fromarray(restored_img)
else:
restored_img = modules.codeformer_model.codeformer.restore(np.array(img, dtype=np.uint8),
w=1.0)
img = Image.fromarray(restored_img)
shared.state.current_image = img
save_pic(img, src_image, image_index, existing_caption=full_caption)
image_index += 1
if upscale:
shared.state.textinfo = "Upscaling..."
upscaler = shared.sd_upscalers[scaler]
res = upscaler.scaler.upscale(img, upscale_ratio, upscaler.data_path)
img = res
default_resize = True
shared.state.current_image = img
if flip:
save_pic(ImageOps.flip(img), src_image, image_index, existing_caption=full_caption, flipped=True)
image_index += 1
if pad_image:
ratio = width / height
src_ratio = img.width / img.height
shared.state.nextjob()
src_w = width if ratio < src_ratio else img.width * height // img.height
src_h = height if ratio >= src_ratio else img.height * width // img.width
if caption_clip or crop:
printi("Unloading CLIP interrogator...")
shared.interrogator.send_blip_to_ram()
resized = images.resize_image(0, img, src_w, src_h)
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
img = res
if caption_deepbooru:
printi("Unloading Deepbooru interrogator...")
db_interrogator.unload()
if default_resize:
img = images.resize_image(1, img, width, height)
shared.state.current_image = img
full_caption = build_caption(img)
save_pic(img, index, existing_caption=full_caption)
if caption_wd14:
printi("Unloading wd14 interrogator...")
wd_interrogator.unload()
shared.state.nextjob()
return f"Successfully processed {len(files)} images.", f"Successfully processed {len(files)} images."
except Exception as e:
msg = f"Exception processing: {e}"
traceback.print_exc()
pass
return msg, msg