add Unload model button; Fixed #2

pull/12/head
pkuliyi2015 2023-05-23 09:55:11 +00:00
parent bc86d8a980
commit 9ad38b32e4
2 changed files with 25 additions and 4 deletions

View File

@ -87,8 +87,13 @@ class StableSR:
def unet_forward(x, timesteps=None, context=None, y=None,**kwargs):
self.latent_image = self.latent_image.to(x.device)
# Ensure the device of all modules layers is the same as the unet
# This will fix the issue when user use --medvram or --lowvram
self.spade_layers.to(x.device)
self.struct_cond_model.to(x.device)
timesteps = timesteps.to(x.device)
self.struct_cond = None # mitigate vram peak
self.struct_cond = self.struct_cond_model(self.latent_image, timesteps.to(x.device)[:self.latent_image.shape[0]])
self.struct_cond = self.struct_cond_model(self.latent_image, timesteps[:self.latent_image.shape[0]])
return getattr(unet, FORWARD_CACHE_NAME)(x, timesteps, context, y, **kwargs)
unet.forward = unet_forward
@ -107,7 +112,7 @@ class StableSR:
delattr(unet, FORWARD_CACHE_NAME)
# unhook spade layers
self.spade_layers.unhook(unet)
self.spade_layers.unhook()
class Script(scripts.Script):
@ -151,7 +156,15 @@ class Script(scripts.Script):
save_original = gr.Checkbox(label='Save Original', value=False, elem_id=f'StableSR-save-original', visible=color_fix.value != 'None')
color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None'), inputs=color_fix, outputs=save_original, show_progress=False)
pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id=f'StableSR-pure-noise')
unload_model= gr.Button(value='Unload Model', variant='tool')
def unload_model_fn():
if self.stablesr_model is not None:
self.stablesr_model = None
devices.torch_gc()
print('[StableSR] Model unloaded!')
else:
print('[StableSR] No model loaded.')
unload_model.click(fn=unload_model_fn)
return [model, scale_factor, pure_noise, color_fix, save_original]
def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:str, save_original:bool) -> Processed:
@ -220,6 +233,9 @@ class Script(scripts.Script):
return samples
finally:
self.stablesr_model.unhook(unet)
# in --medvram and --lowvram mode, we send the model back to the initial device
self.stablesr_model.struct_cond_model.to(device=first_param.device)
self.stablesr_model.spade_layers.to(device=first_param.device)
# replace the sample function

View File

@ -125,10 +125,12 @@ class SPADELayers(nn.Module):
self.output_ids = list(range(12))
self.mid_ids = [0,2]
self.forward_cache_name = 'org_forward_stablesr'
self.unet = None
def hook(self, unet: UNetModel, get_struct_cond):
# hook all resblocks
self.unet = unet
resblock: ResBlock = None
for i in self.input_ids:
resblock = unet.input_blocks[i][0]
@ -154,7 +156,9 @@ class SPADELayers(nn.Module):
setattr(resblock, self.forward_cache_name, resblock._forward)
resblock._forward = lambda x, timesteps, resblock=resblock, spade=self.middle_block[i]: dual_resblock_forward(resblock, x, timesteps, spade, get_struct_cond)
def unhook(self, unet: UNetModel):
def unhook(self):
unet = self.unet
if unet is None: return
resblock: ResBlock = None
for i in self.input_ids:
resblock = unet.input_blocks[i][0]
@ -173,6 +177,7 @@ class SPADELayers(nn.Module):
if hasattr(resblock, self.forward_cache_name):
resblock._forward = getattr(resblock, self.forward_cache_name)
delattr(resblock, self.forward_cache_name)
self.unet = None
def load_from_dict(self, state_dict):