Fixed an infinite loop after getting an error in the image generation

pull/28/head
h3rmit 2023-08-26 10:21:38 +03:00
parent 0018b6be2d
commit 2f207ff37b
1 changed files with 2 additions and 2 deletions

View File

@ -118,7 +118,7 @@ def patch_unet_forward_pass(p, unet, params):
# save original forward pass
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
module.attn1._fabric_old_forward = module.attn1.forward
# fix for medvram option
@ -200,7 +200,7 @@ def patch_unet_forward_pass(p, unet, params):
# restore original pass
for module in self.modules():
if isinstance(module, BasicTransformerBlock):
if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"):
module.attn1.forward = module.attn1._fabric_old_forward
del module.attn1._fabric_old_forward