Update to use alternate before image save implementation
parent
e990b3cadd
commit
3ad99b5eeb
|
|
@ -1,4 +1,5 @@
|
|||
from modules import sd_samplers, shared, scripts, script_callbacks
|
||||
from modules.script_callbacks import ImageSaveParams
|
||||
import modules.images as images
|
||||
from modules.processing import Processed, process_images, StableDiffusionProcessing
|
||||
from modules.shared import opts, OptionInfo
|
||||
|
|
@ -26,6 +27,7 @@ if not Path(state_name).exists():
|
|||
with open(state_name, "wb") as f:
|
||||
f.write(r.content)
|
||||
|
||||
|
||||
class AestheticPredictor(nn.Module):
|
||||
def __init__(self, input_size):
|
||||
super().__init__()
|
||||
|
|
@ -44,6 +46,7 @@ class AestheticPredictor(nn.Module):
|
|||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
try:
|
||||
force_cpu = opts.ais_force_cpu
|
||||
except:
|
||||
|
|
@ -53,7 +56,7 @@ if force_cpu:
|
|||
print("Aesthtic Image Scorer: Forcing prediction model to run on CPU")
|
||||
device = "cuda" if not force_cpu and torch.cuda.is_available() else "cpu"
|
||||
# load the model you trained previously or the model available in this repo
|
||||
pt_state = torch.load(state_name, map_location=torch.device(device=device))
|
||||
pt_state = torch.load(state_name, map_location=torch.device(device=device))
|
||||
|
||||
# CLIP embedding dim is 768 for CLIP ViT L 14
|
||||
predictor = AestheticPredictor(768)
|
||||
|
|
@ -63,6 +66,7 @@ predictor.eval()
|
|||
|
||||
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
|
||||
|
||||
|
||||
def get_image_features(image, device=device, model=clip_model, preprocess=clip_preprocess):
|
||||
image = preprocess(image).unsqueeze(0).to(device)
|
||||
with torch.no_grad():
|
||||
|
|
@ -72,6 +76,7 @@ def get_image_features(image, device=device, model=clip_model, preprocess=clip_p
|
|||
image_features = image_features.cpu().detach().numpy()
|
||||
return image_features
|
||||
|
||||
|
||||
def get_score(image):
|
||||
image_features = get_image_features(image)
|
||||
score = predictor(torch.from_numpy(image_features).to(device).float())
|
||||
|
|
@ -85,34 +90,33 @@ def on_ui_settings():
|
|||
"ais_windows_tag": OptionInfo(False, "Save score as tag (Windows Only)"),
|
||||
"ais_force_cpu": OptionInfo(False, "Force CPU (Requires Custom Script Reload)"),
|
||||
}))
|
||||
|
||||
|
||||
opts.add_option("ais_add_exif", options["ais_add_exif"])
|
||||
opts.add_option("ais_windows_tag", options["ais_windows_tag"])
|
||||
opts.add_option("ais_force_cpu", options["ais_force_cpu"])
|
||||
|
||||
|
||||
def on_before_image_saved(image, p, **kwargs):
|
||||
def on_before_image_saved(params: ImageSaveParams):
|
||||
if opts.ais_add_exif:
|
||||
score = round(get_score(image), 1)
|
||||
if "existing_info" not in kwargs or kwargs["existing_info"] is None:
|
||||
kwargs["existing_info"] = {}
|
||||
kwargs["existing_info"].update({
|
||||
score = round(get_score(params.image), 1)
|
||||
params.pnginfo.update({
|
||||
"aesthetic_score": score,
|
||||
})
|
||||
return image, p, kwargs
|
||||
|
||||
def on_image_saved(image, p, fullfn, txt_fullfn, **kwargs):
|
||||
if "existing_info" in kwargs and kwargs["existing_info"] is not None and "aesthetic_score" in kwargs["existing_info"]:
|
||||
score = kwargs["existing_info"]["aesthetic_score"]
|
||||
|
||||
def on_image_saved(params: ImageSaveParams):
|
||||
if "aesthetic_score" in params.pnginfo:
|
||||
score = params.pnginfo["aesthetic_score"]
|
||||
else:
|
||||
score = round(get_score(image), 1)
|
||||
score = round(get_score(params.image), 1)
|
||||
if score is not None and opts.ais_windows_tag:
|
||||
if tag_files is not None:
|
||||
tags = [f"aesthetic_score_{score}"]
|
||||
tag_files(filename=fullfn, tags=tags)
|
||||
tag_files(filename=params.filename, tags=tags)
|
||||
else:
|
||||
print("Aesthetic Image Scorer: Unable to load Windows tagging script")
|
||||
|
||||
|
||||
class AestheticImageScorer(scripts.Script):
|
||||
def title(self):
|
||||
return "Aesthetic Image Scorer"
|
||||
|
|
@ -126,6 +130,7 @@ class AestheticImageScorer(scripts.Script):
|
|||
def process(self, p):
|
||||
pass
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
script_callbacks.on_before_image_saved(on_before_image_saved)
|
||||
script_callbacks.on_image_saved(on_image_saved)
|
||||
|
|
|
|||
Loading…
Reference in New Issue