mirror of https://github.com/vladmandic/automatic
parent
d0d9759840
commit
3dcb70e8a2
|
|
@ -278,14 +278,14 @@ def set_cuda_sync_mode(mode):
|
|||
def set_cuda_memory_limit():
|
||||
if not cuda_ok or opts.cuda_mem_fraction == 0:
|
||||
return
|
||||
from modules.shared import cmd_opts
|
||||
try:
|
||||
from modules.shared import cmd_opts
|
||||
torch_gc(force=True)
|
||||
mem = torch.cuda.get_device_properties(device).total_memory
|
||||
torch.cuda.set_per_process_memory_fraction(float(opts.cuda_mem_fraction), cmd_opts.device_id if cmd_opts.device_id is not None else 0)
|
||||
log.info(f'Torch CUDA memory limit: fraction={opts.cuda_mem_fraction:.2f} limit={round(opts.cuda_mem_fraction * mem / 1024 / 1024)} total={round(mem / 1024 / 1024)}')
|
||||
log.info(f'Torch memory limit: fraction={opts.cuda_mem_fraction:.2f} limit={round(opts.cuda_mem_fraction * mem / 1024 / 1024)} total={round(mem / 1024 / 1024)}')
|
||||
except Exception as e:
|
||||
log.warning(f'Torch CUDA memory limit: fraction={opts.cuda_mem_fraction:.2f} {e}')
|
||||
log.warning(f'Torch memory limit: fraction={opts.cuda_mem_fraction:.2f} {e}')
|
||||
|
||||
|
||||
def set_cuda_tunable():
|
||||
|
|
@ -298,11 +298,18 @@ def set_cuda_tunable():
|
|||
torch.cuda.tunable.set_max_tuning_duration(1000) # set to high value as actual is min(duration, iterations)
|
||||
torch.cuda.tunable.set_max_tuning_iterations(opts.torch_tunable_limit)
|
||||
fn = os.path.join(opts.tunable_dir, 'tunable.csv')
|
||||
lines={0}
|
||||
try:
|
||||
if os.path.exists(fn):
|
||||
with open(fn, 'r', encoding='utf8') as f:
|
||||
lines = sum(1 for _line in f)
|
||||
except Exception:
|
||||
pass
|
||||
torch.cuda.tunable.set_filename(fn)
|
||||
if torch.cuda.tunable.is_enabled():
|
||||
log.debug(f'Torce tunable: enabled={torch.cuda.tunable.is_enabled()} tuning={torch.cuda.tunable.tuning_is_enabled()} iterations={torch.cuda.tunable.get_max_tuning_iterations()} duration={torch.cuda.tunable.get_max_tuning_duration()} fn="{fn}"')
|
||||
except Exception:
|
||||
pass
|
||||
log.debug(f'Torche tunable: enabled={torch.cuda.tunable.is_enabled()} tuning={torch.cuda.tunable.tuning_is_enabled()} iterations={torch.cuda.tunable.get_max_tuning_iterations()} duration={torch.cuda.tunable.get_max_tuning_duration()} fn="{fn}" entries={lines}')
|
||||
except Exception as e:
|
||||
log.warning(f'Torch tunable: {e}')
|
||||
|
||||
|
||||
def test_fp16():
|
||||
|
|
@ -361,90 +368,98 @@ def test_bf16():
|
|||
|
||||
|
||||
def set_cudnn_params():
|
||||
if cuda_ok:
|
||||
if not cuda_ok:
|
||||
return
|
||||
try:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
except Exception as e:
|
||||
log.warning(f'Torch matmul: {e}')
|
||||
if torch.backends.cudnn.is_available():
|
||||
try:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
except Exception:
|
||||
pass
|
||||
if torch.backends.cudnn.is_available():
|
||||
try:
|
||||
torch.backends.cudnn.deterministic = opts.cudnn_deterministic
|
||||
torch.use_deterministic_algorithms(opts.cudnn_deterministic)
|
||||
if opts.cudnn_deterministic:
|
||||
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
|
||||
torch.backends.cudnn.benchmark = True
|
||||
if opts.cudnn_benchmark:
|
||||
log.debug('Torch cuDNN: enable benchmark')
|
||||
torch.backends.cudnn.benchmark_limit = 0
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
except Exception:
|
||||
pass
|
||||
torch.backends.cudnn.deterministic = opts.cudnn_deterministic
|
||||
torch.use_deterministic_algorithms(opts.cudnn_deterministic)
|
||||
if opts.cudnn_deterministic:
|
||||
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
|
||||
torch.backends.cudnn.benchmark = True
|
||||
if opts.cudnn_benchmark:
|
||||
log.debug('Torch cuDNN: enable benchmark')
|
||||
torch.backends.cudnn.benchmark_limit = 0
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
except Exception as e:
|
||||
log.warning(f'Torch cudnn: {e}')
|
||||
|
||||
|
||||
def override_ipex_math():
|
||||
if backend == "ipex":
|
||||
try:
|
||||
torch.xpu.set_fp32_math_mode(mode=torch.xpu.FP32MathMode.TF32)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.warning(f'Torch ipex: {e}')
|
||||
|
||||
|
||||
def set_sdpa_params():
|
||||
try:
|
||||
if opts.cross_attention_optimization == "Scaled-Dot-Product":
|
||||
if opts.cross_attention_optimization != "Scaled-Dot-Product":
|
||||
return
|
||||
try:
|
||||
torch.backends.cuda.enable_flash_sdp('Flash attention' in opts.sdp_options)
|
||||
torch.backends.cuda.enable_mem_efficient_sdp('Memory attention' in opts.sdp_options)
|
||||
torch.backends.cuda.enable_math_sdp('Math attention' in opts.sdp_options)
|
||||
except Exception as e:
|
||||
log.warning(f'Torch attention: {e}')
|
||||
try:
|
||||
global sdpa_original # pylint: disable=global-statement
|
||||
if sdpa_original is not None:
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_original
|
||||
else:
|
||||
sdpa_original = torch.nn.functional.scaled_dot_product_attention
|
||||
if backend == "rocm":
|
||||
if 'Flash attention' in opts.sdp_options:
|
||||
try:
|
||||
# https://github.com/huggingface/diffusers/discussions/7172
|
||||
from flash_attn import flash_attn_func
|
||||
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_flash_atten)
|
||||
def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
|
||||
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
|
||||
return flash_attn_func(q=query.transpose(1, 2), k=key.transpose(1, 2), v=value.transpose(1, 2), dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2)
|
||||
else:
|
||||
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
|
||||
log.debug('ROCm Flash Attention Hijacked')
|
||||
except Exception as err:
|
||||
log.error(f'ROCm Flash Attention failed: {err}')
|
||||
if 'Sage attention' in opts.sdp_options:
|
||||
except Exception as e:
|
||||
log.warning(f'Torch SDPA: {e}')
|
||||
if backend == "rocm":
|
||||
if 'Flash attention' in opts.sdp_options:
|
||||
try:
|
||||
install('sageattention')
|
||||
from sageattention import sageattn
|
||||
sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_sage_atten)
|
||||
def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
|
||||
if query.shape[-1] in {128, 96, 64} and attn_mask is None and query.dtype != torch.float32:
|
||||
return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
|
||||
# https://github.com/huggingface/diffusers/discussions/7172
|
||||
from flash_attn import flash_attn_func
|
||||
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_flash_atten)
|
||||
def sdpa_flash_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
|
||||
if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
|
||||
return flash_attn_func(q=query.transpose(1, 2), k=key.transpose(1, 2), v=value.transpose(1, 2), dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2)
|
||||
else:
|
||||
return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
|
||||
log.debug('SDPA Sage Attention Hijacked')
|
||||
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
|
||||
log.debug('Torch ROCm Flash Attention')
|
||||
except Exception as err:
|
||||
log.error(f'SDPA Sage Attention failed: {err}')
|
||||
if 'Dynamic attention' in opts.sdp_options:
|
||||
try:
|
||||
global sdpa_pre_dyanmic_atten # pylint: disable=global-statement
|
||||
sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention
|
||||
torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention
|
||||
log.debug('SDPA Dynamic Attention Hijacked')
|
||||
except Exception as err:
|
||||
log.error(f'SDPA Dynamic Attention failed: {err}')
|
||||
except Exception:
|
||||
pass
|
||||
log.error(f'Torch ROCm Flash Attention: {err}')
|
||||
if 'Sage attention' in opts.sdp_options:
|
||||
try:
|
||||
install('sageattention')
|
||||
from sageattention import sageattn
|
||||
sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
@wraps(sdpa_pre_sage_atten)
|
||||
def sdpa_sage_atten(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
|
||||
if query.shape[-1] in {128, 96, 64} and attn_mask is None and query.dtype != torch.float32:
|
||||
return sageattn(q=query, k=key, v=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
|
||||
else:
|
||||
return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
|
||||
torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
|
||||
log.debug('Torch SDPA Sage Attention')
|
||||
except Exception as err:
|
||||
log.error(f'Torch SDPA Sage Attention: {err}')
|
||||
if 'Dynamic attention' in opts.sdp_options:
|
||||
try:
|
||||
global sdpa_pre_dyanmic_atten # pylint: disable=global-statement
|
||||
sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
|
||||
from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention
|
||||
torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention
|
||||
log.debug('Torch SDPA Dynamic Attention')
|
||||
except Exception as err:
|
||||
log.error(f'Torch SDPA Dynamic Attention: {err}')
|
||||
except Exception as e:
|
||||
log.warning(f'Torch SDPA: {e}')
|
||||
|
||||
|
||||
def set_dtype():
|
||||
|
|
|
|||
|
|
@ -24,13 +24,13 @@ def initialize():
|
|||
from modules.control.units import xs # vislearn ControlNet-XS
|
||||
from modules.control.units import lite # vislearn ControlNet-XS
|
||||
from modules.control.units import t2iadapter # TencentARC T2I-Adapter
|
||||
shared.log.debug(f'UI initialize: control models={shared.opts.control_dir}')
|
||||
shared.log.debug(f'UI initialize: control models="{shared.opts.control_dir}"')
|
||||
controlnet.cache_dir = os.path.join(shared.opts.control_dir, 'controlnet')
|
||||
xs.cache_dir = os.path.join(shared.opts.control_dir, 'xs')
|
||||
lite.cache_dir = os.path.join(shared.opts.control_dir, 'lite')
|
||||
t2iadapter.cache_dir = os.path.join(shared.opts.control_dir, 'adapter')
|
||||
processors.cache_dir = os.path.join(shared.opts.control_dir, 'processor')
|
||||
masking.cache_dir = os.path.join(shared.opts.control_dir, 'segment')
|
||||
masking.cache_dir = os.path.join(shared.opts.control_dir, 'segment')
|
||||
unit.default_device = devices.device
|
||||
unit.default_dtype = devices.dtype
|
||||
try:
|
||||
|
|
|
|||
4
webui.py
4
webui.py
|
|
@ -233,9 +233,9 @@ def get_remote_ip():
|
|||
def start_common():
|
||||
log.debug('Entering start sequence')
|
||||
if shared.cmd_opts.data_dir is not None and len(shared.cmd_opts.data_dir) > 0:
|
||||
log.info(f'Using data path: {shared.cmd_opts.data_dir}')
|
||||
log.info(f'Base path: data="{shared.cmd_opts.data_dir}"')
|
||||
if shared.cmd_opts.models_dir is not None and len(shared.cmd_opts.models_dir) > 0 and shared.cmd_opts.models_dir != 'models':
|
||||
log.info(f'Models path: {shared.cmd_opts.models_dir}')
|
||||
log.info(f'Base path: models="{shared.cmd_opts.models_dir}"')
|
||||
paths.create_paths(shared.opts)
|
||||
async_policy()
|
||||
initialize()
|
||||
|
|
|
|||
Loading…
Reference in New Issue