device init logging

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/3740/head
Vladimir Mandic 2025-01-31 09:31:12 -05:00
parent d0d9759840
commit 3dcb70e8a2
3 changed files with 87 additions and 72 deletions

View File

@@ -278,14 +278,14 @@ def set_cuda_sync_mode(mode):
def set_cuda_memory_limit():
if not cuda_ok or opts.cuda_mem_fraction == 0:
return
from modules.shared import cmd_opts
try:
from modules.shared import cmd_opts
torch_gc(force=True)
mem = torch.cuda.get_device_properties(device).total_memory
torch.cuda.set_per_process_memory_fraction(float(opts.cuda_mem_fraction), cmd_opts.device_id if cmd_opts.device_id is not None else 0)
log.info(f'Torch CUDA memory limit: fraction={opts.cuda_mem_fraction:.2f} limit={round(opts.cuda_mem_fraction * mem / 1024 / 1024)} total={round(mem / 1024 / 1024)}')
log.info(f'Torch memory limit: fraction={opts.cuda_mem_fraction:.2f} limit={round(opts.cuda_mem_fraction * mem / 1024 / 1024)} total={round(mem / 1024 / 1024)}')
except Exception as e:
log.warning(f'Torch CUDA memory limit: fraction={opts.cuda_mem_fraction:.2f} {e}')
log.warning(f'Torch memory limit: fraction={opts.cuda_mem_fraction:.2f} {e}')
def set_cuda_tunable():
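
A minimal standalone sketch of the PyTorch API the hunk above relies on; the 0.5 fraction and device 0 are illustrative, not values from the patch:

import torch

if torch.cuda.is_available():
    total = torch.cuda.get_device_properties(0).total_memory
    # cap this process at 50% of device 0; allocations past the cap raise OOM
    torch.cuda.set_per_process_memory_fraction(0.5, 0)
    print(f'limit={round(0.5 * total / 1024 / 1024)} total={round(total / 1024 / 1024)}')
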
@@ -298,11 +298,18 @@ def set_cuda_tunable():
torch.cuda.tunable.set_max_tuning_duration(1000) # set to high value as actual is min(duration, iterations)
torch.cuda.tunable.set_max_tuning_iterations(opts.torch_tunable_limit)
fn = os.path.join(opts.tunable_dir, 'tunable.csv')
lines = 0
try:
if os.path.exists(fn):
with open(fn, 'r', encoding='utf8') as f:
lines = sum(1 for _line in f)
except Exception:
pass
torch.cuda.tunable.set_filename(fn)
if torch.cuda.tunable.is_enabled():
log.debug(f'Torch tunable: enabled={torch.cuda.tunable.is_enabled()} tuning={torch.cuda.tunable.tuning_is_enabled()} iterations={torch.cuda.tunable.get_max_tuning_iterations()} duration={torch.cuda.tunable.get_max_tuning_duration()} fn="{fn}"')
except Exception:
pass
log.debug(f'Torch tunable: enabled={torch.cuda.tunable.is_enabled()} tuning={torch.cuda.tunable.tuning_is_enabled()} iterations={torch.cuda.tunable.get_max_tuning_iterations()} duration={torch.cuda.tunable.get_max_tuning_duration()} fn="{fn}" entries={lines}')
except Exception as e:
log.warning(f'Torch tunable: {e}')
def test_fp16():
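
For reference, a minimal sketch of the torch.cuda.tunable (TunableOp) calls exercised above, available in recent PyTorch builds; the filename and iteration count are illustrative:

import torch

torch.cuda.tunable.enable(True)                   # turn TunableOp on
torch.cuda.tunable.tuning_enable(True)            # allow tuning of new GEMM shapes
torch.cuda.tunable.set_max_tuning_iterations(20)  # per-solution tuning budget
torch.cuda.tunable.set_filename('tunable.csv')    # persist results across runs
print(torch.cuda.tunable.is_enabled(), torch.cuda.tunable.get_filename())
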
@@ -361,90 +368,98 @@ def test_bf16():
def set_cudnn_params():
if cuda_ok:
if not cuda_ok:
return
try:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except Exception as e:
log.warning(f'Torch matmul: {e}')
if torch.backends.cudnn.is_available():
try:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except Exception:
pass
if torch.backends.cudnn.is_available():
try:
torch.backends.cudnn.deterministic = opts.cudnn_deterministic
torch.use_deterministic_algorithms(opts.cudnn_deterministic)
if opts.cudnn_deterministic:
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
torch.backends.cudnn.benchmark = True
if opts.cudnn_benchmark:
log.debug('Torch cuDNN: enable benchmark')
torch.backends.cudnn.benchmark_limit = 0
torch.backends.cudnn.allow_tf32 = True
except Exception:
pass
torch.backends.cudnn.deterministic = opts.cudnn_deterministic
torch.use_deterministic_algorithms(opts.cudnn_deterministic)
if opts.cudnn_deterministic:
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
torch.backends.cudnn.benchmark = True
if opts.cudnn_benchmark:
log.debug('Torch cuDNN: enable benchmark')
torch.backends.cudnn.benchmark_limit = 0
torch.backends.cudnn.allow_tf32 = True
except Exception as e:
log.warning(f'Torch cudnn: {e}')
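
The deterministic and benchmark paths above pull in opposite directions (reproducibility vs. autotuned speed); a minimal standalone sketch, with the deterministic flag standing in for opts.cudnn_deterministic:

import os
import torch

deterministic = False  # illustrative stand-in for opts.cudnn_deterministic
if deterministic:
    # required by use_deterministic_algorithms for cuBLAS on CUDA >= 10.2
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)
else:
    torch.backends.cudnn.benchmark = True      # autotune conv algorithms per input shape
    torch.backends.cudnn.benchmark_limit = 0   # 0 = no limit on candidate algorithms tried
    torch.backends.cudnn.allow_tf32 = True
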
def override_ipex_math():
if backend == "ipex":
try:
torch.xpu.set_fp32_math_mode(mode=torch.xpu.FP32MathMode.TF32)
except Exception:
pass
except Exception as e:
log.warning(f'Torch ipex: {e}')
def set_sdpa_params():
try:
if opts.cross_attention_optimization == "Scaled-Dot-Product":
if opts.cross_attention_optimization != "Scaled-Dot-Product":
return
try:
torch.backends.cuda.enable_flash_sdp('Flash attention' in opts.sdp_options)
torch.backends.cuda.enable_mem_efficient_sdp('Memory attention' in opts.sdp_options)
torch.backends.cuda.enable_math_sdp('Math attention' in opts.sdp_options)
except Exception as e:
log.warning(f'Torch attention: {e}')
try:
global sdpa_original # pylint: disable=global-statement
if sdpa_original is not None:
torch.nn.functional.scaled_dot_product_attention = sdpa_original
else:
sdpa_original = torch.nn.functional.scaled_dot_product_attention
if backend == "rocm":
if 'Flash attention' in opts.sdp_options:
try:
# https://github.com/huggingface/diffusers/discussions/7172
from flash_attn import flash_attn_func
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_flash_atten)
def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
return flash_attn_func(q=query.transpose(1, 2), k=key.transpose(1, 2), v=value.transpose(1, 2), dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2)
else:
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
log.debug('ROCm Flash Attention Hijacked')
except Exception as err:
log.error(f'ROCm Flash Attention failed: {err}')
if 'Sage attention' in opts.sdp_options:
except Exception as e:
log.warning(f'Torch SDPA: {e}')
if backend == "rocm":
if 'Flash attention' in opts.sdp_options:
try:
install('sageattention')
from sageattention import sageattn
sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_sage_atten)
def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
if query.shape[-1] in {128, 96, 64} and attn_mask is None and query.dtype != torch.float32:
return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
# https://github.com/huggingface/diffusers/discussions/7172
from flash_attn import flash_attn_func
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_flash_atten)
def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
return flash_attn_func(q=query.transpose(1, 2), k=key.transpose(1, 2), v=value.transpose(1, 2), dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2)
else:
return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
log.debug('SDPA Sage Attention Hijacked')
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
log.debug('Torch ROCm Flash Attention')
except Exception as err:
log.error(f'SDPA Sage Attention failed: {err}')
if 'Dynamic attention' in opts.sdp_options:
try:
global sdpa_pre_dyanmic_atten # pylint: disable=global-statement
sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention
torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention
log.debug('SDPA Dynamic Attention Hijacked')
except Exception as err:
log.error(f'SDPA Dynamic Attention failed: {err}')
except Exception:
pass
log.error(f'Torch ROCm Flash Attention: {err}')
if 'Sage attention' in opts.sdp_options:
try:
install('sageattention')
from sageattention import sageattn
sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_sage_atten)
def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
if query.shape[-1] in {128, 96, 64} and attn_mask is None and query.dtype != torch.float32:
return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
else:
return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
log.debug('Torch SDPA Sage Attention')
except Exception as err:
log.error(f'Torch SDPA Sage Attention: {err}')
if 'Dynamic attention' in opts.sdp_options:
try:
global sdpa_pre_dyanmic_atten # pylint: disable=global-statement
sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention
torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention
log.debug('Torch SDPA Dynamic Attention')
except Exception as err:
log.error(f'Torch SDPA Dynamic Attention: {err}')
except Exception as e:
log.warning(f'Torch SDPA: {e}')
def set_dtype():
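
The ROCm Flash, Sage, and Dynamic branches above all follow the same hijack pattern: keep a reference to the stock SDPA, route eligible shapes to a custom kernel, and fall back otherwise. A minimal sketch with the custom kernel left as a placeholder:

from functools import wraps
import torch

# keeping the original allows restoring it later, as the patch does via sdpa_original
sdpa_original = torch.nn.functional.scaled_dot_product_attention

@wraps(sdpa_original)
def sdpa_routed(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # same eligibility checks the patch uses: bounded head dim, no mask, half precision
    if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
        pass  # a custom kernel (flash_attn_func, sageattn, ...) would run here
    return sdpa_original(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)

torch.nn.functional.scaled_dot_product_attention = sdpa_routed
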

View File

@@ -24,13 +24,13 @@ def initialize():
from modules.control.units import xs # vislearn ControlNet-XS
from modules.control.units import lite # vislearn ControlNet-XS
from modules.control.units import t2iadapter # TencentARC T2I-Adapter
shared.log.debug(f'UI initialize: control models={shared.opts.control_dir}')
shared.log.debug(f'UI initialize: control models="{shared.opts.control_dir}"')
controlnet.cache_dir = os.path.join(shared.opts.control_dir, 'controlnet')
xs.cache_dir = os.path.join(shared.opts.control_dir, 'xs')
lite.cache_dir = os.path.join(shared.opts.control_dir, 'lite')
t2iadapter.cache_dir = os.path.join(shared.opts.control_dir, 'adapter')
processors.cache_dir = os.path.join(shared.opts.control_dir, 'processor')
masking.cache_dir = os.path.join(shared.opts.control_dir, 'segment')
unit.default_device = devices.device
unit.default_dtype = devices.dtype
try:
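
The cache layout configured above, collected in one place as a sketch (the root directory is illustrative):

import os

control_dir = 'models/control'  # illustrative stand-in for shared.opts.control_dir
cache_dirs = {name: os.path.join(control_dir, name)
              for name in ('controlnet', 'xs', 'lite', 'adapter', 'processor', 'segment')}
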

View File

@@ -233,9 +233,9 @@ def get_remote_ip():
def start_common():
log.debug('Entering start sequence')
if shared.cmd_opts.data_dir is not None and len(shared.cmd_opts.data_dir) > 0:
log.info(f'Using data path: {shared.cmd_opts.data_dir}')
log.info(f'Base path: data="{shared.cmd_opts.data_dir}"')
if shared.cmd_opts.models_dir is not None and len(shared.cmd_opts.models_dir) > 0 and shared.cmd_opts.models_dir != 'models':
log.info(f'Models path: {shared.cmd_opts.models_dir}')
log.info(f'Base path: models="{shared.cmd_opts.models_dir}"')
paths.create_paths(shared.opts)
async_policy()
initialize()
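
The reworded messages above standardize on a key="value" form so paths containing spaces stay unambiguous in logs; a sketch with an illustrative path:

import logging

log = logging.getLogger(__name__)
data_dir = '/mnt/sd data'  # illustrative path containing a space
log.info(f'Base path: data="{data_dir}"')  # quotes delimit the path in log output
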