Merge branch 'fix-medvram'
commit
10e76b321e
|
|
@ -15,7 +15,7 @@ from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass
|
|||
from scripts.helpers import WebUiComponents
|
||||
|
||||
|
||||
__version__ = "0.3.4"
|
||||
__version__ = "0.3.5"
|
||||
|
||||
DEBUG = False
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
import torchvision.transforms.functional as functional
|
||||
|
||||
from modules import devices, images
|
||||
from modules import devices, images, shared
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img
|
||||
|
||||
from ldm.modules.attention import BasicTransformerBlock
|
||||
|
|
@ -102,6 +102,14 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
if isinstance(module, BasicTransformerBlock):
|
||||
module.attn1._fabric_old_forward = module.attn1.forward
|
||||
|
||||
# fix for medvram option
|
||||
if shared.cmd_opts.medvram:
|
||||
try:
|
||||
# Trigger register_forward_pre_hook to move the model to correct device
|
||||
p.sd_model.model()
|
||||
except:
|
||||
pass
|
||||
|
||||
## cache hidden states
|
||||
|
||||
cached_hiddens = []
|
||||
|
|
|
|||
Loading…
Reference in New Issue