fix fp64 not supported by intel hardware (#79)

* fix

* dtype in diff space only
main
ljleb 2023-12-12 20:00:15 -05:00 committed by GitHub
parent a18775603a
commit 1d86996ac4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 8 deletions

View File

@ -21,7 +21,7 @@ def get_origin_cond_at(step: int, is_hires: bool = False):
def get_slerp_scale():
return shared.opts.data.get('prompt_fusion_slerp_scale', 1)
return shared.opts.data.get('prompt_fusion_slerp_scale', 0.0)
def get_slerp_epsilon():

View File

@ -18,28 +18,33 @@ class InterpolationTensor:
self.__interpolation_function = interpolation_function
def interpolate(self, params: InterpolationParams, origin_cond, empty_cond):
cond_delta = self.interpolate_cond_delta_rec(params, origin_cond, empty_cond)
return (cond_delta + origin_cond.extend_like(cond_delta, empty_cond)).to(dtype=origin_cond.dtype)
cond = self.interpolate_cond_rec(params, origin_cond, empty_cond)
if params.slerp_scale != 0:
cond = (cond + origin_cond.extend_like(cond, empty_cond)).to(dtype=origin_cond.dtype)
return cond
def interpolate_cond_delta_rec(self, params: InterpolationParams, origin_cond, empty_cond):
def interpolate_cond_rec(self, params: InterpolationParams, origin_cond, empty_cond):
if self.__interpolation_function is None:
return self.to_cond_delta(params.step, origin_cond, empty_cond)
return self.get_cond_point(params.step, origin_cond, empty_cond, params.slerp_scale)
control_points = [
sub_tensor.interpolate_cond_delta_rec(params, origin_cond, empty_cond)
sub_tensor.interpolate_cond_rec(params, origin_cond, empty_cond)
for sub_tensor in self.__sub_tensors
]
CondWrapper, control_points_values = conds_to_cp_values(control_points)
return CondWrapper.from_cp_values(self.__interpolation_function(control_points, params) for control_points in control_points_values)
def to_cond_delta(self, step, origin_cond, empty_cond):
def get_cond_point(self, step, origin_cond, empty_cond, slerp_scale):
schedule = None
for schedule in self.__sub_tensors:
if schedule.end_at_step >= step:
break
return schedule.cond.extend_like(origin_cond, empty_cond).to(dtype=torch.double) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.double)
res = schedule.cond.extend_like(origin_cond, empty_cond)
if slerp_scale != 0:
res = res.to(dtype=torch.float) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.float)
return res
def conds_to_cp_values(conds):

2
metadata.ini Normal file
View File

@ -0,0 +1,2 @@
[Extension]
Name = prompt-fusion