132 lines
4.6 KiB
Python
132 lines
4.6 KiB
Python
from modules import sd_samplers, shared, scripts, script_callbacks
|
|
import modules.images as images
|
|
from modules.processing import Processed, process_images, StableDiffusionProcessing
|
|
from modules.shared import opts, OptionInfo
|
|
|
|
from pathlib import Path
|
|
import torch
|
|
import torch.nn as nn
|
|
import clip
|
|
import platform
|
|
from launch import is_installed, run_pip
|
|
|
|
# Install pywin32 on demand when running on Windows (presumably required by
# tools.add_tags -- confirm against that module).
if platform.system() == "Windows" and not is_installed("pywin32"):
    run_pip("install pywin32", "pywin32")

# Tagging is an optional feature: when the helper cannot be imported,
# `tag_files` is set to None and callers must check for that.
# `except Exception` (not a bare `except`) so that Ctrl-C / SystemExit raised
# during start-up are not silently swallowed.
try:
    from tools.add_tags import tag_files
except Exception:
    print("Aesthetic Image Scorer: Unable to load Windows tagging script")
    tag_files = None
|
|
|
|
# File name of the predictor checkpoint, resolved relative to the current
# working directory. It is downloaded once if it is not already present.
state_name = "sac+logos+ava1-l14-linearMSE.pth"
if not Path(state_name).exists():
    import requests

    url = f"https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/{state_name}?raw=true"
    # A timeout keeps start-up from hanging forever on a stalled connection.
    r = requests.get(url, timeout=60)
    # Fail loudly on HTTP errors: otherwise an error page would be written to
    # disk as the checkpoint, the file would then "exist", and every later
    # start-up would fail while trying to load it.
    r.raise_for_status()
    with open(state_name, "wb") as f:
        f.write(r.content)
|
|
|
|
class AestheticPredictor(nn.Module):
    """Stack of linear layers with dropout mapping an embedding to one score.

    The sub-module is stored as ``layers`` and the order of its entries is
    fixed, because checkpoint keys (``layers.<index>.weight`` / ``.bias``)
    depend on both.
    """

    def __init__(self, input_size):
        """Build the network for embeddings of width *input_size*."""
        super().__init__()
        self.input_size = input_size

        # Order matters: positions in this list become the Sequential indices.
        stack = [
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run *x* through the layer stack and return the result."""
        prediction = self.layers(x)
        return prediction
|
|
|
|
# `ais_force_cpu` is only added to `opts` by on_ui_settings(), so the lookup
# can fail when this module is imported first; fall back to "not forced".
# Only lookup failures are caught, so unrelated errors are not hidden.
try:
    force_cpu = opts.ais_force_cpu
except (AttributeError, KeyError):
    force_cpu = False

if force_cpu:
    print("Aesthetic Image Scorer: Forcing prediction model to run on CPU")
device = "cuda" if not force_cpu and torch.cuda.is_available() else "cpu"

# load the model you trained previously or the model available in this repo
# NOTE(security): torch.load unpickles the file, and the checkpoint may have
# been downloaded from the network. Consider `weights_only=True` once the
# minimum supported torch version provides it.
pt_state = torch.load(state_name, map_location=torch.device(device))

# CLIP embedding dim is 768 for CLIP ViT L 14
predictor = AestheticPredictor(768)
predictor.load_state_dict(pt_state)
predictor.to(device)
# eval() switches the dropout layers off for inference.
predictor.eval()

clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
|
|
|
|
def get_image_features(image, device=device, model=clip_model, preprocess=clip_preprocess):
    """Encode *image* with CLIP and return its L2-normalised features.

    The defaults for *device*, *model* and *preprocess* are the module-level
    objects as they were when this function was defined.

    Args:
        image: Input accepted by *preprocess*.
        device: Device the preprocessed tensor is moved to before encoding.
        model: Object providing ``encode_image``.
        preprocess: Callable turning *image* into a tensor.

    Returns:
        The normalised features as a NumPy array; a leading batch axis of
        size 1 is added before encoding.
    """
    batch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model.encode_image(batch)
        # l2 normalize
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    return embedding.cpu().detach().numpy()
|
|
|
|
def get_score(image):
    """Return the predictor's aesthetic score for *image* as a Python float.

    The image is encoded with ``get_image_features`` and the resulting
    features are fed to the module-level ``predictor`` on ``device``.
    """
    image_features = get_image_features(image)
    features = torch.from_numpy(image_features).to(device).float()
    # Inference only: no_grad avoids building an autograd graph that would be
    # thrown away immediately after .item().
    with torch.no_grad():
        score = predictor(features)
    return score.item()
|
|
|
|
|
|
def on_ui_settings():
    """Declare the extension's settings and add each one to ``opts``.

    All three options are created in the ('ais', "Aesthetic Image Scorer")
    section and default to False.
    """
    section = ('ais', "Aesthetic Image Scorer")
    declared = {
        "ais_add_exif": OptionInfo(False, "Save score as EXIF or PNG Info Chunk"),
        "ais_windows_tag": OptionInfo(False, "Save score as tag (Windows Only)"),
        "ais_force_cpu": OptionInfo(False, "Force CPU (Requires Custom Script Reload)"),
    }
    options = dict(shared.options_section(section, declared))

    # Register in a fixed, explicit order.
    for key in ("ais_add_exif", "ais_windows_tag", "ais_force_cpu"):
        opts.add_option(key, options[key])
|
|
|
|
|
|
def on_before_image_saved(image, p, **kwargs):
    """Store the rounded aesthetic score in ``kwargs["existing_info"]``.

    Nothing is computed unless ``opts.ais_add_exif`` is enabled. An existing
    ``existing_info`` dict is updated in place; a missing or None one is
    replaced by a new dict.

    Returns:
        The tuple ``(image, p, kwargs)``.
    """
    if not opts.ais_add_exif:
        return image, p, kwargs

    score = round(get_score(image), 1)
    # Test for None explicitly so an existing (possibly empty) dict is reused.
    if kwargs.get("existing_info") is None:
        kwargs["existing_info"] = {}
    kwargs["existing_info"]["aesthetic_score"] = score
    return image, p, kwargs
|
|
|
|
def on_image_saved(image, p, fullfn, txt_fullfn, **kwargs):
    """Tag the saved file with its aesthetic score when tagging is enabled.

    The score is taken from ``kwargs["existing_info"]["aesthetic_score"]``
    when present, otherwise it is computed from *image*.

    Args:
        image: The image that was saved; only scored if no score is cached.
        p: Unused here.
        fullfn: Path of the saved image; passed to ``tag_files``.
        txt_fullfn: Unused here.
        **kwargs: May carry an ``existing_info`` dict holding a cached score.
    """
    # Bail out before any scoring: running CLIP and the predictor on every
    # saved image is wasted work when the result is never used.
    if not opts.ais_windows_tag:
        return
    if tag_files is None:
        print("Aesthetic Image Scorer: Unable to load Windows tagging script")
        return

    existing_info = kwargs.get("existing_info")
    if existing_info is not None and "aesthetic_score" in existing_info:
        score = existing_info["aesthetic_score"]
    else:
        score = round(get_score(image), 1)

    # A cached score may be None; there is nothing to tag in that case.
    if score is None:
        return
    tag_files(filename=fullfn, tags=[f"aesthetic_score_{score}"])
|
|
|
|
class AestheticImageScorer(scripts.Script):
    """Script entry that is always visible and exposes no UI controls."""

    def title(self):
        """Return the script's display name."""
        return "Aesthetic Image Scorer"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` regardless of *is_img2img*."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Return an empty list: this script declares no UI components."""
        return []

    def process(self, p):
        """Do nothing; scoring is done in the image-save callbacks."""
        return None
|
|
|
|
# Hook the handlers defined in this module into the web UI, in this order:
# settings registration, pre-save scoring, post-save tagging.
for _register, _handler in (
    (script_callbacks.on_ui_settings, on_ui_settings),
    (script_callbacks.on_before_image_saved, on_before_image_saved),
    (script_callbacks.on_image_saved, on_image_saved),
):
    _register(_handler)
|