mirror of https://github.com/vladmandic/automatic
Added further rocblas support enhancements and performance-related best practice settings.
parent
668a94141d
commit
f5c037a735
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
@ -121,14 +122,22 @@ def _get_venv() -> str:
|
||||||
return os.environ.get("VIRTUAL_ENV", "") or sys.prefix
|
return os.environ.get("VIRTUAL_ENV", "") or sys.prefix
|
||||||
|
|
||||||
|
|
||||||
|
def _get_root() -> str:
|
||||||
|
"""App root — one level above the venv folder (e.g. E:\\Sd.Next)."""
|
||||||
|
return str(Path(_get_venv()).parent)
|
||||||
|
|
||||||
|
|
||||||
def _expand_venv(value: str) -> str:
|
def _expand_venv(value: str) -> str:
|
||||||
return value.replace("{VIRTUAL_ENV}", _get_venv())
|
return value.replace("{VIRTUAL_ENV}", _get_venv()).replace("{ROOT}", _get_root())
|
||||||
|
|
||||||
|
|
||||||
def _collapse_venv(value: str) -> str:
|
def _collapse_venv(value: str) -> str:
|
||||||
venv = _get_venv()
|
venv = _get_venv()
|
||||||
|
root = _get_root()
|
||||||
if venv and value.startswith(venv):
|
if venv and value.startswith(venv):
|
||||||
return "{VIRTUAL_ENV}" + value[len(venv):]
|
return "{VIRTUAL_ENV}" + value[len(venv):]
|
||||||
|
if root and value.startswith(root):
|
||||||
|
return "{ROOT}" + value[len(root):]
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -365,6 +374,28 @@ def _user_db_summary(path: Path) -> dict:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_db_hash(db_path: Path) -> str:
|
||||||
|
"""Derive the cache subfolder name from udb.txt filenames.
|
||||||
|
e.g. gfx1030_30.HIP.3_5_1_5454e9e2da.udb.txt → '3.5.1.5454e9e2da'"""
|
||||||
|
for f in db_path.glob("*.HIP.*.udb.txt"):
|
||||||
|
m = re.search(r'\.HIP\.([^.]+)\.udb\.txt$', f.name)
|
||||||
|
if m:
|
||||||
|
return m.group(1).replace("_", ".")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _user_cache_summary(path: Path) -> dict:
|
||||||
|
"""Return {filename: 'N KB'} for binary cache blobs in the resolved cache path."""
|
||||||
|
out = {}
|
||||||
|
if not path.exists():
|
||||||
|
return out
|
||||||
|
for f in sorted(path.iterdir()):
|
||||||
|
if f.is_file():
|
||||||
|
kb = f.stat().st_size // 1024
|
||||||
|
out[f.name] = f"{kb} KB"
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def info() -> dict:
|
def info() -> dict:
|
||||||
config = load_config()
|
config = load_config()
|
||||||
db_path = Path(_expand_venv(config.get("MIOPEN_SYSTEM_DB_PATH", "")))
|
db_path = Path(_expand_venv(config.get("MIOPEN_SYSTEM_DB_PATH", "")))
|
||||||
|
|
@ -427,12 +458,23 @@ def info() -> dict:
|
||||||
if ufiles:
|
if ufiles:
|
||||||
udb["files"] = ufiles
|
udb["files"] = ufiles
|
||||||
|
|
||||||
|
# --- User cache (~/.miopen/cache/<version-hash>) ---
|
||||||
|
cache_base = Path.home() / ".miopen" / "cache"
|
||||||
|
db_hash = _extract_db_hash(user_db_path) if user_db_path.exists() else ""
|
||||||
|
cache_path = cache_base / db_hash if db_hash else cache_base
|
||||||
|
ucache = {"path": str(cache_path), "exists": cache_path.exists()}
|
||||||
|
if cache_path.exists():
|
||||||
|
cfiles = _user_cache_summary(cache_path)
|
||||||
|
if cfiles:
|
||||||
|
ucache["files"] = cfiles
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"rocm": rocm_section,
|
"rocm": rocm_section,
|
||||||
"torch": torch_section,
|
"torch": torch_section,
|
||||||
"gpu": gpu_section,
|
"gpu": gpu_section,
|
||||||
"system_db": sdb,
|
"system_db": sdb,
|
||||||
"user_db": udb,
|
"user_db": udb,
|
||||||
|
"user_cache": ucache,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from typing import Dict, Any, List, Tuple
|
||||||
# --- General MIOpen/rocBLAS variables (dropdown/textbox/checkbox) ---
|
# --- General MIOpen/rocBLAS variables (dropdown/textbox/checkbox) ---
|
||||||
GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
||||||
|
|
||||||
|
# ── GEMM backend selector + companion toggles ──────────────────────────
|
||||||
"MIOPEN_GEMM_ENFORCE_BACKEND": {
|
"MIOPEN_GEMM_ENFORCE_BACKEND": {
|
||||||
"default": "1",
|
"default": "1",
|
||||||
"desc": "Enforce GEMM backend",
|
"desc": "Enforce GEMM backend",
|
||||||
|
|
@ -10,6 +11,29 @@ GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
||||||
"options": [("1 - rocBLAS", "1"), ("5 - hipBLASLt", "5")],
|
"options": [("1 - rocBLAS", "1"), ("5 - hipBLASLt", "5")],
|
||||||
"restart_required": False,
|
"restart_required": False,
|
||||||
},
|
},
|
||||||
|
"PYTORCH_ROCM_USE_ROCBLAS": {
|
||||||
|
"default": "1",
|
||||||
|
"desc": "PyTorch ROCm: prioritise rocBLAS for linear algebra",
|
||||||
|
"widget": "dropdown",
|
||||||
|
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||||
|
"restart_required": True,
|
||||||
|
},
|
||||||
|
"PYTORCH_HIPBLASLT_DISABLE": {
|
||||||
|
"default": "1",
|
||||||
|
"desc": "Disable PyTorch hipBLASLt dispatcher",
|
||||||
|
"widget": "dropdown",
|
||||||
|
"options": [("0 - Allow hipBLASLt", "0"), ("1 - Disable hipBLASLt", "1")],
|
||||||
|
"restart_required": True,
|
||||||
|
},
|
||||||
|
"ROCBLAS_USE_HIPBLASLT": {
|
||||||
|
"default": "0",
|
||||||
|
"desc": "rocBLAS: use hipBLASLt backend (0 = Tensile)",
|
||||||
|
"widget": "dropdown",
|
||||||
|
"options": [("0 - Tensile (rocBLAS)", "0"), ("1 - hipBLASLt", "1")],
|
||||||
|
"restart_required": True,
|
||||||
|
},
|
||||||
|
|
||||||
|
# ── MIOpen behavioural settings ────────────────────────────────────────
|
||||||
"MIOPEN_FIND_MODE": {
|
"MIOPEN_FIND_MODE": {
|
||||||
"default": "2",
|
"default": "2",
|
||||||
"desc": "MIOpen Find Mode",
|
"desc": "MIOpen Find Mode",
|
||||||
|
|
@ -31,6 +55,15 @@ GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
||||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||||
"restart_required": True,
|
"restart_required": True,
|
||||||
},
|
},
|
||||||
|
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": {
|
||||||
|
"default": "0",
|
||||||
|
"desc": "Deterministic convolution (reproducible results, may be slower)",
|
||||||
|
"widget": "dropdown",
|
||||||
|
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||||
|
"restart_required": False,
|
||||||
|
},
|
||||||
|
|
||||||
|
# ── Paths / sizes ──────────────────────────────────────────────────────
|
||||||
"MIOPEN_SYSTEM_DB_PATH": {
|
"MIOPEN_SYSTEM_DB_PATH": {
|
||||||
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\",
|
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\",
|
||||||
"desc": "MIOpen system DB path",
|
"desc": "MIOpen system DB path",
|
||||||
|
|
@ -38,6 +71,75 @@ GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
||||||
"options": None,
|
"options": None,
|
||||||
"restart_required": True,
|
"restart_required": True,
|
||||||
},
|
},
|
||||||
|
"MIOPEN_CONVOLUTION_MAX_WORKSPACE": {
|
||||||
|
"default": "1073741824",
|
||||||
|
"desc": "MIOpen convolution max workspace (bytes; 1 GB default)",
|
||||||
|
"widget": "textbox",
|
||||||
|
"options": None,
|
||||||
|
"restart_required": False,
|
||||||
|
},
|
||||||
|
"ROCBLAS_TENSILE_LIBPATH": {
|
||||||
|
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\rocblas\\library",
|
||||||
|
"desc": "rocBLAS Tensile library path",
|
||||||
|
"widget": "textbox",
|
||||||
|
"options": None,
|
||||||
|
"restart_required": True,
|
||||||
|
},
|
||||||
|
"ROCBLAS_DEVICE_MEMORY_SIZE": {
|
||||||
|
"default": "",
|
||||||
|
"desc": "rocBLAS workspace size in bytes (empty = dynamic)",
|
||||||
|
"widget": "textbox",
|
||||||
|
"options": None,
|
||||||
|
"restart_required": False,
|
||||||
|
},
|
||||||
|
"PYTORCH_TUNABLEOP_CACHE_DIR": {
|
||||||
|
"default": "{ROOT}\\models\\tunable",
|
||||||
|
"desc": "TunableOp: kernel profile cache directory",
|
||||||
|
"widget": "textbox",
|
||||||
|
"options": None,
|
||||||
|
"restart_required": False,
|
||||||
|
},
|
||||||
|
|
||||||
|
# ── rocBLAS settings ───────────────────────────────────────────────────
|
||||||
|
"ROCBLAS_STREAM_ORDER_ALLOC": {
|
||||||
|
"default": "1",
|
||||||
|
"desc": "rocBLAS stream-ordered memory allocation",
|
||||||
|
"widget": "dropdown",
|
||||||
|
"options": [("0 - Standard", "0"), ("1 - Stream-ordered", "1")],
|
||||||
|
"restart_required": False,
|
||||||
|
},
|
||||||
|
"ROCBLAS_DEFAULT_ATOMICS_MODE": {
|
||||||
|
"default": "1",
|
||||||
|
"desc": "rocBLAS default atomics mode (1 = allow non-deterministic for performance)",
|
||||||
|
"widget": "dropdown",
|
||||||
|
"options": [("0 - Off (deterministic)", "0"), ("1 - On (performance)", "1")],
|
||||||
|
"restart_required": False,
|
||||||
|
},
|
||||||
|
"PYTORCH_TUNABLEOP_ROCBLAS_ENABLED": {
|
||||||
|
"default": "1",
|
||||||
|
"desc": "TunableOp: wrap and optimise rocBLAS GEMM calls",
|
||||||
|
"widget": "dropdown",
|
||||||
|
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||||
|
"restart_required": False,
|
||||||
|
},
|
||||||
|
"PYTORCH_TUNABLEOP_TUNING": {
|
||||||
|
"default": "0",
|
||||||
|
"desc": "TunableOp: tuning mode (1 = benchmark; 0 = use saved CSV)",
|
||||||
|
"widget": "dropdown",
|
||||||
|
"options": [("0 - Use saved CSV", "0"), ("1 - Benchmark new shapes", "1")],
|
||||||
|
"restart_required": False,
|
||||||
|
},
|
||||||
|
|
||||||
|
# ── hipBLASLt settings ─────────────────────────────────────────────────
|
||||||
|
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED": {
|
||||||
|
"default": "0",
|
||||||
|
"desc": "TunableOp: benchmark hipBLASLt kernels",
|
||||||
|
"widget": "dropdown",
|
||||||
|
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||||
|
"restart_required": False,
|
||||||
|
},
|
||||||
|
|
||||||
|
# ── Logging: MIOpen → rocBLAS → hipBLASLt ─────────────────────────────
|
||||||
"MIOPEN_LOG_LEVEL": {
|
"MIOPEN_LOG_LEVEL": {
|
||||||
"default": "0",
|
"default": "0",
|
||||||
"desc": "MIOpen log verbosity level",
|
"desc": "MIOpen log verbosity level",
|
||||||
|
|
@ -66,13 +168,6 @@ GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
||||||
"options": [("0 - Off", "0"), ("1 - Error", "1"), ("2 - Trace", "2"), ("3 - Hints", "3"), ("4 - Info", "4"), ("5 - API Trace", "5")],
|
"options": [("0 - Off", "0"), ("1 - Error", "1"), ("2 - Trace", "2"), ("3 - Hints", "3"), ("4 - Info", "4"), ("5 - API Trace", "5")],
|
||||||
"restart_required": False,
|
"restart_required": False,
|
||||||
},
|
},
|
||||||
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": {
|
|
||||||
"default": "0",
|
|
||||||
"desc": "Deterministic convolution (reproducible results, may be slower)",
|
|
||||||
"widget": "dropdown",
|
|
||||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
|
||||||
"restart_required": False,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# --- Solver toggles (inference/FWD only, RDNA2/3/4 compatible) ---
|
# --- Solver toggles (inference/FWD only, RDNA2/3/4 compatible) ---
|
||||||
|
|
@ -251,3 +346,13 @@ SOLVER_GROUPS: List[Tuple[str, List[str]]] = [
|
||||||
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
|
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
|
||||||
]),
|
]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Variables that are relevant only when hipBLASLt is the active GEMM backend.
|
||||||
|
# These are visually greyed-out in the UI when rocBLAS (MIOPEN_GEMM_ENFORCE_BACKEND="1") is selected.
|
||||||
|
HIPBLASLT_VARS: set = {
|
||||||
|
"PYTORCH_HIPBLASLT_DISABLE",
|
||||||
|
"ROCBLAS_USE_HIPBLASLT",
|
||||||
|
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED",
|
||||||
|
"HIPBLASLT_LOG_LEVEL",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ class ROCmScript(scripts_manager.Script):
|
||||||
if not shared.cmd_opts.use_rocm and not installer.torch_info.get('type') == 'rocm': # skip ui creation if not rocm
|
if not shared.cmd_opts.use_rocm and not installer.torch_info.get('type') == 'rocm': # skip ui creation if not rocm
|
||||||
return []
|
return []
|
||||||
|
|
||||||
from scripts.rocm import rocm_mgr, rocm_vars # pylint: disable=no-name-in-module
|
from scripts.rocm import rocm_mgr, rocm_vars, rocm_profiles # pylint: disable=no-name-in-module
|
||||||
|
|
||||||
config = rocm_mgr.load_config()
|
config = rocm_mgr.load_config()
|
||||||
var_names = []
|
var_names = []
|
||||||
|
|
@ -59,11 +59,25 @@ class ROCmScript(scripts_manager.Script):
|
||||||
row("path", udb.get("path", ""))
|
row("path", udb.get("path", ""))
|
||||||
for fname, finfo in udb.get("files", {}).items():
|
for fname, finfo in udb.get("files", {}).items():
|
||||||
row(fname, finfo)
|
row(fname, finfo)
|
||||||
|
section("User cache (~/.miopen/cache)")
|
||||||
|
ucache = d.get("user_cache", {})
|
||||||
|
row("path", ucache.get("path", ""))
|
||||||
|
for fname, sz in ucache.get("files", {}).items():
|
||||||
|
row(fname, sz)
|
||||||
return f"<table style='width:100%;border-collapse:collapse'>{''.join(rows)}</table>"
|
return f"<table style='width:100%;border-collapse:collapse'>{''.join(rows)}</table>"
|
||||||
|
|
||||||
|
def _build_style(unavailable, hipblaslt_disabled=False):
|
||||||
|
rules = []
|
||||||
|
for v in (unavailable or []):
|
||||||
|
rules.append(f"#rocm_var_{v.lower()} label {{ text-decoration: line-through; opacity: 0.5; }}")
|
||||||
|
if hipblaslt_disabled:
|
||||||
|
for v in rocm_vars.HIPBLASLT_VARS:
|
||||||
|
rules.append(f"#rocm_var_{v.lower()} {{ opacity: 0.45; pointer-events: none; }}")
|
||||||
|
return f"<style>{' '.join(rules)}</style>" if rules else ""
|
||||||
|
|
||||||
with gr.Accordion('ROCm: Advanced Config', open=False, elem_id='rocm_config'):
|
with gr.Accordion('ROCm: Advanced Config', open=False, elem_id='rocm_config'):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
gr.HTML("<p>Advanced configuration for ROCm users.</p><br><p>Set your database and solver selections based on GPU profile or individually.</p><br><p>Enable cuDNN in Backend Settings to activate MIOpen.</p>")
|
gr.HTML("<p>Advanced configuration for ROCm users.</p><br><p>For best performance ensure that cudnn and torch tunable ops are set to default in Backend Settings.</p>")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
btn_info = gr.Button("Refresh Info", variant="primary", elem_id="rocm_btn_info", size="sm")
|
btn_info = gr.Button("Refresh Info", variant="primary", elem_id="rocm_btn_info", size="sm")
|
||||||
btn_apply = gr.Button("Apply", variant="primary", elem_id="rocm_btn_apply", size="sm")
|
btn_apply = gr.Button("Apply", variant="primary", elem_id="rocm_btn_apply", size="sm")
|
||||||
|
|
@ -74,7 +88,10 @@ class ROCmScript(scripts_manager.Script):
|
||||||
btn_rdna2 = gr.Button("RDNA2 (RX 6000)", elem_id="rocm_btn_rdna2")
|
btn_rdna2 = gr.Button("RDNA2 (RX 6000)", elem_id="rocm_btn_rdna2")
|
||||||
btn_rdna3 = gr.Button("RDNA3 (RX 7000)", elem_id="rocm_btn_rdna3")
|
btn_rdna3 = gr.Button("RDNA3 (RX 7000)", elem_id="rocm_btn_rdna3")
|
||||||
btn_rdna4 = gr.Button("RDNA4 (RX 9000)", elem_id="rocm_btn_rdna4")
|
btn_rdna4 = gr.Button("RDNA4 (RX 9000)", elem_id="rocm_btn_rdna4")
|
||||||
style_out = gr.HTML("")
|
_init_gemm = config.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||||
|
_init_arch = config.get(rocm_mgr._ARCH_KEY, "")
|
||||||
|
_init_unavailable = rocm_profiles.UNAVAILABLE.get(_init_arch, set()) if _init_arch else set()
|
||||||
|
style_out = gr.HTML(_build_style(_init_unavailable, _init_gemm == "1"))
|
||||||
info_out = gr.HTML(value=_info_html, elem_id="rocm_info_table")
|
info_out = gr.HTML(value=_info_html, elem_id="rocm_info_table")
|
||||||
|
|
||||||
# General vars (dropdowns, textboxes, checkboxes)
|
# General vars (dropdowns, textboxes, checkboxes)
|
||||||
|
|
@ -106,13 +123,46 @@ class ROCmScript(scripts_manager.Script):
|
||||||
|
|
||||||
for name, comp in zip(var_names, components):
|
for name, comp in zip(var_names, components):
|
||||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||||
if meta["widget"] == "dropdown":
|
if meta["widget"] == "dropdown" and name != "MIOPEN_GEMM_ENFORCE_BACKEND":
|
||||||
comp.change(fn=lambda v, n=name: _autosave_field(n, v), inputs=[comp], outputs=[], show_progress='hidden')
|
comp.change(fn=lambda v, n=name: _autosave_field(n, v), inputs=[comp], outputs=[], show_progress='hidden')
|
||||||
|
|
||||||
|
_GEMM_COMPANIONS = {
|
||||||
|
"PYTORCH_ROCM_USE_ROCBLAS": {"1": "1", "5": "0"},
|
||||||
|
"PYTORCH_HIPBLASLT_DISABLE": {"1": "1", "5": "0"},
|
||||||
|
"ROCBLAS_USE_HIPBLASLT": {"1": "0", "5": "1"},
|
||||||
|
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED": {"1": "0", "5": "1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
def gemm_changed(gemm_display_val):
|
||||||
|
stored = rocm_mgr._dropdown_stored(str(gemm_display_val), rocm_vars.ROCM_ENV_VARS["MIOPEN_GEMM_ENFORCE_BACKEND"]["options"])
|
||||||
|
cfg = rocm_mgr.load_config().copy()
|
||||||
|
cfg["MIOPEN_GEMM_ENFORCE_BACKEND"] = stored
|
||||||
|
for var, vals in _GEMM_COMPANIONS.items():
|
||||||
|
cfg[var] = vals.get(stored, cfg.get(var, ""))
|
||||||
|
rocm_mgr.save_config(cfg)
|
||||||
|
rocm_mgr.apply_env(cfg)
|
||||||
|
arch = cfg.get(rocm_mgr._ARCH_KEY, "")
|
||||||
|
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||||
|
result = [gr.update(value=_build_style(unavailable, stored == "1"))]
|
||||||
|
for pname in var_names:
|
||||||
|
if pname in _GEMM_COMPANIONS:
|
||||||
|
meta = rocm_vars.ROCM_ENV_VARS[pname]
|
||||||
|
val = _GEMM_COMPANIONS[pname].get(stored, cfg.get(pname, ""))
|
||||||
|
result.append(gr.update(value=rocm_mgr._dropdown_display(val, meta["options"])))
|
||||||
|
else:
|
||||||
|
result.append(gr.update())
|
||||||
|
return result
|
||||||
|
|
||||||
|
gemm_comp = components[var_names.index("MIOPEN_GEMM_ENFORCE_BACKEND")]
|
||||||
|
gemm_comp.change(fn=gemm_changed, inputs=[gemm_comp], outputs=[style_out] + components, show_progress='hidden')
|
||||||
|
|
||||||
def apply_fn(*values):
|
def apply_fn(*values):
|
||||||
rocm_mgr.apply_all(var_names, list(values))
|
rocm_mgr.apply_all(var_names, list(values))
|
||||||
saved = rocm_mgr.load_config()
|
saved = rocm_mgr.load_config()
|
||||||
result = [gr.update(value="")]
|
arch = saved.get(rocm_mgr._ARCH_KEY, "")
|
||||||
|
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||||
|
gemm_val = saved.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||||
|
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
|
||||||
for name in var_names:
|
for name in var_names:
|
||||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||||
val = saved.get(name, meta["default"])
|
val = saved.get(name, meta["default"])
|
||||||
|
|
@ -124,19 +174,13 @@ class ROCmScript(scripts_manager.Script):
|
||||||
result.append(gr.update(value=rocm_mgr._expand_venv(val)))
|
result.append(gr.update(value=rocm_mgr._expand_venv(val)))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _build_style(unavailable):
|
|
||||||
if not unavailable:
|
|
||||||
return ""
|
|
||||||
rules = " ".join(
|
|
||||||
f"#rocm_var_{v.lower()} label {{ text-decoration: line-through; opacity: 0.5; }}"
|
|
||||||
for v in unavailable
|
|
||||||
)
|
|
||||||
return f"<style>{rules}</style>"
|
|
||||||
|
|
||||||
def reset_fn():
|
def reset_fn():
|
||||||
rocm_mgr.reset_defaults()
|
rocm_mgr.reset_defaults()
|
||||||
updated = rocm_mgr.load_config()
|
updated = rocm_mgr.load_config()
|
||||||
result = [gr.update(value="")]
|
arch = updated.get(rocm_mgr._ARCH_KEY, "")
|
||||||
|
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||||
|
gemm_val = updated.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||||
|
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
|
||||||
for name in var_names:
|
for name in var_names:
|
||||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||||
val = updated.get(name, meta["default"])
|
val = updated.get(name, meta["default"])
|
||||||
|
|
@ -150,7 +194,9 @@ class ROCmScript(scripts_manager.Script):
|
||||||
|
|
||||||
def clear_fn():
|
def clear_fn():
|
||||||
rocm_mgr.clear_env()
|
rocm_mgr.clear_env()
|
||||||
result = [gr.update(value="")]
|
cfg = rocm_mgr.load_config()
|
||||||
|
gemm_val = cfg.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||||
|
result = [gr.update(value=_build_style(None, gemm_val == "1"))]
|
||||||
for name in var_names:
|
for name in var_names:
|
||||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||||
if meta["widget"] == "checkbox":
|
if meta["widget"] == "checkbox":
|
||||||
|
|
@ -163,7 +209,8 @@ class ROCmScript(scripts_manager.Script):
|
||||||
|
|
||||||
def delete_fn():
|
def delete_fn():
|
||||||
rocm_mgr.delete_config()
|
rocm_mgr.delete_config()
|
||||||
result = [gr.update(value="")]
|
gemm_default = rocm_vars.ROCM_ENV_VARS.get("MIOPEN_GEMM_ENFORCE_BACKEND", {}).get("default", "1")
|
||||||
|
result = [gr.update(value=_build_style(None, gemm_default == "1"))]
|
||||||
for name in var_names:
|
for name in var_names:
|
||||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||||
if meta["widget"] == "checkbox":
|
if meta["widget"] == "checkbox":
|
||||||
|
|
@ -175,11 +222,11 @@ class ROCmScript(scripts_manager.Script):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def profile_fn(arch):
|
def profile_fn(arch):
|
||||||
from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module
|
|
||||||
rocm_mgr.apply_profile(arch)
|
rocm_mgr.apply_profile(arch)
|
||||||
updated = rocm_mgr.load_config()
|
updated = rocm_mgr.load_config()
|
||||||
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||||
result = [gr.update(value=_build_style(unavailable))]
|
gemm_val = updated.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||||
|
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
|
||||||
for pname in var_names:
|
for pname in var_names:
|
||||||
meta = rocm_vars.ROCM_ENV_VARS[pname]
|
meta = rocm_vars.ROCM_ENV_VARS[pname]
|
||||||
val = updated.get(pname, meta["default"])
|
val = updated.get(pname, meta["default"])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue