mirror of https://github.com/vladmandic/automatic
conditional imports and summary timer
Signed-off-by: Vladimir Mandic <mandic00@live.com>pull/3593/head
parent
6aa7a4707e
commit
b6963470a9
|
|
@ -79,7 +79,6 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae])
|
||||
self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension])
|
||||
self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork])
|
||||
self.add_api_route("/sdapi/v1/loras", endpoints.get_loras, methods=["GET"], response_model=List[dict])
|
||||
|
||||
# functional api
|
||||
self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
|
||||
|
|
@ -89,10 +88,14 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/unload-checkpoint", endpoints.post_unload_checkpoint, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-loras", endpoints.post_refresh_loras, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/history", endpoints.get_history, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/history", endpoints.post_history, methods=["POST"], response_model=int)
|
||||
|
||||
# lora api
|
||||
if shared.native:
|
||||
self.add_api_route("/sdapi/v1/loras", endpoints.get_loras, methods=["GET"], response_model=List[dict])
|
||||
self.add_api_route("/sdapi/v1/refresh-loras", endpoints.post_refresh_loras, methods=["POST"])
|
||||
|
||||
# gallery api
|
||||
gallery.register_api(app)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ def unquote(text):
|
|||
return text
|
||||
|
||||
|
||||
# disabled by default can be enabled if needed
|
||||
def check_lora(params):
|
||||
try:
|
||||
import modules.lora.networks as networks
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import modules.lora.network_norm as network_norm
|
|||
import modules.lora.network_glora as network_glora
|
||||
import modules.lora.network_overrides as network_overrides
|
||||
import modules.lora.lora_convert as lora_convert
|
||||
from modules import shared, devices, sd_models, sd_models_compile, errors, scripts, files_cache, model_quant
|
||||
from modules import shared, devices, sd_models, sd_models_compile, errors, files_cache, model_quant
|
||||
|
||||
|
||||
debug = os.environ.get('SD_LORA_DEBUG', None) is not None
|
||||
|
|
@ -44,6 +44,10 @@ module_types = [
|
|||
]
|
||||
|
||||
|
||||
def total_time():
|
||||
return sum(timer.values())
|
||||
|
||||
|
||||
def assign_network_names_to_compvis_modules(sd_model):
|
||||
if sd_model is None:
|
||||
return
|
||||
|
|
@ -394,6 +398,8 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||
|
||||
|
||||
def network_load():
|
||||
for k in timer.keys():
|
||||
timer[k] = 0
|
||||
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model) # wrapped model compatiblility
|
||||
for component_name in ['text_encoder','text_encoder_2', 'unet', 'transformer']:
|
||||
component = getattr(sd_model, component_name, None)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import numpy as np
|
|||
from modules import shared, processing_correction, extra_networks, timer, prompt_parser_diffusers
|
||||
from modules.lora.networks import network_load
|
||||
|
||||
|
||||
p = None
|
||||
debug = os.environ.get('SD_CALLBACK_DEBUG', None) is not None
|
||||
debug_callback = shared.log.trace if debug else lambda *args, **kwargs: None
|
||||
|
|
@ -15,6 +16,7 @@ def set_callbacks_p(processing):
|
|||
global p # pylint: disable=global-statement
|
||||
p = processing
|
||||
|
||||
|
||||
def prompt_callback(step, kwargs):
|
||||
if prompt_parser_diffusers.embedder is None or 'prompt_embeds' not in kwargs:
|
||||
return kwargs
|
||||
|
|
@ -29,6 +31,7 @@ def prompt_callback(step, kwargs):
|
|||
debug_callback(f"Callback: {e}")
|
||||
return kwargs
|
||||
|
||||
|
||||
def diffusers_callback_legacy(step: int, timestep: int, latents: typing.Union[torch.FloatTensor, np.ndarray]):
|
||||
if p is None:
|
||||
return
|
||||
|
|
@ -64,7 +67,7 @@ def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: dict = {}
|
|||
if shared.state.interrupted or shared.state.skipped:
|
||||
raise AssertionError('Interrupted...')
|
||||
time.sleep(0.1)
|
||||
if hasattr(p, "stepwise_lora"):
|
||||
if hasattr(p, "stepwise_lora") and shared.native:
|
||||
extra_networks.activate(p, p.extra_network_data, step=step)
|
||||
network_load()
|
||||
if latents is None:
|
||||
|
|
|
|||
|
|
@ -8,8 +8,7 @@ from modules import shared, devices, processing, sd_models, errors, sd_hijack_hy
|
|||
from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled
|
||||
from modules.processing_args import set_pipeline_args
|
||||
from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed
|
||||
from modules.lora.networks import network_load
|
||||
from modules.lora.networks import timer as network_timer
|
||||
from modules.lora import networks
|
||||
|
||||
|
||||
debug = shared.log.trace if os.environ.get('SD_DIFFUSERS_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
|
@ -427,9 +426,9 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
|
|||
p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
|
||||
if p.negative_prompts is None or len(p.negative_prompts) == 0:
|
||||
p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
|
||||
network_timer['apply'] = 0
|
||||
network_timer['restore'] = 0
|
||||
network_load()
|
||||
|
||||
# load loras
|
||||
networks.network_load()
|
||||
|
||||
sd_models.move_model(shared.sd_model, devices.device)
|
||||
sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes
|
||||
|
|
@ -459,6 +458,8 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
|
|||
results = process_decode(p, output)
|
||||
|
||||
timer.process.record('decode')
|
||||
timer.process.add('lora', networks.total_time())
|
||||
|
||||
shared.sd_model = orig_pipeline
|
||||
if p.state == '':
|
||||
global last_p # pylint: disable=global-statement
|
||||
|
|
|
|||
|
|
@ -460,19 +460,20 @@ def register_page(page: ExtraNetworksPage):
|
|||
|
||||
|
||||
def register_pages():
|
||||
debug('EN register-pages')
|
||||
from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
|
||||
from modules.ui_extra_networks_lora import ExtraNetworksPageLora
|
||||
from modules.ui_extra_networks_vae import ExtraNetworksPageVAEs
|
||||
from modules.ui_extra_networks_styles import ExtraNetworksPageStyles
|
||||
from modules.ui_extra_networks_history import ExtraNetworksPageHistory
|
||||
from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
|
||||
debug('EN register-pages')
|
||||
register_page(ExtraNetworksPageCheckpoints())
|
||||
register_page(ExtraNetworksPageLora())
|
||||
register_page(ExtraNetworksPageVAEs())
|
||||
register_page(ExtraNetworksPageStyles())
|
||||
register_page(ExtraNetworksPageHistory())
|
||||
register_page(ExtraNetworksPageTextualInversion())
|
||||
if shared.native:
|
||||
from modules.ui_extra_networks_lora import ExtraNetworksPageLora
|
||||
register_page(ExtraNetworksPageLora())
|
||||
if shared.opts.hypernetwork_enabled:
|
||||
from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
|
||||
register_page(ExtraNetworksPageHypernetworks())
|
||||
|
|
|
|||
7
webui.py
7
webui.py
|
|
@ -34,7 +34,6 @@ import modules.img2img
|
|||
import modules.upscaler
|
||||
import modules.textual_inversion.textual_inversion
|
||||
import modules.hypernetworks.hypernetwork
|
||||
import modules.lora.networks
|
||||
import modules.script_callbacks
|
||||
from modules.api.middleware import setup_middleware
|
||||
from modules.shared import cmd_opts, opts # pylint: disable=unused-import
|
||||
|
|
@ -104,8 +103,10 @@ def initialize():
|
|||
modules.sd_models.setup_model()
|
||||
timer.startup.record("models")
|
||||
|
||||
modules.lora.networks.list_available_networks()
|
||||
timer.startup.record("lora")
|
||||
if shared.native:
|
||||
import modules.lora.networks as lora_networks
|
||||
lora_networks.list_available_networks()
|
||||
timer.startup.record("lora")
|
||||
|
||||
shared.prompt_styles.reload()
|
||||
timer.startup.record("styles")
|
||||
|
|
|
|||
Loading…
Reference in New Issue