parent
676beea9c3
commit
8692c63aa4
|
|
@ -22,6 +22,17 @@ class ListExpression:
|
|||
|
||||
|
||||
class InterpolationExpression:
|
||||
@staticmethod
|
||||
def create(exprs, steps, function_name):
|
||||
if function_name == "mean":
|
||||
return AverageExpression(exprs, steps)
|
||||
|
||||
max_len = min(len(exprs), len(steps))
|
||||
exprs = exprs[:max_len]
|
||||
steps = steps[:max_len]
|
||||
|
||||
return InterpolationExpression(exprs, steps, function_name)
|
||||
|
||||
def __init__(self, expressions, steps, function_name=None):
|
||||
assert len(expressions) >= 2
|
||||
assert len(steps) == len(expressions), 'the number of steps must be the same as the number of expressions'
|
||||
|
|
@ -88,6 +99,50 @@ class InterpolationExpression:
|
|||
return steps_scale_t
|
||||
|
||||
|
||||
class AverageExpression:
|
||||
def __init__(self, expressions, weights):
|
||||
if len(expressions) < len(weights):
|
||||
raise ValueError
|
||||
|
||||
self.__expressions = expressions
|
||||
self.__weights = weights
|
||||
|
||||
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
||||
def tensor_updater(expr):
|
||||
return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
||||
|
||||
tensor_builder.extrude(
|
||||
[tensor_updater(expr) for expr in self.__expressions],
|
||||
self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))
|
||||
|
||||
def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
||||
weights = [
|
||||
_eval_int_or_float(weight, steps_range, total_steps, context, is_hires, use_old_scheduling) if weight is not None else None
|
||||
for weight in self.__weights
|
||||
]
|
||||
explicit_weights = [weight for weight in weights if weight is not None]
|
||||
weights = [
|
||||
weight / sum(explicit_weights) * len(explicit_weights) / len(self.__expressions)
|
||||
if weight is not None
|
||||
else 1 / len(self.__expressions)
|
||||
for weight in weights
|
||||
]
|
||||
weights.extend(1 / len(self.__expressions) for _ in range(len(self.__expressions) - len(weights)))
|
||||
|
||||
def interpolation_function(conds, _params):
|
||||
total = None
|
||||
for cond, weight in zip(conds, weights):
|
||||
cond *= weight
|
||||
if total is None:
|
||||
total = cond
|
||||
else:
|
||||
total += cond
|
||||
|
||||
return total
|
||||
|
||||
return interpolation_function
|
||||
|
||||
|
||||
class AlternationExpression:
|
||||
def __init__(self, expressions, speed):
|
||||
self.__expressions = expressions
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import dataclasses
|
||||
import torch
|
||||
from modules import prompt_parser
|
||||
from typing import NamedTuple
|
||||
from typing import NamedTuple, Union
|
||||
|
||||
|
||||
class InterpolationParams(NamedTuple):
|
||||
|
|
@ -19,7 +19,7 @@ class InterpolationTensor:
|
|||
|
||||
def interpolate(self, params: InterpolationParams, origin_cond, empty_cond):
|
||||
cond_delta = self.interpolate_cond_delta_rec(params, origin_cond, empty_cond)
|
||||
return cond_delta + origin_cond.extend_like(cond_delta, empty_cond)
|
||||
return (cond_delta + origin_cond.extend_like(cond_delta, empty_cond)).to(dtype=origin_cond.dtype)
|
||||
|
||||
def interpolate_cond_delta_rec(self, params: InterpolationParams, origin_cond, empty_cond):
|
||||
if self.__interpolation_function is None:
|
||||
|
|
@ -39,7 +39,7 @@ class InterpolationTensor:
|
|||
if schedule.end_at_step >= step:
|
||||
break
|
||||
|
||||
return schedule.cond.extend_like(origin_cond, empty_cond) - origin_cond.extend_like(schedule.cond, empty_cond)
|
||||
return schedule.cond.extend_like(origin_cond, empty_cond).to(dtype=torch.double) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.double)
|
||||
|
||||
|
||||
def conds_to_cp_values(conds):
|
||||
|
|
@ -168,6 +168,24 @@ class DictCondWrapper:
|
|||
def to_cp_values(self):
|
||||
return list(self.original_cond.values())
|
||||
|
||||
def to(self, dtype: Union[dict, torch.dtype]):
|
||||
if not isinstance(dtype, dict):
|
||||
dtype = {
|
||||
k: dtype
|
||||
for k in self.original_cond.items()
|
||||
}
|
||||
return DictCondWrapper({
|
||||
k: v.to(dtype=dtype[k])
|
||||
for k, v in self.original_cond.items()
|
||||
})
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return {
|
||||
k: v.dtype
|
||||
for k, v in self.original_cond.items()
|
||||
}
|
||||
|
||||
def __sub__(self, that):
|
||||
return DictCondWrapper({
|
||||
k: v - that.original_cond[k]
|
||||
|
|
@ -209,6 +227,13 @@ class TensorCondWrapper:
|
|||
def to_cp_values(self):
|
||||
return [self.original_cond]
|
||||
|
||||
def to(self, dtype: torch.dtype):
|
||||
return TensorCondWrapper(self.original_cond.to(dtype=dtype))
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.original_cond.dtype
|
||||
|
||||
def __sub__(self, that):
|
||||
return TensorCondWrapper(self.original_cond - that.original_cond)
|
||||
|
||||
|
|
|
|||
|
|
@ -125,12 +125,7 @@ def parse_interpolation(prompt, stoppers):
|
|||
prompt, steps = parse_interpolation_steps(prompt, stoppers)
|
||||
prompt, function_name = parse_interpolation_function_name(prompt, stoppers)
|
||||
prompt, _ = parse_close_square(prompt, stoppers)
|
||||
|
||||
max_len = min(len(exprs), len(steps))
|
||||
exprs = exprs[:max_len]
|
||||
steps = steps[:max_len]
|
||||
|
||||
return ParseResult(prompt=prompt, expr=ast.InterpolationExpression(exprs, steps, function_name))
|
||||
return ParseResult(prompt=prompt, expr=ast.InterpolationExpression.create(exprs, steps, function_name))
|
||||
|
||||
|
||||
def parse_interpolation_exprs(prompt, stoppers):
|
||||
|
|
@ -153,7 +148,7 @@ def parse_interpolation_exprs(prompt, stoppers):
|
|||
def parse_interpolation_function_name(prompt, stoppers):
|
||||
try:
|
||||
prompt, _ = parse_colon(prompt, stoppers)
|
||||
function_names = ('linear', 'catmull', 'bezier')
|
||||
function_names = ('linear', 'catmull', 'bezier', 'mean')
|
||||
return parse_token(prompt, whitespace_tail_regex('|'.join(function_names), stoppers))
|
||||
except ValueError:
|
||||
return ParseResult(prompt=prompt, expr=None)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ def run_functional_tests(total_steps=100):
|
|||
for i, (given, expected) in enumerate(functional_parse_test_cases):
|
||||
expr = parse_prompt(given)
|
||||
tensor_builder = InterpolationTensorBuilder()
|
||||
expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict())
|
||||
expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires=False, use_old_scheduling=False)
|
||||
|
||||
actual = tensor_builder.get_prompt_database()
|
||||
|
||||
|
|
@ -94,6 +94,9 @@ functional_parse_test_cases = [
|
|||
('[a|b|c:0.5]', {'a', 'b', 'c'}),
|
||||
('[a|b|c:1.1]', {'a', 'b', 'c'}),
|
||||
('[[[Imperial Yellow|Amber]:[Ruby|Plum|Bronze]:9]::39]',)*2,
|
||||
('[a:b:c::mean]', {'a', 'b', 'c'}),
|
||||
('[a:b:c:,,:mean]', {'a', 'b', 'c'}),
|
||||
('[a:b:c: 1, 2, 3:mean]', {'a', 'b', 'c'}),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue