Merge branch 'fix-medvram'

pull/23/head
Dimitri 2023-07-23 21:56:15 +02:00
commit 10e76b321e
2 changed files with 10 additions and 2 deletions

View File

@ -15,7 +15,7 @@ from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass
from scripts.helpers import WebUiComponents
__version__ = "0.3.4"
__version__ = "0.3.5"
DEBUG = False

View File

@ -1,7 +1,7 @@
import torch
import torchvision.transforms.functional as functional
from modules import devices, images
from modules import devices, images, shared
from modules.processing import StableDiffusionProcessingTxt2Img
from ldm.modules.attention import BasicTransformerBlock
@ -102,6 +102,14 @@ def patch_unet_forward_pass(p, unet, params):
if isinstance(module, BasicTransformerBlock):
module.attn1._fabric_old_forward = module.attn1.forward
# fix for medvram option
if shared.cmd_opts.medvram:
try:
# Trigger register_forward_pre_hook to move the model to correct device
p.sd_model.model()
except:
pass
## cache hidden states
cached_hiddens = []