refactor ddpmedit

pull/2256/head
Vladimir Mandic 2023-09-20 13:17:33 -04:00
parent 200ced8b1c
commit b2e7bcd546
3 changed files with 4 additions and 4 deletions

View File

@ -3,7 +3,7 @@
model:
base_learning_rate: 1.0e-04
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
target: modules.hijack.ddpm_edit.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120

View File

@ -59,9 +59,9 @@ ddpm_edit_hijack = None
def hijack_ddpm_edit():
global ddpm_edit_hijack # pylint: disable=global-statement
if not ddpm_edit_hijack:
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('modules.hijack.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('modules.hijack.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
ddpm_edit_hijack = CondFunc('modules.hijack.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast # pylint: disable=unnecessary-lambda-assignment