diff --git a/modules/api/api.py b/modules/api/api.py index 7d2c2f279..b958085ea 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -79,7 +79,6 @@ class Api: self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae]) self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension]) self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork]) - self.add_api_route("/sdapi/v1/loras", endpoints.get_loras, methods=["GET"], response_model=List[dict]) # functional api self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo) @@ -89,10 +88,14 @@ class Api: self.add_api_route("/sdapi/v1/unload-checkpoint", endpoints.post_unload_checkpoint, methods=["POST"]) self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"]) - self.add_api_route("/sdapi/v1/refresh-loras", endpoints.post_refresh_loras, methods=["POST"]) self.add_api_route("/sdapi/v1/history", endpoints.get_history, methods=["GET"], response_model=List[str]) self.add_api_route("/sdapi/v1/history", endpoints.post_history, methods=["POST"], response_model=int) + # lora api + if shared.native: + self.add_api_route("/sdapi/v1/loras", endpoints.get_loras, methods=["GET"], response_model=List[dict]) + self.add_api_route("/sdapi/v1/refresh-loras", endpoints.post_refresh_loras, methods=["POST"]) + # gallery api gallery.register_api(app) diff --git a/modules/infotext.py b/modules/infotext.py index 4b9dd15ff..baa995c88 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -28,6 +28,7 @@ def unquote(text): return text +# disabled by default can be enabled if needed def check_lora(params): try: import modules.lora.networks as networks diff --git a/modules/lora/networks.py b/modules/lora/networks.py index dc6d86b2f..2db145a5a 100644 --- a/modules/lora/networks.py +++ b/modules/lora/networks.py @@ -17,7 +17,7 @@ import modules.lora.network_norm as network_norm import modules.lora.network_glora as network_glora import modules.lora.network_overrides as network_overrides import modules.lora.lora_convert as lora_convert -from modules import shared, devices, sd_models, sd_models_compile, errors, scripts, files_cache, model_quant +from modules import shared, devices, sd_models, sd_models_compile, errors, files_cache, model_quant debug = os.environ.get('SD_LORA_DEBUG', None) is not None @@ -44,6 +44,10 @@ module_types = [ ] +def total_time(): + return sum(timer.values()) + + def assign_network_names_to_compvis_modules(sd_model): if sd_model is None: return @@ -394,6 +398,8 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn def network_load(): + for k in timer.keys(): + timer[k] = 0 sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatiblility for component_name in ['text_encoder','text_encoder_2', 'unet', 'transformer']: component = getattr(sd_model, component_name, None) diff --git a/modules/processing_callbacks.py b/modules/processing_callbacks.py index 0bb94abad..e1bf723cc 100644 --- a/modules/processing_callbacks.py +++ b/modules/processing_callbacks.py @@ -6,6 +6,7 @@ import numpy as np from modules import shared, processing_correction, extra_networks, timer, prompt_parser_diffusers from modules.lora.networks import network_load + p = None debug = os.environ.get('SD_CALLBACK_DEBUG', None) is not None debug_callback = shared.log.trace if debug else lambda *args, **kwargs: None @@ -15,6 +16,7 @@ def set_callbacks_p(processing): global p # pylint: disable=global-statement p = processing + def prompt_callback(step, kwargs): if prompt_parser_diffusers.embedder is None or 'prompt_embeds' not in kwargs: return kwargs @@ -29,6 +31,7 @@ def prompt_callback(step, kwargs): debug_callback(f"Callback: {e}") return kwargs + def diffusers_callback_legacy(step: int, timestep: int, latents: typing.Union[torch.FloatTensor, np.ndarray]): if p is None: return @@ -64,7 +67,7 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {} if shared.state.interrupted or shared.state.skipped: raise AssertionError('Interrupted...') time.sleep(0.1) - if hasattr(p, "stepwise_lora"): + if hasattr(p, "stepwise_lora") and shared.native: extra_networks.activate(p, p.extra_network_data, step=step) network_load() if latents is None: diff --git a/modules/processing_diffusers.py b/modules/processing_diffusers.py index 2e8fb357c..ae24f5f80 100644 --- a/modules/processing_diffusers.py +++ b/modules/processing_diffusers.py @@ -8,8 +8,7 @@ from modules import shared, devices, processing, sd_models, errors, sd_hijack_hy from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled from modules.processing_args import set_pipeline_args from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed -from modules.lora.networks import network_load -from modules.lora.networks import timer as network_timer +from modules.lora import networks debug = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None @@ -427,9 +426,9 @@ def process_diffusers(p: processing.StableDiffusionProcessing): p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size] if p.negative_prompts is None or len(p.negative_prompts) == 0: p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size] - network_timer['apply'] = 0 - network_timer['restore'] = 0 - network_load() + + # load loras + networks.network_load() sd_models.move_model(shared.sd_model, devices.device) sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes @@ -459,6 +458,8 @@ def process_diffusers(p: processing.StableDiffusionProcessing): results = process_decode(p, output) timer.process.record('decode') + timer.process.add('lora', networks.total_time()) + shared.sd_model = orig_pipeline if p.state == '': global last_p # pylint: disable=global-statement diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 898522366..94664c5cb 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -460,19 +460,20 @@ def register_page(page: ExtraNetworksPage): def register_pages(): + debug('EN register-pages') from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints - from modules.ui_extra_networks_lora import ExtraNetworksPageLora from modules.ui_extra_networks_vae import ExtraNetworksPageVAEs from modules.ui_extra_networks_styles import ExtraNetworksPageStyles from modules.ui_extra_networks_history import ExtraNetworksPageHistory from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion - debug('EN register-pages') register_page(ExtraNetworksPageCheckpoints()) - register_page(ExtraNetworksPageLora()) register_page(ExtraNetworksPageVAEs()) register_page(ExtraNetworksPageStyles()) register_page(ExtraNetworksPageHistory()) register_page(ExtraNetworksPageTextualInversion()) + if shared.native: + from modules.ui_extra_networks_lora import ExtraNetworksPageLora + register_page(ExtraNetworksPageLora()) if shared.opts.hypernetwork_enabled: from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks register_page(ExtraNetworksPageHypernetworks()) diff --git a/webui.py b/webui.py index 3aae34447..2b8d7c56f 100644 --- a/webui.py +++ b/webui.py @@ -34,7 +34,6 @@ import modules.img2img import modules.upscaler import modules.textual_inversion.textual_inversion import modules.hypernetworks.hypernetwork -import modules.lora.networks import modules.script_callbacks from modules.api.middleware import setup_middleware from modules.shared import cmd_opts, opts # pylint: disable=unused-import @@ -104,8 +103,10 @@ def initialize(): modules.sd_models.setup_model() timer.startup.record("models") - modules.lora.networks.list_available_networks() - timer.startup.record("lora") + if shared.native: + import modules.lora.networks as lora_networks + lora_networks.list_available_networks() + timer.startup.record("lora") shared.prompt_styles.reload() timer.startup.record("styles")