merge: modules/sd_models_compile.py

pull/4678/head
vladmandic 2026-03-12 14:16:51 +01:00
parent b526d98cb0
commit 939eea1cd9
1 changed file with 17 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import os
import time
import logging
import torch
@@ -5,6 +6,9 @@ from modules import shared, errors, devices, sd_models, sd_models_utils
from modules.logger import log
from installer import setup_logging
debug = os.environ.get('SD_COMPILE_DEBUG', None) is not None
debug_log = log.trace if debug else lambda *args, **kwargs: None
#Used by OpenVINO, can be used with TensorRT or Olive
class CompiledModelState:
@@ -90,6 +94,7 @@ def compile_onediff(sd_model):
log.warning(f"Model compile: task=onediff {e}")
return sd_model
debug_log(f"Model compile: task=onediff pipeline={sd_model.__class__.__name__}")
try:
t0 = time.time()
# For some reason compiling the text_encoder, when it is used by
@@ -139,6 +144,7 @@ def compile_stablefast(sd_model):
# config.trace_scheduler = False
# config.enable_cnn_optimization
# config.prefer_lowp_gemm
debug_log(f"Model compile: task=stablefast config={config.__dict__}")
try:
t0 = time.time()
sd_model = sf.compile(sd_model, config)
@@ -160,8 +166,12 @@ def compile_torch(sd_model, apply_to_components=True, op="Model"):
import torch._dynamo # pylint: disable=unused-import,redefined-outer-name
torch._dynamo.reset() # pylint: disable=protected-access
log.debug(f"{op} compile: task=torch backends={torch._dynamo.list_backends()}") # pylint: disable=protected-access
debug_log(f"{op} compile: options={shared.opts.cuda_compile_options} mode={shared.opts.cuda_compile_mode} backend={shared.opts.cuda_compile_backend} targets={shared.opts.cuda_compile}")
compiled_components = []
def torch_compile_model(model, op=None, sd_model=None): # pylint: disable=unused-argument
compiled_components.append(model.__class__.__name__)
if hasattr(model, 'compile_repeated_blocks') and 'repeated' in shared.opts.cuda_compile_options:
model.compile_repeated_blocks(
mode=shared.opts.cuda_compile_mode,
@@ -195,11 +205,12 @@ def compile_torch(sd_model, apply_to_components=True, op="Model"):
return sd_model
elif shared.opts.cuda_compile_backend == "migraphx":
pass # pylint: disable=unused-import
log_level = logging.WARNING if 'verbose' in shared.opts.cuda_compile_options else logging.CRITICAL # pylint: disable=protected-access
verbose = debug or 'verbose' in shared.opts.cuda_compile_options
log_level = logging.WARNING if verbose else logging.CRITICAL # pylint: disable=protected-access
if hasattr(torch, '_logging'):
torch._logging.set_logs(dynamo=log_level, aot=log_level, inductor=log_level) # pylint: disable=protected-access
torch._dynamo.config.verbose = 'verbose' in shared.opts.cuda_compile_options # pylint: disable=protected-access
torch._dynamo.config.suppress_errors = 'verbose' not in shared.opts.cuda_compile_options # pylint: disable=protected-access
torch._dynamo.config.verbose = verbose # pylint: disable=protected-access
torch._dynamo.config.suppress_errors = not verbose # pylint: disable=protected-access
try:
torch._inductor.config.conv_1x1_as_mm = True # pylint: disable=protected-access
@@ -214,7 +225,7 @@ def compile_torch(sd_model, apply_to_components=True, op="Model"):
if apply_to_components:
sd_model = sd_models.apply_function_to_model(sd_model, function=torch_compile_model, options=shared.opts.cuda_compile, op="compile")
else:
sd_model = torch_compile_model(sd_model)
sd_model = torch_compile_model(sd_model, op=op)
setup_logging() # compile messes with logging so reset is needed
if apply_to_components and 'precompile' in shared.opts.cuda_compile_options:
@@ -225,6 +236,7 @@ def compile_torch(sd_model, apply_to_components=True, op="Model"):
pass
t1 = time.time()
log.info(f"{op} compile: task=torch time={t1-t0:.2f}")
debug_log(f"{op} compile: task=torch completed components={compiled_components} targets={shared.opts.cuda_compile} verbose={verbose} time={t1-t0:.2f}")
except Exception as e:
log.warning(f"{op} compile: task=torch {e}")
errors.display(e, 'Compile')
@@ -249,6 +261,7 @@ def compile_deepcache(sd_model):
except Exception as e:
log.warning(f'Model compile: task=deepcache {e}')
return sd_model
debug_log(f"Model compile: task=deepcache pipeline={sd_model.__class__.__name__} interval={shared.opts.deep_cache_interval}")
t0 = time.time()
check_deepcache(False)
deepcache_worker = DeepCacheSDHelper(pipe=sd_model)