conditional imports and summary timer

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/3593/head
Vladimir Mandic 2024-11-29 10:05:06 -05:00
parent 6aa7a4707e
commit b6963470a9
7 changed files with 31 additions and 15 deletions

View File

@@ -79,7 +79,6 @@ class Api:
self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae])
self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension])
self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork])
self.add_api_route("/sdapi/v1/loras", endpoints.get_loras, methods=["GET"], response_model=List[dict])
# functional api
self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
@@ -89,10 +88,14 @@ class Api:
self.add_api_route("/sdapi/v1/unload-checkpoint", endpoints.post_unload_checkpoint, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-loras", endpoints.post_refresh_loras, methods=["POST"])
self.add_api_route("/sdapi/v1/history", endpoints.get_history, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/history", endpoints.post_history, methods=["POST"], response_model=int)
# lora api
if shared.native:
self.add_api_route("/sdapi/v1/loras", endpoints.get_loras, methods=["GET"], response_model=List[dict])
self.add_api_route("/sdapi/v1/refresh-loras", endpoints.post_refresh_loras, methods=["POST"])
# gallery api
gallery.register_api(app)

View File

@@ -28,6 +28,7 @@ def unquote(text):
return text
# disabled by default can be enabled if needed
def check_lora(params):
try:
import modules.lora.networks as networks

View File

@@ -17,7 +17,7 @@ import modules.lora.network_norm as network_norm
import modules.lora.network_glora as network_glora
import modules.lora.network_overrides as network_overrides
import modules.lora.lora_convert as lora_convert
from modules import shared, devices, sd_models, sd_models_compile, errors, scripts, files_cache, model_quant
from modules import shared, devices, sd_models, sd_models_compile, errors, files_cache, model_quant
debug = os.environ.get('SD_LORA_DEBUG', None) is not None
@@ -44,6 +44,10 @@ module_types = [
]
def total_time():
return sum(timer.values())
def assign_network_names_to_compvis_modules(sd_model):
if sd_model is None:
return
@@ -394,6 +398,8 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
def network_load():
for k in timer.keys():
timer[k] = 0
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatiblility
for component_name in ['text_encoder','text_encoder_2', 'unet', 'transformer']:
component = getattr(sd_model, component_name, None)

View File

@@ -6,6 +6,7 @@ import numpy as np
from modules import shared, processing_correction, extra_networks, timer, prompt_parser_diffusers
from modules.lora.networks import network_load
p = None
debug = os.environ.get('SD_CALLBACK_DEBUG', None) is not None
debug_callback = shared.log.trace if debug else lambda *args, **kwargs: None
@@ -15,6 +16,7 @@ def set_callbacks_p(processing):
global p # pylint: disable=global-statement
p = processing
def prompt_callback(step, kwargs):
if prompt_parser_diffusers.embedder is None or 'prompt_embeds' not in kwargs:
return kwargs
@@ -29,6 +31,7 @@ def prompt_callback(step, kwargs):
debug_callback(f"Callback: {e}")
return kwargs
def diffusers_callback_legacy(step: int, timestep: int, latents: typing.Union[torch.FloatTensor, np.ndarray]):
if p is None:
return
@@ -64,7 +67,7 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
if shared.state.interrupted or shared.state.skipped:
raise AssertionError('Interrupted...')
time.sleep(0.1)
if hasattr(p, "stepwise_lora"):
if hasattr(p, "stepwise_lora") and shared.native:
extra_networks.activate(p, p.extra_network_data, step=step)
network_load()
if latents is None:

View File

@@ -8,8 +8,7 @@ from modules import shared, devices, processing, sd_models, errors, sd_hijack_hy
from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled
from modules.processing_args import set_pipeline_args
from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed
from modules.lora.networks import network_load
from modules.lora.networks import timer as network_timer
from modules.lora import networks
debug = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None
@@ -427,9 +426,9 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
if p.negative_prompts is None or len(p.negative_prompts) == 0:
p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
network_timer['apply'] = 0
network_timer['restore'] = 0
network_load()
# load loras
networks.network_load()
sd_models.move_model(shared.sd_model, devices.device)
sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes
@@ -459,6 +458,8 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
results = process_decode(p, output)
timer.process.record('decode')
timer.process.add('lora', networks.total_time())
shared.sd_model = orig_pipeline
if p.state == '':
global last_p # pylint: disable=global-statement

View File

@@ -460,19 +460,20 @@ def register_page(page: ExtraNetworksPage):
def register_pages():
debug('EN register-pages')
from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
from modules.ui_extra_networks_lora import ExtraNetworksPageLora
from modules.ui_extra_networks_vae import ExtraNetworksPageVAEs
from modules.ui_extra_networks_styles import ExtraNetworksPageStyles
from modules.ui_extra_networks_history import ExtraNetworksPageHistory
from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
debug('EN register-pages')
register_page(ExtraNetworksPageCheckpoints())
register_page(ExtraNetworksPageLora())
register_page(ExtraNetworksPageVAEs())
register_page(ExtraNetworksPageStyles())
register_page(ExtraNetworksPageHistory())
register_page(ExtraNetworksPageTextualInversion())
if shared.native:
from modules.ui_extra_networks_lora import ExtraNetworksPageLora
register_page(ExtraNetworksPageLora())
if shared.opts.hypernetwork_enabled:
from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
register_page(ExtraNetworksPageHypernetworks())

View File

@@ -34,7 +34,6 @@ import modules.img2img
import modules.upscaler
import modules.textual_inversion.textual_inversion
import modules.hypernetworks.hypernetwork
import modules.lora.networks
import modules.script_callbacks
from modules.api.middleware import setup_middleware
from modules.shared import cmd_opts, opts # pylint: disable=unused-import
@@ -104,8 +103,10 @@ def initialize():
modules.sd_models.setup_model()
timer.startup.record("models")
modules.lora.networks.list_available_networks()
timer.startup.record("lora")
if shared.native:
import modules.lora.networks as lora_networks
lora_networks.list_available_networks()
timer.startup.record("lora")
shared.prompt_styles.reload()
timer.startup.record("styles")