mirror of https://github.com/vladmandic/automatic
Merge pull request #4726 from resonantsky/dev
Added further rocblas support enhancements
pull/4733/head^2
commit
bfd9a0c0f5
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
|
@ -6,39 +7,28 @@ from typing import Dict, Optional
|
|||
import installer
|
||||
from modules.logger import log
|
||||
from modules.json_helpers import readfile, writefile
|
||||
from modules.shared import opts
|
||||
|
||||
from scripts.rocm.rocm_vars import ROCM_ENV_VARS # pylint: disable=no-name-in-module
|
||||
from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module
|
||||
|
||||
|
||||
def _check_rocm() -> bool:
    """Return True when this process should be treated as running on ROCm.

    Three signals are checked in order and the first positive one wins:

    1. ``shared.cmd_opts.use_rocm`` is truthy (a missing attribute counts as False),
    2. ``installer.torch_info`` reports ``type == 'rocm'``,
    3. the imported torch build has a non-None ``torch.version.hip``.

    torch is imported only when the first two checks are negative.
    """
    from modules import shared  # pylint: disable=import-outside-toplevel
    if getattr(shared.cmd_opts, 'use_rocm', False):
        return True
    if installer.torch_info.get('type') == 'rocm':
        return True
    import torch  # pylint: disable=import-outside-toplevel
    return hasattr(torch.version, 'hip') and torch.version.hip is not None
|
||||
|
||||
|
||||
is_rocm = _check_rocm()
|
||||
|
||||
|
||||
CONFIG = Path(os.path.abspath(os.path.join('data', 'rocm.json')))
|
||||
|
||||
_cache: Optional[Dict[str, str]] = None # loaded once, invalidated on save
|
||||
|
||||
# Metadata key written into rocm.json to record which architecture profile is active.
|
||||
# Not an environment variable — always skipped during env application but preserved in the
|
||||
# Not an environment variable - always skipped during env application but preserved in the
|
||||
# saved config so that arch-safety enforcement is consistent across restarts.
|
||||
_ARCH_KEY = "_rocm_arch"
|
||||
|
||||
# Vars that must never appear in the process environment.
|
||||
#
|
||||
# _DTYPE_UNSAFE: alter FP16 inference dtype — must be cleared regardless of config
|
||||
# MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL — DEBUG alias: routes all FP16 convs through BF16 exponent math
|
||||
# MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL — API-level alias: same BF16-exponent effect
|
||||
# MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_EXPEREMENTAL_FP16_TRANSFORM — unstable experimental FP16 path
|
||||
# MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16 — changes FP16 WrW atomic accumulation
|
||||
# _DTYPE_UNSAFE: alter FP16 inference dtype - must be cleared regardless of config
|
||||
# MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL - DEBUG alias: routes all FP16 convs through BF16 exponent math
|
||||
# MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL - API-level alias: same BF16-exponent effect
|
||||
# MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_EXPEREMENTAL_FP16_TRANSFORM - unstable experimental FP16 path
|
||||
# MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16 - changes FP16 WrW atomic accumulation
|
||||
#
|
||||
# SOLVER_DISABLED_BY_DEFAULT: every solver known to be incompatible with this runtime
|
||||
# (FP32-only, training-only WrW/BWD, fixed-geometry mismatches, XDLOPS/CDNA-only, arch-specific).
|
||||
|
|
@ -53,18 +43,18 @@ _DTYPE_UNSAFE = {
|
|||
# regardless of saved config. Limited to dtype-corrupting vars only.
|
||||
# IMPORTANT: SOLVER_DISABLED_BY_DEFAULT is intentionally NOT included here.
|
||||
# When a solver var is absent (unset) MIOpen still calls IsApplicable() on every
|
||||
# conv-find — wasted probing overhead. When a var is explicitly "0" MIOpen skips
|
||||
# conv-find - wasted probing overhead. When a var is explicitly "0" MIOpen skips
|
||||
# IsApplicable() immediately. Solver defaults flow through the config loop as "0"
|
||||
# (their ROCM_ENV_VARS default is "0") so they are explicitly set to "0" in the env.
|
||||
_UNSET_VARS = _DTYPE_UNSAFE
|
||||
|
||||
# Additional environment vars that must be removed from the process before MIOpen loads.
|
||||
# These are not MIOpen solver toggles but can corrupt MIOpen's runtime behaviour:
|
||||
# HIP_PATH / HIP_PATH_71 — point to the system AMD ROCm install; override the venv-bundled
|
||||
# HIP_PATH / HIP_PATH_71 - point to the system AMD ROCm install; override the venv-bundled
|
||||
# _rocm_sdk_devel DLLs with a potentially mismatched system version
|
||||
# QML_*/QT_* — QtQuick shader/disk-cache flags leaked from Qt tools; harmless for
|
||||
# QML_*/QT_* - QtQuick shader/disk-cache flags leaked from Qt tools; harmless for
|
||||
# PyTorch but can conflict with Gradio's embedded Qt helpers
|
||||
# PYENV_VIRTUALENV_DISABLE_PROMPT — pyenv noise that confuses venv detection
|
||||
# PYENV_VIRTUALENV_DISABLE_PROMPT - pyenv noise that confuses venv detection
|
||||
_EXTRA_CLEAR_VARS = {
|
||||
"HIP_PATH",
|
||||
"HIP_PATH_71",
|
||||
|
|
@ -72,7 +62,7 @@ _EXTRA_CLEAR_VARS = {
|
|||
"QML_DISABLE_DISK_CACHE",
|
||||
"QML_FORCE_DISK_CACHE",
|
||||
"QT_DISABLE_SHADER_DISK_CACHE",
|
||||
# PERF_VALS vars are NOT boolean toggles — MIOpen reads them as perf-config strings.
|
||||
# PERF_VALS vars are NOT boolean toggles - MIOpen reads them as perf-config strings.
|
||||
# If inherited from a parent shell with value "1", MIOpen's GetPerfConfFromEnv parses
|
||||
# "1" as a degenerate config and can return dtype=float32 output from FP16 tensors.
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_PERF_VALS",
|
||||
|
|
@ -81,12 +71,12 @@ _EXTRA_CLEAR_VARS = {
|
|||
|
||||
# Solvers whose MIOpen IsApplicable() explicitly rejects non-FP32 tensors.
|
||||
# They are safe to leave enabled in FP32 mode. When the active dtype is FP16 or BF16
|
||||
# we force them OFF so MIOpen skips the IsApplicable probe entirely — avoids overhead on
|
||||
# we force them OFF so MIOpen skips the IsApplicable probe entirely - avoids overhead on
|
||||
# every conv shape find. These are NOT in _UNSET_VARS because they are valid in FP32.
|
||||
_FP32_ONLY_SOLVERS = {
|
||||
"MIOPEN_DEBUG_CONV_FFT", # FFT convolution — FP32 only (MIOpen source: IsFp32 check)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3", # Winograd 3x3 — FP32 only
|
||||
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD", # Fused Winograd — FP32 only
|
||||
"MIOPEN_DEBUG_CONV_FFT", # FFT convolution - FP32 only (MIOpen source: IsFp32 check)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3", # Winograd 3x3 - FP32 only
|
||||
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD", # Fused Winograd - FP32 only
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -106,8 +96,7 @@ def _resolve_dtype() -> str:
|
|||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from modules import shared as _sh # pylint: disable=import-outside-toplevel
|
||||
v = getattr(getattr(_sh, 'opts', None), 'cuda_dtype', None)
|
||||
v = getattr(opts, 'cuda_dtype', None)
|
||||
if v in ('FP16', 'BF16', 'FP32'):
|
||||
return v
|
||||
except Exception:
|
||||
|
|
@ -118,17 +107,25 @@ def _resolve_dtype() -> str:
|
|||
# --- venv helpers ---
|
||||
|
||||
def _get_venv() -> str:
    """Return the environment root substituted for the ``{VIRTUAL_ENV}`` placeholder.

    The value is ``sys.prefix``, i.e. the prefix of the interpreter that is
    actually running, rather than the ``VIRTUAL_ENV`` environment variable.
    """
    return sys.prefix
|
||||
|
||||
|
||||
def _get_root() -> str:
    """Return ``modules.paths.script_path`` as a string.

    This is the value substituted for the ``{ROOT}`` placeholder. The import is
    done inside the function rather than at module top level.
    """
    from modules.paths import script_path  # pylint: disable=import-outside-toplevel
    return str(script_path)
|
||||
|
||||
|
||||
def _expand_venv(value: str) -> str:
    """Expand path placeholders in *value*.

    ``{VIRTUAL_ENV}`` is replaced with the result of ``_get_venv()`` and
    ``{ROOT}`` with the result of ``_get_root()``. Text without either
    placeholder is returned unchanged.
    """
    return value.replace("{VIRTUAL_ENV}", _get_venv()).replace("{ROOT}", _get_root())
|
||||
|
||||
|
||||
def _collapse_venv(value: str) -> str:
    """Replace a leading environment or root path in *value* with its placeholder.

    A value starting with ``_get_venv()`` gets that prefix replaced by
    ``{VIRTUAL_ENV}``; otherwise a value starting with ``_get_root()`` gets that
    prefix replaced by ``{ROOT}``. Anything else is returned unchanged.
    """
    venv = _get_venv()
    root = _get_root()
    # The venv prefix is tested first, so a value matching both prefixes
    # is collapsed to {VIRTUAL_ENV}.
    if venv and value.startswith(venv):
        return "{VIRTUAL_ENV}" + value[len(venv):]
    if root and value.startswith(root):
        return "{ROOT}" + value[len(root):]
    return value
|
||||
|
||||
|
||||
|
|
@ -163,7 +160,7 @@ def load_config() -> Dict[str, str]:
|
|||
_cache = data if data else {k: v["default"] for k, v in ROCM_ENV_VARS.items()}
|
||||
# Purge unsafe vars from a stale saved config and re-persist only if the file existed.
|
||||
# When running without a saved config (first run / after Delete), load_config() must
|
||||
# never create the file — that only happens via save_config() on Apply or Apply Profile.
|
||||
# never create the file - that only happens via save_config() on Apply or Apply Profile.
|
||||
dirty = {k for k in _cache if k in _UNSET_VARS or (k != _ARCH_KEY and k not in ROCM_ENV_VARS)}
|
||||
if dirty:
|
||||
_cache = {k: v for k, v in _cache.items() if k not in dirty}
|
||||
|
|
@ -212,7 +209,7 @@ def apply_env(config: Optional[Dict[str, str]] = None) -> None:
|
|||
os.environ[var] = expanded
|
||||
# Arch safety net: hard-force all hardware-incompatible vars to "0" in the env.
|
||||
# This runs *after* the config loop so it overrides any stale "1" that survived in the JSON.
|
||||
# Source of truth: rocm_profiles.UNAVAILABLE[arch] — vars with no supporting hardware.
|
||||
# Source of truth: rocm_profiles.UNAVAILABLE[arch] - vars with no supporting hardware.
|
||||
arch = config.get(_ARCH_KEY, "")
|
||||
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||
if unavailable:
|
||||
|
|
@ -240,7 +237,7 @@ def apply_all(names: list, values: list) -> None:
|
|||
meta = ROCM_ENV_VARS[name]
|
||||
if meta["widget"] == "checkbox":
|
||||
if value is None:
|
||||
pass # Gradio passed None (component not interacted with) — leave config unchanged
|
||||
pass # Gradio passed None (component not interacted with) - leave config unchanged
|
||||
else:
|
||||
config[name] = "1" if value else "0"
|
||||
elif meta["widget"] == "radio":
|
||||
|
|
@ -248,7 +245,7 @@ def apply_all(names: list, values: list) -> None:
|
|||
valid = {v for _, v in meta["options"]} if meta["options"] and isinstance(meta["options"][0], tuple) else set(meta["options"] or [])
|
||||
if stored in valid:
|
||||
config[name] = stored
|
||||
# else: value was None/invalid — leave the existing saved value untouched
|
||||
# else: value was None/invalid - leave the existing saved value untouched
|
||||
else:
|
||||
if meta.get("options"):
|
||||
value = _dropdown_stored(str(value), meta["options"])
|
||||
|
|
@ -291,7 +288,7 @@ def delete_config() -> None:
|
|||
CONFIG.unlink()
|
||||
log.info(f'ROCm delete_config: deleted {CONFIG}')
|
||||
_cache = None
|
||||
# Delete the MIOpen user DB (~/.miopen/db) — stale entries can cause solver mismatches
|
||||
# Delete the MIOpen user DB (~/.miopen/db) - stale entries can cause solver mismatches
|
||||
miopen_db = Path(os.path.expanduser('~')) / '.miopen' / 'db'
|
||||
if miopen_db.exists():
|
||||
shutil.rmtree(miopen_db, ignore_errors=True)
|
||||
|
|
@ -365,6 +362,28 @@ def _user_db_summary(path: Path) -> dict:
|
|||
return out
|
||||
|
||||
|
||||
def _extract_db_hash(db_path: Path) -> str:
    """Derive the cache subfolder name from udb.txt filenames.

    e.g. gfx1030_30.HIP.3_5_1_5454e9e2da.udb.txt -> '3.5.1.5454e9e2da'

    The segment between ``.HIP.`` and ``.udb.txt`` is taken from the first
    matching file and its underscores are turned into dots. Returns an empty
    string when no file in *db_path* matches.
    """
    # Sort the candidates so the same file is picked on every call when
    # several udb files are present; raw glob order is not guaranteed.
    for f in sorted(db_path.glob("*.HIP.*.udb.txt")):
        m = re.search(r'\.HIP\.([^.]+)\.udb\.txt$', f.name)
        if m:
            return m.group(1).replace("_", ".")
    return ""
|
||||
|
||||
|
||||
def _user_cache_summary(path: Path) -> dict:
    """Return {filename: 'N KB'} for binary cache blobs in the resolved cache path.

    Only regular files directly inside *path* are listed, in sorted order;
    sizes are whole kibibytes rounded down. An empty dict is returned when
    *path* is missing or is not a directory.
    """
    # is_dir() rather than exists(): iterdir() raises on a regular file.
    if not path.is_dir():
        return {}
    return {
        entry.name: f"{entry.stat().st_size // 1024} KB"
        for entry in sorted(path.iterdir())
        if entry.is_file()
    }
|
||||
|
||||
|
||||
def info() -> dict:
|
||||
config = load_config()
|
||||
db_path = Path(_expand_venv(config.get("MIOPEN_SYSTEM_DB_PATH", "")))
|
||||
|
|
@ -427,20 +446,29 @@ def info() -> dict:
|
|||
if ufiles:
|
||||
udb["files"] = ufiles
|
||||
|
||||
# User cache (~/.miopen/cache/<version-hash>)
|
||||
cache_base = Path.home() / ".miopen" / "cache"
|
||||
db_hash = _extract_db_hash(user_db_path) if user_db_path.exists() else ""
|
||||
cache_path = cache_base / db_hash if db_hash else cache_base
|
||||
ucache = {"path": str(cache_path), "exists": cache_path.exists()}
|
||||
if cache_path.exists():
|
||||
cfiles = _user_cache_summary(cache_path)
|
||||
if cfiles:
|
||||
ucache["files"] = cfiles
|
||||
|
||||
return {
|
||||
"rocm": rocm_section,
|
||||
"torch": torch_section,
|
||||
"gpu": gpu_section,
|
||||
"system_db": sdb,
|
||||
"user_db": udb,
|
||||
"user_cache": ucache,
|
||||
}
|
||||
|
||||
|
||||
# Apply saved config to os.environ at import time (only when ROCm is present).
# A failure here is logged at debug level instead of aborting the import.
if installer.torch_info.get('type', None) == 'rocm':
    try:
        apply_env()
    except Exception as _e:
        log.debug(f"[rocm_mgr] Warning: failed to apply env at import: {_e}")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""
|
||||
"""
|
||||
Architecture-specific MIOpen solver profiles for AMD GCN/RDNA GPUs.
|
||||
|
||||
Sources:
|
||||
|
|
@ -6,8 +6,8 @@ Sources:
|
|||
|
||||
Key axis: consumer RDNA GPUs have NO XDLOPS hardware (that's CDNA/Instinct only).
|
||||
RDNA2 (gfx1030): RX 6000 series
|
||||
RDNA3 (gfx1100): RX 7000 series — adds Fury Winograd, wider MPASS
|
||||
RDNA4 (gfx1200): RX 9000 series — adds Rage Winograd, wider MPASS
|
||||
RDNA3 (gfx1100): RX 7000 series - adds Fury Winograd, wider MPASS
|
||||
RDNA4 (gfx1200): RX 9000 series - adds Rage Winograd, wider MPASS
|
||||
|
||||
Each profile is a dict of {var: value} that will be MERGED on top of the
|
||||
current config (general vars like DB path / log level are preserved).
|
||||
|
|
@ -15,9 +15,9 @@ current config (general vars like DB path / log level are preserved).
|
|||
|
||||
from typing import Dict
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Shared: everything that must be OFF on ALL consumer RDNA (no XDLOPS hw)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_XDLOPS_OFF: Dict[str, str] = {
|
||||
# GTC XDLOPS (CDNA-only)
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS": "0",
|
||||
|
|
@ -55,7 +55,7 @@ _XDLOPS_OFF: Dict[str, str] = {
|
|||
# MLIR (CDNA-only in practice)
|
||||
"MIOPEN_DEBUG_CONV_MLIR_IGEMM_WRW_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_MLIR_IGEMM_BWD_XDLOPS": "0",
|
||||
# MP BD Winograd (Multi-pass Block-Decomposed — CDNA / high-end only)
|
||||
# MP BD Winograd (Multi-pass Block-Decomposed - CDNA / high-end only)
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F2X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F3X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F4X3": "0",
|
||||
|
|
@ -68,17 +68,17 @@ _XDLOPS_OFF: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F6X3": "0",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RDNA2 — gfx1030 (RX 6000 series)
|
||||
|
||||
# RDNA2 - gfx1030 (RX 6000 series)
|
||||
# No XDLOPS, no Fury/Rage Winograd, MPASS limited to F3x2/F3x3
|
||||
# ASM IGEMM: V4R1 variants only; HIP IGEMM: non-XDLOPS V4R1/R4 only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
RDNA2: Dict[str, str] = {
|
||||
**_XDLOPS_OFF,
|
||||
# General settings (architecture-independent; set here so all profiles cover them)
|
||||
"MIOPEN_SEARCH_CUTOFF": "0",
|
||||
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": "0",
|
||||
# Core algo enables — FFT is FP32-only but harmless (IsApplicable rejects it for fp16 tensors)
|
||||
# Core algo enables - FFT is FP32-only but harmless (IsApplicable rejects it for fp16 tensors)
|
||||
"MIOPEN_DEBUG_CONV_FFT": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT": "1",
|
||||
"MIOPEN_DEBUG_CONV_GEMM": "1",
|
||||
|
|
@ -93,16 +93,16 @@ RDNA2: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_OPENCL_CONVOLUTIONS": "1",
|
||||
"MIOPEN_DEBUG_OPENCL_WAVE64_NOWGP": "1",
|
||||
"MIOPEN_DEBUG_ATTN_SOFTMAX": "1",
|
||||
# Direct ASM — dtype notes
|
||||
# 3X3U / 1X1U / 1X1UV2: FP32/FP16 forward — enabled
|
||||
# Direct ASM - dtype notes
|
||||
# 3X3U / 1X1U / 1X1UV2: FP32/FP16 forward - enabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "1",
|
||||
# 5X10U2V2: fixed geometry (5*10 stride-2), no SD conv matches — disabled
|
||||
# 5X10U2V2: fixed geometry (5*10 stride-2), no SD conv matches - disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_5X10U2V2": "0",
|
||||
# 7X7C3H224W224: hard-coded ImageNet stem (C=3, H=W=224, K=64) — never matches SD — disabled
|
||||
# 7X7C3H224W224: hard-coded ImageNet stem (C=3, H=W=224, K=64) - never matches SD - disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_7X7C3H224W224": "0",
|
||||
# WRW3X3 / WRW1X1: FP32-only weight-gradient (training only) — disabled for inference
|
||||
# WRW3X3 / WRW1X1: FP32-only weight-gradient (training only) - disabled for inference
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW3X3": "0",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW1X1": "0",
|
||||
# PERF_VALS intentionally blank: MIOpen reads this as a config string not a boolean;
|
||||
|
|
@ -110,30 +110,30 @@ RDNA2: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_PERF_VALS": "",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_SEARCH_OPTIMIZED": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_AI_HEUR": "1",
|
||||
# NAIVE_CONV_FWD: scalar FP32 reference solver — IsApplicable does NOT reliably filter for FP16;
|
||||
# NAIVE_CONV_FWD: scalar FP32 reference solver - IsApplicable does NOT reliably filter for FP16;
|
||||
# can be selected for unusual shapes (e.g. VAE decoder 3-ch output) and returns dtype=float32
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD": "0",
|
||||
# Direct OCL — dtype notes
|
||||
# FWD / FWD1X1: FP32/FP16 forward — enabled
|
||||
# Direct OCL - dtype notes
|
||||
# FWD / FWD1X1: FP32/FP16 forward - enabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "1",
|
||||
# FWD11X11: requires 11*11 kernel — no SD match — disabled
|
||||
# FWD11X11: requires 11*11 kernel - no SD match - disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD11X11": "0",
|
||||
# FWDGEN: FP32 generic OCL fallback — IsApplicable does NOT reliably reject for FP16;
|
||||
# can produce dtype=float32 output for FP16 inputs — disabled
|
||||
# FWDGEN: FP32 generic OCL fallback - IsApplicable does NOT reliably reject for FP16;
|
||||
# can produce dtype=float32 output for FP16 inputs - disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWDGEN": "0",
|
||||
# WRW2 / WRW53 / WRW1X1: training-only weight-gradient — disabled
|
||||
# WRW2 / WRW53 / WRW1X1: training-only weight-gradient - disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW2": "0",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW53": "0",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW1X1": "0",
|
||||
# Winograd RxS — dtype per MIOpen docs
|
||||
# WINOGRAD_3X3: FP32-only — harmless (IsApplicable rejects for fp16); enabled
|
||||
# Winograd RxS - dtype per MIOpen docs
|
||||
# WINOGRAD_3X3: FP32-only - harmless (IsApplicable rejects for fp16); enabled
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3": "1",
|
||||
# RXS: covers FP32/FP16 F(3,3) Fwd/Bwd + FP32 F(3,2) WrW — keep enabled (fp16 fwd/bwd path exists)
|
||||
# RXS: covers FP32/FP16 F(3,3) Fwd/Bwd + FP32 F(3,2) WrW - keep enabled (fp16 fwd/bwd path exists)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS": "1",
|
||||
# RXS_FWD_BWD: FP32/FP16 — explicitly the fp16-capable subset
|
||||
# RXS_FWD_BWD: FP32/FP16 - explicitly the fp16-capable subset
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD": "1",
|
||||
# RXS_WRW: FP32 WrW only — training-only, disabled for inference fp16 profile
|
||||
# RXS_WRW: FP32 WrW only - training-only, disabled for inference fp16 profile
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_WRW": "0",
|
||||
# RXS_F3X2: FP32/FP16 Fwd/Bwd
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2": "1",
|
||||
|
|
@ -141,15 +141,15 @@ RDNA2: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3": "1",
|
||||
# RXS_F2X3_G1: FP32/FP16 Fwd/Bwd (non-group convolutions)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1": "1",
|
||||
# FUSED_WINOGRAD: FP32-only — harmless (IsApplicable rejects for fp16); enabled
|
||||
# FUSED_WINOGRAD: FP32-only - harmless (IsApplicable rejects for fp16); enabled
|
||||
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD": "1",
|
||||
# PERF_VALS intentionally blank: same reason as ASM_1X1U — not a boolean, config string
|
||||
# PERF_VALS intentionally blank: same reason as ASM_1X1U - not a boolean, config string
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_PERF_VALS": "",
|
||||
# Fury/Rage Winograd — NOT available on RDNA2
|
||||
# Fury/Rage Winograd - NOT available on RDNA2
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "0",
|
||||
# MPASS — only F3x2 and F3x3 are safe on RDNA2
|
||||
# MPASS - only F3x2 and F3x3 are safe on RDNA2
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2": "1",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3": "1",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4": "0",
|
||||
|
|
@ -159,50 +159,50 @@ RDNA2: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3": "0",
|
||||
# ASM Implicit GEMM — forward V4R1 only; no GTC/XDLOPS on RDNA2
|
||||
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only — disabled
|
||||
# ASM Implicit GEMM - forward V4R1 only; no GTC/XDLOPS on RDNA2
|
||||
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only - disabled
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_V4R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_V4R1": "0",
|
||||
# HIP Implicit GEMM — non-XDLOPS V4R1/R4 forward only
|
||||
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only — disabled
|
||||
# HIP Implicit GEMM - non-XDLOPS V4R1/R4 forward only
|
||||
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only - disabled
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V1R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4": "0",
|
||||
# Group Conv XDLOPS / CK default kernels — RDNA3/4 only, not available on RDNA2
|
||||
# Group Conv XDLOPS / CK default kernels - RDNA3/4 only, not available on RDNA2
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "0",
|
||||
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "0",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RDNA3 — gfx1100 (RX 7000 series)
|
||||
# RDNA3 - gfx1100 (RX 7000 series)
|
||||
# Fury Winograd added; MPASS F3x4 enabled; Group Conv XDLOPS + CK default kernels enabled
|
||||
# ---------------------------------------------------------------------------
|
||||
# RDNA3 profile: starts from every RDNA2 entry, then overrides the keys below.
RDNA3: Dict[str, str] = {
    **RDNA2,
    # Fury Winograd - introduced for gfx1100 (RDNA3)
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "1",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "1",
    # Wider MPASS on RDNA3
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4": "1",
    # Group Conv XDLOPS / CK - available from gfx1100 (RDNA3) onwards
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "1",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "1",
    "MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "1",
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RDNA4 — gfx1200 (RX 9000 series)
|
||||
# RDNA4 - gfx1200 (RX 9000 series)
|
||||
# Rage Winograd added; MPASS F3x5 enabled
|
||||
# ---------------------------------------------------------------------------
|
||||
RDNA4: Dict[str, str] = {
|
||||
**RDNA3,
|
||||
# Rage Winograd — introduced for gfx1200 (RDNA4)
|
||||
# Rage Winograd - introduced for gfx1200 (RDNA4)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "1",
|
||||
# Wider MPASS on RDNA4
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5": "1",
|
||||
|
|
|
|||
|
|
@ -1,15 +1,48 @@
|
|||
from typing import Dict, Any, List, Tuple
|
||||
|
||||
# --- General MIOpen/rocBLAS variables (dropdown/textbox/checkbox) ---
|
||||
GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
||||
|
||||
"MIOPEN_SYSTEM_DB_PATH": {
|
||||
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\",
|
||||
"desc": "MIOpen system DB path",
|
||||
"widget": "textbox",
|
||||
"options": None,
|
||||
"restart_required": True,
|
||||
},
|
||||
"ROCBLAS_TENSILE_LIBPATH": {
|
||||
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\rocblas\\library",
|
||||
"desc": "rocBLAS Tensile library path",
|
||||
"widget": "textbox",
|
||||
"options": None,
|
||||
"restart_required": True,
|
||||
},
|
||||
"MIOPEN_GEMM_ENFORCE_BACKEND": {
|
||||
"default": "1",
|
||||
"desc": "Enforce GEMM backend",
|
||||
"desc": "GEMM backend",
|
||||
"widget": "dropdown",
|
||||
"options": [("1 - rocBLAS", "1"), ("5 - hipBLASLt", "5")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"PYTORCH_ROCM_USE_ROCBLAS": {
|
||||
"default": "0",
|
||||
"desc": "PyTorch: Use rocBLAS",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": True,
|
||||
},
|
||||
"PYTORCH_HIPBLASLT_DISABLE": {
|
||||
"default": "1",
|
||||
"desc": "PyTorch: Use hipBLASLt",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Allow hipBLASLt", "0"), ("1 - Disable hipBLASLt", "1")],
|
||||
"restart_required": True,
|
||||
},
|
||||
"ROCBLAS_USE_HIPBLASLT": {
|
||||
"default": "0",
|
||||
"desc": "rocBLAS: use hipBLASLt backend",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Tensile (rocBLAS)", "0"), ("1 - hipBLASLt", "1")],
|
||||
"restart_required": True,
|
||||
},
|
||||
"MIOPEN_FIND_MODE": {
|
||||
"default": "2",
|
||||
"desc": "MIOpen Find Mode",
|
||||
|
|
@ -31,12 +64,69 @@ GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
|||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": True,
|
||||
},
|
||||
"MIOPEN_SYSTEM_DB_PATH": {
|
||||
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\",
|
||||
"desc": "MIOpen system DB path",
|
||||
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": {
|
||||
"default": "0",
|
||||
"desc": "Deterministic convolutions",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"MIOPEN_CONVOLUTION_MAX_WORKSPACE": {
|
||||
"default": "1073741824",
|
||||
"desc": "MIOpen convolutions: max workspace (bytes; 1 GB)",
|
||||
"widget": "textbox",
|
||||
"options": None,
|
||||
"restart_required": True,
|
||||
"restart_required": False,
|
||||
},
|
||||
"ROCBLAS_DEVICE_MEMORY_SIZE": {
|
||||
"default": "",
|
||||
"desc": "rocBLAS workspace size in bytes (empty = dynamic)",
|
||||
"widget": "textbox",
|
||||
"options": None,
|
||||
"restart_required": False,
|
||||
},
|
||||
"PYTORCH_TUNABLEOP_CACHE_DIR": {
|
||||
"default": "{ROOT}\\models\\tunable",
|
||||
"desc": "TunableOp cache directory",
|
||||
"widget": "textbox",
|
||||
"options": None,
|
||||
"restart_required": False,
|
||||
},
|
||||
|
||||
"ROCBLAS_STREAM_ORDER_ALLOC": {
|
||||
"default": "1",
|
||||
"desc": "rocBLAS stream-ordered memory allocation",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Standard", "0"), ("1 - Stream-ordered", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"ROCBLAS_DEFAULT_ATOMICS_MODE": {
|
||||
"default": "1",
|
||||
"desc": "rocBLAS allow atomics",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off (deterministic)", "0"), ("1 - On (performance)", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"PYTORCH_TUNABLEOP_ROCBLAS_ENABLED": {
|
||||
"default": "0",
|
||||
"desc": "TunableOp: Enable tuning",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"PYTORCH_TUNABLEOP_TUNING": {
|
||||
"default": "0",
|
||||
"desc": "TunableOp: Tuning mode",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Use Cache", "0"), ("1 - Benchmark new shapes", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED": {
|
||||
"default": "0",
|
||||
"desc": "TunableOp: benchmark hipBLASLt kernels",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"MIOPEN_LOG_LEVEL": {
|
||||
"default": "0",
|
||||
|
|
@ -66,23 +156,8 @@ GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
|||
"options": [("0 - Off", "0"), ("1 - Error", "1"), ("2 - Trace", "2"), ("3 - Hints", "3"), ("4 - Info", "4"), ("5 - API Trace", "5")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": {
|
||||
"default": "0",
|
||||
"desc": "Deterministic convolution (reproducible results, may be slower)",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Solver toggles (inference/FWD only, RDNA2/3/4 compatible) ---
|
||||
# Removed entirely — not representable in the UI, cannot be set by users:
|
||||
# WRW (weight-gradient) and BWD (data-gradient) — training passes only, never run during inference
|
||||
# XDLOPS/CK CDNA-exclusive (MI100/MI200/MI300 matrix engine variants) — not on any RDNA
|
||||
# Fixed-geometry (5x10, 7x7-ImageNet, 11x11) — shapes never appear in SD/video inference
|
||||
# FP32-reference (NAIVE_CONV_FWD, FWDGEN) — IsApplicable() unreliable for FP16/BF16
|
||||
# Wide MPASS (F3x4..F7x3) — kernel sizes that cannot match any SD convolution shape
|
||||
# Disabled by default (added but off): RDNA3/4-only — Group Conv XDLOPS, CK default kernels
|
||||
_SOLVER_DESCS: Dict[str, str] = {}
|
||||
|
||||
_SOLVER_DESCS.update({
|
||||
|
|
@ -251,3 +326,13 @@ SOLVER_GROUPS: List[Tuple[str, List[str]]] = [
|
|||
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
|
||||
]),
|
||||
]
|
||||
|
||||
# Variables that are relevant only when hipBLASLt is the active GEMM backend.
# These are visually greyed-out in the UI when rocBLAS (MIOPEN_GEMM_ENFORCE_BACKEND="1") is selected.
HIPBLASLT_VARS: set = {
    "PYTORCH_HIPBLASLT_DISABLE",
    "ROCBLAS_USE_HIPBLASLT",
    "PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED",
    "HIPBLASLT_LOG_LEVEL",
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,9 @@ from modules import scripts_manager, shared
|
|||
# rocm_mgr exposes package-internal helpers (prefixed _) that are intentionally called here
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
||||
class ROCmScript(scripts_manager.Script):
|
||||
def title(self):
|
||||
return "ROCm: Advanced Config"
|
||||
return "Windows ROCm: Advanced Config"
|
||||
|
||||
def show(self, _is_img2img):
|
||||
if shared.cmd_opts.use_rocm or installer.torch_info.get('type') == 'rocm':
|
||||
|
|
@ -19,7 +18,7 @@ class ROCmScript(scripts_manager.Script):
|
|||
if not shared.cmd_opts.use_rocm and not installer.torch_info.get('type') == 'rocm': # skip ui creation if not rocm
|
||||
return []
|
||||
|
||||
from scripts.rocm import rocm_mgr, rocm_vars # pylint: disable=no-name-in-module
|
||||
from scripts.rocm import rocm_mgr, rocm_vars, rocm_profiles # pylint: disable=no-name-in-module
|
||||
|
||||
config = rocm_mgr.load_config()
|
||||
var_names = []
|
||||
|
|
@ -59,11 +58,25 @@ class ROCmScript(scripts_manager.Script):
|
|||
row("path", udb.get("path", ""))
|
||||
for fname, finfo in udb.get("files", {}).items():
|
||||
row(fname, finfo)
|
||||
section("User cache (~/.miopen/cache)")
|
||||
ucache = d.get("user_cache", {})
|
||||
row("path", ucache.get("path", ""))
|
||||
for fname, sz in ucache.get("files", {}).items():
|
||||
row(fname, sz)
|
||||
return f"<table style='width:100%;border-collapse:collapse'>{''.join(rows)}</table>"
|
||||
|
||||
def _build_style(unavailable, hipblaslt_disabled=False):
    """Build the inline ``<style>`` markup used to mark ROCm variable widgets.

    Args:
        unavailable: iterable of variable names (or ``None``/empty); each one
            gets a rule that strikes through and dims the widget's label.
        hipblaslt_disabled: when true, every name in
            ``rocm_vars.HIPBLASLT_VARS`` additionally gets a rule that dims
            the whole widget and disables pointer events on it.

    Returns:
        A ``<style>...</style>`` string holding all rules joined by single
        spaces, or an empty string when no rule was produced.
    """
    # Element ids are derived from the lower-cased variable name.
    struck = [
        f"#rocm_var_{name.lower()} label {{ text-decoration: line-through; opacity: 0.5; }}"
        for name in (unavailable or [])
    ]
    # rocm_vars is only touched when the hipBLASLt widgets must be dimmed.
    dimmed = [
        f"#rocm_var_{name.lower()} {{ opacity: 0.45; pointer-events: none; }}"
        for name in rocm_vars.HIPBLASLT_VARS
    ] if hipblaslt_disabled else []
    css = struck + dimmed
    if not css:
        return ""
    return f"<style>{' '.join(css)}</style>"
|
||||
|
||||
with gr.Accordion('ROCm: Advanced Config', open=False, elem_id='rocm_config'):
|
||||
with gr.Row():
|
||||
gr.HTML("<p>Advanced configuration for ROCm users.</p><br><p>Set your database and solver selections based on GPU profile or individually.</p><br><p>Enable cuDNN in Backend Settings to activate MIOpen.</p>")
|
||||
gr.HTML("<p><u>Advanced configuration for ROCm users.</u></p><br><p>This script aims to take the guesswork out of configuring MIOpen and rocBLAS on Windows ROCm, but also to expose the functioning switches of MIOpen for advanced configurations.</p><br><p>For best performance ensure that cuDNN and PyTorch tunable ops are set to <b><i>default</i></b> in Backend Settings.</p><br><p>This script was written with the intent to support ROCm Windows users, it should however, function identically for Linux users.</p><br>")
|
||||
with gr.Row():
|
||||
btn_info = gr.Button("Refresh Info", variant="primary", elem_id="rocm_btn_info", size="sm")
|
||||
btn_apply = gr.Button("Apply", variant="primary", elem_id="rocm_btn_apply", size="sm")
|
||||
|
|
@ -74,12 +87,15 @@ class ROCmScript(scripts_manager.Script):
|
|||
btn_rdna2 = gr.Button("RDNA2 (RX 6000)", elem_id="rocm_btn_rdna2")
|
||||
btn_rdna3 = gr.Button("RDNA3 (RX 7000)", elem_id="rocm_btn_rdna3")
|
||||
btn_rdna4 = gr.Button("RDNA4 (RX 9000)", elem_id="rocm_btn_rdna4")
|
||||
style_out = gr.HTML("")
|
||||
_init_gemm = config.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||
_init_arch = config.get(rocm_mgr._ARCH_KEY, "")
|
||||
_init_unavailable = rocm_profiles.UNAVAILABLE.get(_init_arch, set()) if _init_arch else set()
|
||||
style_out = gr.HTML(_build_style(_init_unavailable, _init_gemm == "1"))
|
||||
info_out = gr.HTML(value=_info_html, elem_id="rocm_info_table")
|
||||
|
||||
# General vars (dropdowns, textboxes, checkboxes)
|
||||
with gr.Group():
|
||||
gr.HTML("<h3>MIOpen Settings</h3><hr>")
|
||||
gr.HTML("<br><h3>MIOpen Settings</h3><hr>")
|
||||
for name, meta in rocm_vars.GENERAL_VARS.items():
|
||||
comp = _make_component(name, meta, config)
|
||||
var_names.append(name)
|
||||
|
|
@ -106,13 +122,46 @@ class ROCmScript(scripts_manager.Script):
|
|||
|
||||
for name, comp in zip(var_names, components):
|
||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||
if meta["widget"] == "dropdown":
|
||||
if meta["widget"] == "dropdown" and name != "MIOPEN_GEMM_ENFORCE_BACKEND":
|
||||
comp.change(fn=lambda v, n=name: _autosave_field(n, v), inputs=[comp], outputs=[], show_progress='hidden')
|
||||
|
||||
_GEMM_COMPANIONS = {
|
||||
"PYTORCH_ROCM_USE_ROCBLAS": {"1": "1", "5": "0"},
|
||||
"PYTORCH_HIPBLASLT_DISABLE": {"1": "1", "5": "0"},
|
||||
"ROCBLAS_USE_HIPBLASLT": {"1": "0", "5": "1"},
|
||||
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED": {"1": "0", "5": "1"},
|
||||
}
|
||||
|
||||
def gemm_changed(gemm_display_val):
    """React to a change of the ``MIOPEN_GEMM_ENFORCE_BACKEND`` dropdown.

    The displayed value is converted to its stored form, written into a copy
    of the loaded config together with the companion values listed in
    ``_GEMM_COMPANIONS``, then the config is saved and applied to the
    environment.

    Args:
        gemm_display_val: value currently shown in the GEMM backend dropdown.

    Returns:
        A list of ``gr.update`` objects: the first carries the rebuilt style
        markup, followed by one entry per name in ``var_names`` (a value
        update for companion variables, a no-op update for everything else).
    """
    stored = rocm_mgr._dropdown_stored(
        str(gemm_display_val),
        rocm_vars.ROCM_ENV_VARS["MIOPEN_GEMM_ENFORCE_BACKEND"]["options"],
    )

    # Work on a copy so the object returned by load_config() is not mutated.
    cfg = rocm_mgr.load_config().copy()
    cfg["MIOPEN_GEMM_ENFORCE_BACKEND"] = stored
    for companion, mapping in _GEMM_COMPANIONS.items():
        # Keep the existing value when the stored backend has no mapping.
        cfg[companion] = mapping.get(stored, cfg.get(companion, ""))
    rocm_mgr.save_config(cfg)
    rocm_mgr.apply_env(cfg)

    unavailable = rocm_profiles.UNAVAILABLE.get(cfg.get(rocm_mgr._ARCH_KEY, ""), set())
    updates = [gr.update(value=_build_style(unavailable, stored == "1"))]

    for pname in var_names:
        mapping = _GEMM_COMPANIONS.get(pname)
        if mapping is None:
            # Not a companion variable: leave its widget untouched.
            updates.append(gr.update())
            continue
        options = rocm_vars.ROCM_ENV_VARS[pname]["options"]
        shown = rocm_mgr._dropdown_display(mapping.get(stored, cfg.get(pname, "")), options)
        updates.append(gr.update(value=shown))
    return updates
|
||||
|
||||
gemm_comp = components[var_names.index("MIOPEN_GEMM_ENFORCE_BACKEND")]
|
||||
gemm_comp.change(fn=gemm_changed, inputs=[gemm_comp], outputs=[style_out] + components, show_progress='hidden')
|
||||
|
||||
def apply_fn(*values):
|
||||
rocm_mgr.apply_all(var_names, list(values))
|
||||
saved = rocm_mgr.load_config()
|
||||
result = [gr.update(value="")]
|
||||
arch = saved.get(rocm_mgr._ARCH_KEY, "")
|
||||
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||
gemm_val = saved.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
|
||||
for name in var_names:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||
val = saved.get(name, meta["default"])
|
||||
|
|
@ -124,19 +173,13 @@ class ROCmScript(scripts_manager.Script):
|
|||
result.append(gr.update(value=rocm_mgr._expand_venv(val)))
|
||||
return result
|
||||
|
||||
def _build_style(unavailable):
|
||||
if not unavailable:
|
||||
return ""
|
||||
rules = " ".join(
|
||||
f"#rocm_var_{v.lower()} label {{ text-decoration: line-through; opacity: 0.5; }}"
|
||||
for v in unavailable
|
||||
)
|
||||
return f"<style>{rules}</style>"
|
||||
|
||||
def reset_fn():
|
||||
rocm_mgr.reset_defaults()
|
||||
updated = rocm_mgr.load_config()
|
||||
result = [gr.update(value="")]
|
||||
arch = updated.get(rocm_mgr._ARCH_KEY, "")
|
||||
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||
gemm_val = updated.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
|
||||
for name in var_names:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||
val = updated.get(name, meta["default"])
|
||||
|
|
@ -150,7 +193,9 @@ class ROCmScript(scripts_manager.Script):
|
|||
|
||||
def clear_fn():
|
||||
rocm_mgr.clear_env()
|
||||
result = [gr.update(value="")]
|
||||
cfg = rocm_mgr.load_config()
|
||||
gemm_val = cfg.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||
result = [gr.update(value=_build_style(None, gemm_val == "1"))]
|
||||
for name in var_names:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||
if meta["widget"] == "checkbox":
|
||||
|
|
@ -163,7 +208,8 @@ class ROCmScript(scripts_manager.Script):
|
|||
|
||||
def delete_fn():
|
||||
rocm_mgr.delete_config()
|
||||
result = [gr.update(value="")]
|
||||
gemm_default = rocm_vars.ROCM_ENV_VARS.get("MIOPEN_GEMM_ENFORCE_BACKEND", {}).get("default", "1")
|
||||
result = [gr.update(value=_build_style(None, gemm_default == "1"))]
|
||||
for name in var_names:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||
if meta["widget"] == "checkbox":
|
||||
|
|
@ -175,11 +221,11 @@ class ROCmScript(scripts_manager.Script):
|
|||
return result
|
||||
|
||||
def profile_fn(arch):
|
||||
from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module
|
||||
rocm_mgr.apply_profile(arch)
|
||||
updated = rocm_mgr.load_config()
|
||||
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||
result = [gr.update(value=_build_style(unavailable))]
|
||||
gemm_val = updated.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
|
||||
for pname in var_names:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[pname]
|
||||
val = updated.get(pname, meta["default"])
|
||||
|
|
|
|||
Loading…
Reference in New Issue