New hrfix float syntax (#73)

* new hrfix float syntax

* respect user settings
pull/74/head
ljleb 2023-10-13 19:36:11 -04:00 committed by GitHub
parent a3ba3075c7
commit 95ff81c8cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 37 deletions

View File

@@ -8,12 +8,12 @@ class ListExpression:
def __init__(self, expressions):
self.__expressions = expressions
def extend_tensor(self, tensor_builder, steps_range, total_steps, context):
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
if not self.__expressions:
return
def expr_extend_tensor(expr):
expr.extend_tensor(tensor_builder, steps_range, total_steps, context)
expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
expr_extend_tensor(self.__expressions[0])
for expression in self.__expressions[1:]:
@@ -29,15 +29,15 @@ class InterpolationExpression:
self.__steps = steps
self.__function_name = function_name if function_name is not None else 'linear'
def extend_tensor(self, tensor_builder, steps_range, total_steps, context):
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
def tensor_updater(expr):
return lambda t: expr.extend_tensor(t, steps_range, total_steps, context)
return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
tensor_builder.extrude(
[tensor_updater(expr) for expr in self.__expressions],
self.get_interpolation_function(steps_range, total_steps, context))
self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))
def get_interpolation_function(self, steps_range, total_steps, context):
def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
steps = list(self.__steps)
if steps[0] is None:
steps[0] = LiftExpression(str(steps_range[0] - 1))
@@ -48,10 +48,12 @@ class InterpolationExpression:
if step is None:
continue
step = _eval_float(step, steps_range, total_steps, context)
step = _eval_int_or_float(step, steps_range, total_steps, context, is_hires, use_old_scheduling)
if 0 < step < 1:
if use_old_scheduling and 0 < step < 1:
step *= total_steps
elif not use_old_scheduling and isinstance(step, float):
step = (step - int(is_hires)) * total_steps
else:
step += 1
@@ -91,23 +93,23 @@ class AlternationExpression:
self.__expressions = expressions
self.__speed = speed
def extend_tensor(self, tensor_builder, steps_range, total_steps, context):
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
if self.__speed is None:
speed = None
else:
speed = _eval_float(self.__speed, steps_range, total_steps, context)
speed = _eval_int_or_float(self.__speed, steps_range, total_steps, context, is_hires, use_old_scheduling)
if speed is None:
tensor_builder.append('[')
for expr_i, expr in enumerate(self.__expressions):
if expr_i >= 1:
tensor_builder.append('|')
expr.extend_tensor(tensor_builder, steps_range, total_steps, context)
expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
tensor_builder.append(']')
return
def tensor_updater(expr):
return lambda t: expr.extend_tensor(t, steps_range, total_steps, context)
return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
exprs = self.__expressions + [self.__expressions[0]]
@@ -132,18 +134,20 @@ class EditingExpression:
self.__expressions = expressions
self.__step = step
def extend_tensor(self, tensor_builder, steps_range, total_steps, context):
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
if self.__step is None:
tensor_builder.append('[')
for expr_i, expr in enumerate(self.__expressions):
expr.extend_tensor(tensor_builder, steps_range, total_steps, context)
expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
tensor_builder.append(':')
tensor_builder.append(']')
return
step = _eval_float(self.__step, steps_range, total_steps, context)
if 0 < step < 1:
step = _eval_int_or_float(self.__step, steps_range, total_steps, context, is_hires, use_old_scheduling)
if use_old_scheduling and 0 < step < 1:
step *= total_steps
elif not use_old_scheduling and isinstance(step, float):
step = (step - int(is_hires)) * total_steps
else:
step += 1
@@ -152,7 +156,7 @@ class EditingExpression:
tensor_builder.append('[')
for expr_i, expr in enumerate(self.__expressions):
expr_steps_range = (steps_range[0], step) if expr_i == 0 and len(self.__expressions) >= 2 else (step, steps_range[1])
expr.extend_tensor(tensor_builder, expr_steps_range, total_steps, context)
expr.extend_tensor(tensor_builder, expr_steps_range, total_steps, context, is_hires, use_old_scheduling)
tensor_builder.append(':')
tensor_builder.append(f'{step - 1}]')
@@ -167,14 +171,14 @@ class WeightedExpression:
self.__weight = weight
self.__positive = positive
def extend_tensor(self, tensor_builder, steps_range, total_steps, context):
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
open_bracket, close_bracket = ('(', ')') if self.__positive else ('[', ']')
tensor_builder.append(open_bracket)
self.__nested.extend_tensor(tensor_builder, steps_range, total_steps, context)
self.__nested.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
if self.__weight is not None:
tensor_builder.append(':')
self.__weight.extend_tensor(tensor_builder, steps_range, total_steps, context)
self.__weight.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
tensor_builder.append(close_bracket)
@@ -185,11 +189,11 @@ class WeightInterpolationExpression:
self.__weight_begin = weight_begin if weight_begin is not None else LiftExpression(str(1.))
self.__weight_end = weight_end if weight_end is not None else LiftExpression(str(1.))
def extend_tensor(self, tensor_builder, steps_range, total_steps, context):
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
steps_range_size = steps_range[1] - steps_range[0]
weight_begin = _eval_float(self.__weight_begin, steps_range, total_steps, context)
weight_end = _eval_float(self.__weight_end, steps_range, total_steps, context)
weight_begin = _eval_int_or_float(self.__weight_begin, steps_range, total_steps, context, is_hires, use_old_scheduling)
weight_end = _eval_int_or_float(self.__weight_end, steps_range, total_steps, context, is_hires, use_old_scheduling)
for i in range(steps_range_size):
step = i + steps_range[0]
@@ -201,7 +205,7 @@ class WeightInterpolationExpression:
if step + 1 < steps_range[1]:
weight_step_expr = EditingExpression([weight_step_expr, ListExpression([])], LiftExpression(str(step)))
weight_step_expr.extend_tensor(tensor_builder, steps_range, total_steps, context)
weight_step_expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
class DeclarationExpression:
@@ -211,10 +215,10 @@ class DeclarationExpression:
self.__target = target
self.__parameters = parameters
def extend_tensor(self, tensor_builder, steps_range, total_steps, context):
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
updated_context = dict(context)
updated_context[self.__symbol] = (self.__value, self.__parameters)
self.__target.extend_tensor(tensor_builder, steps_range, total_steps, updated_context)
self.__target.extend_tensor(tensor_builder, steps_range, total_steps, updated_context, is_hires, use_old_scheduling)
class SubstitutionExpression:
@@ -222,12 +226,12 @@ class SubstitutionExpression:
self.__symbol = symbol
self.__arguments = arguments
def extend_tensor(self, tensor_builder, steps_range, total_steps, context):
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
updated_context = dict(context)
nested, parameters = context[self.__symbol]
for argument, parameter in zip(self.__arguments, parameters):
updated_context[parameter] = argument, []
nested.extend_tensor(tensor_builder, steps_range, total_steps, updated_context)
nested.extend_tensor(tensor_builder, steps_range, total_steps, updated_context, is_hires, use_old_scheduling)
class LiftExpression:
@@ -238,7 +242,10 @@ class LiftExpression:
tensor_builder.append(self.__value)
def _eval_float(expression, steps_range, total_steps, context):
def _eval_int_or_float(expression, steps_range, total_steps, context, is_hires, use_old_scheduling):
mock_database = ['']
expression.extend_tensor(interpolation_tensor.InterpolationTensorBuilder(prompt_database=mock_database), steps_range, total_steps, context)
return float(mock_database[0])
expression.extend_tensor(interpolation_tensor.InterpolationTensorBuilder(prompt_database=mock_database), steps_range, total_steps, context, is_hires, use_old_scheduling)
try:
return int(mock_database[0])
except ValueError:
return float(mock_database[0])

View File

@@ -280,6 +280,11 @@ def parse_attention_weights(prompt, stoppers):
def parse_step(prompt, stoppers):
try:
prompt, step = parse_int_not_float(prompt, stoppers)
return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
except ValueError:
pass
try:
prompt, step = parse_float(prompt, stoppers)
return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
@@ -312,6 +317,10 @@ def parse_float(prompt, stoppers):
return parse_token(prompt, whitespace_tail_regex(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)', stoppers))
def parse_int_not_float(prompt, stoppers):
return parse_token(prompt, whitespace_tail_regex(r'[+-]?\d+(?!\.)', stoppers))
def parse_dollar(prompt):
dollar_sign = re.escape('$')
return parse_token(prompt, f'({dollar_sign})')

View File

@@ -26,8 +26,9 @@ def _hijacked_get_learned_conditioning(model, prompts, total_steps, *args, origi
if not shared.opts.prompt_fusion_enabled:
return original_function(model, prompts, total_steps, *args, **kwargs)
hires_steps, *_ = args if args else (None, True)
if hires_steps is not None:
hires_steps, use_old_scheduling, *_ = args if args else (None, True)
is_hires = hires_steps is not None
if is_hires:
real_total_steps = hires_steps
else:
real_total_steps = total_steps
@@ -39,7 +40,7 @@ def _hijacked_get_learned_conditioning(model, prompts, total_steps, *args, origi
empty_cond.init(model)
tensor_builders = _parse_tensor_builders(prompts, real_total_steps)
tensor_builders = _parse_tensor_builders(prompts, real_total_steps, is_hires, use_old_scheduling)
if hasattr(prompt_parser, 'SdConditioning'):
empty_conditioning = prompt_parser.SdConditioning(prompts)
empty_conditioning.clear()
@@ -66,7 +67,7 @@ def _hijacked_get_learned_conditioning(model, prompts, total_steps, *args, origi
for begin, end, tensor_builder
in zip(consecutive_ranges[:-1], consecutive_ranges[1:], tensor_builders))
schedules = [_sample_tensor_schedules(cond_tensor, real_total_steps, is_hires=hires_steps is not None)
schedules = [_sample_tensor_schedules(cond_tensor, real_total_steps, is_hires)
for cond_tensor in cond_tensors]
if is_negative_prompt:
@@ -93,13 +94,13 @@ def _hijacked_get_multicond_learned_conditioning(*args, original_function, **kwa
return res
def _parse_tensor_builders(prompts, total_steps):
def _parse_tensor_builders(prompts, total_steps, is_hires, use_old_scheduling):
tensor_builders = []
for prompt in prompts:
expr = prompt_fusion_parser.parse_prompt(prompt)
tensor_builder = interpolation_tensor.InterpolationTensorBuilder()
expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict())
expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires, use_old_scheduling)
tensor_builders.append(tensor_builder)
return tensor_builders