Add more tagging / png info / gen params options
parent
e327f40c18
commit
cfc8cc1581
|
|
@ -1,3 +0,0 @@
|
|||
[submodule "tools"]
|
||||
path = tools
|
||||
url = https://github.com/tsngo/stable-diffusion-webui-tools.git
|
||||
|
|
@ -1,3 +1,6 @@
|
|||
import copy
|
||||
from hashlib import md5
|
||||
import json
|
||||
import os
|
||||
from modules import sd_samplers, shared, scripts, script_callbacks
|
||||
from modules.script_callbacks import ImageSaveParams
|
||||
|
|
@ -6,12 +9,14 @@ from modules.processing import Processed, process_images, StableDiffusionProcess
|
|||
from modules.shared import opts, OptionInfo
|
||||
from modules.paths import script_path
|
||||
|
||||
import gradio as gr
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import clip
|
||||
import platform
|
||||
from launch import is_installed, run_pip
|
||||
from modules.generation_parameters_copypaste import parse_generation_parameters
|
||||
|
||||
extension_name = "Aesthetic Image Scorer"
|
||||
if platform.system() == "Windows" and not is_installed("pywin32"):
|
||||
|
|
@ -85,41 +90,169 @@ def get_score(image):
|
|||
score = predictor(torch.from_numpy(image_features).to(device).float())
|
||||
return score.item()
|
||||
|
||||
class AISGroup:
|
||||
def __init__(self, name="", apply_choices=lambda choice, choice_values: {"tags": [f"{choice}_{choice_values[choice]}"]}, default=[]):
|
||||
gen_params = lambda p: {
|
||||
"steps": p.steps,
|
||||
"sampler": sd_samplers.samplers[p.sampler_index].name,
|
||||
"cfg_scale": p.cfg_scale,
|
||||
"seed": p.seed,
|
||||
"width": p.width,
|
||||
"height": p.height,
|
||||
"model": shared.sd_model.sd_model_hash,
|
||||
"prompt": p.prompt,
|
||||
"negative_prompt": p.negative_prompt,
|
||||
}
|
||||
self.choice_processors = {
|
||||
"aesthetic_score": lambda params: params.pnginfo["aesthetic_score"] if "aesthetic_score" in params.pnginfo else round(get_score(params.image), 1),
|
||||
"sampler": lambda params: sd_samplers.samplers[params.p.sampler_index].name if params.p is not None and params.p.sampler else None,
|
||||
"cfg_scale": lambda params: params.p.cfg_scale if params.p is not None and params.p.cfg_scale else None,
|
||||
"sd_model_hash": lambda params: shared.sd_model.sd_model_hash,
|
||||
"seed": lambda params: str(int(params.p.seed)) if params.p is not None and params.p.seed else None,
|
||||
"hash": lambda params: md5(json.dumps(gen_params(params.p)).encode('utf-8')).hexdigest() if params.p is not None else None,
|
||||
}
|
||||
self.name = name
|
||||
self.apply_choices = apply_choices
|
||||
self.default = default
|
||||
|
||||
def get_choices(self):
|
||||
return list(self.choice_processors.keys())
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
|
||||
def selected(self, opts):
|
||||
return opts.__getattr__(self.name)
|
||||
|
||||
def get_choice_processors(self):
|
||||
return self.choice_processors
|
||||
|
||||
def apply(self, choice_values, applied, opts):
|
||||
for choice in self.get_choices():
|
||||
if choice in self.selected(opts) and choice in choice_values:
|
||||
applied_choices = self.apply_choices(choice, choice_values)
|
||||
for key, value in applied_choices.items():
|
||||
if key not in applied:
|
||||
applied[key] = self.get_default()
|
||||
if isinstance(applied[key], dict):
|
||||
applied[key].update(value)
|
||||
else:
|
||||
applied[key] = applied[key] + value
|
||||
return applied
|
||||
|
||||
def get_default(self):
|
||||
return self.default
|
||||
|
||||
class AISGroups:
|
||||
def __init__(self, groups=[]):
|
||||
self.groups = groups
|
||||
self.choices = {}
|
||||
self.choice_processors = {}
|
||||
for group in groups:
|
||||
self.choice_processors.update(group.get_choice_processors())
|
||||
self.choices = list(self.choice_processors.keys())
|
||||
|
||||
def all_selected(self, opts):
|
||||
selected = {}
|
||||
for group in self.groups:
|
||||
for select in group.selected(opts):
|
||||
selected.update({select: 0})
|
||||
return selected.keys()
|
||||
|
||||
def apply(self, opts: list, params: ImageSaveParams):
|
||||
parsed_info = parse_generation_parameters(
|
||||
params.pnginfo.get("parameters", ""))
|
||||
params_hash = md5(json.dumps(
|
||||
str(vars(params)), default=lambda o: dir(o), sort_keys=True).encode('utf-8')).hexdigest()
|
||||
applied = {}
|
||||
if params_hash not in output_cache:
|
||||
choice_values = {}
|
||||
choices_selected = self.all_selected(opts)
|
||||
for choice, processor in self.choice_processors.items():
|
||||
if choice in choices_selected:
|
||||
choice_values.update({choice: processor(params)})
|
||||
|
||||
if "seed" in choice_values and int(choice_values["seed"]) == -1:
|
||||
choice_values["seed"] = int(parsed_info["Seed"]) if "Seed" in parsed_info else int(choice_values["seed"])
|
||||
|
||||
for group in self.groups:
|
||||
applied = group.apply(choice_values, applied, opts)
|
||||
|
||||
expected_keys = ["tags", "categories", "info", "pnginfo"]
|
||||
for key in expected_keys:
|
||||
if key not in applied:
|
||||
applied[key] = []
|
||||
|
||||
output_cache[params_hash] = applied
|
||||
else:
|
||||
applied = output_cache[params_hash]
|
||||
output_cache.clear()
|
||||
|
||||
return applied
|
||||
|
||||
|
||||
ais_exif_pnginfo_choices = AISGroup(name="ais_exif_pnginfo_group", apply_choices=lambda choice, choice_values: {
|
||||
"pnginfo": {choice: choice_values[choice]}}, default={})
|
||||
ais_windows_tag_group_choices = AISGroup(name="ais_windows_tag_group")
|
||||
ais_windows_category_group_choices = AISGroup(name="ais_windows_category_group", apply_choices=lambda choice, choice_values: {
|
||||
"categories": [f"{choice}_{choice_values[choice]}"]})
|
||||
ais_generation_params_text_choices = AISGroup(name="ais_generation_params_text_group", apply_choices=lambda choice, choice_values: {
|
||||
"info": {choice: choice_values[choice]}}, default={})
|
||||
|
||||
ais_group = AISGroups([
|
||||
ais_windows_tag_group_choices,
|
||||
ais_windows_category_group_choices,
|
||||
ais_generation_params_text_choices,
|
||||
ais_exif_pnginfo_choices,
|
||||
])
|
||||
|
||||
output_cache = {}
|
||||
|
||||
def on_ui_settings():
|
||||
options = {}
|
||||
|
||||
options.update(shared.options_section(('ais', extension_name), {
|
||||
"ais_add_exif": OptionInfo(False, "Save score as EXIF or PNG Info Chunk"),
|
||||
"ais_windows_tag": OptionInfo(False, "Save score as tag (Windows Only)"),
|
||||
ais_exif_pnginfo_choices.get_name(): OptionInfo([], "Save score as EXIF or PNG Info Chunk", gr.CheckboxGroup, {"choices": ais_exif_pnginfo_choices.get_choices()}),
|
||||
ais_windows_tag_group_choices.get_name(): OptionInfo([], "Save tags (Windows only)", gr.CheckboxGroup, {"choices": ais_windows_tag_group_choices.get_choices()}),
|
||||
ais_windows_category_group_choices.get_name(): OptionInfo([], "Save category (Windows only)", gr.CheckboxGroup, {"choices": ais_windows_category_group_choices.get_choices()}),
|
||||
ais_generation_params_text_choices.get_name(): OptionInfo([], "Save generation params text", gr.CheckboxGroup, {"choices": ais_generation_params_text_choices.get_choices()}),
|
||||
"ais_force_cpu": OptionInfo(False, "Force CPU (Requires Custom Script Reload)"),
|
||||
}))
|
||||
|
||||
opts.add_option("ais_add_exif", options["ais_add_exif"])
|
||||
opts.add_option("ais_windows_tag", options["ais_windows_tag"])
|
||||
opts.add_option(ais_exif_pnginfo_choices.get_name(),
|
||||
options[ais_exif_pnginfo_choices.get_name()])
|
||||
opts.add_option(ais_windows_tag_group_choices.get_name(),
|
||||
options[ais_windows_tag_group_choices.get_name()])
|
||||
opts.add_option(ais_windows_category_group_choices.get_name(),
|
||||
options[ais_windows_category_group_choices.get_name()])
|
||||
opts.add_option(ais_generation_params_text_choices.get_name(),
|
||||
options[ais_generation_params_text_choices.get_name()])
|
||||
opts.add_option("ais_force_cpu", options["ais_force_cpu"])
|
||||
|
||||
|
||||
def on_before_image_saved(params: ImageSaveParams):
|
||||
if opts.ais_add_exif:
|
||||
score = round(get_score(params.image), 1)
|
||||
params.pnginfo.update({
|
||||
"aesthetic_score": score,
|
||||
})
|
||||
|
||||
|
||||
def on_before_image_saved(params: ImageSaveParams):
|
||||
applied = ais_group.apply(opts, params)
|
||||
if len(applied["pnginfo"]) > 0:
|
||||
params.pnginfo.update(applied["pnginfo"])
|
||||
|
||||
if len(applied["info"]) > 0:
|
||||
parts = []
|
||||
for label, value in applied["info"].items():
|
||||
parts.append(f"{label}: {value}")
|
||||
if len(parts) > 0:
|
||||
if len(params.pnginfo["parameters"]) > 0:
|
||||
params.pnginfo["parameters"] += ", "
|
||||
params.pnginfo["parameters"] += f"{', '.join(parts)}\n"
|
||||
|
||||
return params
|
||||
|
||||
def on_image_saved(params: ImageSaveParams):
|
||||
filename = os.path.realpath(os.path.join(script_path, params.filename))
|
||||
if "aesthetic_score" in params.pnginfo:
|
||||
score = params.pnginfo["aesthetic_score"]
|
||||
applied = ais_group.apply(opts, params)
|
||||
if tag_files is not None:
|
||||
tag_files(filename=filename, tags=applied["tags"], categories=applied["categories"],
|
||||
log_prefix=f"{extension_name}: ")
|
||||
else:
|
||||
score = round(get_score(params.image), 1)
|
||||
if score is not None and opts.ais_windows_tag:
|
||||
if tag_files is not None:
|
||||
tags = [f"aesthetic_score_{score}"]
|
||||
tag_files(filename=filename, tags=tags, log_prefix=f"{extension_name}: ")
|
||||
else:
|
||||
print(f"{extension_name}: Unable to load tagging script")
|
||||
|
||||
print(f"{extension_name}: Unable to load tagging script")
|
||||
|
||||
class AestheticImageScorer(scripts.Script):
|
||||
def title(self):
|
||||
|
|
|
|||
1
tools
1
tools
|
|
@ -1 +0,0 @@
|
|||
Subproject commit cb92f1a566ba9c6e0dcbb2393d69f7bc4bd0a9e2
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
import sys
|
||||
# requires pywin32 (See requirements.txt. Supports Windows only obviously)
|
||||
# also recommend FileMeta https://github.com/Dijji/FileMeta to allow tags for PNG
|
||||
import pythoncom
|
||||
try:
|
||||
from win32com.propsys import propsys
|
||||
from win32com.shell import shellcon
|
||||
except:
|
||||
propsys = None
|
||||
shellcon = None
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
def set_property(file="", property="System.Keywords", values=[], remove_values=[], remove_all=False, ps=None):
|
||||
array_properties = ["System.Keywords", "System.Category"]
|
||||
if len(values) == 0 and len(remove_values) == 0 and not remove_all:
|
||||
return ps
|
||||
# get property store for a given shell item (here a file)
|
||||
try:
|
||||
pk = propsys.PSGetPropertyKeyFromName(property)
|
||||
except:
|
||||
pythoncom.CoInitialize()
|
||||
pk = propsys.PSGetPropertyKeyFromName(property)
|
||||
|
||||
if (ps is None):
|
||||
ps = propsys.SHGetPropertyStoreFromParsingName(os.path.realpath(file), None, shellcon.GPS_READWRITE, propsys.IID_IPropertyStore)
|
||||
|
||||
if property in array_properties:
|
||||
# read & print existing (or not) property value, System.Keywords type is an array of string
|
||||
existingValues = ps.GetValue(pk).GetValue()
|
||||
if existingValues == None:
|
||||
existingValues = []
|
||||
filteredValues = []
|
||||
|
||||
if not remove_all:
|
||||
for value in existingValues:
|
||||
if value in remove_values:
|
||||
continue
|
||||
filteredValues.append(value)
|
||||
|
||||
# build an array of string type PROPVARIANT
|
||||
newValue = propsys.PROPVARIANTType(filteredValues + values, pythoncom.VT_VECTOR | pythoncom.VT_BSTR)
|
||||
|
||||
# write property
|
||||
ps.SetValue(pk, newValue)
|
||||
|
||||
return ps
|
||||
|
||||
def tag_files(files_glob="", tags=[], remove_tags=[], remove_all_tags=False, filename="", comment="", categories=[], remove_categories=[], remove_all_categories=False, log_prefix=""):
|
||||
if propsys == None or shellcon == None:
|
||||
return
|
||||
|
||||
if files_glob=="":
|
||||
files = [os.path.realpath(filename)]
|
||||
else:
|
||||
files = glob.glob(files_glob)
|
||||
|
||||
for file in files:
|
||||
try:
|
||||
ps = set_property(file=file, property="System.Keywords", values=tags,
|
||||
remove_values=remove_tags, remove_all=remove_all_tags)
|
||||
ps = set_property(file=file, property="System.Category", values=categories,
|
||||
remove_values=remove_categories, remove_all=remove_all_categories, ps=ps)
|
||||
ps.Commit()
|
||||
except:
|
||||
print(f"{log_prefix}Unable to write tag or category for {file}")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-f", "--files-glob", type=str, default="", help="glob pattern to files to tag", required=True)
|
||||
parser.add_argument("-t", "--tags", type=str, default="", help="comma separated list of tags", required=False)
|
||||
parser.add_argument("-r", "--remove-tags", type=str, default="", help="comma separated list of tags to remove", required=False)
|
||||
parser.add_argument("-c", "--categories", type=str, default="", help="comma separated list of categories add", required=False)
|
||||
parser.add_argument("-rc", "--remove-categories", type=str, default="", help="comma separated list of categories to remove", required=False)
|
||||
parser.add_argument("-rt", "--remove-all-tags", action="store_true", default=False, help="remove all tags", required=False)
|
||||
parser.add_argument("-rac", "--remove-all-categories", action="store_true", default=False, help="remove all tags", required=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.tags == "":
|
||||
args.tags = []
|
||||
else:
|
||||
args.tags = args.tags.split(',')
|
||||
|
||||
if args.categories == "":
|
||||
args.categories = []
|
||||
else:
|
||||
args.categories = args.categories.split(',')
|
||||
|
||||
if args.remove_categories == "":
|
||||
args.remove_categories = []
|
||||
else:
|
||||
args.remove_categories = args.remove_categories.split(',')
|
||||
|
||||
tag_files(**args.__dict__)
|
||||
Loading…
Reference in New Issue