new server info panel

Signed-off-by: vladmandic <mandic00@live.com>
pull/4690/head
vladmandic 2026-03-18 15:37:57 +01:00
parent c187aa706a
commit fb2f9ea650
11 changed files with 139 additions and 43 deletions

View File

@ -31,6 +31,7 @@ const jsConfig = defineConfig([
...globals.jquery, ...globals.jquery,
panzoom: 'readonly', panzoom: 'readonly',
authFetch: 'readonly', authFetch: 'readonly',
initServerInfo: 'readonly',
log: 'readonly', log: 'readonly',
debug: 'readonly', debug: 'readonly',
error: 'readonly', error: 'readonly',

@ -1 +1 @@
Subproject commit f90df02fe036d737096e4c3083d86d9a09fae932 Subproject commit ba8955c02f633da243b8dc66d3631f68d986f6ca

View File

@ -20,6 +20,11 @@ class Dot(dict): # dot notation access to dictionary attributes
__setattr__ = dict.__setitem__ __setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__ __delattr__ = dict.__delitem__
class Torch(dict):
    """Dictionary of torch runtime details (version, backend type, attention, ...).

    Behaves exactly like a plain ``dict``; ``set`` is a keyword-argument
    convenience for recording one or more entries in a single call.
    """

    def set(self, **kwargs):
        """Store every keyword argument as a key/value pair, overwriting existing keys."""
        self.update(kwargs)
version = { version = {
'app': 'sd.next', 'app': 'sd.next',
'updated': 'unknown', 'updated': 'unknown',
@ -75,6 +80,8 @@ control_extensions = [ # 3rd party extensions marked as safe for control ui
'IP Adapters', 'IP Adapters',
'Remove background', 'Remove background',
] ]
gpu_info = [] # one dict per detected device, appended by check_torch
torch_info = Torch() # torch runtime details, filled in through torch_info.set(...)
try: try:
@ -588,7 +595,7 @@ def install_rocm_zluda():
if device_id < len(amd_gpus): if device_id < len(amd_gpus):
device = amd_gpus[device_id] device = amd_gpus[device_id]
if sys.platform == "win32" and not args.use_zluda and device is not None and device.therock is not None and not installed("rocm"): if sys.platform == "win32" and (not args.use_zluda) and (device is not None) and (device.therock is not None) and not installed("rocm"):
check_python(supported_minors=[11, 12, 13], reason='ROCm backend requires a Python version between 3.11 and 3.13') check_python(supported_minors=[11, 12, 13], reason='ROCm backend requires a Python version between 3.11 and 3.13')
install(f"rocm[devel,libraries] --index-url https://rocm.nightlies.amd.com/{device.therock}") install(f"rocm[devel,libraries] --index-url https://rocm.nightlies.amd.com/{device.therock}")
rocm.refresh() rocm.refresh()
@ -750,6 +757,18 @@ def check_cudnn():
os.environ['CUDA_PATH'] = cuda_path os.environ['CUDA_PATH'] = cuda_path
def get_cuda_arch(capability):
    """Format a CUDA compute capability as ``"<major>.<minor> <architecture>"``.

    Args:
        capability: ``(major, minor)`` pair of integers.

    Returns:
        A string such as ``"8.9 Ada Lovelace"``. The architecture name is
        ``"Unknown"`` when the major version is not in the table.
    """
    major, minor = capability
    # Names follow NVIDIA compute-capability numbering.
    # NOTE(review): the caller also reaches this on ROCm/ZLUDA builds, where the
    # reported numbers may not follow NVIDIA's scheme - confirm before relying on the name.
    mapping = {
        12: "Blackwell",
        10: "Blackwell",
        9: "Hopper",
        8: "Ada Lovelace" if minor == 9 else "Ampere",  # 8.9 is Ada, other 8.x are Ampere
        7: "Turing" if minor == 5 else "Volta",  # 7.5 is Turing, other 7.x are Volta
        6: "Pascal",
        5: "Maxwell",
        3: "Kepler",
    }
    name = mapping.get(major, "Unknown")
    return f"{major}.{minor} {name}"
# check torch version # check torch version
def check_torch(): def check_torch():
log.info('Torch: verifying installation') log.info('Torch: verifying installation')
@ -832,6 +851,7 @@ def check_torch():
log.info(f'Torch backend: type=IPEX version={ipex.__version__}') log.info(f'Torch backend: type=IPEX version={ipex.__version__}')
except Exception: except Exception:
pass pass
torch_info.set(version=torch.__version__)
if 'cpu' in torch.__version__: if 'cpu' in torch.__version__:
if is_cuda_available: if is_cuda_available:
if args.use_cuda: if args.use_cuda:
@ -845,20 +865,44 @@ def check_torch():
install(torch_command, 'torch torchvision', quiet=True, reinstall=True, force=True) # foce reinstall install(torch_command, 'torch torchvision', quiet=True, reinstall=True, force=True) # foce reinstall
else: else:
log.warning(f'Torch: version="{torch.__version__}" CPU version installed and ROCm is available - consider reinstalling') log.warning(f'Torch: version="{torch.__version__}" CPU version installed and ROCm is available - consider reinstalling')
if args.use_openvino:
torch_info.set(type='openvino')
else:
torch_info.set(type='cpu')
if hasattr(torch, "xpu") and torch.xpu.is_available() and allow_ipex: if hasattr(torch, "xpu") and torch.xpu.is_available() and allow_ipex:
if shutil.which('icpx') is not None: if shutil.which('icpx') is not None:
log.info(f'{os.popen("icpx --version").read().rstrip()}') log.info(f'{os.popen("icpx --version").read().rstrip()}')
torch_info.set(type='xpu', oneapi=torch.xpu.runtime_version(), dpc=torch.xpu.dpcpp_version(), driver=torch.xpu.driver_version())
for device in range(torch.xpu.device_count()): for device in range(torch.xpu.device_count()):
log.info(f'Torch detected: gpu="{torch.xpu.get_device_name(device)}" vram={round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024)} units={torch.xpu.get_device_properties(device).max_compute_units}') gpu = {
'gpu': torch.xpu.get_device_name(device),
'vram': round(torch.xpu.get_device_properties(device).total_memory / 1024 / 1024),
'units': torch.xpu.get_device_properties(device).max_compute_units,
}
log.info(f'Torch detected: {gpu}')
gpu_info.append(gpu)
elif torch.cuda.is_available() and (allow_cuda or allow_rocm): elif torch.cuda.is_available() and (allow_cuda or allow_rocm):
if torch.version.cuda and allow_cuda: if args.use_zluda:
log.info(f'Torch backend: version="{torch.__version__}" type=CUDA CUDA={torch.version.cuda} cuDNN={torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "N/A"}') torch_info.set(type="zluda", cuda=torch.version.cuda)
elif torch.version.cuda and allow_cuda:
torch_info.set(type='cuda', cuda=torch.version.cuda, cudnn=torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else 'N/A')
elif torch.version.hip and allow_rocm: elif torch.version.hip and allow_rocm:
log.info(f'Torch backend: version="{torch.__version__}" type=ROCm HIP={torch.version.hip}') torch_info.set(type='rocm', hip=torch.version.hip)
else: else:
log.warning('Unknown Torch backend') log.warning('Unknown Torch backend')
log.info(f"Torch backend: {torch_info}")
for device in [torch.cuda.device(i) for i in range(torch.cuda.device_count())]: for device in [torch.cuda.device(i) for i in range(torch.cuda.device_count())]:
log.info(f'Torch detected: gpu="{torch.cuda.get_device_name(device)}" vram={round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024)} arch={torch.cuda.get_device_capability(device)} cores={torch.cuda.get_device_properties(device).multi_processor_count}') gpu = {
'gpu': torch.cuda.get_device_name(device),
'vram': round(torch.cuda.get_device_properties(device).total_memory / 1024 / 1024),
'arch': get_cuda_arch(torch.cuda.get_device_capability(device)),
'cores': torch.cuda.get_device_properties(device).multi_processor_count,
}
gpu_info.append(gpu)
log.info(f'Torch detected: {gpu}')
else: else:
try: try:
if args.use_directml and allow_directml: if args.use_directml and allow_directml:
@ -867,7 +911,11 @@ def check_torch():
log.warning(f'Torch backend: DirectML ({dml_ver})') log.warning(f'Torch backend: DirectML ({dml_ver})')
log.warning('DirectML: end-of-life') log.warning('DirectML: end-of-life')
for i in range(0, torch_directml.device_count()): for i in range(0, torch_directml.device_count()):
log.info(f'Torch detected GPU: {torch_directml.device_name(i)}') gpu = {
'gpu': torch_directml.device_name(i),
}
gpu_info.append(gpu)
log.info(f'Torch detected GPU: {gpu}')
except Exception: except Exception:
log.warning("Torch reports CUDA not available") log.warning("Torch reports CUDA not available")
except Exception as e: except Exception as e:

View File

@ -24,9 +24,9 @@ async function authFetch(url, options = {}) {
let res; let res;
try { try {
res = await fetch(url, options); res = await fetch(url, options);
if (!res.ok) error('fetch', { status: res.status, url, user, token }); if (!res.ok) error('fetch', { status: res?.status || 503, url, user, token });
} catch (err) { } catch (err) {
error('fetch', { status: res.status, url, user, token, error: err }); error('fetch', { status: res?.status || 503, url, user, token, error: err });
} }
return res; return res;
} }

View File

@ -30,7 +30,7 @@ async function updateGPU() {
const gpuEl = document.getElementById('gpu'); const gpuEl = document.getElementById('gpu');
const gpuTable = document.getElementById('gpu-table'); const gpuTable = document.getElementById('gpu-table');
try { try {
const res = await authFetch(`${window.api}/gpu`); const res = await authFetch(`${window.api}/gpu-smi`);
if (!res.ok) { if (!res.ok) {
clearInterval(gpuInterval); clearInterval(gpuInterval);
gpuEl.style.display = 'none'; gpuEl.style.display = 'none';

View File

@ -39,12 +39,14 @@ class Api:
def register(self): def register(self):
# fetch js/css # fetch js/css
self.add_api_route("/js", server.get_js, methods=["GET"], auth=False) self.add_api_route("/js", server.get_js, methods=["GET"], auth=False)
# server api # server api
self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str) self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=list[str]) self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=list[str])
self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"]) self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"])
self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"]) self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"]) self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
self.add_api_route("/sdapi/v1/torch", server.get_torch, methods=["GET"])
self.add_api_route("/sdapi/v1/status", server.get_status, methods=["GET"], response_model=models.ResStatus) self.add_api_route("/sdapi/v1/status", server.get_status, methods=["GET"], response_model=models.ResStatus)
self.add_api_route("/sdapi/v1/platform", server.get_platform, methods=["GET"]) self.add_api_route("/sdapi/v1/platform", server.get_platform, methods=["GET"])
self.add_api_route("/sdapi/v1/progress", server.get_progress, methods=["GET"], response_model=models.ResProgress) self.add_api_route("/sdapi/v1/progress", server.get_progress, methods=["GET"], response_model=models.ResProgress)
@ -54,7 +56,8 @@ class Api:
self.add_api_route("/sdapi/v1/shutdown", server.post_shutdown, methods=["POST"]) self.add_api_route("/sdapi/v1/shutdown", server.post_shutdown, methods=["POST"])
self.add_api_route("/sdapi/v1/memory", server.get_memory, methods=["GET"], response_model=models.ResMemory) self.add_api_route("/sdapi/v1/memory", server.get_memory, methods=["GET"], response_model=models.ResMemory)
self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=list[models.ResGPU]) self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu, methods=["GET"])
self.add_api_route("/sdapi/v1/gpu-smi", gpu.get_gpu_smi, methods=["GET"], response_model=list[models.ResGPU])
# core api using locking # core api using locking
self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img, tags=["Generation"]) self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img, tags=["Generation"])

View File

@ -5,7 +5,17 @@ from modules.logger import log
device = None device = None
def get_gpu():
    """Return the static GPU details recorded in ``installer.gpu_info``.

    Returns:
        The device dict itself when exactly one entry was recorded, otherwise a
        dict mapping each list index to its entry (empty when nothing was recorded).
    """
    import installer
    devices = installer.gpu_info
    if len(devices) == 1:
        return devices[0]
    return dict(enumerate(devices))
def get_gpu_smi():
"""Return real-time GPU metrics (utilization, temperature, memory, clock speeds) via vendor-specific APIs (NVML, ROCm SMI, XPU SMI).""" """Return real-time GPU metrics (utilization, temperature, memory, clock speeds) via vendor-specific APIs (NVML, ROCm SMI, XPU SMI)."""
global device # pylint: disable=global-statement global device # pylint: disable=global-statement
if device is None: if device is None:
@ -37,5 +47,5 @@ class ResGPU(BaseModel):
if __name__ == '__main__': if __name__ == '__main__':
from rich import print as rprint from rich import print as rprint
for gpu in get_gpu_status(): for gpu in get_gpu_smi():
rprint(gpu) rprint(gpu)

View File

@ -19,7 +19,7 @@ ignore_endpoints = [
'/sdapi/v1/version', '/sdapi/v1/version',
'/sdapi/v1/log', '/sdapi/v1/log',
'/sdapi/v1/browser', '/sdapi/v1/browser',
'/sdapi/v1/gpu', '/sdapi/v1/gpu-smi',
'/sdapi/v1/network/thumb', '/sdapi/v1/network/thumb',
'/sdapi/v1/progress', '/sdapi/v1/progress',
] ]

View File

@ -8,14 +8,6 @@ from modules import shared
from modules.logger import log from modules.logger import log
from modules.api import models, helpers from modules.api import models, helpers
def _get_version():
return installer.get_version()
def post_shutdown():
log.info('Shutdown request received')
import sys
sys.exit(0)
def get_js(request: Request): def get_js(request: Request):
file = request.query_params.get("file", None) file = request.query_params.get("file", None)
@ -43,32 +35,34 @@ def get_js(request: Request):
media_type = 'application/octet-stream' media_type = 'application/octet-stream'
return FileResponse(file, media_type=media_type) return FileResponse(file, media_type=media_type)
def get_version():
    """Return the version info reported by ``installer.get_version``."""
    version_info = installer.get_version()
    return version_info
def get_motd():
    """Build the message-of-the-day HTML string.

    Starts with a version line built from ``get_version()`` and, when
    ``shared.opts.motd`` is enabled, appends the text fetched from the remote
    motd page. Fetch failures are logged and otherwise ignored.

    Returns:
        The assembled string (possibly empty).
    """
    import requests
    ver = get_version()
    text = ""
    if ver.get("updated", None) is not None:
        text = f"version <b>{ver['commit']} {ver['updated']}</b> <span style='color: var(--primary-500)'>{ver['url'].split('/')[-1]}</span><br>" # pylint: disable=use-maxsplit-arg
    if not shared.opts.motd:
        return text
    try:
        response = requests.get("https://vladmandic.github.io/sdnext/motd", timeout=3)
        if response.status_code != 200:
            log.error(f"MOTD: {response.status_code}")
        else:
            msg = (response.text or "").strip()  # stripped copy is used for logging only
            log.info(f"MOTD: {msg if len(msg) > 0 else 'N/A'}")
            text += response.text
    except Exception as err:
        log.error(f"MOTD: {err}")
    return text
def get_version():
return _get_version()
def get_platform():
    """Return installer platform info merged with loader package info.

    Keys from ``loader_get_packages()`` win over ``installer.get_platform()``
    when both provide the same key.
    """
    from modules.loader import get_packages as loader_get_packages
    platform_info = installer.get_platform()
    package_info = loader_get_packages()
    return {**platform_info, **package_info}
def get_torch():
    """Return a plain-dict copy of the entries in ``installer.torch_info``."""
    snapshot = {}
    snapshot.update(installer.torch_info)
    return snapshot
def get_log(req: models.ReqGetLog = Depends()): def get_log(req: models.ReqGetLog = Depends()):
lines = log.buffer[:req.lines] if req.lines > 0 else log.buffer.copy() lines = log.buffer[:req.lines] if req.lines > 0 else log.buffer.copy()
@ -85,6 +79,11 @@ def post_log(req: models.ReqPostLog):
log.error(f'UI: {req.error}') log.error(f'UI: {req.error}')
return {} return {}
def post_shutdown():
    """Log the shutdown request and terminate by raising ``SystemExit`` with code 0."""
    log.info("Shutdown request received")
    raise SystemExit(0)  # same effect as sys.exit(0)
def get_cmd_flags():
    """Return the attribute dictionary of ``shared.cmd_opts`` as given by ``vars``."""
    flags = vars(shared.cmd_opts)
    return flags

View File

@ -2,7 +2,7 @@ from functools import wraps
import torch import torch
from modules import rocm from modules import rocm
from modules.errors import log from modules.errors import log
from installer import install, installed from installer import install, installed, torch_info
def set_dynamic_attention(): def set_dynamic_attention():
@ -10,6 +10,7 @@ def set_dynamic_attention():
sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention sdpa_pre_dyanmic_atten = torch.nn.functional.scaled_dot_product_attention
from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention
torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention
torch_info.set(attention='dynamic')
return sdpa_pre_dyanmic_atten return sdpa_pre_dyanmic_atten
except Exception as err: except Exception as err:
log.error(f'Torch attention: type="dynamic attention" {err}') log.error(f'Torch attention: type="dynamic attention" {err}')
@ -20,6 +21,7 @@ def set_triton_flash_attention(backend: str):
try: try:
if backend in {"rocm", "zluda"}: # flash_attn_triton_amd only works with AMD if backend in {"rocm", "zluda"}: # flash_attn_triton_amd only works with AMD
from modules.flash_attn_triton_amd import interface_fa from modules.flash_attn_triton_amd import interface_fa
sdpa_pre_triton_flash_atten = torch.nn.functional.scaled_dot_product_attention sdpa_pre_triton_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_triton_flash_atten) @wraps(sdpa_pre_triton_flash_atten)
def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
@ -42,6 +44,7 @@ def set_triton_flash_attention(backend: str):
kwargs["enable_gqa"] = enable_gqa kwargs["enable_gqa"] = enable_gqa
return sdpa_pre_triton_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs) return sdpa_pre_triton_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)
torch.nn.functional.scaled_dot_product_attention = sdpa_triton_flash_atten torch.nn.functional.scaled_dot_product_attention = sdpa_triton_flash_atten
torch_info.set(attention='triton')
log.debug('Torch attention: type="Triton Flash attention"') log.debug('Torch attention: type="Triton Flash attention"')
except Exception as err: except Exception as err:
log.error(f'Torch attention: type="Triton Flash attention" {err}') log.error(f'Torch attention: type="Triton Flash attention" {err}')
@ -78,6 +81,7 @@ def set_flex_attention():
return flex_attention(query, key, value, score_mod=score_mod, block_mask=block_mask, scale=scale, enable_gqa=enable_gqa) return flex_attention(query, key, value, score_mod=score_mod, block_mask=block_mask, scale=scale, enable_gqa=enable_gqa)
torch.nn.functional.scaled_dot_product_attention = sdpa_flex_atten torch.nn.functional.scaled_dot_product_attention = sdpa_flex_atten
torch_info.set(attention="flex")
log.debug('Torch attention: type="Flex attention"') log.debug('Torch attention: type="Flex attention"')
except Exception as err: except Exception as err:
log.error(f'Torch attention: type="Flex attention" {err}') log.error(f'Torch attention: type="Flex attention" {err}')
@ -93,6 +97,7 @@ def set_ck_flash_attention(backend: str, device: torch.device):
else: else:
install('flash-attn') install('flash-attn')
from flash_attn import flash_attn_func from flash_attn import flash_attn_func
sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention
@wraps(sdpa_pre_flash_atten) @wraps(sdpa_pre_flash_atten)
def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, scale: float | None = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
@ -120,6 +125,7 @@ def set_ck_flash_attention(backend: str, device: torch.device):
kwargs["enable_gqa"] = enable_gqa kwargs["enable_gqa"] = enable_gqa
return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs) return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)
torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
torch_info.set(attention="flash")
log.debug('Torch attention: type="Flash attention"') log.debug('Torch attention: type="Flash attention"')
except Exception as err: except Exception as err:
log.error(f'Torch attention: type="Flash attention" {err}') log.error(f'Torch attention: type="Flash attention" {err}')
@ -174,6 +180,7 @@ def set_sage_attention(backend: str, device: torch.device):
kwargs["enable_gqa"] = enable_gqa kwargs["enable_gqa"] = enable_gqa
return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs) return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)
torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
torch_info.set(attention="sage")
log.debug(f'Torch attention: type="Sage attention" backend={"cuda" if use_cuda_backend else "auto"}') log.debug(f'Torch attention: type="Sage attention" backend={"cuda" if use_cuda_backend else "auto"}')
except Exception as err: except Exception as err:
log.error(f'Torch attention: type="Sage attention" {err}') log.error(f'Torch attention: type="Sage attention" {err}')
@ -208,19 +215,22 @@ def set_diffusers_attention(pipe, quiet:bool=False):
log.quiet(quiet, f'Setting model: attention="{shared.opts.cross_attention_optimization}"') log.quiet(quiet, f'Setting model: attention="{shared.opts.cross_attention_optimization}"')
if shared.opts.cross_attention_optimization == "Disabled": if shared.opts.cross_attention_optimization == "Disabled":
pass # do nothing torch_info.set(attention="disabled")
elif shared.opts.cross_attention_optimization == "Scaled-Dot-Product": # The default set by Diffusers elif shared.opts.cross_attention_optimization == "Scaled-Dot-Product": # The default set by Diffusers
torch_info.set(attention="sdpa")
# set_attn(pipe, p.AttnProcessor2_0(), name="Scaled-Dot-Product") # set_attn(pipe, p.AttnProcessor2_0(), name="Scaled-Dot-Product")
pass
elif shared.opts.cross_attention_optimization == "xFormers": elif shared.opts.cross_attention_optimization == "xFormers":
if hasattr(pipe, 'enable_xformers_memory_efficient_attention'): if hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
torch_info.set(attention="xformers")
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
else: else:
log.warning(f"Attention: xFormers is not compatible with {pipe.__class__.__name__}") log.warning(f"Attention: xFormers is not compatible with {pipe.__class__.__name__}")
elif shared.opts.cross_attention_optimization == "Batch matrix-matrix": elif shared.opts.cross_attention_optimization == "Batch matrix-matrix":
torch_info.set(attention="bmm")
set_attn(pipe, p.AttnProcessor(), name="Batch matrix-matrix") set_attn(pipe, p.AttnProcessor(), name="Batch matrix-matrix")
elif shared.opts.cross_attention_optimization == "Dynamic Attention BMM": elif shared.opts.cross_attention_optimization == "Dynamic Attention BMM":
from modules.sd_hijack_dynamic_atten import DynamicAttnProcessorBMM from modules.sd_hijack_dynamic_atten import DynamicAttnProcessorBMM
torch_info.set(attention="dynamic_bmm")
set_attn(pipe, DynamicAttnProcessorBMM(), name="Dynamic Attention BMM") set_attn(pipe, DynamicAttnProcessorBMM(), name="Dynamic Attention BMM")
if shared.opts.attention_slicing != "Default" and hasattr(pipe, "enable_attention_slicing") and hasattr(pipe, "disable_attention_slicing"): if shared.opts.attention_slicing != "Default" and hasattr(pipe, "enable_attention_slicing") and hasattr(pipe, "disable_attention_slicing"):

View File

@ -4,8 +4,10 @@ import time
import contextlib import contextlib
import importlib.metadata import importlib.metadata
import torch import torch
from installer import torch_info
from modules.logger import log
from modules import rocm, attention from modules import rocm, attention
from modules.errors import log, display, install as install_traceback from modules.errors import display, install as install_traceback
debug = os.environ.get('SD_DEVICE_DEBUG', None) is not None debug = os.environ.get('SD_DEVICE_DEBUG', None) is not None
@ -398,17 +400,35 @@ def test_triton(early: bool = False):
test_triton_func(torch.randn(16, device=device), torch.randn(16, device=device), torch.randn(16, device=device)) test_triton_func(torch.randn(16, device=device), torch.randn(16, device=device), torch.randn(16, device=device))
triton_ok = True triton_ok = True
else: else:
torch_info.set(triton=False)
triton_ok = False triton_ok = False
except Exception as e: except Exception as e:
torch_info.set(triton=False)
triton_ok = False triton_ok = False
line = str(e).splitlines()[0] line = str(e).splitlines()[0]
log.warning(f"Triton test fail: {line}") log.warning(f"Triton test fail: {line}")
if debug: if debug:
from modules import errors from modules import errors
errors.display(e, 'Triton') errors.display(e, 'Triton')
triton_version = False
if triton_ok:
if triton_version is None:
try:
import torch._inductor.triton as torch_triton
triton_version = torch_triton.__version__
except Exception:
pass
if triton_version is None:
try:
import triton
triton_version = triton.__version__
except Exception:
pass
torch_info.set(triton=triton_version)
t1 = time.time() t1 = time.time()
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'Triton: pass={triton_ok} fn={fn} time={t1-t0:.2f}') log.debug(f'Triton: pass={triton_ok} version={triton_version} fn={fn} time={t1-t0:.2f}')
if not triton_ok and opts is not None: if not triton_ok and opts is not None:
opts.sdnq_dequantize_compile = False opts.sdnq_dequantize_compile = False
return triton_ok return triton_ok
@ -469,6 +489,7 @@ def set_sdpa_params():
torch.backends.cuda.enable_math_sdp('Math' in opts.sdp_options or 'Math attention' in opts.sdp_options) torch.backends.cuda.enable_math_sdp('Math' in opts.sdp_options or 'Math attention' in opts.sdp_options)
if hasattr(torch.backends.cuda, "allow_fp16_bf16_reduction_math_sdp"): # only valid for torch >= 2.5 if hasattr(torch.backends.cuda, "allow_fp16_bf16_reduction_math_sdp"): # only valid for torch >= 2.5
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
torch_info.set(attention="sdpa")
log.debug(f'Torch attention: type="sdpa" kernels={opts.sdp_options} overrides={opts.sdp_overrides}') log.debug(f'Torch attention: type="sdpa" kernels={opts.sdp_options} overrides={opts.sdp_overrides}')
except Exception as err: except Exception as err:
log.warning(f'Torch attention: type="sdpa" {err}') log.warning(f'Torch attention: type="sdpa" {err}')
@ -559,6 +580,10 @@ def set_dtype():
inference_context = contextlib.nullcontext inference_context = contextlib.nullcontext
else: else:
inference_context = torch.no_grad inference_context = torch.no_grad
if dtype == dtype_vae:
torch_info.set(dtype=str(dtype))
else:
torch_info.set(dtype=str(dtype), vae=str(dtype_vae))
def set_cuda_params(): def set_cuda_params():