# Mirror of https://github.com/vladmandic/automatic
# Python source: 252 lines, 12 KiB
import re
|
|
import os
|
|
import uuid
|
|
import string
|
|
import hashlib
|
|
import datetime
|
|
from pathlib import Path
|
|
from modules import shared, errors
|
|
|
|
|
|
# trace logging is enabled only when the SD_NAMEGEN_DEBUG environment variable is set; otherwise debug() is a no-op
debug = errors.log.trace if os.environ.get('SD_NAMEGEN_DEBUG', None) is not None else lambda *args, **kwargs: None
# runs of whitespace and/or punctuation; punctuation is escaped so characters such as '\' and ']' are matched literally
re_nonletters = re.compile(r'[\s' + re.escape(string.punctuation) + r']+')
# (literal text)(optional [pattern] without nested brackets) -- iterated with finditer to tokenize a filename template
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
# splits one trailing <arg> off a pattern body, e.g. 'datetime<%Y><UTC>' -> ('datetime<%Y>', 'UTC')
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
# attention syntax such as '(word:1.2)' -> group 1 is the word; the trailing empty alternative lets sub() pass other text through
re_attention = re.compile(r'[\(*\[*](\w+)(:\d+(\.\d+))?[\)*\]*]|')
# network tags such as '<lora:name:0.8>' -> group 1 is the name; the trailing empty alternative lets sub() pass other text through
re_network = re.compile(r'\<\w+:(\w+)(:\d+(\.\d+))?\>|')
# any single bracket character: ( ) [ ] { }
re_brackets = re.compile(r'[\([{})\]]')
# sentinel a replacement may return to drop its pattern together with the literal text preceding it
NOTHING = object()
class FilenameGenerator:
    """Expand ``[pattern]`` placeholders in save-path templates.

    ``replacements`` maps a lower-case pattern name to a callable that receives the
    generator instance plus any ``<arg>`` values written after the name, for example
    ``[datetime<%Y%m%d><UTC>]``. ``apply`` resolves the patterns of a template, while
    ``sanitize`` and ``sequence`` post-process the resulting path.
    """

    replacements = {
        'width': lambda self: self.image.width,
        'height': lambda self: self.image.height,
        'batch_number': lambda self: self.batch_number,
        'iter_number': lambda self: self.iter_number,
        'num': lambda self: NOTHING if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
        'generation_number': lambda self: NOTHING if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
        'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
        'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
        'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
        'hash': lambda self: self.image_hash(),
        'image_hash': lambda self: self.image_hash(),
        'timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
        'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),

        'model': lambda self: shared.sd_model.sd_checkpoint_info.title if shared.sd_loaded and getattr(shared.sd_model, 'sd_checkpoint_info', None) is not None else '',
        'model_shortname': lambda self: shared.sd_model.sd_checkpoint_info.model_name if shared.sd_loaded and getattr(shared.sd_model, 'sd_checkpoint_info', None) is not None else '',
        'model_name': lambda self: shared.sd_model.sd_checkpoint_info.model_name if shared.sd_loaded and getattr(shared.sd_model, 'sd_checkpoint_info', None) is not None else '',
        'model_type': lambda self: shared.sd_model_type if shared.sd_loaded else '',
        'model_hash': lambda self: shared.sd_model.sd_checkpoint_info.shorthash if shared.sd_loaded and getattr(shared.sd_model, 'sd_checkpoint_info', None) is not None else '',

        'prompt': lambda self: self.prompt_full(),
        'prompt_no_styles': lambda self: self.prompt_no_style(),
        'prompt_words': lambda self: self.prompt_words(),
        'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],

        'sampler': lambda self: self.p and self.p.sampler_name,
        'seed': lambda self: (self.seed and str(self.seed)) or '',
        'steps': lambda self: self.p and getattr(self.p, 'steps', 0),
        'cfg': lambda self: self.p and getattr(self.p, 'cfg_scale', 0),
        'clip_skip': lambda self: self.p and getattr(self.p, 'clip_skip', 0),
        'denoising': lambda self: self.p and getattr(self.p, 'denoising_strength', 0),
        'styles': lambda self: (self.p and ", ".join([style for style in self.p.styles if style != "None"])) or "None",
        'uuid': lambda self: str(uuid.uuid4()),
    }
    default_time_format = '%Y%m%d%H%M%S'

    def __init__(self, p, seed, prompt, image, grid=False):
        """Capture the context used to resolve patterns.

        Args:
            p: object carrying generation parameters, or None. It is read for seeds, the
                prompt, and batch/iteration information.
            seed: explicit seed. Used only when it converts to an int greater than 0.
                Otherwise the first non-empty ``p.all_seeds`` / ``p.seeds`` entry is used,
                then ``p.seed``, then 0.
            prompt: explicit prompt, falling back to ``p.prompt`` (or '' without ``p``).
                A list prompt is joined with spaces.
            image: image object used by ``[width]``, ``[height]`` and ``[image_hash]``.
            grid: when True, batch and iteration numbers are always ``NOTHING``.
        """
        if p is None:
            debug('Filename generator init skip')
        else:
            debug(f'Filename generator init: {seed} {prompt}')
        self.p = p
        if seed is not None and int(seed) > 0:
            self.seed = seed
        elif p is not None and getattr(p, 'all_seeds', None): # truthiness check: a None or empty list would fail on [0]
            self.seed = p.all_seeds[0]
        elif p is not None and getattr(p, 'seeds', None):
            self.seed = p.seeds[0]
        else:
            self.seed = getattr(p, 'seed', 0) # also yields 0 when p is None
        if prompt is not None:
            self.prompt = prompt
        else:
            self.prompt = p.prompt if p is not None else ''
        if isinstance(self.prompt, list):
            self.prompt = ' '.join(self.prompt)
        self.image = image
        if not grid:
            self.batch_number = NOTHING if self.p is None or getattr(self.p, 'batch_size', 1) == 1 else (self.p.batch_index + 1 if hasattr(self.p, 'batch_index') else NOTHING)
            self.iter_number = NOTHING if self.p is None or getattr(self.p, 'n_iter', 1) == 1 else (self.p.iteration + 1 if hasattr(self.p, 'iteration') else NOTHING)
        else:
            self.batch_number = NOTHING
            self.iter_number = NOTHING

    def hasprompt(self, *args):
        """Resolve ``[hasprompt<word|default>...]``.

        For each non-empty argument of the form ``word`` or ``word|default``, the
        lower-cased ``word`` is appended when it occurs (case-insensitively) in the
        prompt. Otherwise ``default`` is appended, or nothing if no default is given.
        Returns None when there is no ``p`` or no prompt.
        """
        if getattr(self, 'p', None) is None or getattr(self, 'prompt', None) is None:
            return None
        lower = self.prompt.lower() # only after the guard: a missing/None prompt must return None rather than raise
        outres = ""
        for arg in args:
            if arg == "":
                continue
            division = arg.split("|")
            expected = division[0].lower()
            default = division[1] if len(division) > 1 else ""
            outres += expected if expected in lower else default
        return outres

    def image_hash(self):
        """Return the first 8 hex digits of the SHA-256 of a base64-encoded JPEG rendering of ``self.image``.

        Returns None when there is no image.
        """
        if getattr(self, 'image', None) is None:
            return None
        import base64
        from io import BytesIO
        image = self.image
        # the JPEG encoder rejects modes such as RGBA or P; flatten those to RGB so hashing does not fail
        if getattr(image, 'mode', 'RGB') not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
            image = image.convert('RGB')
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        return hashlib.sha256(img_str).hexdigest()[0:8]

    def prompt_full(self):
        """Return the full prompt passed through ``prompt_sanitize``."""
        return self.prompt_sanitize(self.prompt)

    def prompt_words(self):
        """Return the prompt reduced to plain words.

        Attention syntax such as ``(word:1.2)`` and network tags such as
        ``<lora:name:0.8>`` are reduced to the word/name, remaining brackets are
        removed, and the text is split on whitespace/punctuation. The result is cut to
        ``shared.opts.directories_max_prompt_words`` words and then sanitized.
        """
        if getattr(self, 'prompt', None) is None:
            return ''
        no_attention = re_attention.sub(r'\1', self.prompt)
        no_network = re_network.sub(r'\1', no_attention)
        no_brackets = re_brackets.sub('', no_network)
        words = [x for x in re_nonletters.split(no_brackets or "") if len(x) > 0]
        prompt = " ".join(words[0:shared.opts.directories_max_prompt_words])
        return self.prompt_sanitize(prompt)

    def prompt_no_style(self):
        """Return the prompt with the text of the styles listed in ``p.styles`` removed, then sanitized.

        Each style is split around ``{prompt}``, every fragment is removed, and so is the
        whole style text. Returns None when there is no ``p`` or no prompt.
        """
        if getattr(self, 'p', None) is None or getattr(self, 'prompt', None) is None:
            return None
        prompt_no_style = self.prompt
        for style in shared.prompt_styles.get_style_prompts(self.p.styles):
            if len(style) > 0:
                for part in style.split("{prompt}"):
                    prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",")
                prompt_no_style = prompt_no_style.replace(style, "")
        return self.prompt_sanitize(prompt_no_style)

    def datetime(self, *args):
        """Format the current time for ``[datetime<format><zone>]``.

        ``args[0]`` is a strftime format (empty or missing means ``default_time_format``).
        ``args[1]`` is an optional pytz zone name; an unknown zone falls back to local time.
        A format that strftime rejects falls back to ``default_time_format``.
        """
        import pytz
        time_datetime = datetime.datetime.now()
        time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
        try:
            time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
        except pytz.exceptions.UnknownTimeZoneError:
            time_zone = None
        time_zone_time = time_datetime.astimezone(time_zone)
        try:
            formatted_time = time_zone_time.strftime(time_format)
        except (ValueError, TypeError):
            formatted_time = time_zone_time.strftime(self.default_time_format)
        return formatted_time

    def prompt_sanitize(self, prompt):
        """Replace characters that are unsafe in filenames (``# < > : ' " \\ | ? *`` and newline/tab/CR) with '_', then strip."""
        invalid_chars = '#<>:\'"\\|?*\n\t\r'
        sanitized = prompt.translate({ ord(x): '_' for x in invalid_chars }).strip()
        debug(f'Prompt sanitize: input="{prompt}" output={sanitized}')
        return sanitized

    def sanitize(self, filename):
        """Return *filename* cleaned up for use as a save path.

        For each path component:
        - ``' " | ? *`` and newline/tab/CR become '_';
        - ':' becomes '_' except in a leading drive component such as ``C:``;
        - leading ', ' and trailing '.,_ ' characters are stripped;
        - a component equal to a listed reserved device name is neutralized.
        The final component is then shortened until the absolute path fits ``max_length``.
        The extension is preserved.
        """
        invalid_chars = '\'"|?*\n\t\r' # <https://learn.microsoft.com/en-us/windows/win32/fileio/naming-a-file>
        invalid_folder = ':'
        invalid_files = ['CON', 'PRN', 'AUX', 'NUL', 'NULL', 'COM0', 'COM1', 'LPT0', 'LPT1']
        invalid_prefix = ', '
        invalid_suffix = '.,_ '
        fn, ext = os.path.splitext(filename)
        parts = Path(fn).parts
        newparts = []
        for i, part in enumerate(parts):
            part = part.translate({ ord(x): '_' for x in invalid_chars })
            if i > 0 or (len(part) >= 2 and part[1] != invalid_folder): # skip drive, otherwise remove
                part = part.translate({ ord(x): '_' for x in invalid_folder })
            part = part.lstrip(invalid_prefix).rstrip(invalid_suffix)
            if part in invalid_files: # reserved names
                for word in invalid_files:
                    part = part.replace(word, '_')
            newparts.append(part)
        fn = str(Path(*newparts))
        # NOTE(review): f_namemax is a per-component limit of the filesystem holding this module, yet it bounds the whole absolute path below
        max_length = max(256 - len(ext), os.statvfs(__file__).f_namemax - 32 if hasattr(os, 'statvfs') else 256 - len(ext))
        # Shorten only the final component. Trimming the raw string could cut into the
        # directory part, and it could loop forever once fn became '' while the working
        # directory alone exceeded max_length.
        head, tail = os.path.split(fn)
        while len(tail) > 1 and len(os.path.abspath(os.path.join(head, tail))) > max_length:
            tail = tail[:-1]
        fn = os.path.join(head, tail)
        fn += ext
        debug(f'Filename sanitize: input="{filename}" parts={parts} output="{fn}" ext={ext} max={max_length} len={len(fn)}')
        return fn

    def sequence(self, fn, dirname, basename):
        """Replace ``[seq]`` in *fn* with the first 5-digit sequence number whose path does not exist yet.

        This only runs when ``shared.opts.save_images_add_number`` is set or *fn* already
        contains ``[seq]``; in the former case a ``[seq]-`` prefix is added to the file name.
        Numbering starts at ``get_next_sequence_number(dirname, basename)``.
        Returns *fn* unchanged when numbering is off or no free number is found within 9999 tries.
        """
        x = fn
        if shared.opts.save_images_add_number or '[seq]' in fn:
            if '[seq]' not in fn:
                fn = os.path.join(os.path.dirname(fn), f"[seq]-{os.path.basename(fn)}")
            basecount = get_next_sequence_number(dirname, basename)
            for i in range(9999):
                seq = f"{basecount + i:05}"
                filename = fn.replace('[seq]', seq)
                if not os.path.exists(filename):
                    debug(f'Prompt sequence: input="{fn}" seq={seq} output="{filename}"')
                    x = filename
                    break
        return x

    def apply(self, x):
        """Resolve every ``[pattern]`` / ``[pattern<arg>...]`` in template *x* and return the result.

        - Known pattern with a value: the preceding text plus the value, with '/' and '\\' replaced by '-'.
        - Known pattern returning ``NOTHING`` or None, or raising: the pattern and its
          preceding text are dropped; exceptions are also logged.
        - Unknown pattern: the preceding text plus ``[pattern]`` reinserted; ``<arg>`` values are not kept.
        """
        res = ''
        for match in re_pattern.finditer(x):
            text, pattern = match.groups()
            if pattern is None:
                res += text
                continue
            pattern_args = []
            while (arg_match := re_pattern_arg.match(pattern)) is not None: # peel trailing <arg> groups right-to-left
                pattern, arg = arg_match.groups()
                pattern_args.insert(0, arg)
            fun = self.replacements.get(pattern.lower(), None)
            if fun is None:
                res += text + f'[{pattern}]' # reinsert unknown pattern
                continue
            try:
                debug(f'Filename apply: pattern={pattern.lower()} args={pattern_args}')
                replacement = fun(self, *pattern_args)
            except Exception as e:
                replacement = None
                errors.display(e, 'Filename apply pattern')
                shared.log.error(f'Filename apply pattern: {x} {e}')
            if replacement is NOTHING or replacement is None:
                continue
            res += text + str(replacement).replace('/', '-').replace('\\', '-')
        return res
def get_next_sequence_number(path, basename):
    """
    Return the next free sequence number for images saved into directory *path*.

    Only entries whose name starts with ``"<basename>-"`` are considered; every entry is
    considered when *basename* is empty. The text after that prefix, with the extension
    removed, is split on '-', and its first field is parsed as an integer. Entries that
    do not parse are ignored. The result is one more than the largest number found, or 0
    when *path* is not a directory or nothing parsed.
    """
    if not os.path.isdir(path):
        return 0
    prefix = f"{basename}-" if basename != '' else basename
    highest = -1
    for entry in os.listdir(path):
        if not entry.startswith(prefix):
            continue
        # drop the prefix so the sequence number is always the leading dash-separated field
        stem = os.path.splitext(entry[len(prefix):])[0]
        leading = stem.split('-')[0]
        try:
            highest = max(highest, int(leading))
        except ValueError:
            continue
    return highest + 1