add Unload model button; Fixed #2
parent
bc86d8a980
commit
9ad38b32e4
|
|
@ -87,8 +87,13 @@ class StableSR:
|
||||||
|
|
||||||
def unet_forward(x, timesteps=None, context=None, y=None,**kwargs):
|
def unet_forward(x, timesteps=None, context=None, y=None,**kwargs):
|
||||||
self.latent_image = self.latent_image.to(x.device)
|
self.latent_image = self.latent_image.to(x.device)
|
||||||
|
# Ensure the device of all modules layers is the same as the unet
|
||||||
|
# This will fix the issue when user use --medvram or --lowvram
|
||||||
|
self.spade_layers.to(x.device)
|
||||||
|
self.struct_cond_model.to(x.device)
|
||||||
|
timesteps = timesteps.to(x.device)
|
||||||
self.struct_cond = None # mitigate vram peak
|
self.struct_cond = None # mitigate vram peak
|
||||||
self.struct_cond = self.struct_cond_model(self.latent_image, timesteps.to(x.device)[:self.latent_image.shape[0]])
|
self.struct_cond = self.struct_cond_model(self.latent_image, timesteps[:self.latent_image.shape[0]])
|
||||||
return getattr(unet, FORWARD_CACHE_NAME)(x, timesteps, context, y, **kwargs)
|
return getattr(unet, FORWARD_CACHE_NAME)(x, timesteps, context, y, **kwargs)
|
||||||
|
|
||||||
unet.forward = unet_forward
|
unet.forward = unet_forward
|
||||||
|
|
@ -107,7 +112,7 @@ class StableSR:
|
||||||
delattr(unet, FORWARD_CACHE_NAME)
|
delattr(unet, FORWARD_CACHE_NAME)
|
||||||
|
|
||||||
# unhook spade layers
|
# unhook spade layers
|
||||||
self.spade_layers.unhook(unet)
|
self.spade_layers.unhook()
|
||||||
|
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
|
|
@ -151,7 +156,15 @@ class Script(scripts.Script):
|
||||||
save_original = gr.Checkbox(label='Save Original', value=False, elem_id=f'StableSR-save-original', visible=color_fix.value != 'None')
|
save_original = gr.Checkbox(label='Save Original', value=False, elem_id=f'StableSR-save-original', visible=color_fix.value != 'None')
|
||||||
color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None'), inputs=color_fix, outputs=save_original, show_progress=False)
|
color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None'), inputs=color_fix, outputs=save_original, show_progress=False)
|
||||||
pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id=f'StableSR-pure-noise')
|
pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id=f'StableSR-pure-noise')
|
||||||
|
unload_model= gr.Button(value='Unload Model', variant='tool')
|
||||||
|
def unload_model_fn():
|
||||||
|
if self.stablesr_model is not None:
|
||||||
|
self.stablesr_model = None
|
||||||
|
devices.torch_gc()
|
||||||
|
print('[StableSR] Model unloaded!')
|
||||||
|
else:
|
||||||
|
print('[StableSR] No model loaded.')
|
||||||
|
unload_model.click(fn=unload_model_fn)
|
||||||
return [model, scale_factor, pure_noise, color_fix, save_original]
|
return [model, scale_factor, pure_noise, color_fix, save_original]
|
||||||
|
|
||||||
def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:str, save_original:bool) -> Processed:
|
def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:str, save_original:bool) -> Processed:
|
||||||
|
|
@ -220,6 +233,9 @@ class Script(scripts.Script):
|
||||||
return samples
|
return samples
|
||||||
finally:
|
finally:
|
||||||
self.stablesr_model.unhook(unet)
|
self.stablesr_model.unhook(unet)
|
||||||
|
# in --medvram and --lowvram mode, we send the model back to the initial device
|
||||||
|
self.stablesr_model.struct_cond_model.to(device=first_param.device)
|
||||||
|
self.stablesr_model.spade_layers.to(device=first_param.device)
|
||||||
|
|
||||||
|
|
||||||
# replace the sample function
|
# replace the sample function
|
||||||
|
|
|
||||||
|
|
@ -125,10 +125,12 @@ class SPADELayers(nn.Module):
|
||||||
self.output_ids = list(range(12))
|
self.output_ids = list(range(12))
|
||||||
self.mid_ids = [0,2]
|
self.mid_ids = [0,2]
|
||||||
self.forward_cache_name = 'org_forward_stablesr'
|
self.forward_cache_name = 'org_forward_stablesr'
|
||||||
|
self.unet = None
|
||||||
|
|
||||||
|
|
||||||
def hook(self, unet: UNetModel, get_struct_cond):
|
def hook(self, unet: UNetModel, get_struct_cond):
|
||||||
# hook all resblocks
|
# hook all resblocks
|
||||||
|
self.unet = unet
|
||||||
resblock: ResBlock = None
|
resblock: ResBlock = None
|
||||||
for i in self.input_ids:
|
for i in self.input_ids:
|
||||||
resblock = unet.input_blocks[i][0]
|
resblock = unet.input_blocks[i][0]
|
||||||
|
|
@ -154,7 +156,9 @@ class SPADELayers(nn.Module):
|
||||||
setattr(resblock, self.forward_cache_name, resblock._forward)
|
setattr(resblock, self.forward_cache_name, resblock._forward)
|
||||||
resblock._forward = lambda x, timesteps, resblock=resblock, spade=self.middle_block[i]: dual_resblock_forward(resblock, x, timesteps, spade, get_struct_cond)
|
resblock._forward = lambda x, timesteps, resblock=resblock, spade=self.middle_block[i]: dual_resblock_forward(resblock, x, timesteps, spade, get_struct_cond)
|
||||||
|
|
||||||
def unhook(self, unet: UNetModel):
|
def unhook(self):
|
||||||
|
unet = self.unet
|
||||||
|
if unet is None: return
|
||||||
resblock: ResBlock = None
|
resblock: ResBlock = None
|
||||||
for i in self.input_ids:
|
for i in self.input_ids:
|
||||||
resblock = unet.input_blocks[i][0]
|
resblock = unet.input_blocks[i][0]
|
||||||
|
|
@ -173,6 +177,7 @@ class SPADELayers(nn.Module):
|
||||||
if hasattr(resblock, self.forward_cache_name):
|
if hasattr(resblock, self.forward_cache_name):
|
||||||
resblock._forward = getattr(resblock, self.forward_cache_name)
|
resblock._forward = getattr(resblock, self.forward_cache_name)
|
||||||
delattr(resblock, self.forward_cache_name)
|
delattr(resblock, self.forward_cache_name)
|
||||||
|
self.unet = None
|
||||||
|
|
||||||
|
|
||||||
def load_from_dict(self, state_dict):
|
def load_from_dict(self, state_dict):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue