add Unload model button; Fixed #2
parent
bc86d8a980
commit
9ad38b32e4
|
|
@ -87,8 +87,13 @@ class StableSR:
|
|||
|
||||
def unet_forward(x, timesteps=None, context=None, y=None,**kwargs):
|
||||
self.latent_image = self.latent_image.to(x.device)
|
||||
# Ensure the device of all modules layers is the same as the unet
|
||||
# This will fix the issue when user use --medvram or --lowvram
|
||||
self.spade_layers.to(x.device)
|
||||
self.struct_cond_model.to(x.device)
|
||||
timesteps = timesteps.to(x.device)
|
||||
self.struct_cond = None # mitigate vram peak
|
||||
self.struct_cond = self.struct_cond_model(self.latent_image, timesteps.to(x.device)[:self.latent_image.shape[0]])
|
||||
self.struct_cond = self.struct_cond_model(self.latent_image, timesteps[:self.latent_image.shape[0]])
|
||||
return getattr(unet, FORWARD_CACHE_NAME)(x, timesteps, context, y, **kwargs)
|
||||
|
||||
unet.forward = unet_forward
|
||||
|
|
@ -107,7 +112,7 @@ class StableSR:
|
|||
delattr(unet, FORWARD_CACHE_NAME)
|
||||
|
||||
# unhook spade layers
|
||||
self.spade_layers.unhook(unet)
|
||||
self.spade_layers.unhook()
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
|
|
@ -151,7 +156,15 @@ class Script(scripts.Script):
|
|||
save_original = gr.Checkbox(label='Save Original', value=False, elem_id=f'StableSR-save-original', visible=color_fix.value != 'None')
|
||||
color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None'), inputs=color_fix, outputs=save_original, show_progress=False)
|
||||
pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id=f'StableSR-pure-noise')
|
||||
|
||||
unload_model= gr.Button(value='Unload Model', variant='tool')
|
||||
def unload_model_fn():
|
||||
if self.stablesr_model is not None:
|
||||
self.stablesr_model = None
|
||||
devices.torch_gc()
|
||||
print('[StableSR] Model unloaded!')
|
||||
else:
|
||||
print('[StableSR] No model loaded.')
|
||||
unload_model.click(fn=unload_model_fn)
|
||||
return [model, scale_factor, pure_noise, color_fix, save_original]
|
||||
|
||||
def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:str, save_original:bool) -> Processed:
|
||||
|
|
@ -220,6 +233,9 @@ class Script(scripts.Script):
|
|||
return samples
|
||||
finally:
|
||||
self.stablesr_model.unhook(unet)
|
||||
# in --medvram and --lowvram mode, we send the model back to the initial device
|
||||
self.stablesr_model.struct_cond_model.to(device=first_param.device)
|
||||
self.stablesr_model.spade_layers.to(device=first_param.device)
|
||||
|
||||
|
||||
# replace the sample function
|
||||
|
|
|
|||
|
|
@ -125,10 +125,12 @@ class SPADELayers(nn.Module):
|
|||
self.output_ids = list(range(12))
|
||||
self.mid_ids = [0,2]
|
||||
self.forward_cache_name = 'org_forward_stablesr'
|
||||
self.unet = None
|
||||
|
||||
|
||||
def hook(self, unet: UNetModel, get_struct_cond):
|
||||
# hook all resblocks
|
||||
self.unet = unet
|
||||
resblock: ResBlock = None
|
||||
for i in self.input_ids:
|
||||
resblock = unet.input_blocks[i][0]
|
||||
|
|
@ -154,7 +156,9 @@ class SPADELayers(nn.Module):
|
|||
setattr(resblock, self.forward_cache_name, resblock._forward)
|
||||
resblock._forward = lambda x, timesteps, resblock=resblock, spade=self.middle_block[i]: dual_resblock_forward(resblock, x, timesteps, spade, get_struct_cond)
|
||||
|
||||
def unhook(self, unet: UNetModel):
|
||||
def unhook(self):
|
||||
unet = self.unet
|
||||
if unet is None: return
|
||||
resblock: ResBlock = None
|
||||
for i in self.input_ids:
|
||||
resblock = unet.input_blocks[i][0]
|
||||
|
|
@ -173,6 +177,7 @@ class SPADELayers(nn.Module):
|
|||
if hasattr(resblock, self.forward_cache_name):
|
||||
resblock._forward = getattr(resblock, self.forward_cache_name)
|
||||
delattr(resblock, self.forward_cache_name)
|
||||
self.unet = None
|
||||
|
||||
|
||||
def load_from_dict(self, state_dict):
|
||||
|
|
|
|||
Loading…
Reference in New Issue