parent
a18775603a
commit
1d86996ac4
|
|
@ -21,7 +21,7 @@ def get_origin_cond_at(step: int, is_hires: bool = False):
|
|||
|
||||
|
||||
def get_slerp_scale():
|
||||
return shared.opts.data.get('prompt_fusion_slerp_scale', 1)
|
||||
return shared.opts.data.get('prompt_fusion_slerp_scale', 0.0)
|
||||
|
||||
|
||||
def get_slerp_epsilon():
|
||||
|
|
|
|||
|
|
@ -18,28 +18,33 @@ class InterpolationTensor:
|
|||
self.__interpolation_function = interpolation_function
|
||||
|
||||
def interpolate(self, params: InterpolationParams, origin_cond, empty_cond):
|
||||
cond_delta = self.interpolate_cond_delta_rec(params, origin_cond, empty_cond)
|
||||
return (cond_delta + origin_cond.extend_like(cond_delta, empty_cond)).to(dtype=origin_cond.dtype)
|
||||
cond = self.interpolate_cond_rec(params, origin_cond, empty_cond)
|
||||
if params.slerp_scale != 0:
|
||||
cond = (cond + origin_cond.extend_like(cond, empty_cond)).to(dtype=origin_cond.dtype)
|
||||
return cond
|
||||
|
||||
def interpolate_cond_delta_rec(self, params: InterpolationParams, origin_cond, empty_cond):
|
||||
def interpolate_cond_rec(self, params: InterpolationParams, origin_cond, empty_cond):
|
||||
if self.__interpolation_function is None:
|
||||
return self.to_cond_delta(params.step, origin_cond, empty_cond)
|
||||
return self.get_cond_point(params.step, origin_cond, empty_cond, params.slerp_scale)
|
||||
|
||||
control_points = [
|
||||
sub_tensor.interpolate_cond_delta_rec(params, origin_cond, empty_cond)
|
||||
sub_tensor.interpolate_cond_rec(params, origin_cond, empty_cond)
|
||||
for sub_tensor in self.__sub_tensors
|
||||
]
|
||||
|
||||
CondWrapper, control_points_values = conds_to_cp_values(control_points)
|
||||
return CondWrapper.from_cp_values(self.__interpolation_function(control_points, params) for control_points in control_points_values)
|
||||
|
||||
def to_cond_delta(self, step, origin_cond, empty_cond):
|
||||
def get_cond_point(self, step, origin_cond, empty_cond, slerp_scale):
|
||||
schedule = None
|
||||
for schedule in self.__sub_tensors:
|
||||
if schedule.end_at_step >= step:
|
||||
break
|
||||
|
||||
return schedule.cond.extend_like(origin_cond, empty_cond).to(dtype=torch.double) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.double)
|
||||
res = schedule.cond.extend_like(origin_cond, empty_cond)
|
||||
if slerp_scale != 0:
|
||||
res = res.to(dtype=torch.float) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.float)
|
||||
return res
|
||||
|
||||
|
||||
def conds_to_cp_values(conds):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
[Extension]
|
||||
Name = prompt-fusion
|
||||
Loading…
Reference in New Issue