diff --git a/scripts/fabric.py b/scripts/fabric.py index 0aea716..26e68a9 100644 --- a/scripts/fabric.py +++ b/scripts/fabric.py @@ -15,7 +15,7 @@ from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass from scripts.helpers import WebUiComponents -__version__ = "0.3.4" +__version__ = "0.3.5" DEBUG = False diff --git a/scripts/patching.py b/scripts/patching.py index 65105ab..e7e468f 100644 --- a/scripts/patching.py +++ b/scripts/patching.py @@ -1,7 +1,7 @@ import torch import torchvision.transforms.functional as functional -from modules import devices, images +from modules import devices, images, shared from modules.processing import StableDiffusionProcessingTxt2Img from ldm.modules.attention import BasicTransformerBlock @@ -102,6 +102,14 @@ def patch_unet_forward_pass(p, unet, params): if isinstance(module, BasicTransformerBlock): module.attn1._fabric_old_forward = module.attn1.forward + # fix for medvram option + if shared.cmd_opts.medvram: + try: + # Trigger register_forward_pre_hook to move the model to correct device + p.sd_model.model() + except: + pass + ## cache hidden states cached_hiddens = []