Update to use alternate before image save implementation

main
Trung Ngo 2022-10-26 07:18:19 -05:00
parent e990b3cadd
commit 3ad99b5eeb
1 changed files with 18 additions and 13 deletions

View File

@ -1,4 +1,5 @@
from modules import sd_samplers, shared, scripts, script_callbacks
from modules.script_callbacks import ImageSaveParams
import modules.images as images
from modules.processing import Processed, process_images, StableDiffusionProcessing
from modules.shared import opts, OptionInfo
@ -26,6 +27,7 @@ if not Path(state_name).exists():
with open(state_name, "wb") as f:
f.write(r.content)
class AestheticPredictor(nn.Module):
def __init__(self, input_size):
super().__init__()
@ -44,6 +46,7 @@ class AestheticPredictor(nn.Module):
def forward(self, x):
return self.layers(x)
try:
force_cpu = opts.ais_force_cpu
except:
@ -53,7 +56,7 @@ if force_cpu:
print("Aesthtic Image Scorer: Forcing prediction model to run on CPU")
device = "cuda" if not force_cpu and torch.cuda.is_available() else "cpu"
# load the model you trained previously or the model available in this repo
pt_state = torch.load(state_name, map_location=torch.device(device=device))
pt_state = torch.load(state_name, map_location=torch.device(device=device))
# CLIP embedding dim is 768 for CLIP ViT L 14
predictor = AestheticPredictor(768)
@ -63,6 +66,7 @@ predictor.eval()
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
def get_image_features(image, device=device, model=clip_model, preprocess=clip_preprocess):
image = preprocess(image).unsqueeze(0).to(device)
with torch.no_grad():
@ -72,6 +76,7 @@ def get_image_features(image, device=device, model=clip_model, preprocess=clip_p
image_features = image_features.cpu().detach().numpy()
return image_features
def get_score(image):
image_features = get_image_features(image)
score = predictor(torch.from_numpy(image_features).to(device).float())
@ -85,34 +90,33 @@ def on_ui_settings():
"ais_windows_tag": OptionInfo(False, "Save score as tag (Windows Only)"),
"ais_force_cpu": OptionInfo(False, "Force CPU (Requires Custom Script Reload)"),
}))
opts.add_option("ais_add_exif", options["ais_add_exif"])
opts.add_option("ais_windows_tag", options["ais_windows_tag"])
opts.add_option("ais_force_cpu", options["ais_force_cpu"])
def on_before_image_saved(image, p, **kwargs):
def on_before_image_saved(params: ImageSaveParams):
if opts.ais_add_exif:
score = round(get_score(image), 1)
if "existing_info" not in kwargs or kwargs["existing_info"] is None:
kwargs["existing_info"] = {}
kwargs["existing_info"].update({
score = round(get_score(params.image), 1)
params.pnginfo.update({
"aesthetic_score": score,
})
return image, p, kwargs
def on_image_saved(image, p, fullfn, txt_fullfn, **kwargs):
if "existing_info" in kwargs and kwargs["existing_info"] is not None and "aesthetic_score" in kwargs["existing_info"]:
score = kwargs["existing_info"]["aesthetic_score"]
def on_image_saved(params: ImageSaveParams):
if "aesthetic_score" in params.pnginfo:
score = params.pnginfo["aesthetic_score"]
else:
score = round(get_score(image), 1)
score = round(get_score(params.image), 1)
if score is not None and opts.ais_windows_tag:
if tag_files is not None:
tags = [f"aesthetic_score_{score}"]
tag_files(filename=fullfn, tags=tags)
tag_files(filename=params.filename, tags=tags)
else:
print("Aesthetic Image Scorer: Unable to load Windows tagging script")
class AestheticImageScorer(scripts.Script):
def title(self):
return "Aesthetic Image Scorer"
@ -126,6 +130,7 @@ class AestheticImageScorer(scripts.Script):
def process(self, p):
pass
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_before_image_saved(on_before_image_saved)
script_callbacks.on_image_saved(on_image_saved)