diff --git a/scripts/stablesr.py b/scripts/stablesr.py index b08a561..b21b12f 100644 --- a/scripts/stablesr.py +++ b/scripts/stablesr.py @@ -87,8 +87,13 @@ class StableSR: def unet_forward(x, timesteps=None, context=None, y=None,**kwargs): self.latent_image = self.latent_image.to(x.device) + # Ensure the device of all modules layers is the same as the unet + # This will fix the issue when user use --medvram or --lowvram + self.spade_layers.to(x.device) + self.struct_cond_model.to(x.device) + timesteps = timesteps.to(x.device) self.struct_cond = None # mitigate vram peak - self.struct_cond = self.struct_cond_model(self.latent_image, timesteps.to(x.device)[:self.latent_image.shape[0]]) + self.struct_cond = self.struct_cond_model(self.latent_image, timesteps[:self.latent_image.shape[0]]) return getattr(unet, FORWARD_CACHE_NAME)(x, timesteps, context, y, **kwargs) unet.forward = unet_forward @@ -107,7 +112,7 @@ class StableSR: delattr(unet, FORWARD_CACHE_NAME) # unhook spade layers - self.spade_layers.unhook(unet) + self.spade_layers.unhook() class Script(scripts.Script): @@ -151,7 +156,15 @@ class Script(scripts.Script): save_original = gr.Checkbox(label='Save Original', value=False, elem_id=f'StableSR-save-original', visible=color_fix.value != 'None') color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None'), inputs=color_fix, outputs=save_original, show_progress=False) pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id=f'StableSR-pure-noise') - + unload_model= gr.Button(value='Unload Model', variant='tool') + def unload_model_fn(): + if self.stablesr_model is not None: + self.stablesr_model = None + devices.torch_gc() + print('[StableSR] Model unloaded!') + else: + print('[StableSR] No model loaded.') + unload_model.click(fn=unload_model_fn) return [model, scale_factor, pure_noise, color_fix, save_original] def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:str, save_original:bool) -> Processed: @@ -220,6 +233,9 @@ class Script(scripts.Script): return samples finally: self.stablesr_model.unhook(unet) + # in --medvram and --lowvram mode, we send the model back to the initial device + self.stablesr_model.struct_cond_model.to(device=first_param.device) + self.stablesr_model.spade_layers.to(device=first_param.device) # replace the sample function diff --git a/srmodule/spade.py b/srmodule/spade.py index cc4fe2c..2d21e9b 100644 --- a/srmodule/spade.py +++ b/srmodule/spade.py @@ -125,10 +125,12 @@ class SPADELayers(nn.Module): self.output_ids = list(range(12)) self.mid_ids = [0,2] self.forward_cache_name = 'org_forward_stablesr' + self.unet = None def hook(self, unet: UNetModel, get_struct_cond): # hook all resblocks + self.unet = unet resblock: ResBlock = None for i in self.input_ids: resblock = unet.input_blocks[i][0] @@ -154,7 +156,9 @@ class SPADELayers(nn.Module): setattr(resblock, self.forward_cache_name, resblock._forward) resblock._forward = lambda x, timesteps, resblock=resblock, spade=self.middle_block[i]: dual_resblock_forward(resblock, x, timesteps, spade, get_struct_cond) - def unhook(self, unet: UNetModel): + def unhook(self): + unet = self.unet + if unet is None: return resblock: ResBlock = None for i in self.input_ids: resblock = unet.input_blocks[i][0] @@ -173,6 +177,7 @@ class SPADELayers(nn.Module): if hasattr(resblock, self.forward_cache_name): resblock._forward = getattr(resblock, self.forward_cache_name) delattr(resblock, self.forward_cache_name) + self.unet = None def load_from_dict(self, state_dict):