diff --git a/scripts/image_scorer.py b/scripts/image_scorer.py
index eef2a1e..feb9d04 100644
--- a/scripts/image_scorer.py
+++ b/scripts/image_scorer.py
@@ -1,4 +1,5 @@
 from modules import sd_samplers, shared, scripts, script_callbacks
+from modules.script_callbacks import ImageSaveParams
 import modules.images as images
 from modules.processing import Processed, process_images, StableDiffusionProcessing
 from modules.shared import opts, OptionInfo
@@ -26,6 +27,7 @@ if not Path(state_name).exists():
     with open(state_name, "wb") as f:
         f.write(r.content)
 
+
 class AestheticPredictor(nn.Module):
     def __init__(self, input_size):
         super().__init__()
@@ -44,6 +46,7 @@ class AestheticPredictor(nn.Module):
     def forward(self, x):
         return self.layers(x)
 
+
 try:
     force_cpu = opts.ais_force_cpu
 except:
@@ -53,7 +56,7 @@ if force_cpu:
     print("Aesthtic Image Scorer: Forcing prediction model to run on CPU")
 device = "cuda" if not force_cpu and torch.cuda.is_available() else "cpu"
 # load the model you trained previously or the model available in this repo
-pt_state = torch.load(state_name, map_location=torch.device(device=device)) 
+pt_state = torch.load(state_name, map_location=torch.device(device=device))
 
 # CLIP embedding dim is 768 for CLIP ViT L 14
 predictor = AestheticPredictor(768)
@@ -63,6 +66,7 @@ predictor.eval()
 
 clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
 
+
 def get_image_features(image, device=device, model=clip_model, preprocess=clip_preprocess):
     image = preprocess(image).unsqueeze(0).to(device)
     with torch.no_grad():
@@ -72,6 +76,7 @@ def get_image_features(image, device=device, model=clip_model, preprocess=clip_p
     image_features = image_features.cpu().detach().numpy()
     return image_features
 
+
 def get_score(image):
     image_features = get_image_features(image)
     score = predictor(torch.from_numpy(image_features).to(device).float())
@@ -85,34 +90,33 @@ def on_ui_settings():
         "ais_windows_tag": OptionInfo(False, "Save score as tag (Windows Only)"),
         "ais_force_cpu": OptionInfo(False, "Force CPU (Requires Custom Script Reload)"),
     }))
-    
+
     opts.add_option("ais_add_exif", options["ais_add_exif"])
     opts.add_option("ais_windows_tag", options["ais_windows_tag"])
     opts.add_option("ais_force_cpu", options["ais_force_cpu"])
 
 
-def on_before_image_saved(image, p, **kwargs):
+def on_before_image_saved(params: ImageSaveParams):
     if opts.ais_add_exif:
-        score = round(get_score(image), 1)
-        if "existing_info" not in kwargs or kwargs["existing_info"] is None:
-            kwargs["existing_info"] = {}
-        kwargs["existing_info"].update({
+        score = round(get_score(params.image), 1)
+        params.pnginfo.update({
             "aesthetic_score": score,
         })
-    return image, p, kwargs
 
-def on_image_saved(image, p, fullfn, txt_fullfn, **kwargs):
-    if "existing_info" in kwargs and kwargs["existing_info"] is not None and "aesthetic_score" in kwargs["existing_info"]:
-        score = kwargs["existing_info"]["aesthetic_score"]
+
+def on_image_saved(params: ImageSaveParams):
+    if "aesthetic_score" in params.pnginfo:
+        score = params.pnginfo["aesthetic_score"]
     else:
-        score = round(get_score(image), 1)
+        score = round(get_score(params.image), 1)
     if score is not None and opts.ais_windows_tag:
         if tag_files is not None:
             tags = [f"aesthetic_score_{score}"]
-            tag_files(filename=fullfn, tags=tags)
+            tag_files(filename=params.filename, tags=tags)
         else:
             print("Aesthetic Image Scorer: Unable to load Windows tagging script")
 
+
 class AestheticImageScorer(scripts.Script):
     def title(self):
         return "Aesthetic Image Scorer"
@@ -126,6 +130,7 @@ class AestheticImageScorer(scripts.Script):
     def process(self, p):
         pass
 
+
 script_callbacks.on_ui_settings(on_ui_settings)
 script_callbacks.on_before_image_saved(on_before_image_saved)
 script_callbacks.on_image_saved(on_image_saved)