Super Update
Add WD14 tagger. Add CLIP v2.1 interrogator. Allow filtering wd14, booru tags by score. Add don't rename option. Update ReallySafe Bump BLIP version? Remove split image options. Completely overhaul smartprocess...pull/12/head
parent
93380435c0
commit
58bcff7f78
|
|
@ -8,11 +8,9 @@ import numpy as np
|
|||
import torch
|
||||
from PIL import Image
|
||||
from clip import clip
|
||||
from transformers import CLIPProcessor, CLIPModel, pipeline
|
||||
|
||||
import modules.paths
|
||||
from modules import shared, modelloader
|
||||
from repositories.CodeFormer.facelib.detection.yolov5face.utils.general import xyxy2xywh, xywh2xyxy
|
||||
|
||||
|
||||
def clip_boxes(boxes, shape):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,362 @@
|
|||
import hashlib
|
||||
import inspect
|
||||
import math
|
||||
import numpy as np
|
||||
import open_clip
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from models.blip import blip_decoder, BLIP_Decoder
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from tqdm import tqdm
|
||||
from typing import List
|
||||
|
||||
from extensions.sd_smartprocess.interrogator import Interrogator
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
# models can optionally be passed in directly
|
||||
blip_model: BLIP_Decoder = None
|
||||
clip_model = None
|
||||
clip_preprocess = None
|
||||
|
||||
# blip settings
|
||||
blip_image_eval_size: int = 384
|
||||
blip_max_length: int = 32
|
||||
blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
|
||||
blip_num_beams: int = 8
|
||||
blip_offload: bool = False
|
||||
|
||||
# clip settings
|
||||
clip_model_name: str = 'ViT-L-14/openai'
|
||||
clip_model_path: str = None
|
||||
|
||||
# interrogator settings
|
||||
cache_path: str = 'cache'
|
||||
chunk_size: int = 2048
|
||||
data_path: str = os.path.join(os.path.dirname(__file__), 'data')
|
||||
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
flavor_intermediate_count: int = 2048
|
||||
quiet: bool = False # when quiet progress bars are not shown
|
||||
|
||||
|
||||
class ClipInterrogator(Interrogator):
|
||||
def __init__(self,
|
||||
use_v2,
|
||||
append_artist,
|
||||
append_medium,
|
||||
append_movement,
|
||||
append_flavor,
|
||||
append_trending):
|
||||
if use_v2:
|
||||
model_name = "ViT-H-14/laion2b_s32b_b79k"
|
||||
else:
|
||||
model_name = "ViT-L-14/openai"
|
||||
self.append_artist = append_artist
|
||||
self.append_medium = append_medium
|
||||
self.append_movement = append_movement
|
||||
self.append_trending = append_trending
|
||||
self.append_flavor = append_flavor
|
||||
self.artists = None
|
||||
self.flavors = None
|
||||
self.mediums = None
|
||||
self.movements = None
|
||||
self.tokenize = None
|
||||
self.trendings = None
|
||||
self.clip_model = None
|
||||
self.clip_preprocess = None
|
||||
config = Config
|
||||
config.clip_model_name = model_name
|
||||
self.config = config
|
||||
self.device = config.device
|
||||
|
||||
if config.blip_model is None:
|
||||
if not config.quiet:
|
||||
print("Loading BLIP model...")
|
||||
blip_path = os.path.dirname(inspect.getfile(blip_decoder))
|
||||
configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
|
||||
med_config = os.path.join(configs_path, 'med_config.json')
|
||||
blip_model = blip_decoder(
|
||||
pretrained=config.blip_model_url,
|
||||
image_size=config.blip_image_eval_size,
|
||||
vit='large',
|
||||
med_config=med_config
|
||||
)
|
||||
blip_model.eval()
|
||||
blip_model = blip_model.to(config.device)
|
||||
self.blip_model = blip_model
|
||||
else:
|
||||
self.blip_model = config.blip_model
|
||||
|
||||
self.load_clip_model()
|
||||
|
||||
def load_clip_model(self):
|
||||
start_time = time.time()
|
||||
config = self.config
|
||||
|
||||
if config.clip_model is None:
|
||||
if not config.quiet:
|
||||
print("Loading CLIP model...")
|
||||
|
||||
clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
|
||||
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
|
||||
clip_model_name,
|
||||
pretrained=clip_model_pretrained_name,
|
||||
precision='fp16' if config.device == 'cuda' else 'fp32',
|
||||
device=config.device,
|
||||
jit=False,
|
||||
cache_dir=config.clip_model_path
|
||||
)
|
||||
self.clip_model.to(config.device).eval()
|
||||
else:
|
||||
self.clip_model = config.clip_model
|
||||
self.clip_preprocess = config.clip_preprocess
|
||||
clip_model_name = config.clip_model_name
|
||||
|
||||
self.tokenize = open_clip.get_tokenizer(clip_model_name)
|
||||
|
||||
sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram',
|
||||
'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash',
|
||||
'zbrush central']
|
||||
trending_list = [site for site in sites]
|
||||
trending_list.extend(["trending on " + site for site in sites])
|
||||
trending_list.extend(["featured on " + site for site in sites])
|
||||
trending_list.extend([site + " contest winner" for site in sites])
|
||||
|
||||
raw_artists = _load_list(config.data_path, 'artists.txt')
|
||||
artists = [f"by {a}" for a in raw_artists]
|
||||
artists.extend([f"inspired by {a}" for a in raw_artists])
|
||||
if self.append_artist:
|
||||
self.artists = LabelTable(artists, "artists", self.clip_model, self.tokenize, config)
|
||||
if self.append_flavor:
|
||||
self.flavors = LabelTable(_load_list(config.data_path, 'flavors.txt'), "flavors", self.clip_model,
|
||||
self.tokenize, config)
|
||||
if self.append_medium:
|
||||
self.mediums = LabelTable(_load_list(config.data_path, 'mediums.txt'), "mediums", self.clip_model,
|
||||
self.tokenize, config)
|
||||
if self.append_movement:
|
||||
self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model,
|
||||
self.tokenize, config)
|
||||
if self.append_trending:
|
||||
self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
|
||||
|
||||
end_time = time.time()
|
||||
if not config.quiet:
|
||||
print(f"Loaded CLIP model and data in {end_time - start_time:.2f} seconds.")
|
||||
|
||||
def generate_caption(self, pil_image: Image) -> str:
|
||||
if self.config.blip_offload:
|
||||
self.blip_model = self.blip_model.to(self.device)
|
||||
size = self.config.blip_image_eval_size
|
||||
gpu_image = transforms.Compose([
|
||||
transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
])(pil_image).unsqueeze(0).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
caption = self.blip_model.generate(
|
||||
gpu_image,
|
||||
sample=False,
|
||||
num_beams=self.config.blip_num_beams,
|
||||
max_length=self.config.blip_max_length,
|
||||
min_length=5
|
||||
)
|
||||
if self.config.blip_offload:
|
||||
self.blip_model = self.blip_model.to("cpu")
|
||||
return caption[0]
|
||||
|
||||
def image_to_features(self, image: Image) -> torch.Tensor:
|
||||
images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
|
||||
with torch.no_grad(), torch.cuda.amp.autocast():
|
||||
image_features = self.clip_model.encode_image(images)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
return image_features
|
||||
|
||||
def interrogate(self, image: Image, max_flavors: int = 32, short=False) -> str:
|
||||
caption = self.generate_caption(image)
|
||||
image_features = self.image_to_features(image)
|
||||
best_prompt = caption
|
||||
best_sim = self.similarity(image_features, best_prompt)
|
||||
|
||||
def check(addition: str) -> bool:
|
||||
nonlocal best_prompt, best_sim
|
||||
prompt = best_prompt + ", " + addition
|
||||
sim = self.similarity(image_features, prompt)
|
||||
if sim > best_sim:
|
||||
best_sim = sim
|
||||
best_prompt = prompt
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_multi_batch(opts: List[str]):
|
||||
nonlocal best_prompt, best_sim
|
||||
prompts = []
|
||||
for i in range(2 ** len(opts)):
|
||||
prompt = best_prompt
|
||||
for bit in range(len(opts)):
|
||||
if i & (1 << bit):
|
||||
prompt += ", " + opts[bit]
|
||||
prompts.append(prompt)
|
||||
|
||||
t = LabelTable(prompts, None, self.clip_model, self.tokenize, self.config)
|
||||
best_prompt = t.rank(image_features, 1)[0]
|
||||
best_sim = self.similarity(image_features, best_prompt)
|
||||
|
||||
batch = []
|
||||
|
||||
if not short:
|
||||
if self.append_artist:
|
||||
batch.append(self.artists.rank(image_features, 1)[0])
|
||||
if self.append_flavor:
|
||||
best_flavors = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
|
||||
extended_flavors = set(best_flavors)
|
||||
for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
|
||||
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
|
||||
flave = best[len(best_prompt) + 2:]
|
||||
if not check(flave):
|
||||
break
|
||||
if _prompt_at_max_len(best_prompt, self.tokenize):
|
||||
break
|
||||
extended_flavors.remove(flave)
|
||||
if self.append_medium:
|
||||
batch.append(self.mediums.rank(image_features, 1)[0])
|
||||
if self.append_trending:
|
||||
batch.append(self.trendings.rank(image_features, 1)[0])
|
||||
if self.append_movement:
|
||||
batch.append(self.movements.rank(image_features, 1)[0])
|
||||
|
||||
check_multi_batch(batch)
|
||||
tags = best_prompt.split(",")
|
||||
else:
|
||||
tags = [best_prompt]
|
||||
return tags
|
||||
|
||||
def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str:
|
||||
text_tokens = self.tokenize([text for text in text_array]).to(self.device)
|
||||
with torch.no_grad(), torch.cuda.amp.autocast():
|
||||
text_features = self.clip_model.encode_text(text_tokens)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
similarity = text_features @ image_features.T
|
||||
return text_array[similarity.argmax().item()]
|
||||
|
||||
def similarity(self, image_features: torch.Tensor, text: str) -> float:
|
||||
text_tokens = self.tokenize([text]).to(self.device)
|
||||
with torch.no_grad(), torch.cuda.amp.autocast():
|
||||
text_features = self.clip_model.encode_text(text_tokens)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
similarity = text_features @ image_features.T
|
||||
return similarity[0][0].item()
|
||||
|
||||
|
||||
class LabelTable:
|
||||
def __init__(self, labels: List[str], desc: str, clip_model, tokenize, config: Config):
|
||||
self.chunk_size = config.chunk_size
|
||||
self.config = config
|
||||
self.device = config.device
|
||||
self.embeds = []
|
||||
self.labels = labels
|
||||
self.tokenize = tokenize
|
||||
|
||||
hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
|
||||
|
||||
cache_filepath = None
|
||||
if config.cache_path is not None and desc is not None:
|
||||
os.makedirs(config.cache_path, exist_ok=True)
|
||||
sanitized_name = config.clip_model_name.replace('/', '_').replace('@', '_')
|
||||
cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl")
|
||||
if desc is not None and os.path.exists(cache_filepath):
|
||||
with open(cache_filepath, 'rb') as f:
|
||||
try:
|
||||
data = pickle.load(f)
|
||||
if data.get('hash') == hash:
|
||||
self.labels = data['labels']
|
||||
self.embeds = data['embeds']
|
||||
except Exception as e:
|
||||
print(f"Error loading cached table {desc}: {e}")
|
||||
|
||||
if len(self.labels) != len(self.embeds):
|
||||
self.embeds = []
|
||||
chunks = np.array_split(self.labels, max(1, len(self.labels) / config.chunk_size))
|
||||
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
|
||||
text_tokens = self.tokenize(chunk).to(self.device)
|
||||
with torch.no_grad(), torch.cuda.amp.autocast():
|
||||
text_features = clip_model.encode_text(text_tokens)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
text_features = text_features.half().cpu().numpy()
|
||||
for i in range(text_features.shape[0]):
|
||||
self.embeds.append(text_features[i])
|
||||
|
||||
if cache_filepath is not None:
|
||||
with open(cache_filepath, 'wb') as f:
|
||||
pickle.dump({
|
||||
"labels": self.labels,
|
||||
"embeds": self.embeds,
|
||||
"hash": hash,
|
||||
"model": config.clip_model_name
|
||||
}, f)
|
||||
|
||||
if self.device == 'cpu' or self.device == torch.device('cpu'):
|
||||
self.embeds = [e.astype(np.float32) for e in self.embeds]
|
||||
|
||||
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int = 1) -> str:
|
||||
top_count = min(top_count, len(text_embeds))
|
||||
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
|
||||
with torch.cuda.amp.autocast():
|
||||
similarity = image_features @ text_embeds.T
|
||||
_, top_labels = similarity.float().cpu().topk(top_count, dim=-1)
|
||||
return [top_labels[0][i].numpy() for i in range(top_count)]
|
||||
|
||||
def rank(self, image_features: torch.Tensor, top_count: int = 1) -> List[str]:
|
||||
if len(self.labels) <= self.chunk_size:
|
||||
tops = self._rank(image_features, self.embeds, top_count=top_count)
|
||||
return [self.labels[i] for i in tops]
|
||||
|
||||
num_chunks = int(math.ceil(len(self.labels) / self.chunk_size))
|
||||
keep_per_chunk = int(self.chunk_size / num_chunks)
|
||||
|
||||
top_labels, top_embeds = [], []
|
||||
for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
|
||||
start = chunk_idx * self.chunk_size
|
||||
stop = min(start + self.chunk_size, len(self.embeds))
|
||||
tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
|
||||
top_labels.extend([self.labels[start + i] for i in tops])
|
||||
top_embeds.extend([self.embeds[start + i] for i in tops])
|
||||
|
||||
tops = self._rank(image_features, top_embeds, top_count=top_count)
|
||||
return [top_labels[i] for i in tops]
|
||||
|
||||
|
||||
def _load_list(data_path: str, filename: str) -> List[str]:
|
||||
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
|
||||
items = [line.strip() for line in f.readlines()]
|
||||
return items
|
||||
|
||||
|
||||
def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
|
||||
m = LabelTable([], None, None, None, config)
|
||||
for table in tables:
|
||||
m.labels.extend(table.labels)
|
||||
m.embeds.extend(table.embeds)
|
||||
return m
|
||||
|
||||
|
||||
def _prompt_at_max_len(text: str, tokenize) -> bool:
|
||||
tokens = tokenize([text])
|
||||
return tokens[0][-1] != 0
|
||||
|
||||
|
||||
def _truncate_to_fit(text: str, tokenize) -> str:
|
||||
parts = text.split(', ')
|
||||
new_text = parts[0]
|
||||
for part in parts[1:]:
|
||||
if _prompt_at_max_len(new_text + part, tokenize):
|
||||
break
|
||||
new_text += ', ' + part
|
||||
return new_text
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,95 @@
|
|||
a 3D render
|
||||
a black and white photo
|
||||
a bronze sculpture
|
||||
a cartoon
|
||||
a cave painting
|
||||
a character portrait
|
||||
a charcoal drawing
|
||||
a child's drawing
|
||||
a color pencil sketch
|
||||
a colorized photo
|
||||
a comic book panel
|
||||
a computer rendering
|
||||
a cross stitch
|
||||
a cubist painting
|
||||
a detailed drawing
|
||||
a detailed matte painting
|
||||
a detailed painting
|
||||
a diagram
|
||||
a digital painting
|
||||
a digital rendering
|
||||
a drawing
|
||||
a fine art painting
|
||||
a flemish Baroque
|
||||
a gouache
|
||||
a hologram
|
||||
a hyperrealistic painting
|
||||
a jigsaw puzzle
|
||||
a low poly render
|
||||
a macro photograph
|
||||
a manga drawing
|
||||
a marble sculpture
|
||||
a matte painting
|
||||
a microscopic photo
|
||||
a mid-nineteenth century engraving
|
||||
a minimalist painting
|
||||
a mosaic
|
||||
a painting
|
||||
a pastel
|
||||
a pencil sketch
|
||||
a photo
|
||||
a photocopy
|
||||
a photorealistic painting
|
||||
a picture
|
||||
a pointillism painting
|
||||
a polaroid photo
|
||||
a pop art painting
|
||||
a portrait
|
||||
a poster
|
||||
a raytraced image
|
||||
a renaissance painting
|
||||
a screenprint
|
||||
a screenshot
|
||||
a silk screen
|
||||
a sketch
|
||||
a statue
|
||||
a still life
|
||||
a stipple
|
||||
a stock photo
|
||||
a storybook illustration
|
||||
a surrealist painting
|
||||
a surrealist sculpture
|
||||
a tattoo
|
||||
a tilt shift photo
|
||||
a watercolor painting
|
||||
a wireframe diagram
|
||||
a woodcut
|
||||
an abstract drawing
|
||||
an abstract painting
|
||||
an abstract sculpture
|
||||
an acrylic painting
|
||||
an airbrush painting
|
||||
an album cover
|
||||
an ambient occlusion render
|
||||
an anime drawing
|
||||
an art deco painting
|
||||
an art deco sculpture
|
||||
an engraving
|
||||
an etching
|
||||
an illustration of
|
||||
an impressionist painting
|
||||
an ink drawing
|
||||
an oil on canvas painting
|
||||
an oil painting
|
||||
an ultrafine detailed painting
|
||||
chalk art
|
||||
computer graphics
|
||||
concept art
|
||||
cyberpunk art
|
||||
digital art
|
||||
egyptian art
|
||||
graffiti art
|
||||
lineart
|
||||
pixel art
|
||||
poster art
|
||||
vector art
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
abstract art
|
||||
abstract expressionism
|
||||
abstract illusionism
|
||||
academic art
|
||||
action painting
|
||||
aestheticism
|
||||
afrofuturism
|
||||
altermodern
|
||||
american barbizon school
|
||||
american impressionism
|
||||
american realism
|
||||
american romanticism
|
||||
american scene painting
|
||||
analytical art
|
||||
antipodeans
|
||||
arabesque
|
||||
arbeitsrat für kunst
|
||||
art & language
|
||||
art brut
|
||||
art deco
|
||||
art informel
|
||||
art nouveau
|
||||
art photography
|
||||
arte povera
|
||||
arts and crafts movement
|
||||
ascii art
|
||||
ashcan school
|
||||
assemblage
|
||||
australian tonalism
|
||||
auto-destructive art
|
||||
barbizon school
|
||||
baroque
|
||||
bauhaus
|
||||
bengal school of art
|
||||
berlin secession
|
||||
black arts movement
|
||||
brutalism
|
||||
classical realism
|
||||
cloisonnism
|
||||
cobra
|
||||
color field
|
||||
computer art
|
||||
conceptual art
|
||||
concrete art
|
||||
constructivism
|
||||
context art
|
||||
crayon art
|
||||
crystal cubism
|
||||
cubism
|
||||
cubo-futurism
|
||||
cynical realism
|
||||
dada
|
||||
danube school
|
||||
dau-al-set
|
||||
de stijl
|
||||
deconstructivism
|
||||
digital art
|
||||
ecological art
|
||||
environmental art
|
||||
excessivism
|
||||
expressionism
|
||||
fantastic realism
|
||||
fantasy art
|
||||
fauvism
|
||||
feminist art
|
||||
figuration libre
|
||||
figurative art
|
||||
figurativism
|
||||
fine art
|
||||
fluxus
|
||||
folk art
|
||||
funk art
|
||||
furry art
|
||||
futurism
|
||||
generative art
|
||||
geometric abstract art
|
||||
german romanticism
|
||||
gothic art
|
||||
graffiti
|
||||
gutai group
|
||||
happening
|
||||
harlem renaissance
|
||||
heidelberg school
|
||||
holography
|
||||
hudson river school
|
||||
hurufiyya
|
||||
hypermodernism
|
||||
hyperrealism
|
||||
impressionism
|
||||
incoherents
|
||||
institutional critique
|
||||
interactive art
|
||||
international gothic
|
||||
international typographic style
|
||||
kinetic art
|
||||
kinetic pointillism
|
||||
kitsch movement
|
||||
land art
|
||||
les automatistes
|
||||
les nabis
|
||||
letterism
|
||||
light and space
|
||||
lowbrow
|
||||
lyco art
|
||||
lyrical abstraction
|
||||
magic realism
|
||||
magical realism
|
||||
mail art
|
||||
mannerism
|
||||
massurrealism
|
||||
maximalism
|
||||
metaphysical painting
|
||||
mingei
|
||||
minimalism
|
||||
modern european ink painting
|
||||
modernism
|
||||
modular constructivism
|
||||
naive art
|
||||
naturalism
|
||||
neo-dada
|
||||
neo-expressionism
|
||||
neo-fauvism
|
||||
neo-figurative
|
||||
neo-primitivism
|
||||
neo-romanticism
|
||||
neoclassicism
|
||||
neogeo
|
||||
neoism
|
||||
neoplasticism
|
||||
net art
|
||||
new objectivity
|
||||
new sculpture
|
||||
northwest school
|
||||
nuclear art
|
||||
objective abstraction
|
||||
op art
|
||||
optical illusion
|
||||
orphism
|
||||
panfuturism
|
||||
paris school
|
||||
photorealism
|
||||
pixel art
|
||||
plasticien
|
||||
plein air
|
||||
pointillism
|
||||
pop art
|
||||
pop surrealism
|
||||
post-impressionism
|
||||
postminimalism
|
||||
pre-raphaelitism
|
||||
precisionism
|
||||
primitivism
|
||||
private press
|
||||
process art
|
||||
psychedelic art
|
||||
purism
|
||||
qajar art
|
||||
quito school
|
||||
rasquache
|
||||
rayonism
|
||||
realism
|
||||
regionalism
|
||||
remodernism
|
||||
renaissance
|
||||
retrofuturism
|
||||
rococo
|
||||
romanesque
|
||||
romanticism
|
||||
samikshavad
|
||||
serial art
|
||||
shin hanga
|
||||
shock art
|
||||
socialist realism
|
||||
sots art
|
||||
space art
|
||||
street art
|
||||
stuckism
|
||||
sumatraism
|
||||
superflat
|
||||
suprematism
|
||||
surrealism
|
||||
symbolism
|
||||
synchromism
|
||||
synthetism
|
||||
sōsaku hanga
|
||||
tachisme
|
||||
temporary art
|
||||
tonalism
|
||||
toyism
|
||||
transgressive art
|
||||
ukiyo-e
|
||||
underground comix
|
||||
unilalianism
|
||||
vancouver school
|
||||
vanitas
|
||||
verdadism
|
||||
video art
|
||||
viennese actionism
|
||||
visual art
|
||||
vorticism
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# DanBooru IMage Utility functions, borrowed from
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/master/tagger/dbimutils.py
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
|
||||
if img.endswith(".gif"):
|
||||
img = Image.open(img)
|
||||
img = img.convert("RGB")
|
||||
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
img = cv2.imread(img, flag)
|
||||
return img
|
||||
|
||||
|
||||
def smart_24bit(img):
|
||||
if img.dtype is np.dtype(np.uint16):
|
||||
img = (img / 257).astype(np.uint8)
|
||||
|
||||
if len(img.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif img.shape[2] == 4:
|
||||
trans_mask = img[:, :, 3] == 0
|
||||
img[trans_mask] = [255, 255, 255, 255]
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
|
||||
return img
|
||||
|
||||
|
||||
def make_square(img, target_size):
|
||||
old_size = img.shape[:2]
|
||||
desired_size = max(old_size)
|
||||
desired_size = max(desired_size, target_size)
|
||||
|
||||
delta_w = desired_size - old_size[1]
|
||||
delta_h = desired_size - old_size[0]
|
||||
top, bottom = delta_h // 2, delta_h - (delta_h // 2)
|
||||
left, right = delta_w // 2, delta_w - (delta_w // 2)
|
||||
|
||||
color = [255, 255, 255]
|
||||
new_im = cv2.copyMakeBorder(
|
||||
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
|
||||
)
|
||||
return new_im
|
||||
|
||||
|
||||
def smart_resize(img, size):
|
||||
# Assumes the image has already gone through make_square
|
||||
if img.shape[0] > size:
|
||||
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
|
||||
elif img.shape[0] < size:
|
||||
img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
|
||||
return img
|
||||
|
|
@ -1,9 +1,12 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
from launch import run
|
||||
from launch import run, git_clone, repo_dir
|
||||
|
||||
name = "Smart Crop"
|
||||
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
|
||||
print(f"loading {name} reqs from {req_file}")
|
||||
run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements.", f"Couldn't install {name} requirements.")
|
||||
run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements.",
|
||||
f"Couldn't install {name} requirements.")
|
||||
blip_repo = "https://github.com/pharmapsychotic/BLIP"
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,252 @@
|
|||
# Borrowed from https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/master/tagger/interrogator.py
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Dict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import open_clip
|
||||
from PIL import Image
|
||||
from huggingface_hub import hf_hub_download
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
import modules.deepbooru
|
||||
import modules.shared as shared
|
||||
from extensions.sd_smartprocess import dbimutils
|
||||
from modules import devices, paths, lowvram, modelloader
|
||||
from modules import images
|
||||
from modules.deepbooru import re_special as tag_escape_pattern
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
||||
Category = namedtuple("Category", ["name", "topn", "items"])
|
||||
|
||||
re_topn = re.compile(r"\.top(\d+)\.")
|
||||
|
||||
use_cpu = shared.cmd_opts.use_cpu == 'all' or shared.cmd_opts.use_cpu == 'interrogate'
|
||||
onyx_providers = []
|
||||
if use_cpu:
|
||||
tf_device_name = '/cpu:0'
|
||||
onyx_providers = ['CPUExecutionProvider']
|
||||
else:
|
||||
tf_device_name = '/gpu:0'
|
||||
onyx_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||||
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
try:
|
||||
tf_device_name = f'/gpu:{int(shared.cmd_opts.device_id)}'
|
||||
except ValueError:
|
||||
print('--device-id is not a integer')
|
||||
|
||||
|
||||
class Interrogator:
|
||||
@staticmethod
|
||||
def postprocess_tags(
|
||||
tags: Dict[str, float],
|
||||
threshold=0.35,
|
||||
additional_tags=None,
|
||||
exclude_tags=None,
|
||||
sort_by_alphabetical_order=False,
|
||||
add_confident_as_weight=False,
|
||||
replace_underscore=False,
|
||||
replace_underscore_excludes=None,
|
||||
escape_tag=False
|
||||
) -> Dict[str, float]:
|
||||
|
||||
if replace_underscore_excludes is None:
|
||||
replace_underscore_excludes = []
|
||||
if exclude_tags is None:
|
||||
exclude_tags = []
|
||||
if additional_tags is None:
|
||||
additional_tags = []
|
||||
tags = {
|
||||
**{t: 1.0 for t in additional_tags},
|
||||
**tags
|
||||
}
|
||||
|
||||
# those lines are totally not "pythonic" but looks better to me
|
||||
tags = {
|
||||
t: c
|
||||
|
||||
# sort by tag name or confident
|
||||
for t, c in sorted(
|
||||
tags.items(),
|
||||
key=lambda i: i[0 if sort_by_alphabetical_order else 1],
|
||||
reverse=not sort_by_alphabetical_order
|
||||
)
|
||||
|
||||
# filter tags
|
||||
if (
|
||||
c >= threshold
|
||||
and t not in exclude_tags
|
||||
)
|
||||
}
|
||||
|
||||
for tag in list(tags):
|
||||
new_tag = tag
|
||||
|
||||
if replace_underscore and tag not in replace_underscore_excludes:
|
||||
new_tag = new_tag.replace('_', ' ')
|
||||
|
||||
if escape_tag:
|
||||
new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)
|
||||
|
||||
if add_confident_as_weight:
|
||||
new_tag = f'({new_tag}:{tags[tag]})'
|
||||
|
||||
if new_tag != tag:
|
||||
tags[new_tag] = tags.pop(tag)
|
||||
|
||||
return tags
|
||||
|
||||
def interrogate(
|
||||
self,
|
||||
image: Image
|
||||
) -> Tuple[
|
||||
Dict[str, float], # rating confidence
|
||||
Dict[str, float] # tag confidence
|
||||
]:
|
||||
pass
|
||||
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
|
||||
|
||||
class BooruInterrogator(Interrogator):
|
||||
def __init__(self) -> None:
|
||||
self.tags = None
|
||||
self.booru = modules.deepbooru.DeepDanbooru()
|
||||
self.booru.start()
|
||||
self.model = self.booru.model
|
||||
|
||||
def unload(self):
|
||||
self.booru.stop()
|
||||
|
||||
def interrogate(self, pil_image) -> Dict[str, float]:
|
||||
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
|
||||
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
||||
|
||||
with torch.no_grad(), devices.autocast():
|
||||
x = torch.from_numpy(a).to(devices.device)
|
||||
y = self.model(x)[0].detach().cpu().numpy()
|
||||
|
||||
probability_dict = {}
|
||||
|
||||
for tag, probability in zip(self.model.tags, y):
|
||||
if tag.startswith("rating:"):
|
||||
continue
|
||||
|
||||
probability_dict[tag] = probability
|
||||
|
||||
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
|
||||
|
||||
output = {}
|
||||
for tag in tags:
|
||||
probability = probability_dict[tag]
|
||||
tag_outformat = tag
|
||||
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
|
||||
output[tag_outformat] = probability
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class WaifuDiffusionInterrogator(Interrogator):
|
||||
def __init__(
|
||||
self,
|
||||
repo='SmilingWolf/wd-v1-4-vit-tagger',
|
||||
model_path='model.onnx',
|
||||
tags_path='selected_tags.csv'
|
||||
) -> None:
|
||||
self.tags = None
|
||||
self.model = None
|
||||
self.repo = repo
|
||||
self.model_path = model_path
|
||||
self.tags_path = tags_path
|
||||
self.load()
|
||||
|
||||
def download(self) -> Tuple[os.PathLike, os.PathLike]:
|
||||
print(f'Loading Waifu Diffusion tagger model file from {self.repo}')
|
||||
|
||||
model_path = Path(hf_hub_download(self.repo, filename=self.model_path))
|
||||
tags_path = Path(hf_hub_download(self.repo, filename=self.tags_path))
|
||||
return model_path, tags_path
|
||||
|
||||
def load(self) -> None:
|
||||
model_path, tags_path = self.download()
|
||||
from launch import is_installed, run_pip
|
||||
if not is_installed('onnxruntime'):
|
||||
package_name = 'onnxruntime-gpu'
|
||||
|
||||
if use_cpu or not torch.cuda.is_available():
|
||||
package_name = 'onnxruntime'
|
||||
|
||||
package = os.environ.get(
|
||||
'ONNXRUNTIME_PACKAGE',
|
||||
package_name
|
||||
)
|
||||
|
||||
run_pip(f'install {package}', package_name)
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
self.model = InferenceSession(str(model_path), providers=onyx_providers)
|
||||
|
||||
print(f'Loaded Waifu Diffusion tagger model from {model_path}')
|
||||
self.tags = pd.read_csv(tags_path)
|
||||
|
||||
def unload(self):
|
||||
if self.model is not None:
|
||||
del self.model
|
||||
|
||||
def interrogate(
|
||||
self,
|
||||
image: Image
|
||||
) -> Tuple[
|
||||
Dict[str, float], # rating confidence
|
||||
Dict[str, float] # tag confidence
|
||||
]:
|
||||
# code for converting the image and running the model is taken from the link below
|
||||
# thanks, SmilingWolf!
|
||||
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py
|
||||
|
||||
# convert an image to fit the model
|
||||
_, height, _, _ = self.model.get_inputs()[0].shape
|
||||
|
||||
# alpha to white
|
||||
image = image.convert('RGBA')
|
||||
new_image = Image.new('RGBA', image.size, 'WHITE')
|
||||
new_image.paste(image, mask=image)
|
||||
image = new_image.convert('RGB')
|
||||
image = np.asarray(image)
|
||||
|
||||
# PIL RGB to OpenCV BGR
|
||||
image = image[:, :, ::-1]
|
||||
|
||||
image = dbimutils.make_square(image, height)
|
||||
image = dbimutils.smart_resize(image, height)
|
||||
image = image.astype(np.float32)
|
||||
image = np.expand_dims(image, 0)
|
||||
|
||||
# evaluate model
|
||||
input_name = self.model.get_inputs()[0].name
|
||||
label_name = self.model.get_outputs()[0].name
|
||||
confidence = self.model.run([label_name], {input_name: image})[0]
|
||||
|
||||
tags = self.tags[:][['name']]
|
||||
tags['confidence'] = confidence[0]
|
||||
|
||||
# first 4 items are for rating (general, sensitive, questionable, explicit)
|
||||
ratings = dict(tags[:4].values)
|
||||
|
||||
# rest are regular tags
|
||||
tags = dict(tags[4:].values)
|
||||
|
||||
return ratings, tags
|
||||
112
reallysafe.py
112
reallysafe.py
|
|
@ -1,12 +1,16 @@
|
|||
import _codecs
|
||||
import collections
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import zipfile
|
||||
|
||||
import numpy
|
||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||
import torch
|
||||
|
||||
from modules import safe
|
||||
from modules.safe import TypedStorage
|
||||
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||
|
||||
|
||||
def encode(*args):
|
||||
|
|
@ -15,17 +19,24 @@ def encode(*args):
|
|||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
extra_handler = None
|
||||
|
||||
def persistent_load(self, saved_id):
|
||||
assert saved_id[0] == 'storage'
|
||||
return TypedStorage()
|
||||
|
||||
def find_class(self, module, name):
|
||||
if self.extra_handler is not None:
|
||||
res = self.extra_handler(module, name)
|
||||
if res is not None:
|
||||
return res
|
||||
|
||||
if module == 'collections' and name == 'OrderedDict':
|
||||
return getattr(collections, name)
|
||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
||||
return getattr(torch._utils, name)
|
||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage',
|
||||
'ByteStorage']:
|
||||
'ByteStorage', 'BFloat16Storage']:
|
||||
return getattr(torch, name)
|
||||
if module == 'torch.nn.modules.container' and name in ['ParameterDict', 'Sequential']:
|
||||
return getattr(torch.nn.modules.container, name)
|
||||
|
|
@ -54,7 +65,98 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||
return set
|
||||
|
||||
# Forbid everything else.
|
||||
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
||||
raise Exception(f"global '{module}/{name}' is forbidden")
|
||||
|
||||
|
||||
safe.RestrictedUnpickler = RestrictedUnpickler
|
||||
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
|
||||
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
|
||||
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
|
||||
|
||||
|
||||
def check_zip_filenames(filename, names):
|
||||
for name in names:
|
||||
if allowed_zip_names_re.match(name):
|
||||
continue
|
||||
|
||||
raise Exception(f"bad file inside {filename}: {name}")
|
||||
|
||||
|
||||
def check_pt(filename, extra_handler):
|
||||
try:
|
||||
|
||||
# new pytorch format is a zip file
|
||||
with zipfile.ZipFile(filename) as z:
|
||||
check_zip_filenames(filename, z.namelist())
|
||||
|
||||
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
||||
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
||||
if len(data_pkl_filenames) == 0:
|
||||
raise Exception(f"data.pkl not found in {filename}")
|
||||
if len(data_pkl_filenames) > 1:
|
||||
raise Exception(f"Multiple data.pkl found in {filename}")
|
||||
with z.open(data_pkl_filenames[0]) as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.extra_handler = extra_handler
|
||||
unpickler.load()
|
||||
|
||||
except zipfile.BadZipfile:
|
||||
|
||||
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
||||
with open(filename, "rb") as file:
|
||||
unpickler = RestrictedUnpickler(file)
|
||||
unpickler.extra_handler = extra_handler
|
||||
for i in range(5):
|
||||
unpickler.load()
|
||||
|
||||
|
||||
def load(filename, *args, **kwargs):
|
||||
return load_with_extra(filename, *args, **kwargs)
|
||||
|
||||
|
||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||
"""
|
||||
this functon is intended to be used by extensions that want to load models with
|
||||
some extra classes in them that the usual unpickler would find suspicious.
|
||||
|
||||
Use the extra_handler argument to specify a function that takes module and field name as text,
|
||||
and returns that field's value:
|
||||
|
||||
```python
|
||||
def extra(module, name):
|
||||
if module == 'collections' and name == 'OrderedDict':
|
||||
return collections.OrderedDict
|
||||
|
||||
return None
|
||||
|
||||
safe.load_with_extra('model.pt', extra_handler=extra)
|
||||
```
|
||||
|
||||
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
|
||||
definitely unsafe.
|
||||
"""
|
||||
|
||||
from modules import shared
|
||||
|
||||
try:
|
||||
if not shared.cmd_opts.disable_safe_unpickle:
|
||||
check_pt(filename, extra_handler)
|
||||
|
||||
except pickle.UnpicklingError:
|
||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
|
||||
print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
|
||||
print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
|
||||
return None
|
||||
|
||||
return unsafe_torch_load(filename, *args, **kwargs)
|
||||
|
||||
|
||||
unsafe_torch_load = torch.load
|
||||
torch.load = load
|
||||
|
|
|
|||
|
|
@ -1,2 +1,4 @@
|
|||
ipython==8.6.0
|
||||
seaborn==0.12.1
|
||||
seaborn==0.12.1
|
||||
tensorflow
|
||||
open_clip_torch
|
||||
|
|
@ -1,16 +1,17 @@
|
|||
import gradio as gr
|
||||
|
||||
from extensions.sd_smartprocess import smartprocess
|
||||
from modules import script_callbacks, shared
|
||||
from modules.shared import cmd_opts
|
||||
from modules.ui import setup_progressbar, gr_show
|
||||
from modules.ui import setup_progressbar
|
||||
from webui import wrap_gradio_gpu_call
|
||||
import smartprocess
|
||||
|
||||
|
||||
def on_ui_tabs():
|
||||
with gr.Blocks() as sp_interface:
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(variant="panel"):
|
||||
sp_rename = gr.Checkbox(label="Rename images", value=False)
|
||||
with gr.Tab("Directories"):
|
||||
sp_src = gr.Textbox(label='Source directory')
|
||||
sp_dst = gr.Textbox(label='Destination directory')
|
||||
|
|
@ -20,23 +21,27 @@ def on_ui_tabs():
|
|||
sp_pad = gr.Checkbox(label="Pad Images")
|
||||
sp_crop = gr.Checkbox(label='Crop Images')
|
||||
sp_flip = gr.Checkbox(label='Create flipped copies')
|
||||
sp_split = gr.Checkbox(label='Split over-sized images')
|
||||
sp_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0,
|
||||
maximum=1.0,
|
||||
step=0.05)
|
||||
sp_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0,
|
||||
maximum=0.9, step=0.05)
|
||||
|
||||
with gr.Tab("Captions"):
|
||||
sp_caption = gr.Checkbox(label='Generate Captions')
|
||||
sp_caption_length = gr.Number(label='Max Caption length (0=unlimited)', value=0, precision=0)
|
||||
sp_txt_action = gr.Dropdown(label='Existing Caption Action', value="ignore",
|
||||
choices=["ignore", "copy", "prepend", "append"])
|
||||
sp_caption_append_file = gr.Checkbox(label="Append Caption to File Name", value=True)
|
||||
sp_caption_save_txt = gr.Checkbox(label="Save Caption to .txt File", value=False)
|
||||
sp_caption_deepbooru = gr.Checkbox(label='Append DeepDanbooru to Caption',
|
||||
sp_caption_clip = gr.Checkbox(label="Add CLIP results to Caption")
|
||||
sp_clip_use_v2 = gr.Checkbox(label="Use v2 CLIP Model", value=True)
|
||||
sp_clip_append_flavor = gr.Checkbox(label="Append Flavor tags from CLIP")
|
||||
sp_clip_append_medium = gr.Checkbox(label="Append Medium tags from CLIP")
|
||||
sp_clip_append_movement = gr.Checkbox(label="Append Movement tags from CLIP")
|
||||
sp_clip_append_artist = gr.Checkbox(label="Append Artist tags from CLIP")
|
||||
sp_clip_append_trending = gr.Checkbox(label="Append Trending tags from CLIP")
|
||||
sp_caption_wd14 = gr.Checkbox(label="Add WD14 Tags to Caption")
|
||||
sp_wd14_min_score = gr.Slider(label="Minimum Score for WD14 Tags", value=0.75, minimum=0.01, maximum=1,
|
||||
step=0.01)
|
||||
sp_caption_deepbooru = gr.Checkbox(label='Add DeepDanbooru Tags to Caption',
|
||||
visible=True if cmd_opts.deepdanbooru else False)
|
||||
sp_replace_class = gr.Checkbox(label='Replace Class with Subject in Caption', value=True)
|
||||
sp_booru_min_score = gr.Slider(label="Minimum Score for DeepDanbooru Tags", value=0.75,
|
||||
minimum=0.01, maximum=1, step=0.01)
|
||||
sp_replace_class = gr.Checkbox(label='Replace Class with Subject in Caption', value=False)
|
||||
sp_class = gr.Textbox(label='Subject Class', placeholder='Subject class to crop (leave '
|
||||
'blank to auto-detect)')
|
||||
sp_subject = gr.Textbox(label='Subject Name', placeholder='Subject Name to replace class '
|
||||
|
|
@ -73,21 +78,27 @@ def on_ui_tabs():
|
|||
fn=wrap_gradio_gpu_call(smartprocess.preprocess, extra_outputs=[gr.update()]),
|
||||
_js="start_smart_process",
|
||||
inputs=[
|
||||
sp_rename,
|
||||
sp_src,
|
||||
sp_dst,
|
||||
sp_pad,
|
||||
sp_crop,
|
||||
sp_size,
|
||||
sp_caption_append_file,
|
||||
sp_caption_save_txt,
|
||||
sp_txt_action,
|
||||
sp_flip,
|
||||
sp_split,
|
||||
sp_caption,
|
||||
sp_caption_length,
|
||||
sp_caption_clip,
|
||||
sp_clip_use_v2,
|
||||
sp_clip_append_flavor,
|
||||
sp_clip_append_medium,
|
||||
sp_clip_append_movement,
|
||||
sp_clip_append_artist,
|
||||
sp_clip_append_trending,
|
||||
sp_caption_wd14,
|
||||
sp_wd14_min_score,
|
||||
sp_caption_deepbooru,
|
||||
sp_split_threshold,
|
||||
sp_overlap_ratio,
|
||||
sp_booru_min_score,
|
||||
sp_class,
|
||||
sp_subject,
|
||||
sp_replace_class,
|
||||
|
|
|
|||
564
smartprocess.py
564
smartprocess.py
|
|
@ -1,56 +1,51 @@
|
|||
import math
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from clipcrop import CropClip
|
||||
import reallysafe
|
||||
from modules import shared, images, safe
|
||||
import modules.gfpgan_model
|
||||
import modules.codeformer_model
|
||||
from modules.shared import opts, cmd_opts
|
||||
import modules.gfpgan_model
|
||||
import reallysafe
|
||||
from clipcrop import CropClip
|
||||
from extensions.sd_smartprocess.clipinterrogator import ClipInterrogator
|
||||
from extensions.sd_smartprocess.interrogator import WaifuDiffusionInterrogator, BooruInterrogator
|
||||
from modules import shared, images, safe
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
if cmd_opts.deepdanbooru:
|
||||
import modules.deepbooru as deepbooru
|
||||
|
||||
|
||||
def interrogate_image(image: Image, full=False):
|
||||
if not full:
|
||||
prev_artists = shared.opts.interrogate_use_builtin_artists
|
||||
prev_max = shared.opts.interrogate_clip_max_length
|
||||
prev_min = shared.opts.interrogate_clip_min_length
|
||||
shared.opts.interrogate_clip_min_length = 10
|
||||
shared.opts.interrogate_clip_max_length = 20
|
||||
shared.opts.interrogate_use_builtin_artists = False
|
||||
caption = shared.interrogator.interrogate(image)
|
||||
shared.opts.interrogate_clip_min_length = prev_min
|
||||
shared.opts.interrogate_clip_max_length = prev_max
|
||||
shared.opts.interrogate_use_builtin_artists = prev_artists
|
||||
else:
|
||||
caption = shared.interrogator.interrogate(image)
|
||||
|
||||
return caption
|
||||
def printi(message):
|
||||
shared.state.textinfo = message
|
||||
print(message)
|
||||
|
||||
|
||||
def preprocess(src,
|
||||
def preprocess(rename,
|
||||
src,
|
||||
dst,
|
||||
pad,
|
||||
crop,
|
||||
width,
|
||||
append_filename,
|
||||
save_txt,
|
||||
pretxt_action,
|
||||
max_size,
|
||||
txt_action,
|
||||
flip,
|
||||
split,
|
||||
caption,
|
||||
caption_length,
|
||||
caption_clip,
|
||||
clip_use_v2,
|
||||
clip_append_flavor,
|
||||
clip_append_medium,
|
||||
clip_append_movement,
|
||||
clip_append_artist,
|
||||
clip_append_trending,
|
||||
caption_wd14,
|
||||
wd14_min_score,
|
||||
caption_deepbooru,
|
||||
split_threshold,
|
||||
overlap_ratio,
|
||||
booru_min_score,
|
||||
subject_class,
|
||||
subject,
|
||||
replace_class,
|
||||
|
|
@ -60,324 +55,297 @@ def preprocess(src,
|
|||
upscale_ratio,
|
||||
scaler
|
||||
):
|
||||
|
||||
try:
|
||||
if pad and crop:
|
||||
crop = False
|
||||
shared.state.textinfo = "Loading models for smart processing..."
|
||||
shared.state.textinfo = "Initializing smart processing..."
|
||||
safe.RestrictedUnpickler = reallysafe.RestrictedUnpickler
|
||||
if caption:
|
||||
shared.interrogator.load()
|
||||
|
||||
if caption_deepbooru:
|
||||
deepbooru.model.start()
|
||||
if not crop and not caption and not restore_faces and not upscale and not pad:
|
||||
msg = "Nothing to do."
|
||||
printi(msg)
|
||||
return msg, msg
|
||||
|
||||
prework(src,
|
||||
dst,
|
||||
pad,
|
||||
crop,
|
||||
width,
|
||||
append_filename,
|
||||
save_txt,
|
||||
pretxt_action,
|
||||
flip,
|
||||
split,
|
||||
caption,
|
||||
caption_length,
|
||||
caption_deepbooru,
|
||||
split_threshold,
|
||||
overlap_ratio,
|
||||
subject_class,
|
||||
subject,
|
||||
replace_class,
|
||||
restore_faces,
|
||||
face_model,
|
||||
upscale,
|
||||
upscale_ratio,
|
||||
scaler)
|
||||
wd_interrogator = None
|
||||
db_interrogator = None
|
||||
clip_interrogator = None
|
||||
crop_clip = None
|
||||
|
||||
finally:
|
||||
if caption or crop:
|
||||
printi("\rLoading captioning models...")
|
||||
if caption_clip or crop:
|
||||
printi("\rLoading CLIP interrogator...")
|
||||
if shared.interrogator is not None:
|
||||
shared.interrogator.unload()
|
||||
clip_interrogator = ClipInterrogator(clip_use_v2,
|
||||
clip_append_artist,
|
||||
clip_append_medium,
|
||||
clip_append_movement,
|
||||
clip_append_flavor,
|
||||
clip_append_trending)
|
||||
|
||||
if caption:
|
||||
shared.interrogator.send_blip_to_ram()
|
||||
if caption_deepbooru:
|
||||
printi("\rLoading Deepbooru interrogator...")
|
||||
db_interrogator = BooruInterrogator()
|
||||
|
||||
if caption_deepbooru:
|
||||
deepbooru.model.stop()
|
||||
if caption_wd14:
|
||||
printi("\rLoading wd14 interrogator...")
|
||||
wd_interrogator = WaifuDiffusionInterrogator()
|
||||
|
||||
return "Processing complete.", ""
|
||||
if crop:
|
||||
printi("Loading YOLOv5 interrogator...")
|
||||
crop_clip = CropClip()
|
||||
|
||||
del sys.modules['models']
|
||||
|
||||
def prework(src,
|
||||
dst,
|
||||
pad_image,
|
||||
crop_image,
|
||||
width,
|
||||
append_filename,
|
||||
save_txt,
|
||||
pretxt_action,
|
||||
flip,
|
||||
split,
|
||||
caption_image,
|
||||
caption_length,
|
||||
caption_deepbooru,
|
||||
split_threshold,
|
||||
overlap_ratio,
|
||||
subject_class,
|
||||
subject,
|
||||
replace_class,
|
||||
restore_faces,
|
||||
face_model,
|
||||
upscale,
|
||||
upscale_ratio,
|
||||
scaler):
|
||||
try:
|
||||
del sys.modules['models']
|
||||
except:
|
||||
pass
|
||||
width = width
|
||||
height = width
|
||||
src = os.path.abspath(src)
|
||||
dst = os.path.abspath(dst)
|
||||
src = os.path.abspath(src)
|
||||
dst = os.path.abspath(dst)
|
||||
|
||||
if not crop_image and not caption_image and not restore_faces and not upscale and not pad_image:
|
||||
print("Nothing to do.")
|
||||
shared.state.textinfo = "Nothing to do!"
|
||||
return
|
||||
if src == dst:
|
||||
msg = "Source and destination are the same, returning."
|
||||
printi(msg)
|
||||
return msg, msg
|
||||
|
||||
assert src != dst, 'same directory specified as source and destination'
|
||||
os.makedirs(dst, exist_ok=True)
|
||||
|
||||
os.makedirs(dst, exist_ok=True)
|
||||
files = os.listdir(src)
|
||||
|
||||
files = os.listdir(src)
|
||||
printi("Preprocessing...")
|
||||
shared.state.job_count = len(files)
|
||||
|
||||
shared.state.textinfo = "Preprocessing..."
|
||||
shared.state.job_count = len(files)
|
||||
def build_caption(image):
|
||||
# Read existing caption from path/txt file
|
||||
existing_caption_txt_filename = os.path.splitext(filename)[0] + '.txt'
|
||||
if os.path.exists(existing_caption_txt_filename):
|
||||
with open(existing_caption_txt_filename, 'r', encoding="utf8") as file:
|
||||
existing_caption_txt = file.read()
|
||||
else:
|
||||
existing_caption_txt = ''.join(c for c in filename if c.isalpha() or c in [" ", ", "])
|
||||
|
||||
def build_caption(image):
|
||||
existing_caption = None
|
||||
if not append_filename:
|
||||
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
|
||||
if os.path.exists(existing_caption_filename):
|
||||
with open(existing_caption_filename, 'r', encoding="utf8") as file:
|
||||
existing_caption = file.read()
|
||||
else:
|
||||
existing_caption = ''.join(c for c in filename if c.isalpha() or c in [" ", ","])
|
||||
out_tags = []
|
||||
if clip_interrogator is not None:
|
||||
if caption_clip:
|
||||
tags = clip_interrogator.interrogate(img)
|
||||
for tag in tags:
|
||||
#print(f"CLIPTag: {tag}")
|
||||
out_tags.append(tag)
|
||||
|
||||
caption = ""
|
||||
if caption_image:
|
||||
caption = interrogate_image(img, True)
|
||||
if wd_interrogator is not None:
|
||||
ratings, tags = wd_interrogator.interrogate(img)
|
||||
|
||||
if caption_deepbooru:
|
||||
if len(caption) > 0:
|
||||
caption += ", "
|
||||
caption += deepbooru.model.tag_multi(image)
|
||||
for rating in ratings:
|
||||
#print(f"Rating {rating} score is {ratings[rating]}")
|
||||
if ratings[rating] >= wd14_min_score:
|
||||
out_tags.append(rating)
|
||||
|
||||
if pretxt_action == 'prepend' and existing_caption:
|
||||
caption = existing_caption + ' ' + caption
|
||||
elif pretxt_action == 'append' and existing_caption:
|
||||
caption = caption + ' ' + existing_caption
|
||||
elif pretxt_action == 'copy' and existing_caption:
|
||||
caption = existing_caption
|
||||
for tag in sorted(tags, key=tags.get, reverse=True):
|
||||
if tags[tag] >= wd14_min_score:
|
||||
#print(f"WDTag {tag} score is {tags[tag]}")
|
||||
out_tags.append(tag)
|
||||
else:
|
||||
break
|
||||
|
||||
caption = caption.strip()
|
||||
if replace_class and subject is not None and subject_class is not None:
|
||||
# Find and replace "a SUBJECT CLASS" in caption with subject name
|
||||
if f"a {subject_class}" in caption:
|
||||
caption = caption.replace(f"a {subject_class}", subject)
|
||||
if caption_deepbooru:
|
||||
tags = db_interrogator.interrogate(image)
|
||||
for tag in sorted(tags, key=tags.get, reverse=True):
|
||||
if tags[tag] >= booru_min_score:
|
||||
#print(f"DBTag {tag} score is {tags[tag]}")
|
||||
out_tags.append(tag)
|
||||
|
||||
if subject_class in caption:
|
||||
caption = caption.replace(subject_class, subject)
|
||||
# Remove duplicates
|
||||
unique_tags = []
|
||||
for tag in out_tags:
|
||||
if not tag in unique_tags:
|
||||
unique_tags.append(tag.strip())
|
||||
caption_txt = ", ".join(unique_tags)
|
||||
|
||||
if 0 < caption_length < len(caption):
|
||||
split_cap = caption.split(" ")
|
||||
caption = ""
|
||||
cap_test = ""
|
||||
split_idx = 0
|
||||
while True and split_idx < len(split_cap):
|
||||
cap_test += f" {split_cap[split_idx]}"
|
||||
if len(cap_test) < caption_length:
|
||||
caption = cap_test
|
||||
split_idx += 1
|
||||
if txt_action == 'prepend' and existing_caption_txt:
|
||||
caption_txt = existing_caption_txt + ' ' + caption_txt
|
||||
elif txt_action == 'append' and existing_caption_txt:
|
||||
caption_txt = caption_txt + ' ' + existing_caption_txt
|
||||
elif txt_action == 'copy' and existing_caption_txt:
|
||||
caption_txt = existing_caption_txt
|
||||
|
||||
caption = caption.strip()
|
||||
return caption
|
||||
caption_txt = caption_txt.strip()
|
||||
if replace_class and subject is not None and subject_class is not None:
|
||||
# Find and replace "a SUBJECT CLASS" in caption_txt with subject name
|
||||
if f"a {subject_class}" in caption_txt:
|
||||
caption_txt = caption_txt.replace(f"a {subject_class}", subject)
|
||||
|
||||
def save_pic_with_caption(image, img_index, existing_caption):
|
||||
if subject_class in caption_txt:
|
||||
caption_txt = caption_txt.replace(subject_class, subject)
|
||||
|
||||
if append_filename:
|
||||
filename_part = existing_caption
|
||||
basename = f"{img_index:05}-{subindex[0]}-{filename_part}"
|
||||
else:
|
||||
basename = f"{img_index:05}-{subindex[0]}"
|
||||
if 0 < caption_length < len(caption_txt):
|
||||
split_cap = caption_txt.split(" ")
|
||||
caption_txt = ""
|
||||
cap_test = ""
|
||||
split_idx = 0
|
||||
while True and split_idx < len(split_cap):
|
||||
cap_test += f" {split_cap[split_idx]}"
|
||||
if len(cap_test) < caption_length:
|
||||
caption_txt = cap_test
|
||||
split_idx += 1
|
||||
|
||||
shared.state.current_image = img
|
||||
image.save(os.path.join(dst, f"{basename}.png"))
|
||||
caption_txt = caption_txt.strip()
|
||||
return caption_txt
|
||||
|
||||
if save_txt:
|
||||
if len(existing_caption) > 0:
|
||||
def save_pic(image, src_name, img_index, existing_caption=None, flipped=False):
|
||||
if rename:
|
||||
basename = f"{img_index:05}"
|
||||
else:
|
||||
basename = src_name
|
||||
if flipped:
|
||||
basename += "_flipped"
|
||||
|
||||
shared.state.current_image = img
|
||||
image.save(os.path.join(dst, f"{basename}.png"))
|
||||
|
||||
if existing_caption is not None and len(existing_caption) > 0:
|
||||
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
|
||||
file.write(existing_caption)
|
||||
|
||||
subindex[0] += 1
|
||||
image_index = 0
|
||||
|
||||
def save_pic(image, img_index, existing_caption=None):
|
||||
save_pic_with_caption(image, img_index, existing_caption=existing_caption)
|
||||
# Enumerate images
|
||||
for index, src_image in enumerate(tqdm.tqdm(files)):
|
||||
# Quit on cancel
|
||||
if shared.state.interrupted:
|
||||
msg = f"Processing interrupted, {index}/{len(files)}"
|
||||
return msg, msg
|
||||
|
||||
if flip:
|
||||
save_pic_with_caption(ImageOps.mirror(image), img_index, existing_caption=existing_caption)
|
||||
filename = os.path.join(src, src_image)
|
||||
try:
|
||||
img = Image.open(filename).convert("RGB")
|
||||
except Exception as e:
|
||||
msg = f"Exception processing: {e}"
|
||||
printi(msg)
|
||||
traceback.print_exc()
|
||||
return msg, msg
|
||||
|
||||
def split_pic(image, img_inverse_xy):
|
||||
if img_inverse_xy:
|
||||
from_w, from_h = image.height, image.width
|
||||
to_w, to_h = height, width
|
||||
else:
|
||||
from_w, from_h = image.width, image.height
|
||||
to_w, to_h = width, height
|
||||
h = from_h * to_w // from_w
|
||||
if img_inverse_xy:
|
||||
image = image.resize((h, to_w))
|
||||
else:
|
||||
image = image.resize((to_w, h))
|
||||
if crop:
|
||||
# Interrogate once
|
||||
short_caption = clip_interrogator.interrogate(img, short=True)
|
||||
|
||||
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
|
||||
y_step = (h - to_h) / (split_count - 1)
|
||||
for i in range(split_count):
|
||||
y = int(y_step * i)
|
||||
if img_inverse_xy:
|
||||
split_img = image.crop((y, 0, y + to_h, to_w))
|
||||
else:
|
||||
split_img = image.crop((0, y, to_w, y + to_h))
|
||||
yield split_img
|
||||
if subject_class is not None and subject_class != "":
|
||||
short_caption = subject_class
|
||||
|
||||
crop_clip = None
|
||||
shared.state.textinfo = f"Cropping: {short_caption}"
|
||||
src_ratio = img.width / img.height
|
||||
|
||||
if crop_image:
|
||||
split_threshold = max(0.0, min(1.0, split_threshold))
|
||||
overlap_ratio = max(0.0, min(0.9, overlap_ratio))
|
||||
crop_clip = CropClip()
|
||||
# Pad image before cropping?
|
||||
if src_ratio != 1 and pad:
|
||||
if img.width > img.height:
|
||||
pad_width = img.width
|
||||
pad_height = img.width
|
||||
else:
|
||||
pad_width = img.height
|
||||
pad_height = img.height
|
||||
res = Image.new("RGB", (pad_width, pad_height))
|
||||
res.paste(img, box=(pad_width // 2 - img.width // 2, pad_height // 2 - img.height // 2))
|
||||
img = res
|
||||
|
||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||
# Do the actual crop clip
|
||||
im_data = crop_clip.get_center(img, prompt=short_caption)
|
||||
crop_width = im_data[1] - im_data[0]
|
||||
center_x = im_data[0] + (crop_width / 2)
|
||||
crop_height = im_data[3] - im_data[2]
|
||||
center_y = im_data[2] + (crop_height / 2)
|
||||
crop_ratio = crop_width / crop_height
|
||||
dest_ratio = 1
|
||||
tgt_width = crop_width
|
||||
tgt_height = crop_height
|
||||
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
if crop_ratio != dest_ratio:
|
||||
if crop_width > crop_height:
|
||||
tgt_height = crop_width / dest_ratio
|
||||
tgt_width = crop_width
|
||||
else:
|
||||
tgt_width = crop_height / dest_ratio
|
||||
tgt_height = crop_height
|
||||
|
||||
subindex = [0]
|
||||
filename = os.path.join(src, imagefile)
|
||||
try:
|
||||
img = Image.open(filename).convert("RGB")
|
||||
except Exception:
|
||||
continue
|
||||
# Reverse the above if dest is too big
|
||||
if tgt_width > img.width or tgt_height > img.height:
|
||||
if tgt_width > img.width:
|
||||
tgt_width = img.width
|
||||
tgt_height = tgt_width / dest_ratio
|
||||
else:
|
||||
tgt_height = img.height
|
||||
tgt_width = tgt_height / dest_ratio
|
||||
|
||||
shared.state.textinfo = f"Processing: '({filename})"
|
||||
if crop_image:
|
||||
# Interrogate once
|
||||
short_caption = interrogate_image(img)
|
||||
tgt_height = int(tgt_height)
|
||||
tgt_width = int(tgt_width)
|
||||
left = max(center_x - (tgt_width / 2), 0)
|
||||
right = min(center_x + (tgt_width / 2), img.width)
|
||||
top = max(center_y - (tgt_height / 2), 0)
|
||||
bottom = min(center_y + (tgt_height / 2), img.height)
|
||||
img = img.crop((left, top, right, bottom))
|
||||
shared.state.current_image = img
|
||||
|
||||
if subject_class is not None and subject_class != "":
|
||||
short_caption = subject_class
|
||||
|
||||
shared.state.textinfo = f"Cropping: {short_caption}"
|
||||
if img.height > img.width:
|
||||
ratio = (img.width * height) / (img.height * width)
|
||||
inverse_xy = False
|
||||
else:
|
||||
ratio = (img.height * width) / (img.width * height)
|
||||
inverse_xy = True
|
||||
|
||||
if split and ratio < 1.0 and ratio <= split_threshold:
|
||||
for splitted in split_pic(img, inverse_xy):
|
||||
# Build our caption
|
||||
full_caption = None
|
||||
if caption_image:
|
||||
full_caption = build_caption(splitted)
|
||||
save_pic(splitted, index, existing_caption=full_caption)
|
||||
|
||||
src_ratio = img.width / img.height
|
||||
# Pad image before cropping?
|
||||
if src_ratio != 1:
|
||||
if img.width > img.height:
|
||||
pad_width = img.width
|
||||
pad_height = img.width
|
||||
if restore_faces:
|
||||
shared.state.textinfo = f"Restoring faces using {face_model}..."
|
||||
if face_model == "gfpgan":
|
||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(img, dtype=np.uint8))
|
||||
img = Image.fromarray(restored_img)
|
||||
else:
|
||||
pad_width = img.height
|
||||
pad_height = img.height
|
||||
res = Image.new("RGB", (pad_width, pad_height))
|
||||
res.paste(img, box=(pad_width // 2 - img.width // 2, pad_height // 2 - img.height // 2))
|
||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(img, dtype=np.uint8),
|
||||
w=1.0)
|
||||
img = Image.fromarray(restored_img)
|
||||
shared.state.current_image = img
|
||||
|
||||
if upscale:
|
||||
shared.state.textinfo = "Upscaling..."
|
||||
upscaler = shared.sd_upscalers[scaler]
|
||||
res = upscaler.scaler.upscale(img, upscale_ratio, upscaler.data_path)
|
||||
img = res
|
||||
shared.state.current_image = img
|
||||
|
||||
if pad:
|
||||
ratio = 1
|
||||
src_ratio = img.width / img.height
|
||||
|
||||
src_w = max_size if ratio < src_ratio else img.width * max_size // img.height
|
||||
src_h = max_size if ratio >= src_ratio else img.height * max_size // img.width
|
||||
|
||||
resized = images.resize_image(0, img, src_w, src_h)
|
||||
res = Image.new("RGB", (max_size, max_size))
|
||||
res.paste(resized, box=(max_size // 2 - src_w // 2, max_size // 2 - src_h // 2))
|
||||
img = res
|
||||
|
||||
im_data = crop_clip.get_center(img, prompt=short_caption)
|
||||
crop_width = im_data[1] - im_data[0]
|
||||
center_x = im_data[0] + (crop_width / 2)
|
||||
crop_height = im_data[3] - im_data[2]
|
||||
center_y = im_data[2] + (crop_height / 2)
|
||||
crop_ratio = crop_width / crop_height
|
||||
dest_ratio = width / height
|
||||
tgt_width = crop_width
|
||||
tgt_height = crop_height
|
||||
# Resize again if image is not at the right size.
|
||||
if img.width != max_size or img.height != max_size:
|
||||
img = images.resize_image(1, img, max_size, max_size)
|
||||
|
||||
if crop_ratio != dest_ratio:
|
||||
if crop_width > crop_height:
|
||||
tgt_height = crop_width / dest_ratio
|
||||
tgt_width = crop_width
|
||||
else:
|
||||
tgt_width = crop_height / dest_ratio
|
||||
tgt_height = crop_height
|
||||
|
||||
# Reverse the above if dest is too big
|
||||
if tgt_width > img.width or tgt_height > img.height:
|
||||
if tgt_width > img.width:
|
||||
tgt_width = img.width
|
||||
tgt_height = tgt_width / dest_ratio
|
||||
else:
|
||||
tgt_height = img.height
|
||||
tgt_width = tgt_height / dest_ratio
|
||||
|
||||
tgt_height = int(tgt_height)
|
||||
tgt_width = int(tgt_width)
|
||||
left = max(center_x - (tgt_width / 2), 0)
|
||||
right = min(center_x + (tgt_width / 2), img.width)
|
||||
top = max(center_y - (tgt_height / 2), 0)
|
||||
bottom = min(center_y + (tgt_height / 2), img.height)
|
||||
img = img.crop((left, top, right, bottom))
|
||||
default_resize = True
|
||||
# Build a caption, if enabled
|
||||
full_caption = build_caption(img) if caption else None
|
||||
# Show our output
|
||||
shared.state.current_image = img
|
||||
else:
|
||||
default_resize = False
|
||||
printi(f"Processed: '({src_image} - {full_caption})")
|
||||
|
||||
if restore_faces:
|
||||
shared.state.textinfo = f"Restoring faces using {face_model}..."
|
||||
if face_model == "gfpgan":
|
||||
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(img, dtype=np.uint8))
|
||||
img = Image.fromarray(restored_img)
|
||||
else:
|
||||
restored_img = modules.codeformer_model.codeformer.restore(np.array(img, dtype=np.uint8),
|
||||
w=1.0)
|
||||
img = Image.fromarray(restored_img)
|
||||
shared.state.current_image = img
|
||||
save_pic(img, src_image, image_index, existing_caption=full_caption)
|
||||
image_index += 1
|
||||
|
||||
if upscale:
|
||||
shared.state.textinfo = "Upscaling..."
|
||||
upscaler = shared.sd_upscalers[scaler]
|
||||
res = upscaler.scaler.upscale(img, upscale_ratio, upscaler.data_path)
|
||||
img = res
|
||||
default_resize = True
|
||||
shared.state.current_image = img
|
||||
if flip:
|
||||
save_pic(ImageOps.flip(img), src_image, image_index, existing_caption=full_caption, flipped=True)
|
||||
image_index += 1
|
||||
|
||||
if pad_image:
|
||||
ratio = width / height
|
||||
src_ratio = img.width / img.height
|
||||
shared.state.nextjob()
|
||||
|
||||
src_w = width if ratio < src_ratio else img.width * height // img.height
|
||||
src_h = height if ratio >= src_ratio else img.height * width // img.width
|
||||
if caption_clip or crop:
|
||||
printi("Unloading CLIP interrogator...")
|
||||
shared.interrogator.send_blip_to_ram()
|
||||
|
||||
resized = images.resize_image(0, img, src_w, src_h)
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
img = res
|
||||
if caption_deepbooru:
|
||||
printi("Unloading Deepbooru interrogator...")
|
||||
db_interrogator.unload()
|
||||
|
||||
if default_resize:
|
||||
img = images.resize_image(1, img, width, height)
|
||||
shared.state.current_image = img
|
||||
full_caption = build_caption(img)
|
||||
save_pic(img, index, existing_caption=full_caption)
|
||||
if caption_wd14:
|
||||
printi("Unloading wd14 interrogator...")
|
||||
wd_interrogator.unload()
|
||||
|
||||
shared.state.nextjob()
|
||||
return f"Successfully processed {len(files)} images.", f"Successfully processed {len(files)} images."
|
||||
|
||||
except Exception as e:
|
||||
msg = f"Exception processing: {e}"
|
||||
traceback.print_exc()
|
||||
pass
|
||||
|
||||
return msg, msg
|
||||
|
|
|
|||
Loading…
Reference in New Issue