From 3dcb70e8a241296d7dde9d555c2d5027fe0bbab1 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 31 Jan 2025 09:31:12 -0500 Subject: [PATCH] device init logging Signed-off-by: Vladimir Mandic --- modules/devices.py | 151 +++++++++++++++++++--------------- modules/ui_control_helpers.py | 4 +- webui.py | 4 +- 3 files changed, 87 insertions(+), 72 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 03790f9e1..030593236 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -278,14 +278,14 @@ def set_cuda_sync_mode(mode): def set_cuda_memory_limit(): if not cuda_ok or opts.cuda_mem_fraction == 0: return - from modules.shared import cmd_opts try: + from modules.shared import cmd_opts torch_gc(force=True) mem = torch.cuda.get_device_properties(device).total_memory torch.cuda.set_per_process_memory_fraction(float(opts.cuda_mem_fraction), cmd_opts.device_id if cmd_opts.device_id is not None else 0) - log.info(f'Torch CUDA memory limit: fraction={opts.cuda_mem_fraction:.2f} limit={round(opts.cuda_mem_fraction * mem / 1024 / 1024)} total={round(mem / 1024 / 1024)}') + log.info(f'Torch memory limit: fraction={opts.cuda_mem_fraction:.2f} limit={round(opts.cuda_mem_fraction * mem / 1024 / 1024)} total={round(mem / 1024 / 1024)}') except Exception as e: - log.warning(f'Torch CUDA memory limit: fraction={opts.cuda_mem_fraction:.2f} {e}') + log.warning(f'Torch memory limit: fraction={opts.cuda_mem_fraction:.2f} {e}') def set_cuda_tunable(): @@ -298,11 +298,18 @@ def set_cuda_tunable(): torch.cuda.tunable.set_max_tuning_duration(1000) # set to high value as actual is min(duration, iterations) torch.cuda.tunable.set_max_tuning_iterations(opts.torch_tunable_limit) fn = os.path.join(opts.tunable_dir, 'tunable.csv') + lines={0} + try: + if os.path.exists(fn): + with open(fn, 'r', encoding='utf8') as f: + lines = sum(1 for _line in f) + except Exception: + pass torch.cuda.tunable.set_filename(fn) if torch.cuda.tunable.is_enabled(): - log.debug(f'Torce tunable: enabled={torch.cuda.tunable.is_enabled()} tuning={torch.cuda.tunable.tuning_is_enabled()} iterations={torch.cuda.tunable.get_max_tuning_iterations()} duration={torch.cuda.tunable.get_max_tuning_duration()} fn="{fn}"') - except Exception: - pass + log.debug(f'Torche tunable: enabled={torch.cuda.tunable.is_enabled()} tuning={torch.cuda.tunable.tuning_is_enabled()} iterations={torch.cuda.tunable.get_max_tuning_iterations()} duration={torch.cuda.tunable.get_max_tuning_duration()} fn="{fn}" entries={lines}') + except Exception as e: + log.warning(f'Torch tunable: {e}') def test_fp16(): @@ -361,90 +368,98 @@ def test_bf16(): def set_cudnn_params(): - if cuda_ok: + if not cuda_ok: + return + try: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True + torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) + except Exception as e: + log.warning(f'Torch matmul: {e}') + if torch.backends.cudnn.is_available(): try: - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True - torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True - torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) - except Exception: - pass - if torch.backends.cudnn.is_available(): - try: - torch.backends.cudnn.deterministic = opts.cudnn_deterministic - torch.use_deterministic_algorithms(opts.cudnn_deterministic) - if opts.cudnn_deterministic: - os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8') - torch.backends.cudnn.benchmark = True - if opts.cudnn_benchmark: - log.debug('Torch cuDNN: enable benchmark') - torch.backends.cudnn.benchmark_limit = 0 - torch.backends.cudnn.allow_tf32 = True - except Exception: - pass + torch.backends.cudnn.deterministic = opts.cudnn_deterministic + torch.use_deterministic_algorithms(opts.cudnn_deterministic) + if opts.cudnn_deterministic: + os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8') + torch.backends.cudnn.benchmark = True + if opts.cudnn_benchmark: + log.debug('Torch cuDNN: enable benchmark') + torch.backends.cudnn.benchmark_limit = 0 + torch.backends.cudnn.allow_tf32 = True + except Exception as e: + log.warning(f'Torch cudnn: {e}') def override_ipex_math(): if backend == "ipex": try: torch.xpu.set_fp32_math_mode(mode=torch.xpu.FP32MathMode.TF32) - except Exception: - pass + except Exception as e: + log.warning(f'Torch ipex: {e}') def set_sdpa_params(): try: - if opts.cross_attention_optimization == "Scaled-Dot-Product": + if opts.cross_attention_optimization != "Scaled-Dot-Product": + return + try: torch.backends.cuda.enable_flash_sdp('Flash attention' in opts.sdp_options) torch.backends.cuda.enable_mem_efficient_sdp('Memory attention' in opts.sdp_options) torch.backends.cuda.enable_math_sdp('Math attention' in opts.sdp_options) + except Exception as e: + log.warning(f'Torch attention: {e}') + try: global sdpa_original # pylint: disable=global-statement if sdpa_original is not None: torch.nn.functional.scaled_dot_product_attention = sdpa_original else: sdpa_original = torch.nn.functional.scaled_dot_product_attention - if backend == "rocm": - if 'Flash attention' in opts.sdp_options: - try: - # https://github.com/huggingface/diffusers/discussions/7172 - from flash_attn import flash_attn_func - sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention - @wraps(sdpa_pre_flash_atten) - def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): - if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32: - return flash_attn_func(q=query.transpose(1, 2), k=key.transpose(1, 2), v=value.transpose(1, 2), dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2) - else: - return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale) - torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten - log.debug('ROCm Flash Attention Hijacked') - except Exception as err: - log.error(f'ROCm Flash Attention failed: {err}') - if 'Sage attention' in opts.sdp_options: + except Exception as e: + log.warning(f'Torch SDPA: {e}') + if backend == "rocm": + if 'Flash attention' in opts.sdp_options: try: - install('sageattention') - from sageattention import sageattn - sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention - @wraps(sdpa_pre_sage_atten) - def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): - if query.shape[-1] in {128, 96, 64} and attn_mask is None and query.dtype != torch.float32: - return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale) + # https://github.com/huggingface/diffusers/discussions/7172 + from flash_attn import flash_attn_func + sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention + @wraps(sdpa_pre_flash_atten) + def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): + if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32: + return flash_attn_func(q=query.transpose(1, 2), k=key.transpose(1, 2), v=value.transpose(1, 2), dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2) else: - return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale) - torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten - log.debug('SDPA Sage Attention Hijacked') + return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale) + torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten + log.debug('Torch ROCm Flash Attention') except Exception as err: - log.error(f'SDPA Sage Attention failed: {err}') - if 'Dynamic attention' in opts.sdp_options: - try: - global sdpa_pre_dyanmic_atten # pylint: disable=global-statement - sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention - from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention - torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention - log.debug('SDPA Dynamic Attention Hijacked') - except Exception as err: - log.error(f'SDPA Dynamic Attention failed: {err}') - except Exception: - pass + log.error(f'Torch ROCm Flash Attention: {err}') + if 'Sage attention' in opts.sdp_options: + try: + install('sageattention') + from sageattention import sageattn + sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention + @wraps(sdpa_pre_sage_atten) + def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): + if query.shape[-1] in {128, 96, 64} and attn_mask is None and query.dtype != torch.float32: + return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale) + else: + return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale) + torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten + log.debug('Torch SDPA Sage Attention') + except Exception as err: + log.error(f'Torch SDPA Sage Attention: {err}') + if 'Dynamic attention' in opts.sdp_options: + try: + global sdpa_pre_dyanmic_atten # pylint: disable=global-statement + sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention + from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention + torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention + log.debug('Torch SDPA Dynamic Attention') + except Exception as err: + log.error(f'Torch SDPA Dynamic Attention: {err}') + except Exception as e: + log.warning(f'Torch SDPA: {e}') def set_dtype(): diff --git a/modules/ui_control_helpers.py b/modules/ui_control_helpers.py index 553b5fbac..7573d6f69 100644 --- a/modules/ui_control_helpers.py +++ b/modules/ui_control_helpers.py @@ -24,13 +24,13 @@ def initialize(): from modules.control.units import xs # vislearn ControlNet-XS from modules.control.units import lite # vislearn ControlNet-XS from modules.control.units import t2iadapter # TencentARC T2I-Adapter - shared.log.debug(f'UI initialize: control models={shared.opts.control_dir}') + shared.log.debug(f'UI initialize: control models="{shared.opts.control_dir}"') controlnet.cache_dir = os.path.join(shared.opts.control_dir, 'controlnet') xs.cache_dir = os.path.join(shared.opts.control_dir, 'xs') lite.cache_dir = os.path.join(shared.opts.control_dir, 'lite') t2iadapter.cache_dir = os.path.join(shared.opts.control_dir, 'adapter') processors.cache_dir = os.path.join(shared.opts.control_dir, 'processor') - masking.cache_dir = os.path.join(shared.opts.control_dir, 'segment') + masking.cache_dir = os.path.join(shared.opts.control_dir, 'segment') unit.default_device = devices.device unit.default_dtype = devices.dtype try: diff --git a/webui.py b/webui.py index 4f3de9617..fed72944d 100644 --- a/webui.py +++ b/webui.py @@ -233,9 +233,9 @@ def get_remote_ip(): def start_common(): log.debug('Entering start sequence') if shared.cmd_opts.data_dir is not None and len(shared.cmd_opts.data_dir) > 0: - log.info(f'Using data path: {shared.cmd_opts.data_dir}') + log.info(f'Base path: data="{shared.cmd_opts.data_dir}"') if shared.cmd_opts.models_dir is not None and len(shared.cmd_opts.models_dir) > 0 and shared.cmd_opts.models_dir != 'models': - log.info(f'Models path: {shared.cmd_opts.models_dir}') + log.info(f'Base path: models="{shared.cmd_opts.models_dir}"') paths.create_paths(shared.opts) async_policy() initialize()