Add more tagging / png info / gen params options

main
Trung Ngo 2022-10-31 23:33:25 -05:00
parent e327f40c18
commit cfc8cc1581
4 changed files with 255 additions and 27 deletions

3
.gitmodules vendored
View File

@ -1,3 +0,0 @@
[submodule "tools"]
path = tools
url = https://github.com/tsngo/stable-diffusion-webui-tools.git

View File

@ -1,3 +1,6 @@
import copy
from hashlib import md5
import json
import os
from modules import sd_samplers, shared, scripts, script_callbacks
from modules.script_callbacks import ImageSaveParams
@ -6,12 +9,14 @@ from modules.processing import Processed, process_images, StableDiffusionProcess
from modules.shared import opts, OptionInfo
from modules.paths import script_path
import gradio as gr
from pathlib import Path
import torch
import torch.nn as nn
import clip
import platform
from launch import is_installed, run_pip
from modules.generation_parameters_copypaste import parse_generation_parameters
extension_name = "Aesthetic Image Scorer"
if platform.system() == "Windows" and not is_installed("pywin32"):
@ -85,41 +90,169 @@ def get_score(image):
score = predictor(torch.from_numpy(image_features).to(device).float())
return score.item()
class AISGroup:
def __init__(self, name="", apply_choices=lambda choice, choice_values: {"tags": [f"{choice}_{choice_values[choice]}"]}, default=[]):
gen_params = lambda p: {
"steps": p.steps,
"sampler": sd_samplers.samplers[p.sampler_index].name,
"cfg_scale": p.cfg_scale,
"seed": p.seed,
"width": p.width,
"height": p.height,
"model": shared.sd_model.sd_model_hash,
"prompt": p.prompt,
"negative_prompt": p.negative_prompt,
}
self.choice_processors = {
"aesthetic_score": lambda params: params.pnginfo["aesthetic_score"] if "aesthetic_score" in params.pnginfo else round(get_score(params.image), 1),
"sampler": lambda params: sd_samplers.samplers[params.p.sampler_index].name if params.p is not None and params.p.sampler else None,
"cfg_scale": lambda params: params.p.cfg_scale if params.p is not None and params.p.cfg_scale else None,
"sd_model_hash": lambda params: shared.sd_model.sd_model_hash,
"seed": lambda params: str(int(params.p.seed)) if params.p is not None and params.p.seed else None,
"hash": lambda params: md5(json.dumps(gen_params(params.p)).encode('utf-8')).hexdigest() if params.p is not None else None,
}
self.name = name
self.apply_choices = apply_choices
self.default = default
def get_choices(self):
return list(self.choice_processors.keys())
def get_name(self):
return self.name
def selected(self, opts):
return opts.__getattr__(self.name)
def get_choice_processors(self):
return self.choice_processors
def apply(self, choice_values, applied, opts):
for choice in self.get_choices():
if choice in self.selected(opts) and choice in choice_values:
applied_choices = self.apply_choices(choice, choice_values)
for key, value in applied_choices.items():
if key not in applied:
applied[key] = self.get_default()
if isinstance(applied[key], dict):
applied[key].update(value)
else:
applied[key] = applied[key] + value
return applied
def get_default(self):
return self.default
class AISGroups:
def __init__(self, groups=[]):
self.groups = groups
self.choices = {}
self.choice_processors = {}
for group in groups:
self.choice_processors.update(group.get_choice_processors())
self.choices = list(self.choice_processors.keys())
def all_selected(self, opts):
selected = {}
for group in self.groups:
for select in group.selected(opts):
selected.update({select: 0})
return selected.keys()
def apply(self, opts: list, params: ImageSaveParams):
parsed_info = parse_generation_parameters(
params.pnginfo.get("parameters", ""))
params_hash = md5(json.dumps(
str(vars(params)), default=lambda o: dir(o), sort_keys=True).encode('utf-8')).hexdigest()
applied = {}
if params_hash not in output_cache:
choice_values = {}
choices_selected = self.all_selected(opts)
for choice, processor in self.choice_processors.items():
if choice in choices_selected:
choice_values.update({choice: processor(params)})
if "seed" in choice_values and int(choice_values["seed"]) == -1:
choice_values["seed"] = int(parsed_info["Seed"]) if "Seed" in parsed_info else int(choice_values["seed"])
for group in self.groups:
applied = group.apply(choice_values, applied, opts)
expected_keys = ["tags", "categories", "info", "pnginfo"]
for key in expected_keys:
if key not in applied:
applied[key] = []
output_cache[params_hash] = applied
else:
applied = output_cache[params_hash]
output_cache.clear()
return applied
ais_exif_pnginfo_choices = AISGroup(name="ais_exif_pnginfo_group", apply_choices=lambda choice, choice_values: {
"pnginfo": {choice: choice_values[choice]}}, default={})
ais_windows_tag_group_choices = AISGroup(name="ais_windows_tag_group")
ais_windows_category_group_choices = AISGroup(name="ais_windows_category_group", apply_choices=lambda choice, choice_values: {
"categories": [f"{choice}_{choice_values[choice]}"]})
ais_generation_params_text_choices = AISGroup(name="ais_generation_params_text_group", apply_choices=lambda choice, choice_values: {
"info": {choice: choice_values[choice]}}, default={})
ais_group = AISGroups([
ais_windows_tag_group_choices,
ais_windows_category_group_choices,
ais_generation_params_text_choices,
ais_exif_pnginfo_choices,
])
output_cache = {}
def on_ui_settings():
options = {}
options.update(shared.options_section(('ais', extension_name), {
"ais_add_exif": OptionInfo(False, "Save score as EXIF or PNG Info Chunk"),
"ais_windows_tag": OptionInfo(False, "Save score as tag (Windows Only)"),
ais_exif_pnginfo_choices.get_name(): OptionInfo([], "Save score as EXIF or PNG Info Chunk", gr.CheckboxGroup, {"choices": ais_exif_pnginfo_choices.get_choices()}),
ais_windows_tag_group_choices.get_name(): OptionInfo([], "Save tags (Windows only)", gr.CheckboxGroup, {"choices": ais_windows_tag_group_choices.get_choices()}),
ais_windows_category_group_choices.get_name(): OptionInfo([], "Save category (Windows only)", gr.CheckboxGroup, {"choices": ais_windows_category_group_choices.get_choices()}),
ais_generation_params_text_choices.get_name(): OptionInfo([], "Save generation params text", gr.CheckboxGroup, {"choices": ais_generation_params_text_choices.get_choices()}),
"ais_force_cpu": OptionInfo(False, "Force CPU (Requires Custom Script Reload)"),
}))
opts.add_option("ais_add_exif", options["ais_add_exif"])
opts.add_option("ais_windows_tag", options["ais_windows_tag"])
opts.add_option(ais_exif_pnginfo_choices.get_name(),
options[ais_exif_pnginfo_choices.get_name()])
opts.add_option(ais_windows_tag_group_choices.get_name(),
options[ais_windows_tag_group_choices.get_name()])
opts.add_option(ais_windows_category_group_choices.get_name(),
options[ais_windows_category_group_choices.get_name()])
opts.add_option(ais_generation_params_text_choices.get_name(),
options[ais_generation_params_text_choices.get_name()])
opts.add_option("ais_force_cpu", options["ais_force_cpu"])
def on_before_image_saved(params: ImageSaveParams):
if opts.ais_add_exif:
score = round(get_score(params.image), 1)
params.pnginfo.update({
"aesthetic_score": score,
})
def on_before_image_saved(params: ImageSaveParams):
applied = ais_group.apply(opts, params)
if len(applied["pnginfo"]) > 0:
params.pnginfo.update(applied["pnginfo"])
if len(applied["info"]) > 0:
parts = []
for label, value in applied["info"].items():
parts.append(f"{label}: {value}")
if len(parts) > 0:
if len(params.pnginfo["parameters"]) > 0:
params.pnginfo["parameters"] += ", "
params.pnginfo["parameters"] += f"{', '.join(parts)}\n"
return params
def on_image_saved(params: ImageSaveParams):
filename = os.path.realpath(os.path.join(script_path, params.filename))
if "aesthetic_score" in params.pnginfo:
score = params.pnginfo["aesthetic_score"]
applied = ais_group.apply(opts, params)
if tag_files is not None:
tag_files(filename=filename, tags=applied["tags"], categories=applied["categories"],
log_prefix=f"{extension_name}: ")
else:
score = round(get_score(params.image), 1)
if score is not None and opts.ais_windows_tag:
if tag_files is not None:
tags = [f"aesthetic_score_{score}"]
tag_files(filename=filename, tags=tags, log_prefix=f"{extension_name}: ")
else:
print(f"{extension_name}: Unable to load tagging script")
print(f"{extension_name}: Unable to load tagging script")
class AestheticImageScorer(scripts.Script):
def title(self):

1
tools

@ -1 +0,0 @@
Subproject commit cb92f1a566ba9c6e0dcbb2393d69f7bc4bd0a9e2

99
tools/add_tags.py Normal file
View File

@ -0,0 +1,99 @@
import sys
# requires pywin32 (See requirements.txt. Supports Windows only obviously)
# also recommend FileMeta https://github.com/Dijji/FileMeta to allow tags for PNG
import pythoncom
try:
from win32com.propsys import propsys
from win32com.shell import shellcon
except:
propsys = None
shellcon = None
import os
import argparse
import glob
script_path = os.path.dirname(os.path.realpath(__file__))
def set_property(file="", property="System.Keywords", values=[], remove_values=[], remove_all=False, ps=None):
array_properties = ["System.Keywords", "System.Category"]
if len(values) == 0 and len(remove_values) == 0 and not remove_all:
return ps
# get property store for a given shell item (here a file)
try:
pk = propsys.PSGetPropertyKeyFromName(property)
except:
pythoncom.CoInitialize()
pk = propsys.PSGetPropertyKeyFromName(property)
if (ps is None):
ps = propsys.SHGetPropertyStoreFromParsingName(os.path.realpath(file), None, shellcon.GPS_READWRITE, propsys.IID_IPropertyStore)
if property in array_properties:
# read & print existing (or not) property value, System.Keywords type is an array of string
existingValues = ps.GetValue(pk).GetValue()
if existingValues == None:
existingValues = []
filteredValues = []
if not remove_all:
for value in existingValues:
if value in remove_values:
continue
filteredValues.append(value)
# build an array of string type PROPVARIANT
newValue = propsys.PROPVARIANTType(filteredValues + values, pythoncom.VT_VECTOR | pythoncom.VT_BSTR)
# write property
ps.SetValue(pk, newValue)
return ps
def tag_files(files_glob="", tags=[], remove_tags=[], remove_all_tags=False, filename="", comment="", categories=[], remove_categories=[], remove_all_categories=False, log_prefix=""):
if propsys == None or shellcon == None:
return
if files_glob=="":
files = [os.path.realpath(filename)]
else:
files = glob.glob(files_glob)
for file in files:
try:
ps = set_property(file=file, property="System.Keywords", values=tags,
remove_values=remove_tags, remove_all=remove_all_tags)
ps = set_property(file=file, property="System.Category", values=categories,
remove_values=remove_categories, remove_all=remove_all_categories, ps=ps)
ps.Commit()
except:
print(f"{log_prefix}Unable to write tag or category for {file}")
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--files-glob", type=str, default="", help="glob pattern to files to tag", required=True)
parser.add_argument("-t", "--tags", type=str, default="", help="comma separated list of tags", required=False)
parser.add_argument("-r", "--remove-tags", type=str, default="", help="comma separated list of tags to remove", required=False)
parser.add_argument("-c", "--categories", type=str, default="", help="comma separated list of categories add", required=False)
parser.add_argument("-rc", "--remove-categories", type=str, default="", help="comma separated list of categories to remove", required=False)
parser.add_argument("-rt", "--remove-all-tags", action="store_true", default=False, help="remove all tags", required=False)
parser.add_argument("-rac", "--remove-all-categories", action="store_true", default=False, help="remove all tags", required=False)
if __name__ == "__main__":
args = parser.parse_args()
if args.tags == "":
args.tags = []
else:
args.tags = args.tags.split(',')
if args.categories == "":
args.categories = []
else:
args.categories = args.categories.split(',')
if args.remove_categories == "":
args.remove_categories = []
else:
args.remove_categories = args.remove_categories.split(',')
tag_files(**args.__dict__)