average V1 (#76)

* basic impl

* split tests

* split tests
pull/77/head
ljleb 2023-11-27 23:00:22 -05:00 committed by GitHub
parent 676beea9c3
commit 8692c63aa4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 11 deletions

View File

@ -22,6 +22,17 @@ class ListExpression:
class InterpolationExpression:
@staticmethod
def create(exprs, steps, function_name):
    """Factory for interpolation-style expressions.

    'mean' gets its own AverageExpression node; every other function name
    produces an InterpolationExpression with exprs/steps trimmed so that
    each expression is paired with exactly one step value.
    """
    if function_name == "mean":
        return AverageExpression(exprs, steps)
    # Drop unpaired trailing expressions or steps before constructing.
    paired = min(len(exprs), len(steps))
    return InterpolationExpression(exprs[:paired], steps[:paired], function_name)
def __init__(self, expressions, steps, function_name=None):
assert len(expressions) >= 2
assert len(steps) == len(expressions), 'the number of steps must be the same as the number of expressions'
@ -88,6 +99,50 @@ class InterpolationExpression:
return steps_scale_t
class AverageExpression:
    """Weighted average of several sub-expressions (the ':mean' syntax).

    `weights` may be shorter than `expressions` and may contain None
    entries; missing/None weights default to an even 1/len(expressions)
    share, while explicit weights are renormalized so their group keeps
    its proportional share of the total.
    """

    def __init__(self, expressions, weights):
        # More weights than expressions is a malformed prompt.
        if len(expressions) < len(weights):
            raise ValueError(
                f'got {len(weights)} weights for only {len(expressions)} expressions')
        self.__expressions = expressions
        self.__weights = weights

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Extrude one tensor branch per sub-expression, combined by the
        weighted-average interpolation function below."""
        def tensor_updater(expr):
            # Bind expr per closure so each branch extends with its own expression.
            return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)

        tensor_builder.extrude(
            [tensor_updater(expr) for expr in self.__expressions],
            self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))

    def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Return f(conds, params) -> weighted sum of conds.

        Weights are evaluated once here; the returned closure only mixes.
        """
        evaluated = [
            _eval_int_or_float(w, steps_range, total_steps, context, is_hires, use_old_scheduling)
            if w is not None else None
            for w in self.__weights
        ]
        expr_count = len(self.__expressions)
        explicit = [w for w in evaluated if w is not None]
        # Hoisted out of the comprehension: sum() was recomputed per element (O(n^2)).
        explicit_sum = sum(explicit)
        weights = [
            # Explicit weights: renormalized within their group, scaled to the
            # fraction of expressions they cover.
            w / explicit_sum * len(explicit) / expr_count if w is not None
            # None placeholders: plain even share.
            else 1 / expr_count
            for w in evaluated
        ]
        # Expressions beyond the supplied weights also get the even share.
        weights.extend(1 / expr_count for _ in range(expr_count - len(weights)))

        def interpolation_function(conds, _params):
            total = None
            for cond, weight in zip(conds, weights):
                # `cond * weight` rather than `cond *= weight`: the original
                # mutated the caller's conds in place; the returned value is
                # identical without that side effect.
                scaled = cond * weight
                total = scaled if total is None else total + scaled
            return total

        return interpolation_function
class AlternationExpression:
def __init__(self, expressions, speed):
self.__expressions = expressions

View File

@ -1,7 +1,7 @@
import dataclasses
import torch
from modules import prompt_parser
from typing import NamedTuple
from typing import NamedTuple, Union
class InterpolationParams(NamedTuple):
@ -19,7 +19,7 @@ class InterpolationTensor:
def interpolate(self, params: InterpolationParams, origin_cond, empty_cond):
    """Interpolate this tensor's conds at `params`, anchored on origin_cond.

    The delta math upstream is carried out in double precision; the result
    is cast back to the origin cond's dtype before returning.
    """
    cond_delta = self.interpolate_cond_delta_rec(params, origin_cond, empty_cond)
    # NOTE(review): the diff render left the pre-change return above the new
    # one, making the dtype cast unreachable; only the updated form is kept.
    return (cond_delta + origin_cond.extend_like(cond_delta, empty_cond)).to(dtype=origin_cond.dtype)
def interpolate_cond_delta_rec(self, params: InterpolationParams, origin_cond, empty_cond):
if self.__interpolation_function is None:
@ -39,7 +39,7 @@ class InterpolationTensor:
if schedule.end_at_step >= step:
break
return schedule.cond.extend_like(origin_cond, empty_cond) - origin_cond.extend_like(schedule.cond, empty_cond)
return schedule.cond.extend_like(origin_cond, empty_cond).to(dtype=torch.double) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.double)
def conds_to_cp_values(conds):
@ -168,6 +168,24 @@ class DictCondWrapper:
def to_cp_values(self):
    """Return the wrapped dict's cond values as a plain list."""
    return [*self.original_cond.values()]
def to(self, dtype: Union[dict, torch.dtype]):
    """Return a new DictCondWrapper with each cond cast to its dtype.

    `dtype` is either a per-key {key: torch.dtype} mapping or a single
    dtype broadcast to every key.
    """
    if not isinstance(dtype, dict):
        # BUG FIX: was `for k in self.original_cond.items()`, which keyed the
        # broadcast dict by (key, value) tuples, so dtype[k] below raised
        # KeyError. Iterate the keys themselves.
        dtype = {k: dtype for k in self.original_cond}
    return DictCondWrapper({
        k: v.to(dtype=dtype[k])
        for k, v in self.original_cond.items()
    })
@property
def dtype(self):
    """Per-key dtypes of the wrapped conds, as a {key: dtype} dict."""
    dtypes = {}
    for key, value in self.original_cond.items():
        dtypes[key] = value.dtype
    return dtypes
def __sub__(self, that):
return DictCondWrapper({
k: v - that.original_cond[k]
@ -209,6 +227,13 @@ class TensorCondWrapper:
def to_cp_values(self):
    """Wrap the single underlying cond in a one-element list."""
    cond = self.original_cond
    return [cond]
def to(self, dtype: torch.dtype):
    """Return a new TensorCondWrapper with the cond cast to `dtype`."""
    converted = self.original_cond.to(dtype=dtype)
    return TensorCondWrapper(converted)
@property
def dtype(self):
    """dtype of the wrapped cond tensor."""
    cond = self.original_cond
    return cond.dtype
def __sub__(self, that):
    """Element-wise difference of the two wrapped conds, as a new wrapper."""
    difference = self.original_cond - that.original_cond
    return TensorCondWrapper(difference)

View File

@ -125,12 +125,7 @@ def parse_interpolation(prompt, stoppers):
prompt, steps = parse_interpolation_steps(prompt, stoppers)
prompt, function_name = parse_interpolation_function_name(prompt, stoppers)
prompt, _ = parse_close_square(prompt, stoppers)
max_len = min(len(exprs), len(steps))
exprs = exprs[:max_len]
steps = steps[:max_len]
return ParseResult(prompt=prompt, expr=ast.InterpolationExpression(exprs, steps, function_name))
return ParseResult(prompt=prompt, expr=ast.InterpolationExpression.create(exprs, steps, function_name))
def parse_interpolation_exprs(prompt, stoppers):
@ -153,7 +148,7 @@ def parse_interpolation_exprs(prompt, stoppers):
def parse_interpolation_function_name(prompt, stoppers):
    """Parse an optional ':function_name' suffix of an interpolation.

    Returns the parse_token result on success; when no colon (or no known
    name) follows, returns ParseResult(prompt, None) — the function name
    is optional, so this is not an error.
    """
    try:
        prompt, _ = parse_colon(prompt, stoppers)
        # NOTE(review): the diff render duplicated the pre-change 3-name tuple
        # above the updated one; only the 4-name version (with 'mean') is kept.
        function_names = ('linear', 'catmull', 'bezier', 'mean')
        return parse_token(prompt, whitespace_tail_regex('|'.join(function_names), stoppers))
    except ValueError:
        return ParseResult(prompt=prompt, expr=None)

View File

@ -6,7 +6,7 @@ def run_functional_tests(total_steps=100):
for i, (given, expected) in enumerate(functional_parse_test_cases):
expr = parse_prompt(given)
tensor_builder = InterpolationTensorBuilder()
expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict())
expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires=False, use_old_scheduling=False)
actual = tensor_builder.get_prompt_database()
@ -94,6 +94,9 @@ functional_parse_test_cases = [
('[a|b|c:0.5]', {'a', 'b', 'c'}),
('[a|b|c:1.1]', {'a', 'b', 'c'}),
('[[[Imperial Yellow|Amber]:[Ruby|Plum|Bronze]:9]::39]',)*2,
('[a:b:c::mean]', {'a', 'b', 'c'}),
('[a:b:c:,,:mean]', {'a', 'b', 'c'}),
('[a:b:c: 1, 2, 3:mean]', {'a', 'b', 'c'}),
]