diff --git a/scripts/stable_lora/scripts/lora_webui.py b/scripts/stable_lora/scripts/lora_webui.py index 3964425..cacba5e 100755 --- a/scripts/stable_lora/scripts/lora_webui.py +++ b/scripts/stable_lora/scripts/lora_webui.py @@ -172,7 +172,8 @@ class StableLoraScript(Text2VideoExtension, StableLoraProcessor): first_lora_init = not self.is_lora_loaded(p.sd_model) # If the LoRA is still loaded, unload it. - self.handle_lora_start(lora_files, p.sd_model) + unload_args = [p.sd_model, None, use_bias, use_time, use_conv, use_emb, use_linear, None] + self.handle_lora_start(lora_files, p.sd_model, unload_args) can_use_lora = self.can_use_lora(p.sd_model) diff --git a/scripts/stable_lora/stable_utils/lora_processor.py b/scripts/stable_lora/stable_utils/lora_processor.py index 988770c..286c4c3 100644 --- a/scripts/stable_lora/stable_utils/lora_processor.py +++ b/scripts/stable_lora/stable_utils/lora_processor.py @@ -114,12 +114,19 @@ class StableLoraProcessor: return lora_files_to_load - def handle_lora_load(self, sd_model, lora_files_list, set_lora_loaded=False): + def handle_lora_load( + self, + sd_model, + lora_files_list, + set_lora_loaded=False, + unload_args=[] + ): if not hasattr(sd_model, self.lora_loaded) and set_lora_loaded: setattr(sd_model, self.lora_loaded, True) if self.is_lora_loaded(sd_model) and not set_lora_loaded: - self.process_lora(p, lora_files_list, undo_merge=True) + unload_args[1], unload_args[-1] = self.undo_merge_preprocess() + self.process_lora(*unload_args, undo_merge=True) delattr(sd_model, self.lora_loaded) def handle_alpha_change(self, lora_alpha, model): @@ -130,9 +137,14 @@ class StableLoraProcessor: return (options != self.previous_advanced_options) \ and self.is_lora_loaded(model) - def handle_lora_start(self, lora_files, model): + def handle_lora_start(self, lora_files, model, unload_args): if len(lora_files) == 0 and self.is_lora_loaded(model): - self.handle_lora_load(model, lora_files, set_lora_loaded=False) + self.handle_lora_load( + model, + lora_files, + set_lora_loaded=False, + unload_args=unload_args + ) self.log(f"Unloaded previously loaded LoRA files") return