import gc
import os
import re
import sys
import traceback
from io import StringIO
from math import sqrt
from pathlib import Path
from typing import Union, List, Tuple

import numpy as np
import torch
from PIL import Image, features
from tqdm import tqdm

import modules.codeformer_model
import modules.gfpgan_model
from clipcrop import CropClip
from extensions.sd_smartprocess.file_manager import ImageData
from extensions.sd_smartprocess.interrogators.blip_interrogator import BLIPInterrogator
from extensions.sd_smartprocess.interrogators.interrogator import InterrogatorRegistry
from extensions.sd_smartprocess.model_download import disable_safe_unpickle, enable_safe_unpickle
from extensions.sd_smartprocess.process_params import ProcessParams
from modules import shared, images

blip_interrogator = None
crop_clip = None
image_interrogators = {}
global_unpickler = None
image_features = None


def printi(message):
    shared.state.textinfo = message
    print(message)


def get_backup_path(file_path, params: ProcessParams):
    backup_path = file_path
    if params.do_backup:
        file_base, file_ext = os.path.splitext(file_path)
        backup_index = 0
        backup_path = f"{file_base}_backup{backup_index}{file_ext}"
        while os.path.exists(backup_path):
            backup_index += 1
            backup_path = f"{file_base}_backup{backup_index}{file_ext}"
    return file_path, backup_path


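# Illustrative example (assumed, not from the original source): with do_backup
# enabled and both "photo.png" and "photo_backup0.png" already on disk,
# get_backup_path("photo.png", params) returns ("photo.png", "photo_backup1.png"),
# the first backup name that is not yet taken.

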
def save_pic(img, src_name, img_index, params: ProcessParams):
    dest_dir = os.path.dirname(src_name)
    if params.do_rename:
        basename = f"{img_index:05}"
    else:
        src_name, backup_name = get_backup_path(src_name, params)
        if src_name != backup_name and os.path.exists(src_name):
            os.rename(src_name, backup_name)
        basename = os.path.splitext(os.path.basename(src_name))[0]

    shared.state.current_image = img
    dest = os.path.join(dest_dir, f"{basename}.png")
    img.save(dest)
    return dest


def save_img_caption(image_path: str, img_caption: str, params: ProcessParams):
    basename = os.path.splitext(image_path)[0]
    dest = f"{basename}.txt"
    src_name, backup_name = get_backup_path(dest, params)
    if src_name != backup_name and os.path.exists(src_name):
        os.rename(src_name, backup_name)
    if img_caption is not None and len(img_caption) > 0:
        with open(src_name, "w", encoding="utf8") as file:
            # Write the caption itself, not the destination path
            file.write(img_caption)
    return src_name


def list_features():
    global image_features
    if image_features is None:
        # Create a buffer for pilinfo() to write into rather than stdout
        buffer = StringIO()
        features.pilinfo(out=buffer)
        pil_features = []
        # Parse the buffered report for supported file extensions
        for line in buffer.getvalue().splitlines():
            if "Extensions:" in line:
                ext_list = line.split(": ")[1]
                for extension in ext_list.split(", "):
                    if extension not in pil_features:
                        pil_features.append(extension)
        image_features = pil_features
    else:
        pil_features = image_features
    return pil_features


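# For reference (hedged): features.pilinfo() reports lines such as
# "Extensions: .jfif, .jpe, .jpeg, .jpg", so the parser above collects the
# supported suffixes into a flat list like [".jfif", ".jpe", ".jpeg", ".jpg"].

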
def is_image(path: Union[Path, str], feats=None):
    if not feats:
        feats = list_features()
    if isinstance(path, str):
        path = Path(path)
    return path.is_file() and path.suffix.lower() in feats


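# Illustrative usage (assumed): is_image("photo.JPG") is True when ".jpg" appears
# in the Pillow extension list, since the suffix is lowercased before comparison.

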
def cleanup():
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()
    except Exception as e:
        print(f"cleanup exception: {e}")


def vram_usage():
    if torch.cuda.is_available():
        used_vram = torch.cuda.memory_allocated(0) / 1024 ** 3  # Convert bytes to GB
        total_vram = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3  # Convert bytes to GB
        return used_vram, total_vram
    return 0.0, 0.0


def unload_system():
    disable_safe_unpickle()
    if shared.interrogator is not None:
        shared.interrogator.unload()
    try:
        shared.sd_model.to("cpu")
    except Exception:
        pass
    for former in modules.shared.face_restorers:
        try:
            former.send_model_to("cpu")
        except Exception:
            pass
    cleanup()
    used, total = vram_usage()
    print(f"System unloaded, current VRAM usage: {used}/{total} GB")


def load_system():
    enable_safe_unpickle()
    if shared.interrogator is not None:
        shared.interrogator.send_blip_to_ram()
    try:
        if modules.shared.sd_model is not None:
            modules.shared.sd_model.to(shared.device)
    except Exception:
        pass


def get_crop_clip():
    global crop_clip
    if crop_clip is None:
        # Drop any cached 'models' module first (presumably so CropClip imports its own)
        sys.modules.pop('models', None)

        crop_clip = CropClip()
    return crop_clip


def get_image_interrogators(params: ProcessParams, all_captioners):
    global image_interrogators
    all_interrogators = InterrogatorRegistry.get_all_interrogators()
    interrogators = all_captioners
    caption_agents = []
    print(f"Interrogators: {interrogators}")
    for interrogator_name in interrogators:
        if interrogator_name not in image_interrogators:
            printi(f"\rLoading {interrogator_name} interrogator...")
            interrogator = all_interrogators[f"{interrogator_name}Interrogator"](params)
            image_interrogators[interrogator_name] = interrogator
        else:
            interrogator = image_interrogators[interrogator_name]
            interrogator.unload()
        caption_agents.append(interrogator)

    return caption_agents


def clean_string(s):
    """
    Remove non-alphanumeric characters except spaces, and normalize spacing.

    Args:
        s: The string to clean.

    Returns: A cleaned string.
    """
    # Strip any HTML tags
    cleaned = re.sub(r'<[^>]+>', '', s)
    # Remove non-alphanumeric characters except spaces
    cleaned = re.sub(r'[^a-zA-Z0-9\s]', '', cleaned)
    # Collapse a sentence that is just the same word repeated
    if len(set(cleaned.split())) == 1:
        cleaned = cleaned.split()[0]
    # Replace the stray token "y" with "a" (presumably an interrogator artifact)
    words_out = ["a" if word == "y" else word for word in cleaned.split()]
    cleaned = " ".join(words_out)
    # Replace multiple spaces with a single space
    cleaned = re.sub(r'\s+', ' ', cleaned).strip()
    return cleaned


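# Worked examples (assumed behavior): clean_string("a <b>red</b> car!!") strips the
# tags and punctuation and yields "a red car"; clean_string("dog dog dog") collapses
# to "dog" because every word in the string is identical.

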
def read_caption(image):
    existing_caption_txt_filename = os.path.splitext(image)[0] + '.txt'
    if os.path.exists(existing_caption_txt_filename):
        with open(existing_caption_txt_filename, 'r', encoding="utf8") as file:
            existing_caption_txt = file.read()
    else:
        # Fall back to a caption derived from the file name
        image_name = os.path.splitext(os.path.basename(image))[0]
        existing_caption_txt = clean_string(image_name)
    return existing_caption_txt


def build_caption(image, captions_list, tags_to_ignore, caption_length, subject_class, subject, replace_class,
                  txt_action="ignore"):
    """
    Build a caption from a list of captions, optionally ignoring tags and optionally replacing a class name with a subject name.

    Args:
        image: The image path, used to locate an existing caption .txt file.
        captions_list: A list of generated captions.
        tags_to_ignore: A comma-separated string (or list) of tags to ignore.
        caption_length: The maximum number of tags to include in the caption.
        subject_class: The class name to replace.
        subject: The subject name to replace the class name with.
        replace_class: Whether to replace the class name with the subject name.
        txt_action: What to do with the existing caption, if any.

    Returns: A string containing the caption.
    """

    all_tags = set()
    for cap in captions_list:
        all_tags.update({clean_string(tag) for tag in cap.split(",") if tag.strip()})

    if isinstance(tags_to_ignore, str):
        tags_to_ignore = tags_to_ignore.split(",")
    # Filter out ignored tags
    ignore_tags = set(clean_string(tag) for tag in tags_to_ignore if tag.strip())
    all_tags.difference_update(ignore_tags)

    # Handle any existing caption according to txt_action
    if txt_action == "include":
        # Read the existing caption from the .txt file next to the image
        existing_caption_txt = read_caption(image)

        existing_tags = set(clean_string(tag) for tag in existing_caption_txt.split(",") if tag.strip())
    else:
        existing_tags = set()
    all_tags = all_tags.union(existing_tags)

    # Replace the class name with the subject name
    if replace_class and subject != "" and subject_class != "":
        phrases = ["a", "an", "the", "this", "that"]
        cleaned_tags = []
        for tag in all_tags:
            replaced = False
            lower_tag = tag.lower()
            for phrase in phrases:
                conjunction = f"{phrase} {subject_class.lower()}"
                if conjunction in lower_tag:
                    tag = lower_tag.replace(conjunction, f"{subject.lower()}")
                    replaced = True
                    break
            if not replaced:
                tag = lower_tag.replace(subject_class, subject)
            cleaned_tags.append(tag)
        all_tags = set(cleaned_tags)
    # Limit caption length
    tags_list = list(all_tags)
    # Sort tags by length, longest first
    tags_list.sort(key=len, reverse=True)
    # if caption_length and len(tags_list) > caption_length:
    #     tags_list = tags_list[:caption_length]

    caption_txt = ", ".join(tags_list)
    return caption_txt


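# Worked example (illustrative inputs): with captions_list ["a dog, park", "a dog, grass"],
# replace_class=True, subject_class="dog", subject="rex" and no ignored tags, the merged
# tag set {"a dog", "park", "grass"} becomes {"rex", "park", "grass"} and the result is
# "grass, park, rex" (tags joined longest-first).

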
def calculate_job_length(files, crop, caption, captioners, flip, restore_faces, upscale):
    num_files = len(files)
    job_length = 0
    if crop:
        job_length += num_files
    if caption:
        job_length += num_files * len(captioners)
    if flip:
        job_length += num_files
    if restore_faces:
        job_length += num_files
    if upscale:
        job_length += num_files
    return job_length


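# Example of the arithmetic (hypothetical numbers): 10 files with crop, flip and two
# captioners enabled gives 10 + 10 * 2 + 10 = 40 job steps.

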
def crop_smart(img: Image.Image, interrogator: BLIPInterrogator, cc: CropClip, params: ProcessParams):
    # Caption the image, then let CropClip find the region matching the caption
    short_caption = interrogator.interrogate(img, params)
    # im_data is (left, right, top, bottom) in pixels
    im_data = cc.get_center(img, prompt=short_caption)
    crop_width = im_data[1] - im_data[0]
    center_x = im_data[0] + (crop_width / 2)
    crop_height = im_data[3] - im_data[2]
    center_y = im_data[2] + (crop_height / 2)
    crop_ratio = crop_width / crop_height
    dest_ratio = 1
    tgt_width = crop_width
    tgt_height = crop_height

    if crop_ratio != dest_ratio:
        # Expand the shorter side so the crop matches the destination ratio
        if crop_width > crop_height:
            tgt_height = crop_width / dest_ratio
            tgt_width = crop_width
        else:
            tgt_width = crop_height / dest_ratio
            tgt_height = crop_height

    # Reverse the above if the target exceeds the image bounds
    if tgt_width > img.width or tgt_height > img.height:
        if tgt_width > img.width:
            tgt_width = img.width
            tgt_height = tgt_width / dest_ratio
        else:
            tgt_height = img.height
            tgt_width = tgt_height / dest_ratio

    tgt_height = int(tgt_height)
    tgt_width = int(tgt_width)
    # Center the crop box on the detected region, clamped to the image bounds
    left = max(center_x - (tgt_width / 2), 0)
    right = min(center_x + (tgt_width / 2), img.width)
    top = max(center_y - (tgt_height / 2), 0)
    bottom = min(center_y + (tgt_height / 2), img.height)
    img = img.crop((left, top, right, bottom))
    return img


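# Worked example (illustrative numbers): if CropClip returns (100, 300, 50, 250) on an
# 800x600 image, the detected box is 200x200 centered at (200, 150); with the square
# destination ratio the target stays 200x200, so the final crop box is (100, 50, 300, 250).

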
def crop_center(img: Image.Image, max_size: int):
    # Resize to fit within max_size, then paste centered onto a square canvas
    ratio = 1
    src_ratio = img.width / img.height

    src_w = max_size if ratio < src_ratio else img.width * max_size // img.height
    src_h = max_size if ratio >= src_ratio else img.height * max_size // img.width

    resized = images.resize_image(0, img, src_w, src_h)
    res = Image.new("RGB", (max_size, max_size))
    res.paste(resized, box=(max_size // 2 - src_w // 2, max_size // 2 - src_h // 2))
    return res


def crop_empty(img: Image.Image):
    # Convert PIL Image to OpenCV format
    open_cv_image = np.array(img)
    open_cv_image = open_cv_image[:, :, ::-1].copy()  # Convert RGB to BGR

    # Check whether every pixel in the slice is identical
    def is_uniform_border(border_slice):
        return np.all(border_slice == border_slice[0, 0, :])

    # Take one-pixel slices of the top, bottom, left and right borders
    h, w, _ = open_cv_image.shape
    top_border = open_cv_image[0:1, :]
    bottom_border = open_cv_image[h - 1:h, :]
    left_border = open_cv_image[:, 0:1]
    right_border = open_cv_image[:, w - 1:w]

    # Compare opposite borders to ensure they match
    top_matches_bottom = is_uniform_border(top_border) and is_uniform_border(bottom_border) and np.all(
        top_border == bottom_border)
    left_matches_right = is_uniform_border(left_border) and is_uniform_border(right_border) and np.all(
        left_border == right_border)

    # Find the padding sizes
    top_pad = 0
    bottom_pad = 0
    left_pad = 0
    right_pad = 0

    if top_matches_bottom:
        while top_pad < h and is_uniform_border(open_cv_image[top_pad:top_pad + 1, :]):
            top_pad += 1
        while bottom_pad < h and is_uniform_border(open_cv_image[h - bottom_pad - 1:h - bottom_pad, :]):
            bottom_pad += 1

    if left_matches_right:
        while left_pad < w and is_uniform_border(open_cv_image[:, left_pad:left_pad + 1]):
            left_pad += 1
        while right_pad < w and is_uniform_border(open_cv_image[:, w - right_pad - 1:w - right_pad]):
            right_pad += 1

    # Crop the image
    cropped_image = open_cv_image[top_pad:h - bottom_pad, left_pad:w - right_pad]

    # Convert back to PIL Image
    cropped_image = Image.fromarray(cropped_image[:, :, ::-1])  # Convert BGR to RGB

    return cropped_image


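# Illustrative example (assumed input): a 100x100 image with a uniform 10-pixel frame
# on all four sides comes back as the 80x80 interior; borders are only trimmed when
# opposite edges are uniform and identical to each other.

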
def crop_contain(img, params: ProcessParams):
    ratio = 1
    src_ratio = img.width / img.height

    src_w = params.max_size if ratio < src_ratio else img.width * params.max_size // img.height
    src_h = params.max_size if ratio >= src_ratio else img.height * params.max_size // img.width

    resized = images.resize_image(0, img, src_w, src_h)
    res = Image.new("RGB", (params.max_size, params.max_size))
    res.paste(resized, box=(params.max_size // 2 - src_w // 2, params.max_size // 2 - src_h // 2))
    return res


def get_blip_interrogator(params: ProcessParams):
    global blip_interrogator
    if blip_interrogator is None:
        blip_interrogator = BLIPInterrogator(params)
    else:
        blip_interrogator.unload()
    return blip_interrogator


def process_pre(files: List[ImageData], params: ProcessParams) -> List[ImageData]:
    output = []
    interrogator = None
    cc = None
    if params.crop and params.crop_mode == "smart":
        interrogator = get_blip_interrogator(params.clip_params())
        cc = get_crop_clip()
    total_files = len(files)
    crop_length = 0
    if params.crop:
        crop_length += total_files
    if params.pad:
        crop_length += total_files
    pbar = tqdm(total=crop_length, desc="Processing images")
    for image_data in files:
        img = image_data.get_image()
        if params.crop:
            if params.crop_mode == "smart":
                img = crop_smart(img, interrogator, cc, params)
            elif params.crop_mode == "center":
                img = crop_center(img, params.max_size)
            elif params.crop_mode == "empty":
                img = crop_empty(img)
            elif params.crop_mode == "contain":
                img = crop_contain(img, params)
            shared.state.current_image = img
            pbar.update(1)
            shared.state.job_no += 1

        if params.pad:
            # Letterbox onto a square canvas (same logic as crop_contain)
            img = crop_contain(img, params)
            pbar.update(1)
            shared.state.job_no += 1

        image_data.update_image(img)

        if params.save_image:
            img_path = save_pic(img, image_data.image_path, len(output), params)
            image_data.image_path = img_path
        output.append(image_data)
    if interrogator is not None:
        interrogator.unload()
    if cc is not None:
        cc.unload()
    return output


def process_captions(files: List[ImageData], params: ProcessParams, all_captioners) -> List[ImageData]:
    output = []
    caption_dict = {}
    caption_length = params.max_tokens
    tags_to_ignore = params.tags_to_ignore
    subject_class = params.subject_class
    subject = params.subject
    replace_class = params.replace_class
    txt_action = params.txt_action
    save_captions = params.save_caption or params.auto_save
    agents = get_image_interrogators(params, all_captioners)
    total_files = len(files)
    total_captions = total_files * len(agents)
    pbar = tqdm(total=total_captions, desc="Captioning images")
    for caption_agent in agents:
        print(f"Captioning with {caption_agent.__class__.__name__}...")
        caption_agent.load()
        for image_data in files:
            temp_params = params
            img = image_data.get_image()
            temp_params.image_path = image_data.image_path
            image_path = image_data.image_path
            if image_path not in caption_dict:
                caption_dict[image_path] = []
            try:
                # If the agent is LLAVA2, build the current caption first
                if caption_agent.__class__.__name__ == "LLAVA2Interrogator":
                    print("Building caption for LLAVA2")
                    temp_params.new_caption = build_caption(image_path, caption_dict[image_path], tags_to_ignore,
                                                            caption_length, subject_class, subject, replace_class,
                                                            txt_action)
                caption_out = caption_agent.interrogate(img, temp_params)
                print(f"Caption for {image_path}: {caption_out}")
                caption_dict[image_path].append(caption_out)
                pbar.update(1)
                shared.state.job_no += 1
            except Exception as e:
                print(f"Exception captioning {image_data}: {e}")
                traceback.print_exc()
        caption_agent.unload()

    output_dict = {}
    for image_path, captions in caption_dict.items():
        caption_string = build_caption(image_path, captions, tags_to_ignore, caption_length, subject_class, subject,
                                       replace_class, txt_action)
        output_dict[image_path] = caption_string
        if save_captions:
            save_img_caption(image_path, caption_string, params)
        # Find the ImageData object in files whose path matches image_path
        image_data = next((image_data for image_data in files if image_data.image_path == image_path), None)
        if image_data is not None:
            image_data.update_caption(caption_string, False)
            output.append(image_data)
    return output


def process_post(files: List[ImageData], params: ProcessParams) -> List[ImageData]:
    output = []
    total_files = len(files)
    total_post = 0
    if params.restore_faces:
        total_post += total_files

    if params.upscale:
        total_post += total_files
    pbar = tqdm(total=total_post, desc="Post-processing images")
    upscalers = []

    if params.upscale:
        shared.state.textinfo = "Upscaling..."
        if params.upscaler_1 is not None and params.upscaler_1 != "None":
            upscalers.append(params.upscaler_1)
        if params.upscaler_2 is not None and params.upscaler_2 != "None":
            upscalers.append(params.upscaler_2)
    img_index = 0
    for file in files:
        img = file.get_image()
        if params.restore_faces:
            shared.state.textinfo = f"Restoring faces using {params.face_model}..."
            if params.face_model == "gfpgan":
                restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(img, dtype=np.uint8))
            else:
                restored_img = modules.codeformer_model.codeformer.restore(np.array(img, dtype=np.uint8), w=1.0)
            img = Image.fromarray(restored_img)
            pbar.update(1)
            shared.state.job_no += 1
            shared.state.current_image = img

        if params.upscale:
            shared.state.textinfo = "Upscaling..."
            used, total = vram_usage()
            print(f"Upscaling, current VRAM usage: {used}/{total} GB")
            # Collect native scale factors for the selected upscalers (currently unused)
            scaler_dims = {}
            for scaler_name in upscalers:
                for scaler in shared.sd_upscalers:
                    print(f"Scaler: {scaler.name}")
                    if scaler.name == scaler_name:
                        if scaler.name != "none":
                            scaler_dims[scaler_name] = scaler.scale
                        break

            # Calculate the upscale factor
            if params.upscale_mode == "Size":
                upscale_to_max = params.max_size
                desired_upscale = max(upscale_to_max / img.width, upscale_to_max / img.height)
            else:
                desired_upscale = params.upscale_ratio

            # Split the factor across both passes when two upscalers are used
            if len(upscalers) == 2:
                upscale_by = sqrt(desired_upscale)
            else:
                upscale_by = desired_upscale
            print(f"Upscalers: {upscalers}")
            # Apply each upscaler sequentially
            img_prompt = None
            for scaler_name in upscalers:
                upscaler = None
                for scaler in shared.sd_upscalers:
                    print(f"Scaler: {scaler.name}")
                    if scaler.name == scaler_name:
                        upscaler = scaler
                        if scaler.name == "SD4x":
                            img_prompt = file.caption
                        break
                if upscaler:
                    scaler = upscaler.scaler
                    if img_prompt:
                        scaler.prompt = img_prompt
                    img = scaler.upscale(img, upscale_by, upscaler.data_path)
                    try:
                        scaler.unload()
                    except Exception:
                        pass
            pbar.update(1)
            shared.state.job_no += 1
            shared.state.current_image = img

        if params.save_image:
            # save_pic expects the source path, not the ImageData object
            img_path = save_pic(img, file.image_path, img_index, params)
            file.image_path = img_path
        file.update_image(img, False)
        output.append(file)
        img_index += 1
    return output


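# Note on the factor split (derivable from the code above): with two upscalers and a
# desired 4x result, each pass runs at sqrt(4) = 2x, so the two sequential passes
# combine to the requested 4x.

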
def do_process(params: ProcessParams) -> Tuple[List[ImageData], str]:
    print(f"Processing with params: {params}")
    output = params.src_files
    try:
        global blip_interrogator
        global image_interrogators
        # Combine params.captioners and params.nl_captioners (as a copy, so params is not mutated)
        all_captioners = list(params.captioners) + list(params.nl_captioners)

        job_length = calculate_job_length(params.src_files, params.crop, params.caption, all_captioners, params.flip,
                                          params.restore_faces, params.upscale)

        if job_length == 0:
            msg = "Nothing to do."
            printi(msg)
            return output, msg

        unload_system()
        do_preprocess = params.pad or params.crop or params.flip
        do_postprocess = params.restore_faces or params.upscale

        shared.state.textinfo = "Initializing smart processing..."
        shared.state.job_count = job_length
        shared.state.job_no = 0
        if do_preprocess:
            output = process_pre(output, params)

        if params.caption:
            output = process_captions(output, params, all_captioners)

        if do_postprocess:
            output = process_post(output, params)

        return output, f"Successfully processed {len(output)} images."
    except Exception as e:
        traceback.print_exc()
        msg = f"Error processing images: {e}"
        printi(msg)

        return output, msg
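
# Hedged usage sketch: the field names below follow how ProcessParams is read above;
# the actual constructor lives in process_params.py and may differ, and load_images()
# is a hypothetical helper standing in for however src_files gets populated.
#
#     params = ProcessParams()
#     params.src_files = load_images("/path/to/images")  # hypothetical helper
#     params.crop = True
#     params.crop_mode = "smart"
#     params.caption = True
#     processed, message = do_process(params)
#     print(message)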