Compare commits

...

48 Commits

Author SHA1 Message Date
Vladimir Mandic 0eb4a98e07 Merge branch 'dev' into master 2026-04-04 11:10:28 +02:00
vladmandic 155dabc840 cleanup 2026-04-04 11:09:39 +02:00
vladmandic 2fcabc8047 fix upscaler init causing server fail 2026-04-04 11:03:34 +02:00
vladmandic d98d05ca2d update wiki 2026-04-04 10:53:07 +02:00
Vladimir Mandic 27a62cfa70 Merge pull request #4734 from vladmandic/master (refresh dev) 2026-04-04 09:00:28 +02:00
Vladimir Mandic d97191f342 Merge branch 'dev' into master 2026-04-04 09:00:21 +02:00
vladmandic fbf1a962f2 refresh 2026-04-04 08:59:55 +02:00
vladmandic d7904b239f add ftfy 2026-04-04 08:58:27 +02:00
vladmandic 90b5e7de30 update todo/changelog 2026-04-04 08:56:37 +02:00
Vladimir Mandic 08c28ab257 Merge pull request #4733 from vladmandic/master (refresh dev) 2026-04-04 08:50:18 +02:00
Vladimir Mandic 0c94a169ea Merge branch 'dev' into master 2026-04-04 08:50:09 +02:00
vladmandic 32b69bdd3d guard against spaces 2026-04-04 08:47:59 +02:00
Disty0 b2e071dc52 cleanup 2026-04-04 01:39:26 +03:00
Disty0 470a0d816e SDNQ add tensor descriptor kernel to triton mm for Intel Arc 2026-04-04 01:32:34 +03:00
Disty0 ffeda702c5 Set default openvino_accuracy to no hint 2026-04-03 23:50:45 +03:00
Vladimir Mandic bfd9a0c0f5 Merge pull request #4726 from resonantsky/dev (Added further rocblas support enhancements) 2026-04-03 15:59:24 +02:00
resonantsky 25af3242c3 Merge branch 'vladmandic:dev' into dev 2026-04-03 15:57:49 +02:00
Vladimir Mandic dc6f20ec8f Merge pull request #4729 from liutyi/dev (sdxs-1b reference image) 2026-04-03 15:50:03 +02:00
resonantsky a809b616e6 Restoring platform agnosticism, Linux users report OK 2026-04-03 15:46:15 +02:00
Oleksandr Liutyi 88bde026f7 sdxs-1b reference image 2026-04-03 13:12:08 +00:00
resonantsky e49d6262e9 UI edits, small corrections 2026-04-03 12:25:06 +02:00
resonantsky ac9aacac66 edited if installer.torch_info line 470 2026-04-03 11:58:43 +02:00
resonantsky 2177609e54 further fixes requested by review 2026-04-03 11:28:25 +02:00
resonantsky ee3b141297 fixes as requested by review 2026-04-03 11:21:53 +02:00
resonantsky 24f4490a59 Apply suggestion from @awsr (Co-authored-by: awsr <43862868+awsr@users.noreply.github.com>) 2026-04-03 11:09:55 +02:00
resonantsky d2a47ee0ed code quality and layout fixes 2026-04-03 10:29:53 +02:00
resonantsky 01d53edb25 code quality and layout fixes 2026-04-03 10:12:21 +02:00
resonantsky 4cafae9350 Merge branch 'vladmandic:dev' into dev 2026-04-03 08:53:18 +02:00
resonantsky b659a06c60 rocm_mgr: easy fixes - simplify _get_venv to sys.prefix, _get_root to direct script_path import 2026-04-02 21:13:22 +02:00
vladmandic 9d0ecde462 add sdxs 2026-04-02 20:15:30 +02:00
resonantsky fdc2f46457 further rocblas default settings 2026-04-02 16:20:30 +02:00
resonantsky 1ed2811c80 rocm_mgr: use modules.paths.script_path for app root 2026-04-02 15:21:54 +02:00
resonantsky f5c037a735 Added further rocblas support enhancements and performance-related best practice settings. 2026-04-02 14:59:16 +02:00
Vladimir Mandic 668a94141d Merge pull request #4722 from awsr/various1 (typing and typechecks updates) 2026-04-02 08:36:41 +02:00
Vladimir Mandic 999cbe5d3a Merge branch 'dev' into various1 2026-04-02 08:36:24 +02:00
awsr 3dd09fde08 Syntax change 2026-03-30 20:09:02 -07:00
awsr 95dadab5c3 Revert import format change (not sure why, but it was causing errors) 2026-03-29 02:04:22 -07:00
awsr 59b9ca50ee Guard against divide by zero and out-of-bounds 2026-03-29 01:25:22 -07:00
awsr eeb9b6291b Update typing + enforce variable not unbound + import sort 2026-03-28 19:27:45 -07:00
awsr 8d6ec348b2 Refactor get_grid_size (type safe, avoids redefining parameter types, and the static type checker is able to parse it easily) 2026-03-28 19:18:55 -07:00
awsr a1b03a383c Upgrade Grid to typed NamedTuple 2026-03-28 19:17:59 -07:00
awsr ba362ad3ca Type safety 2026-03-28 17:32:23 -07:00
awsr 7f07d4cb31 Update to match actual code logic 2026-03-28 17:31:37 -07:00
awsr 715b1b0699 More typing updates 2026-03-28 17:30:06 -07:00
awsr b2b6fdf9d5 Typing updates 2026-03-28 17:15:48 -07:00
awsr 8ef8074467 Upgrade to typed NamedTuples 2026-03-28 16:57:59 -07:00
awsr 1e9bef8d56 Type enforcement for ResGPU (only tuples support typing a specific number of entries) 2026-03-27 14:15:44 -07:00
awsr fe7e4b40ff Update models to match actual behavior (ResScripts, ResEmbeddings) 2026-03-27 14:12:39 -07:00
40 changed files with 666 additions and 310 deletions

@ -1,5 +1,21 @@
# Change Log for SD.Next
## Update for 2026-04-04
- **Models**
- [AiArtLab SDXS-1B](https://huggingface.co/AiArtLab/sdxs-1b) Simple Diffusion XS *(training still in progress)*
this model combines Qwen3.5-1.8B text encoder with SDXL-style UNET with only 1.6B parameters and custom 32ch VAE
- **Compute**
- **ROCm** further work on advanced configuration and tuning, thanks @resonantsky
see *main interface -> scripts -> rocm advanced config*
- **Internal**
- additional typing and typechecks, thanks @awsr
- Prohibit python==3.14 unless `--experimental`
- **Fixes**
- UI CSS fixes, thanks @awsr
- detect/warn if space in system path
- add `ftfy` to requirements
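The "prohibit python==3.14 unless `--experimental`" gate above boils down to a version allowlist with an opt-in escape hatch. A minimal sketch (the version tuples and the function name are illustrative, not the actual `check_python` signature):

```python
def python_supported(minor: int, supported=(10, 11, 12, 13),
                     experimental=(14,), allow_experimental=False) -> bool:
    # a minor release in `supported` always passes; one in `experimental`
    # passes only when the user opted in (e.g. via --experimental)
    if minor in supported:
        return True
    return allow_experimental and minor in experimental
```

In the real installer this decision lives in `check_python`, which takes `supported_minors` and `experimental_minors` as the diff further down shows.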
## Update for 2026-04-01
### Highlights for 2026-04-01

@ -1,5 +1,7 @@
# TODO
<https://github.com/huggingface/diffusers/pull/13317>
## Internal
- Feature: implement `unload_auxiliary_models`

@ -864,6 +864,16 @@
"extras": "sampler: Default, cfg_scale: 1.5, steps: 50",
"size": 15.3,
"date": "2025 January"
},
"AiArtLab SDXS-1B": {
"path": "AiArtLab/sdxs-1b",
"preview": "AiArtLab--sdxs-1b.jpg",
"desc": "Simple Diffusion XS (train in progress) combines Qwen3.5-1.8B text encoder with SDXL-style UNET with only 1.6B parameters and custom 32ch VAE",
"skip": true,
"extras": "sampler: Default",
"size": 15.3,
"date": "2026 January"
}
}

@ -1 +1 @@
Subproject commit c7af727f31758c9fc96cf0429bcf3608858a15e8
Subproject commit e3720332f2301fa597c94b40897aa6e983020f1f

@ -474,6 +474,8 @@ def check_python(supported_minors=None, experimental_minors=None, reason=None):
else:
git_version = git('--version', folder=None, ignore=False)
log.debug(f'Git: version={git_version.replace("git version", "").strip()}')
if ' ' in sys.executable:
log.warning(f'Python: path="{sys.executable}" contains spaces which may cause issues')
ts('python', t_start)
@ -1244,7 +1246,7 @@ def install_requirements():
# set environment variables controlling the behavior of various libraries
def set_environment():
log.debug('Setting environment tuning')
os.environ.setdefault('PIP_CONSTRAINT', os.path.abspath('constraints.txt'))
os.environ.setdefault('PIP_CONSTRAINT', 'constraints.txt')
os.environ.setdefault('ACCELERATE', 'True')
os.environ.setdefault('ATTN_PRECISION', 'fp16')
os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100')
@ -1277,7 +1279,7 @@ def set_environment():
os.environ.setdefault('MIOPEN_FIND_MODE', '2')
os.environ.setdefault('UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS', '1')
os.environ.setdefault('USE_TORCH', '1')
os.environ.setdefault('UV_CONSTRAINT', os.path.abspath('constraints.txt'))
os.environ.setdefault('UV_CONSTRAINT', 'constraints.txt')
os.environ.setdefault('UV_INDEX_STRATEGY', 'unsafe-any-match')
os.environ.setdefault('UV_NO_BUILD_ISOLATION', '1')
os.environ.setdefault('UVICORN_TIMEOUT_KEEP_ALIVE', '60')
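The new space-in-path guard in this diff is a one-line check against `sys.executable`; a standalone sketch (the helper name is hypothetical):

```python
import sys


def path_has_spaces(executable: str = sys.executable) -> bool:
    # pip, subprocess wrappers, and some launchers mishandle
    # interpreter paths that contain spaces, so detect this early
    return ' ' in executable
```

For example, `path_has_spaces('C:/Program Files/Python312/python.exe')` is True, while a typical `/usr/bin/python3` passes cleanly.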

Binary file not shown (image added, 63 KiB)

@ -42,7 +42,7 @@ Result should always be: list[ResGPU]
class ResGPU(BaseModel):
name: str = Field(title="GPU Name")
data: dict = Field(title="Name/Value data")
chart: list[float, float] = Field(title="Exactly two items to place on chart")
chart: tuple[float, float] = Field(title="Exactly two items to place on chart")
"""
if __name__ == '__main__':

@ -484,22 +484,22 @@ else:
FlagsModel = create_model("Flags", __config__=pydantic_config, **flags)
class ResEmbeddings(BaseModel):
loaded: list = Field(default=None, title="loaded", description="List of loaded embeddings")
skipped: list = Field(default=None, title="skipped", description="List of skipped embeddings")
loaded: list = Field(title="loaded", description="List of loaded embeddings")
skipped: list = Field(title="skipped", description="List of skipped embeddings")
class ResMemory(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
class ResScripts(BaseModel):
txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
control: list = Field(default=None, title="Control", description="Titles of scripts (control)")
txt2img: list[str] = Field(title="Txt2img", description="Titles of scripts (txt2img)")
img2img: list[str] = Field(title="Img2img", description="Titles of scripts (img2img)")
control: list[str] = Field(title="Control", description="Titles of scripts (control)")
class ResGPU(BaseModel): # definition of http response
name: str = Field(title="GPU Name", description="GPU device name")
data: dict = Field(title="Name/Value data", description="Key-value pairs of GPU metrics (utilization, temperature, clocks, memory, etc.)")
chart: list[float, float] = Field(title="Exactly two items to place on chart", description="Two numeric values for chart display (e.g., GPU utilization %, VRAM usage %)")
chart: tuple[float, float] = Field(title="Exactly two items to place on chart", description="Two numeric values for chart display (e.g., GPU utilization %, VRAM usage %)")
class ItemLoadedModel(BaseModel):
name: str = Field(title="Model Name", description="Model or component name")
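The `list[float, float]` to `tuple[float, float]` changes above fix an invalid annotation: `list[...]` accepts exactly one type parameter, whereas `tuple[float, float]` is the standard way to type a fixed-length pair. A minimal illustration:

```python
from typing import get_args

# tuple[float, float] declares exactly two float items,
# matching the two-value chart payload (VRAM %, GPU %)
ChartPair = tuple[float, float]
assert get_args(ChartPair) == (float, float)

chart: ChartPair = (41.5, 87.0)
assert len(chart) == 2
```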

@ -72,7 +72,7 @@ def get_nvml():
"System load": f'GPU {load.gpu}% | VRAM {load.memory}% | Temp {pynvml.nvmlDeviceGetTemperature(dev, 0)}C | Fan {pynvml.nvmlDeviceGetFanSpeed(dev)}%',
'State': get_reason(pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(dev)),
}
chart = [load.memory, load.gpu]
chart = (load.memory, load.gpu)
devices.append({
'name': name,
'data': data,

@ -96,7 +96,7 @@ def get_rocm_smi():
'Throttle reason': str(ThrottleStatus(int(rocm_smi_data[key].get("throttle_status", 0)))),
}
name = rocm_smi_data[key].get('Device Name', 'unknown')
chart = [load["memory"], load["gpu"]]
chart = (load["memory"], load["gpu"])
devices.append({
'name': name,
'data': data,

@ -29,7 +29,7 @@ def get_xpu_smi():
"VRAM usage": f'{round(100 * load["memory"] / total)}% | {load["memory"]} MB used | {total - load["memory"]} MB free | {total} MB total',
"RAM usage": f'{round(100 * ram["used"] / ram["total"])}% | {round(1024 * ram["used"])} MB used | {round(1024 * ram["free"])} MB free | {round(1024 * ram["total"])} MB total',
}
chart = [load["memory"], load["gpu"]]
chart = (load["memory"], load["gpu"])
devices.append({
'name': torch.xpu.get_device_name(),
'data': data,

@ -40,7 +40,7 @@ def ts2utc(timestamp: int) -> datetime:
except Exception:
return "unknown"
def active():
def active() -> list[Extension]:
if shared.opts.disable_all_extensions == "all":
return []
elif shared.opts.disable_all_extensions == "user":
@ -185,12 +185,12 @@ class Extension:
log.error(f"Extension: failed reading data from git repo={self.name}: {ex}")
self.remote = None
def list_files(self, subdir, extension):
from modules import scripts_manager
def list_files(self, subdir: str, extension: str):
from modules.scripts_manager import ScriptFile
dirpath = os.path.join(self.path, subdir)
res: list[ScriptFile] = []
if not os.path.isdir(dirpath):
return []
res = []
return res
for filename in sorted(os.listdir(dirpath)):
if not filename.endswith(".py") and not filename.endswith(".js") and not filename.endswith(".mjs"):
continue
@ -198,7 +198,7 @@ class Extension:
if os.path.isfile(os.path.join(dirpath, "..", ".priority")):
with open(os.path.join(dirpath, "..", ".priority"), encoding="utf-8") as f:
priority = str(f.read().strip())
res.append(scripts_manager.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
res.append(ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
if priority != '50':
log.debug(f'Extension priority override: {os.path.dirname(dirpath)}:{priority}')
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

@ -1,15 +1,23 @@
import math
from collections import namedtuple
from typing import NamedTuple
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from modules import shared, script_callbacks
from PIL import Image, ImageDraw, ImageFont
from modules import script_callbacks, shared
from modules.logger import log
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
class Grid(NamedTuple):
tiles: list
tile_w: int
tile_h: int
image_w: int
image_h: int
overlap: int
def check_grid_size(imgs):
def check_grid_size(imgs: list[Image.Image] | list[list[Image.Image]] | None):
if imgs is None or len(imgs) == 0:
return False
mp = 0
@ -26,39 +34,36 @@ def check_grid_size(imgs):
return ok
def get_grid_size(imgs, batch_size=1, rows: int | None = None, cols: int | None = None):
if rows and rows > len(imgs):
rows = len(imgs)
if cols and cols > len(imgs):
cols = len(imgs)
def get_grid_size(imgs: list, batch_size=1, rows: int | None = None, cols: int | None = None):
rows_int, cols_int = len(imgs), len(imgs)
if rows is None and cols is None:
if shared.opts.n_rows > 0:
rows = shared.opts.n_rows
cols = math.ceil(len(imgs) / rows)
elif shared.opts.n_rows == 0:
rows = batch_size
cols = math.ceil(len(imgs) / rows)
elif shared.opts.n_cols > 0:
cols = shared.opts.n_cols
rows = math.ceil(len(imgs) / cols)
elif shared.opts.n_cols == 0:
cols = batch_size
rows = math.ceil(len(imgs) / cols)
n_rows, n_cols = shared.opts.n_rows, shared.opts.n_cols
if n_rows >= 0:
rows_int: int = batch_size if n_rows == 0 else n_rows
cols_int = math.ceil(len(imgs) / rows_int)
elif n_cols >= 0:
cols_int: int = batch_size if n_cols == 0 else n_cols
rows_int = math.ceil(len(imgs) / cols_int)
else:
rows = math.floor(math.sqrt(len(imgs)))
while len(imgs) % rows != 0:
rows -= 1
cols = math.ceil(len(imgs) / rows)
elif rows is not None and cols is None:
cols = math.ceil(len(imgs) / rows)
elif rows is None and cols is not None:
rows = math.ceil(len(imgs) / cols)
else:
pass
return rows, cols
rows_int = math.floor(math.sqrt(len(imgs)))
while len(imgs) % rows_int != 0:
rows_int -= 1
cols_int = math.ceil(len(imgs) / rows_int)
return rows_int, cols_int
# Set limits
if rows is not None:
rows_int = max(min(rows, len(imgs)), 1)
if cols is not None:
cols_int = max(min(cols, len(imgs)), 1)
# Calculate
if rows is None:
rows_int = math.ceil(len(imgs) / cols_int)
if cols is None:
cols_int = math.ceil(len(imgs) / rows_int)
return rows_int, cols_int
def image_grid(imgs, batch_size=1, rows: int | None = None, cols: int | None = None):
def image_grid(imgs: list, batch_size=1, rows: int | None = None, cols: int | None = None):
rows, cols = get_grid_size(imgs, batch_size, rows=rows, cols=cols)
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
script_callbacks.image_grid_callback(params)
@ -73,7 +78,7 @@ def image_grid(imgs, batch_size=1, rows: int | None = None, cols: int | None = N
return grid
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
def split_grid(image: Image.Image, tile_w=512, tile_h=512, overlap=64):
w = image.width
h = image.height
non_overlap_width = tile_w - overlap
@ -98,7 +103,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
return grid
def combine_grid(grid):
def combine_grid(grid: Grid):
def make_mask_image(r):
r = r * 255 / grid.overlap
r = r.astype(np.uint8)
@ -127,18 +132,18 @@ class GridAnnotation:
def __init__(self, text='', is_active=True):
self.text = str(text)
self.is_active = is_active
self.size = None
self.size: tuple[int, int] = (10, 10) # Placeholder values
def get_font(fontsize):
def get_font(fontsize: float):
try:
return ImageFont.truetype(shared.opts.font or "javascript/notosans-nerdfont-regular.ttf", fontsize)
except Exception:
return ImageFont.truetype("javascript/notosans-nerdfont-regular.ttf", fontsize)
def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=None):
def wrap(drawing, text, font, line_length):
def draw_grid_annotations(im: Image.Image, width: int, height: int, x_texts: list[list[GridAnnotation]], y_texts: list[list[GridAnnotation]], margin=0, title: list[GridAnnotation] | None = None):
def wrap(drawing: ImageDraw.ImageDraw, text, font, line_length):
lines = ['']
for word in text.split():
line = f'{lines[-1]} {word}'.strip()
@ -148,7 +153,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
lines.append(word)
return lines
def draw_texts(drawing: ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
def draw_texts(drawing: ImageDraw.ImageDraw, draw_x: float, draw_y: float, lines, initial_fnt: ImageFont.FreeTypeFont, initial_fontsize: int):
for line in lines:
font = initial_fnt
fontsize = initial_fontsize
@ -185,9 +190,10 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in x_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in y_texts]
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
title_text_heights = []
title_pad = 0
if title:
title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts] # pylint: disable=unsubscriptable-object
title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts]
title_pad = 0 if sum(title_text_heights) == 0 else max(title_text_heights) + line_spacing * 2
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + title_pad + margin * (rows-1)), shared.opts.grid_background)
for row in range(rows):
@ -210,7 +216,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
return result
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
def draw_prompt_matrix(im: Image.Image, width: int, height: int, all_prompts: list[str], margin=0):
prompts = all_prompts[1:]
boundary = math.ceil(len(prompts) / 2)
prompts_horiz = prompts[:boundary]

@ -100,6 +100,8 @@ def get_model_type(pipe):
model_type = 'hunyuanimage3'
elif 'HunyuanImage' in name:
model_type = 'hunyuanimage'
elif 'sdxs-1b' in name:
model_type = 'sdxs'
# video models
elif "Kandinsky5" in name and '2V' in name:
model_type = 'kandinsky5video'
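The new `sdxs` branch slots into a name-based elif chain. The dispatch pattern, reduced to a hypothetical standalone subset of the checks shown (order matters: more specific names first):

```python
def detect_model_type(name: str) -> str:
    # ordered substring checks mirroring the elif chain above (subset only)
    if 'HunyuanImage3' in name:
        return 'hunyuanimage3'
    if 'HunyuanImage' in name:
        return 'hunyuanimage'
    if 'sdxs-1b' in name:
        return 'sdxs'
    return 'unknown'
```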

@ -449,14 +449,19 @@ def load_upscalers():
used_classes[classname] = cls
upscaler_types = []
for cls in reversed(used_classes.values()):
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
commandline_model_path = commandline_options.get(cmd_name, None)
scaler = cls(commandline_model_path)
scaler.user_path = commandline_model_path
scaler.model_download_path = commandline_model_path or scaler.model_path
upscalers += scaler.scalers
upscaler_types.append(name[8:])
try:
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
commandline_model_path = commandline_options.get(cmd_name, None)
scaler = cls(commandline_model_path)
scaler.user_path = commandline_model_path
scaler.model_download_path = commandline_model_path or scaler.model_path
upscalers += scaler.scalers
upscaler_types.append(name[8:])
except Exception as e:
log.error(f'Upscaler: {cls} {e}')
if len(upscalers) == 0:
log.error('Upscalers: no data')
shared.sd_upscalers = upscalers
t1 = time.time()
log.info(f"Available Upscalers: items={len(shared.sd_upscalers)} downloaded={len([x for x in shared.sd_upscalers if x.data_path is not None and os.path.isfile(x.data_path)])} user={len([x for x in shared.sd_upscalers if x.custom])} time={t1-t0:.2f} types={upscaler_types}")
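Wrapping each upscaler's construction in `try`/`except` isolates a single failing class (the bug behind "fix upscaler init causing server fail") from the rest of the loader. The pattern in miniature (class names hypothetical):

```python
import logging

log = logging.getLogger(__name__)


def load_all(classes: list) -> list:
    loaded = []
    for cls in classes:
        try:
            loaded.append(cls())
        except Exception as e:  # one bad plugin must not take down the server
            log.error('Upscaler: %s %s', cls, e)
    return loaded


class GoodUpscaler:
    pass


class BadUpscaler:
    def __init__(self):
        raise RuntimeError('model path missing')


scalers = load_all([GoodUpscaler, BadUpscaler])
```

`scalers` ends up containing only the `GoodUpscaler` instance; the failure is logged and startup continues.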

@ -325,8 +325,8 @@ class StableDiffusionProcessing:
# initializers
self.prompt = prompt
self.seed = seed
self.subseed = subseed
self.seed = int(seed)
self.subseed = int(subseed)
self.subseed_strength = subseed_strength
self.seed_resize_from_h = seed_resize_from_h
self.seed_resize_from_w = seed_resize_from_w
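Coercing `seed`/`subseed` with `int()` at assignment normalizes values that arrive as strings or floats from the UI/API. A defensive variant (the helper name and the `-1` fallback are assumptions; the diff's bare `int(seed)` raises on invalid input instead):

```python
def normalize_seed(seed) -> int:
    # UI and API callers may pass seeds as str, float, or None
    try:
        return int(seed)
    except (TypeError, ValueError):
        return -1  # -1 conventionally means "pick a random seed"
```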

@ -1,14 +1,23 @@
from __future__ import annotations
import os
import re
import sys
import time
from collections import namedtuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, NamedTuple
import gradio as gr
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
from modules.logger import log
from installer import control_extensions
if TYPE_CHECKING:
from types import ModuleType
from modules.api.models import ItemScript
from gradio.blocks import Block
from gradio.components import IOComponent
from modules.processing import Processed, StableDiffusionProcessing
AlwaysVisible = object()
time_component = {}
@ -28,22 +37,22 @@ class PostprocessBatchListArgs:
@dataclass
class OnComponent:
component: gr.blocks.Block
component: Block
class Script:
parent = None
name = None
filename = None
parent: str | None = None
name: str | None = None
filename: str | None = None
args_from = 0
args_to = 0
alwayson = False
is_txt2img = False
is_img2img = False
api_info = None
api_info: ItemScript | None = None
group = None
infotext_fields = None
paste_field_names = None
infotext_fields: list | None = None
paste_field_names: list[str] | None = None
section = None
standalone = False
external = False
@ -57,14 +66,14 @@ class Script:
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
raise NotImplementedError
def ui(self, is_img2img):
def ui(self, is_img2img) -> list[IOComponent]:
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
The return value should be an array of all components that are used in processing.
Values of those returned components will be passed to run() and process() functions.
"""
pass # pylint: disable=unnecessary-pass
def show(self, is_img2img): # pylint: disable=unused-argument
def show(self, is_img2img) -> bool | AlwaysVisible: # pylint: disable=unused-argument
"""
is_img2img is True if this function is called for the img2img interface, and False otherwise
This function should return:
@ -74,7 +83,7 @@ class Script:
"""
return True
def run(self, p, *args):
def run(self, p: StableDiffusionProcessing, *args):
"""
This function is called if the script has been selected in the script dropdown.
It must do all processing and return the Processed object with results, same as
@ -84,13 +93,13 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def setup(self, p, *args):
def setup(self, p: StableDiffusionProcessing, *args):
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
args contains all values returned by components from ui().
"""
pass # pylint: disable=unnecessary-pass
def before_process(self, p, *args):
def before_process(self, p: StableDiffusionProcessing, *args):
"""
This function is called very early during processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
@ -98,7 +107,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def process(self, p, *args):
def process(self, p: StableDiffusionProcessing, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
@ -106,7 +115,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def process_images(self, p, *args):
def process_images(self, p: StableDiffusionProcessing, *args):
"""
This function is called instead of main processing for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
@ -114,7 +123,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def before_process_batch(self, p, *args, **kwargs):
def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
"""
Called before extra networks are parsed from the prompt, so you can add
new extra network keywords to the prompt with this callback.
@ -126,7 +135,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def process_batch(self, p, *args, **kwargs):
def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
"""
Same as process(), but called for every batch.
**kwargs will have those items:
@ -137,7 +146,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def postprocess_batch(self, p, *args, **kwargs):
def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
"""
Same as process_batch(), but called for every batch after it has been generated.
**kwargs will have same items as process_batch, and also:
@ -146,13 +155,13 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
"""
pass # pylint: disable=unnecessary-pass
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, *args, **kwargs):
"""
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
This is useful when you want to update the entire batch instead of individual images.
@ -168,14 +177,14 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def postprocess(self, p, processed, *args):
def postprocess(self, p: StableDiffusionProcessing, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
args contains all values returned by components from ui()
"""
pass # pylint: disable=unnecessary-pass
def before_component(self, component, **kwargs):
def before_component(self, component: IOComponent, **kwargs):
"""
Called before a component is created.
Use elem_id/label fields of kwargs to figure out which component it is.
@ -184,7 +193,7 @@ class Script:
"""
pass # pylint: disable=unnecessary-pass
def after_component(self, component, **kwargs):
def after_component(self, component: IOComponent, **kwargs):
"""
Called after a component is created. Same as above.
"""
@ -194,7 +203,7 @@ class Script:
"""unused"""
return ""
def elem_id(self, item_id):
def elem_id(self, item_id: str):
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
return f'script_{self.parent}_{title}_{item_id}'
@ -211,21 +220,33 @@ def basedir():
return current_basedir
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "priority"])
class ScriptFile(NamedTuple):
basedir: str
filename: str
path: str
priority: str
class ScriptClassData(NamedTuple):
script_class: type[Script] | type[scripts_postprocessing.ScriptPostprocessing]
path: str
basedir: str
module: ModuleType
scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def list_scripts(scriptdirname, extension):
tmp_list = []
def list_scripts(scriptdirname: str, extension: str):
tmp_list: list[ScriptFile] = []
base = os.path.join(paths.script_path, scriptdirname)
if os.path.exists(base):
for filename in sorted(os.listdir(base)):
tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50'))
for ext in extensions.active():
tmp_list += ext.list_files(scriptdirname, extension)
priority_list = []
priority_list: list[ScriptFile] = []
for script in tmp_list:
if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path):
if script.basedir == paths.script_path:
@ -250,7 +271,7 @@ def list_scripts(scriptdirname, extension):
return priority_sort
def list_files_with_name(filename):
def list_files_with_name(filename: str):
res = []
dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
for dirpath in dirs:
@ -273,7 +294,7 @@ def load_scripts():
scripts_list = sorted(scripts_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
syspath = sys.path
def register_scripts_from_module(module, scriptfile):
def register_scripts_from_module(module: ModuleType, scriptfile):
for script_class in module.__dict__.values():
if type(script_class) != type:
continue
@ -305,7 +326,7 @@ def load_scripts():
return t, time.time()-t0
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
def wrap_call(func: Callable, filename: str, funcname: str, *args, default=None, **kwargs):
try:
res = func(*args, **kwargs)
return res
@ -315,7 +336,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptSummary:
def __init__(self, op):
def __init__(self, op: str):
self.start = time.time()
self.update = time.time()
self.op = op
@ -336,17 +357,17 @@ class ScriptSummary:
class ScriptRunner:
def __init__(self, name=''):
self.name = name
self.scripts = []
self.selectable_scripts = []
self.alwayson_scripts = []
self.auto_processing_scripts = []
self.titles = []
self.alwayson_titles = []
self.infotext_fields = []
self.paste_field_names = []
self.scripts: list[Script] = []
self.selectable_scripts: list[Script] = []
self.alwayson_scripts: list[Script] = []
self.auto_processing_scripts: list[ScriptClassData] = []
self.titles: list[str] = []
self.alwayson_titles: list[str] = []
self.infotext_fields: list[tuple[IOComponent, str]] = []
self.paste_field_names: list[str] = []
self.script_load_ctr = 0
self.is_img2img = False
self.inputs = [None]
self.inputs: list = [None]
self.time = 0
def add_script(self, script_class, path, is_img2img, is_control):
@ -557,7 +578,7 @@ class ScriptRunner:
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts if script.group is not None])
return inputs
def run(self, p, *args):
def run(self, p: StableDiffusionProcessing, *args) -> Processed | None:
s = ScriptSummary('run')
script_index = args[0] if len(args) > 0 else 0
if (script_index is None) or (script_index == 0):
@ -582,7 +603,7 @@ class ScriptRunner:
s.report()
return processed
def after(self, p, processed, *args):
def after(self, p: StableDiffusionProcessing, processed: Processed, *args):
s = ScriptSummary('after')
script_index = args[0] if len(args) > 0 else 0
if (script_index is None) or (script_index == 0):
@ -600,7 +621,7 @@ class ScriptRunner:
s.report()
return processed
def before_process(self, p, **kwargs):
def before_process(self, p: StableDiffusionProcessing, **kwargs):
s = ScriptSummary('before-process')
for script in self.alwayson_scripts:
try:
@ -612,7 +633,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def process(self, p, **kwargs):
def process(self, p: StableDiffusionProcessing, **kwargs):
s = ScriptSummary('process')
for script in self.alwayson_scripts:
try:
@ -624,7 +645,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def process_images(self, p, **kwargs):
def process_images(self, p: StableDiffusionProcessing, **kwargs):
s = ScriptSummary('process_images')
processed = None
for script in self.alwayson_scripts:
@ -640,7 +661,7 @@ class ScriptRunner:
s.report()
return processed
def before_process_batch(self, p, **kwargs):
def before_process_batch(self, p: StableDiffusionProcessing, **kwargs):
s = ScriptSummary('before-process-batch')
for script in self.alwayson_scripts:
try:
@ -652,7 +673,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def process_batch(self, p, **kwargs):
def process_batch(self, p: StableDiffusionProcessing, **kwargs):
s = ScriptSummary('process-batch')
for script in self.alwayson_scripts:
try:
@ -664,7 +685,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def postprocess(self, p, processed):
def postprocess(self, p: StableDiffusionProcessing, processed):
s = ScriptSummary('postprocess')
for script in self.alwayson_scripts:
try:
@ -676,7 +697,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def postprocess_batch(self, p, images, **kwargs):
def postprocess_batch(self, p: StableDiffusionProcessing, images, **kwargs):
s = ScriptSummary('postprocess-batch')
for script in self.alwayson_scripts:
try:
@ -688,7 +709,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, **kwargs):
s = ScriptSummary('postprocess-batch-list')
for script in self.alwayson_scripts:
try:
@ -700,7 +721,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def postprocess_image(self, p, pp: PostprocessImageArgs):
def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs):
s = ScriptSummary('postprocess-image')
for script in self.alwayson_scripts:
try:
@ -712,7 +733,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def before_component(self, component, **kwargs):
def before_component(self, component: IOComponent, **kwargs):
s = ScriptSummary('before-component')
for script in self.scripts:
try:
@ -722,7 +743,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def after_component(self, component, **kwargs):
def after_component(self, component: IOComponent, **kwargs):
s = ScriptSummary('after-component')
for script in self.scripts:
for elem_id, callback in script.on_after_component_elem_id:
@ -738,7 +759,7 @@ class ScriptRunner:
s.record(script.title())
s.report()
def reload_sources(self, cache):
def reload_sources(self, cache: dict):
s = ScriptSummary('reload-sources')
for si, script in list(enumerate(self.scripts)):
if hasattr(script, 'args_to') and hasattr(script, 'args_from'):

View File

@ -150,6 +150,8 @@ def guess_by_name(fn, current_guess):
new_guess = 'Ovis-Image'
elif 'glm-image' in fn.lower():
new_guess = 'GLM-Image'
elif 'sdxs-1b' in fn.lower():
new_guess = 'SDXS'
if debug_load:
log.trace(f'Autodetect: method=name file="{fn}" previous="{current_guess}" current="{new_guess}"')
return new_guess or current_guess
@ -166,6 +168,8 @@ def guess_by_diffusers(fn, current_guess):
if name is not None and name in exclude_by_name:
return current_guess, None
cls = index.get('_class_name', None)
if isinstance(cls, list):
cls = cls[-1]
if cls is not None:
pipeline = getattr(diffusers, cls, None)
if pipeline is None:

View File

@ -8,7 +8,7 @@ def hijack_encode_prompt(*args, **kwargs):
jobid = shared.state.begin('TE Encode')
t0 = time.time()
if 'max_sequence_length' in kwargs and kwargs['max_sequence_length'] is not None:
kwargs['max_sequence_length'] = max(kwargs['max_sequence_length'], os.environ.get('HIDREAM_MAX_SEQUENCE_LENGTH', 256))
kwargs['max_sequence_length'] = max(kwargs['max_sequence_length'], int(os.environ.get('MAX_SEQUENCE_LENGTH', 256)))
try:
prompt = kwargs.get('prompt', None) or (args[0] if len(args) > 0 else None)
if prompt is not None:
@ -20,8 +20,6 @@ def hijack_encode_prompt(*args, **kwargs):
res = None
t1 = time.time()
timer.process.add('te', t1-t0)
# if hasattr(shared.sd_model, "maybe_free_model_hooks"):
# shared.sd_model.maybe_free_model_hooks()
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
shared.state.end(jobid)
return res
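
One detail worth noting about the environment override above: `os.environ` values are always strings, so they must be cast before being compared with an int in `max()`. A minimal standalone sketch:

```python
import os

# os.environ values are always strings; max(300, '512') raises TypeError,
# so the lookup needs an explicit int() cast (sketch, not the repo code).
os.environ['MAX_SEQUENCE_LENGTH'] = '512'
limit = int(os.environ.get('MAX_SEQUENCE_LENGTH', 256))

assert limit == 512
assert max(300, limit) == 512
```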

View File

@ -507,6 +507,10 @@ def load_diffuser_force(detected_model_type, checkpoint_info, diffusers_load_con
from pipelines.model_glm import load_glm_image
sd_model = load_glm_image(checkpoint_info, diffusers_load_config)
allow_post_quant = False
elif model_type in ['SDXS']:
from pipelines.model_sdxs import load_sdxs
sd_model = load_sdxs(checkpoint_info, diffusers_load_config)
allow_post_quant = False
except Exception as e:
log.error(f'Load {op}: path="{checkpoint_info.path}" {e}')
# if debug_load:
@ -1418,7 +1422,7 @@ def hf_auth_check(checkpoint_info, force:bool=False):
return False
def save_model(name: str, path: str | None = None, shard: str | None = None, overwrite = False):
def save_model(name: str, path: str | None = None, shard: str = "5GB", overwrite = False):
if (name is None) or len(name.strip()) == 0:
log.error('Save model: invalid model name')
return 'Invalid model name'
@ -1432,6 +1436,8 @@ def save_model(name: str, path: str | None = None, shard: str | None = None, ove
if os.path.exists(model_name) and not overwrite:
log.error(f'Save model: path="{model_name}" exists')
return f'Path exists: {model_name}'
if not shard.strip():
shard = "5GB" # Guard against empty input
try:
t0 = time.time()
save_sdnq_model(

View File

@ -38,8 +38,8 @@ def load_unet_sdxl_nunchaku(repo_id):
def load_unet(model, repo_id: str | None = None):
global loaded_unet # pylint: disable=global-statement
if ("StableDiffusionXLPipeline" in model.__class__.__name__) and (('stable-diffusion-xl-base' in repo_id) or ('sdxl-turbo' in repo_id)):
if model_quant.check_nunchaku('Model'):
if ("StableDiffusionXLPipeline" in model.__class__.__name__) and repo_id is not None and (("stable-diffusion-xl-base" in repo_id) or ("sdxl-turbo" in repo_id)):
if model_quant.check_nunchaku("Model"):
unet = load_unet_sdxl_nunchaku(repo_id)
if unet is not None:
model.unet = unet

View File

@ -40,18 +40,18 @@ def conv_fp16_matmul(
scale = scale.t()
elif weight.dtype != torch.float16:
weight = weight.to(dtype=torch.float16) # fp8 weights
input, scale = quantize_fp_mm_input_tensorwise(input, scale, matmul_dtype="float16")
input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype, matmul_dtype="float16")
input, weight = check_mats(input, weight)
if groups == 1:
result = fp_mm_func(input, weight)
result = fp_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale)
else:
weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
input = input.view(input.shape[0], groups, input.shape[1] // groups)
result = []
for i in range(groups):
result.append(fp_mm_func(input[:, i], weight[:, i]))
result = torch.cat(result, dim=-1)
result = torch.cat(result, dim=-1).to(dtype=input_scale.dtype).mul_(input_scale)
if bias is not None:
dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape)
else:

View File

@ -38,19 +38,19 @@ def conv_fp8_matmul_tensorwise(
if quantized_weight_shape is not None:
weight = unpack_float(weight, weights_dtype, quantized_weight_shape).to(dtype=torch.float8_e4m3fn).t_()
scale = scale.t()
input, scale = quantize_fp_mm_input_tensorwise(input, scale)
input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype)
input, weight = check_mats(input, weight)
dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32)
if groups == 1:
result = torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype)
result = torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype).mul_(input_scale)
else:
weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
input = input.view(input.shape[0], groups, input.shape[1] // groups)
result = []
for i in range(groups):
result.append(torch._scaled_mm(input[:, i], weight[:, i], scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype))
result = torch.cat(result, dim=-1)
result.append(torch._scaled_mm(input[:, i], weight[:, i], scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype))
result = torch.cat(result, dim=-1).mul_(input_scale)
if bias is not None:
dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape)
else:

View File

@ -38,18 +38,18 @@ def conv_int8_matmul(
if quantized_weight_shape is not None:
weight = unpack_int(weight, weights_dtype, quantized_weight_shape, dtype=torch.int8).t_()
scale = scale.t()
input, scale = quantize_int_mm_input(input, scale)
input, input_scale = quantize_int_mm_input(input, dtype=scale.dtype)
input, weight = check_mats(input, weight)
if groups == 1:
result = int_mm_func(input, weight)
result = int_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale)
else:
weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
input = input.view(input.shape[0], groups, input.shape[1] // groups)
result = []
for i in range(groups):
result.append(int_mm_func(input[:, i], weight[:, i]))
result = torch.cat(result, dim=-1)
result = torch.cat(result, dim=-1).to(dtype=input_scale.dtype).mul_(input_scale)
if bias is not None:
result = dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape)
else:

View File

@ -33,12 +33,12 @@ def fp16_matmul(
bias = torch.addmm(bias.to(dtype=svd_down.dtype), torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
else:
bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
input, scale = quantize_fp_mm_input_tensorwise(input, scale, matmul_dtype="float16")
input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype, matmul_dtype="float16")
input, weight = check_mats(input, weight)
if bias is not None:
return dequantize_symmetric_with_bias(fp_mm_func(input, weight), scale, bias, dtype=return_dtype, result_shape=output_shape)
return dequantize_symmetric_with_bias(fp_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, bias, dtype=return_dtype, result_shape=output_shape)
else:
return dequantize_symmetric(fp_mm_func(input, weight), scale, dtype=return_dtype, result_shape=output_shape)
return dequantize_symmetric(fp_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, dtype=return_dtype, result_shape=output_shape)
def quantized_linear_forward_fp16_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor:

View File

@ -9,13 +9,14 @@ from ...dequantizer import quantize_fp_mm, dequantize_symmetric, dequantize_symm
from .forward import check_mats
def quantize_fp_mm_input_tensorwise(input: torch.FloatTensor, scale: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]:
input = input.flatten(0,-2).to(dtype=scale.dtype)
def quantize_fp_mm_input_tensorwise(input: torch.FloatTensor, dtype: torch.dtype | None = None, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]:
input = input.flatten(0,-2)
if dtype is not None:
input = input.to(dtype=dtype)
input, input_scale = quantize_fp_mm(input, dim=-1, matmul_dtype=matmul_dtype)
scale = torch.mul(input_scale, scale)
if scale.dtype == torch.float16: # fp16 will overflow
scale = scale.to(dtype=torch.float32)
return input, scale
if input_scale.dtype == torch.float16: # fp16 will overflow
input_scale = input_scale.to(dtype=torch.float32)
return input, input_scale
def fp8_matmul_tensorwise(
@ -40,12 +41,12 @@ def fp8_matmul_tensorwise(
else:
bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32)
input, scale = quantize_fp_mm_input_tensorwise(input, scale)
input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype)
input, weight = check_mats(input, weight, allow_contiguous_mm=False)
if bias is not None:
return dequantize_symmetric_with_bias(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype), scale, bias, dtype=return_dtype, result_shape=output_shape)
return dequantize_symmetric_with_bias(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype).to(dtype=input_scale.dtype).mul_(input_scale), scale, bias, dtype=return_dtype, result_shape=output_shape)
else:
return dequantize_symmetric(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype), scale, dtype=return_dtype, result_shape=output_shape)
return dequantize_symmetric(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype).to(dtype=input_scale.dtype).mul_(input_scale), scale, dtype=return_dtype, result_shape=output_shape)
def quantized_linear_forward_fp8_matmul_tensorwise(self, input: torch.FloatTensor) -> torch.FloatTensor:

View File

@ -9,13 +9,14 @@ from ...packed_int import unpack_int # noqa: TID252
from .forward import check_mats
def quantize_int_mm_input(input: torch.FloatTensor, scale: torch.FloatTensor) -> tuple[torch.CharTensor, torch.FloatTensor]:
input = input.flatten(0,-2).to(dtype=scale.dtype)
def quantize_int_mm_input(input: torch.FloatTensor, dtype: torch.dtype | None = None) -> tuple[torch.CharTensor, torch.FloatTensor]:
input = input.flatten(0,-2)
if dtype is not None:
input = input.to(dtype=dtype)
input, input_scale = quantize_int_mm(input, dim=-1)
scale = torch.mul(input_scale, scale)
if scale.dtype == torch.float16: # fp16 will overflow
scale = scale.to(dtype=torch.float32)
return input, scale
if input_scale.dtype == torch.float16: # fp16 will overflow
input_scale = input_scale.to(dtype=torch.float32)
return input, input_scale
def int8_matmul(
@ -39,12 +40,12 @@ def int8_matmul(
bias = torch.addmm(bias.to(dtype=svd_down.dtype), torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
else:
bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
input, scale = quantize_int_mm_input(input, scale)
input, input_scale = quantize_int_mm_input(input, dtype=scale.dtype)
input, weight = check_mats(input, weight)
if bias is not None:
return dequantize_symmetric_with_bias(int_mm_func(input, weight), scale, bias, dtype=return_dtype, result_shape=output_shape)
return dequantize_symmetric_with_bias(int_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, bias, dtype=return_dtype, result_shape=output_shape)
else:
return dequantize_symmetric(int_mm_func(input, weight), scale, dtype=return_dtype, result_shape=output_shape)
return dequantize_symmetric(int_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, dtype=return_dtype, result_shape=output_shape)
def quantized_linear_forward_int8_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor:

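The refactor in the int8/fp8/fp16 matmul files above separates the dynamic input scale (`input_scale`) from the static weight scale: the integer accumulator is multiplied by `input_scale` first, and only afterwards dequantized with the weight scale. A NumPy sketch of that two-scale scheme (illustrative only, not the repo's Triton path):

```python
import numpy as np

def quantize_int8(x):
    # symmetric per-tensor int8 quantization: value = q * scale
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)   # activations
w = rng.standard_normal((8, 16)).astype(np.float32)  # weights

qw, w_scale = quantize_int8(w)  # static weight scale
qx, x_scale = quantize_int8(x)  # dynamic input scale

# int32 accumulate, then apply the input scale, then the weight scale
acc = qx.astype(np.int32) @ qw.astype(np.int32)
result = acc.astype(np.float32) * x_scale * w_scale

assert np.allclose(result, x @ w, atol=0.5)
```

Keeping the two scales apart mirrors the change above, where `mul_(input_scale)` is applied to the raw matmul output before `dequantize_symmetric` handles the weight scale and bias.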
View File

@ -1,8 +1,10 @@
"""
Modified from Triton MatMul example.
PyTorch torch._int_mm is broken on backward pass with Nvidia.
AMD RDNA2 doesn't support torch._int_mm, so we use int_mm via Triton.
PyTorch doesn't support FP32 output type with FP16 MM so we use Triton for it too.
PyTorch torch._int_mm is broken on backward pass with Nvidia, so we use Triton on the backward pass with Nvidia.
AMD RDNA2 doesn't support torch._int_mm as it requires INT8 WMMA, so we use INT8 DP4A via Triton.
PyTorch doesn't support FP32 output type with FP16 MM, so we use Triton for FP16 MM too.
The matmul_configs we use take AMD and Intel into consideration too.
SDNQ Triton configs can outperform RocBLAS and OneDNN.
"""
import torch
@ -22,7 +24,7 @@ matmul_configs = [
]
@triton.autotune(configs=matmul_configs, key=["M", "N", "K", "stride_bk", "ACCUMULATOR_DTYPE"])
@triton.autotune(configs=matmul_configs, key=["M", "N", "K", "stride_bk", "ACCUMULATOR_DTYPE"], cache_results=True)
@triton.jit
def triton_mm_kernel(
a_ptr, b_ptr, c_ptr,
@ -76,6 +78,55 @@ def triton_mm_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)
# Intel requires tensor descriptors to perform well
@triton.autotune(configs=matmul_configs, key=["M", "N", "K", "stride_bk", "ACCUMULATOR_DTYPE"], cache_results=True)
@triton.jit
def triton_mm_td_kernel(
a_ptr, b_ptr, c_ptr,
M: int, N: int, K: int,
stride_am: int, stride_ak: int,
stride_bk: int, stride_bn: int,
stride_cm: int, stride_cn: int,
ACCUMULATOR_DTYPE: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
tl.assume(pid_m >= 0)
tl.assume(pid_n >= 0)
tl.assume(stride_am > 0)
tl.assume(stride_ak > 0)
tl.assume(stride_bn > 0)
tl.assume(stride_bk > 0)
tl.assume(stride_cm > 0)
tl.assume(stride_cn > 0)
a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))
off_k = 0
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
for _ in range(0, K, BLOCK_SIZE_K):
a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
off_k += BLOCK_SIZE_K
c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], accumulator)
def int_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
@ -84,7 +135,8 @@ def int_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
c = torch.empty((M, N), device=a.device, dtype=torch.int32)
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
triton_mm_kernel[grid](
mm_kernel_func = triton_mm_td_kernel if b.is_contiguous() else triton_mm_kernel
mm_kernel_func[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
@ -103,7 +155,8 @@ def fp_mm(a: torch.FloatTensor, b: torch.FloatTensor) -> torch.FloatTensor:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
triton_mm_kernel[grid](
mm_kernel_func = triton_mm_td_kernel if b.is_contiguous() else triton_mm_kernel
mm_kernel_func[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),

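Both Triton kernels in this file map the linear program id to a `(pid_m, pid_n)` tile with the grouped swizzle from the upstream Triton matmul example the docstring mentions. A pure-Python sketch shows the mapping covers every tile exactly once, including a ragged last group:

```python
# Pure-Python sketch of the grouped program-id swizzle used by both kernels
# above (groups of BLOCK_SIZE_M row-tiles improve L2 reuse).
def swizzle(pid, num_pid_m, num_pid_n, group_size_m):
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    gsz = min(num_pid_m - first_pid_m, group_size_m)  # last group may be smaller
    pid_m = first_pid_m + ((pid % num_pid_in_group) % gsz)
    pid_n = (pid % num_pid_in_group) // gsz
    return pid_m, pid_n

# every (pid_m, pid_n) tile is visited exactly once
tiles = {swizzle(p, 6, 4, 2) for p in range(6 * 4)}
assert tiles == {(m, n) for m in range(6) for n in range(4)}

# also holds when num_pid_m is not a multiple of group_size_m
tiles = {swizzle(p, 5, 4, 2) for p in range(5 * 4)}
assert tiles == {(m, n) for m in range(5) for n in range(4)}
```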
View File

@ -65,6 +65,7 @@ pipelines = {
'HunyuanImage3': getattr(diffusers, 'DiffusionPipeline', None),
'ChronoEdit': getattr(diffusers, 'DiffusionPipeline', None),
'Anima': getattr(diffusers, 'DiffusionPipeline', None),
'SDXS': getattr(diffusers, 'DiffusionPipeline', None),
}

View File

@ -265,7 +265,7 @@ def create_settings(cmd_opts):
"openvino_sep": OptionInfo("<h2>OpenVINO</h2>", "", gr.HTML, {"visible": cmd_opts.use_openvino}),
"openvino_devices": OptionInfo([], "OpenVINO devices to use", gr.CheckboxGroup, {"choices": get_openvino_device_list() if cmd_opts.use_openvino else [], "visible": cmd_opts.use_openvino}),
"openvino_accuracy": OptionInfo("performance", "OpenVINO accuracy mode", gr.Radio, {"choices": ["performance", "accuracy"], "visible": cmd_opts.use_openvino}),
"openvino_accuracy": OptionInfo("default", "OpenVINO accuracy mode", gr.Radio, {"choices": ["default", "performance", "accuracy"], "visible": cmd_opts.use_openvino}),
"openvino_disable_model_caching": OptionInfo(True, "OpenVINO disable model caching", gr.Checkbox, {"visible": cmd_opts.use_openvino}),
"openvino_disable_memory_cleanup": OptionInfo(True, "OpenVINO disable memory cleanup after compile", gr.Checkbox, {"visible": cmd_opts.use_openvino}),

View File

@ -365,6 +365,8 @@ def create_resize_inputs(tab, images, accordion=True, latent=False, non_zero=Tru
with gr.Accordion(open=False, label="Resize", elem_classes=["small-accordion"], elem_id=f"{tab}_resize_group") if accordion else gr.Group():
with gr.Row():
available_upscalers = [x.name for x in shared.sd_upscalers]
if len(available_upscalers) == 0:
available_upscalers = ['None']
if not latent:
available_upscalers = [x for x in available_upscalers if not x.lower().startswith('latent')]
resize_mode = gr.Dropdown(label=f"Mode{prefix}" if non_zero else "Resize mode", elem_id=f"{tab}_resize_mode", choices=shared.resize_modes, type="index", value='Fixed')

pipelines/model_sdxs.py Normal file
View File

@ -0,0 +1,59 @@
import time
import diffusers
import transformers
from modules import shared, devices, errors, timer, sd_models, model_quant, sd_hijack_vae
from modules.logger import log
from pipelines import generic
def hijack_encode_text(prompt: str | list[str]):
jobid = shared.state.begin('TE Encode')
t0 = time.time()
try:
prompt = shared.sd_model.refine_prompts(prompt)
except Exception as e:
log.error(f'Encode prompt: {e}')
errors.display(e, 'Encode prompt')
try:
res = shared.sd_model.orig_encode_text(prompt)
except Exception as e:
log.error(f'Encode prompt: {e}')
errors.display(e, 'Encode prompt')
res = None
t1 = time.time()
timer.process.add('te', t1-t0)
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
shared.state.end(jobid)
return res
def load_sdxs(checkpoint_info, diffusers_load_config=None):
if diffusers_load_config is None:
diffusers_load_config = {}
repo_id = sd_models.path_to_repo(checkpoint_info)
sd_models.hf_auth_check(checkpoint_info)
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
log.debug(f'Load model: type=SDXS repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
text_encoder = generic.load_text_encoder(repo_id, cls_name=transformers.Qwen3_5ForConditionalGeneration, load_config=diffusers_load_config, allow_shared=False)
pipe = diffusers.DiffusionPipeline.from_pretrained(
repo_id,
text_encoder=text_encoder,
cache_dir=shared.opts.diffusers_dir,
trust_remote_code=True,
**load_args,
)
pipe.task_args = {
'generator': None,
'output_type': 'np',
}
pipe.orig_encode_text = pipe.encode_text
pipe.encode_text = hijack_encode_text
sd_hijack_vae.init_hijack(pipe)
del text_encoder
devices.torch_gc(force=True, reason='load')
return pipe
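
`load_sdxs` above uses the same hijack pattern as the other pipeline loaders: stash the original bound method as `orig_encode_text`, then swap in a wrapper that adds pre/post processing. A generic standalone sketch with a hypothetical `Pipe` class:

```python
# Generic sketch of the hijack pattern used above: keep a reference to the
# original bound method, then replace it with a wrapper function.
class Pipe:
    def encode_text(self, prompt):
        return f'encoded:{prompt}'

pipe = Pipe()
pipe.orig_encode_text = pipe.encode_text  # bound method, survives the swap

def hijack(prompt):
    # timers, prompt refinement, offload handling would go here
    return pipe.orig_encode_text(prompt.strip())

pipe.encode_text = hijack  # instance attribute shadows the class method

assert pipe.encode_text('  hello ') == 'encoded:hello'
```

Because the instance attribute shadows the class method, callers keep using `pipe.encode_text(...)` unchanged while the wrapper runs.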

View File

@ -18,6 +18,7 @@ fasteners
limits
orjson
websockets
ftfy
# versioned
fastapi==0.124.4

View File

@ -1,4 +1,11 @@
from __future__ import annotations
from modules.logger import log
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
repo_id = 'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf'
policy_template = """Provide a safety assessment for the content provided by the user. Determine compliance or violation against our safety policy by reviewing the following policy categories:
Hate:
@ -89,11 +96,11 @@ To provide your assessment use the following json template for each category:
"rationale": str,
}.
"""
model = None
processor = None
model: LlavaOnevisionForConditionalGeneration | None = None
processor: AutoProcessor | None = None
def image_guard(image, policy:str | None=None) -> str:
def image_guard(image, policy:str | None=None):
global model, processor # pylint: disable=global-statement
import json
from installer import install

View File

@ -49,14 +49,14 @@ def create_ui(accordion=True):
# main processing used in both modes
def process(
p: processing.StableDiffusionProcessing=None,
pp: scripts.PostprocessImageArgs=None,
p: processing.StableDiffusionProcessing | None = None,
pp: scripts.PostprocessImageArgs | scripts_postprocessing.PostprocessedImage | None = None,
enabled=True,
lang=False,
policy=False,
banned=False,
metadata=True,
copy=False,
copy=False, # Compatibility
score=0.2,
blocks=3,
censor=[],
@ -74,7 +74,7 @@ def process(
nudes = nudenet.detector.censor(image=pp.image, method=method, min_score=score, censor=censor, blocks=blocks, overlay=overlay)
t1 = time.time()
if len(nudes.censored) > 0: # Check if there are any censored areas
if not copy:
if p is None:
pp.image = nudes.output
else:
info = processing.create_infotext(p)
@ -85,7 +85,7 @@ def process(
if metadata and p is not None:
p.extra_generation_params["NudeNet"] = meta
p.extra_generation_params["NSFW"] = nsfw
if metadata and hasattr(pp, 'info'):
if metadata and isinstance(pp, scripts_postprocessing.PostprocessedImage):
pp.info['NudeNet'] = meta
pp.info['NSFW'] = nsfw
log.debug(f'NudeNet detect: {dct} nsfw={nsfw} time={(t1 - t0):.2f}')
@ -118,7 +118,7 @@ def process(
if metadata and p is not None:
p.extra_generation_params["Rating"] = res.get('rating', 'N/A')
p.extra_generation_params["Category"] = res.get('category', 'N/A')
if metadata and hasattr(pp, 'info'):
if metadata and isinstance(pp, scripts_postprocessing.PostprocessedImage):
pp.info["Rating"] = res.get('rating', 'N/A')
pp.info["Category"] = res.get('category', 'N/A')
@ -131,11 +131,11 @@ class ScriptNudeNet(scripts.Script):
def title(self):
return 'NudeNet'
def show(self, _is_img2img):
def show(self, *args, **kwargs):
return scripts.AlwaysVisible
# return signature is array of gradio components
def ui(self, _is_img2img):
def ui(self, *args, **kwargs):
return create_ui(accordion=True)
# triggered by callback

View File

@ -1,4 +1,5 @@
import os
import re
import sys
from pathlib import Path
from typing import Dict, Optional
@ -6,39 +7,28 @@ from typing import Dict, Optional
import installer
from modules.logger import log
from modules.json_helpers import readfile, writefile
from modules.shared import opts
from scripts.rocm.rocm_vars import ROCM_ENV_VARS # pylint: disable=no-name-in-module
from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module
def _check_rocm() -> bool:
from modules import shared
if getattr(shared.cmd_opts, 'use_rocm', False):
return True
if installer.torch_info.get('type') == 'rocm':
return True
import torch # pylint: disable=import-outside-toplevel
return hasattr(torch.version, 'hip') and torch.version.hip is not None
is_rocm = _check_rocm()
CONFIG = Path(os.path.abspath(os.path.join('data', 'rocm.json')))
_cache: Optional[Dict[str, str]] = None # loaded once, invalidated on save
# Metadata key written into rocm.json to record which architecture profile is active.
# Not an environment variable always skipped during env application but preserved in the
# Not an environment variable - always skipped during env application but preserved in the
# saved config so that arch-safety enforcement is consistent across restarts.
_ARCH_KEY = "_rocm_arch"
# Vars that must never appear in the process environment.
#
# _DTYPE_UNSAFE: alter FP16 inference dtype must be cleared regardless of config
# MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL DEBUG alias: routes all FP16 convs through BF16 exponent math
# MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL API-level alias: same BF16-exponent effect
# MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_EXPEREMENTAL_FP16_TRANSFORM unstable experimental FP16 path
# MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16 changes FP16 WrW atomic accumulation
# _DTYPE_UNSAFE: alter FP16 inference dtype - must be cleared regardless of config
# MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL - DEBUG alias: routes all FP16 convs through BF16 exponent math
# MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL - API-level alias: same BF16-exponent effect
# MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_EXPEREMENTAL_FP16_TRANSFORM - unstable experimental FP16 path
# MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16 - changes FP16 WrW atomic accumulation
#
# SOLVER_DISABLED_BY_DEFAULT: every solver known to be incompatible with this runtime
# (FP32-only, training-only WrW/BWD, fixed-geometry mismatches, XDLOPS/CDNA-only, arch-specific).
@ -53,18 +43,18 @@ _DTYPE_UNSAFE = {
# regardless of saved config. Limited to dtype-corrupting vars only.
# IMPORTANT: SOLVER_DISABLED_BY_DEFAULT is intentionally NOT included here.
# When a solver var is absent (unset) MIOpen still calls IsApplicable() on every
# conv-find wasted probing overhead. When a var is explicitly "0" MIOpen skips
# conv-find - wasted probing overhead. When a var is explicitly "0" MIOpen skips
# IsApplicable() immediately. Solver defaults flow through the config loop as "0"
# (their ROCM_ENV_VARS default is "0") so they are explicitly set to "0" in the env.
_UNSET_VARS = _DTYPE_UNSAFE
# Additional environment vars that must be removed from the process before MIOpen loads.
# These are not MIOpen solver toggles but can corrupt MIOpen's runtime behaviour:
# HIP_PATH / HIP_PATH_71 point to the system AMD ROCm install; override the venv-bundled
# HIP_PATH / HIP_PATH_71 - point to the system AMD ROCm install; override the venv-bundled
# _rocm_sdk_devel DLLs with a potentially mismatched system version
# QML_*/QT_* QtQuick shader/disk-cache flags leaked from Qt tools; harmless for
# QML_*/QT_* - QtQuick shader/disk-cache flags leaked from Qt tools; harmless for
# PyTorch but can conflict with Gradio's embedded Qt helpers
# PYENV_VIRTUALENV_DISABLE_PROMPT pyenv noise that confuses venv detection
# PYENV_VIRTUALENV_DISABLE_PROMPT - pyenv noise that confuses venv detection
_EXTRA_CLEAR_VARS = {
"HIP_PATH",
"HIP_PATH_71",
@ -72,7 +62,7 @@ _EXTRA_CLEAR_VARS = {
"QML_DISABLE_DISK_CACHE",
"QML_FORCE_DISK_CACHE",
"QT_DISABLE_SHADER_DISK_CACHE",
# PERF_VALS vars are NOT boolean toggles MIOpen reads them as perf-config strings.
# PERF_VALS vars are NOT boolean toggles - MIOpen reads them as perf-config strings.
# If inherited from a parent shell with value "1", MIOpen's GetPerfConfFromEnv parses
# "1" as a degenerate config and can return dtype=float32 output from FP16 tensors.
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_PERF_VALS",
@ -81,12 +71,12 @@ _EXTRA_CLEAR_VARS = {
# Solvers whose MIOpen IsApplicable() explicitly rejects non-FP32 tensors.
# They are safe to leave enabled in FP32 mode. When the active dtype is FP16 or BF16
# we force them OFF so MIOpen skips the IsApplicable probe entirely avoids overhead on
# we force them OFF so MIOpen skips the IsApplicable probe entirely - avoids overhead on
# every conv shape find. These are NOT in _UNSET_VARS because they are valid in FP32.
_FP32_ONLY_SOLVERS = {
"MIOPEN_DEBUG_CONV_FFT", # FFT convolution FP32 only (MIOpen source: IsFp32 check)
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3", # Winograd 3x3 FP32 only
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD", # Fused Winograd FP32 only
"MIOPEN_DEBUG_CONV_FFT", # FFT convolution - FP32 only (MIOpen source: IsFp32 check)
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3", # Winograd 3x3 - FP32 only
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD", # Fused Winograd - FP32 only
}
@ -106,8 +96,7 @@ def _resolve_dtype() -> str:
except Exception:
pass
try:
from modules import shared as _sh # pylint: disable=import-outside-toplevel
v = getattr(getattr(_sh, 'opts', None), 'cuda_dtype', None)
v = getattr(opts, 'cuda_dtype', None)
if v in ('FP16', 'BF16', 'FP32'):
return v
except Exception:
@ -118,17 +107,25 @@ def _resolve_dtype() -> str:
# --- venv helpers ---
def _get_venv() -> str:
return os.environ.get("VIRTUAL_ENV", "") or sys.prefix
return sys.prefix
def _get_root() -> str:
from modules.paths import script_path # pylint: disable=import-outside-toplevel
return str(script_path)
def _expand_venv(value: str) -> str:
return value.replace("{VIRTUAL_ENV}", _get_venv())
return value.replace("{VIRTUAL_ENV}", _get_venv()).replace("{ROOT}", _get_root())
def _collapse_venv(value: str) -> str:
venv = _get_venv()
root = _get_root()
if venv and value.startswith(venv):
return "{VIRTUAL_ENV}" + value[len(venv):]
if root and value.startswith(root):
return "{ROOT}" + value[len(root):]
return value
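The `{VIRTUAL_ENV}`/`{ROOT}` placeholder round-trip these helpers implement can be sketched as follows; `expand` and `collapse` are standalone re-implementations for illustration, with the prefixes passed explicitly rather than read from the environment:

```python
# Standalone sketch of the placeholder expand/collapse round-trip.
def expand(value: str, venv: str, root: str) -> str:
    return value.replace("{VIRTUAL_ENV}", venv).replace("{ROOT}", root)

def collapse(value: str, venv: str, root: str) -> str:
    # Check the venv prefix first: the venv normally lives inside the repo
    # root, so the more specific prefix must win over "{ROOT}".
    if venv and value.startswith(venv):
        return "{VIRTUAL_ENV}" + value[len(venv):]
    if root and value.startswith(root):
        return "{ROOT}" + value[len(root):]
    return value
```

Collapsing on save and expanding on load keeps the JSON config portable across installs that live at different absolute paths.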
@@ -163,7 +160,7 @@ def load_config() -> Dict[str, str]:
_cache = data if data else {k: v["default"] for k, v in ROCM_ENV_VARS.items()}
# Purge unsafe vars from a stale saved config and re-persist only if the file existed.
# When running without a saved config (first run / after Delete), load_config() must
# never create the file - that only happens via save_config() on Apply or Apply Profile.
dirty = {k for k in _cache if k in _UNSET_VARS or (k != _ARCH_KEY and k not in ROCM_ENV_VARS)}
if dirty:
_cache = {k: v for k, v in _cache.items() if k not in dirty}
@@ -212,7 +209,7 @@ def apply_env(config: Optional[Dict[str, str]] = None) -> None:
os.environ[var] = expanded
# Arch safety net: hard-force all hardware-incompatible vars to "0" in the env.
# This runs *after* the config loop so it overrides any stale "1" that survived in the JSON.
# Source of truth: rocm_profiles.UNAVAILABLE[arch] - vars with no supporting hardware.
arch = config.get(_ARCH_KEY, "")
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
if unavailable:
@@ -240,7 +237,7 @@ def apply_all(names: list, values: list) -> None:
meta = ROCM_ENV_VARS[name]
if meta["widget"] == "checkbox":
if value is None:
pass # Gradio passed None (component not interacted with) - leave config unchanged
else:
config[name] = "1" if value else "0"
elif meta["widget"] == "radio":
@@ -248,7 +245,7 @@ def apply_all(names: list, values: list) -> None:
valid = {v for _, v in meta["options"]} if meta["options"] and isinstance(meta["options"][0], tuple) else set(meta["options"] or [])
if stored in valid:
config[name] = stored
# else: value was None/invalid - leave the existing saved value untouched
else:
if meta.get("options"):
value = _dropdown_stored(str(value), meta["options"])
@@ -291,7 +288,7 @@ def delete_config() -> None:
CONFIG.unlink()
log.info(f'ROCm delete_config: deleted {CONFIG}')
_cache = None
# Delete the MIOpen user DB (~/.miopen/db) - stale entries can cause solver mismatches
miopen_db = Path(os.path.expanduser('~')) / '.miopen' / 'db'
if miopen_db.exists():
shutil.rmtree(miopen_db, ignore_errors=True)
@@ -365,6 +362,28 @@ def _user_db_summary(path: Path) -> dict:
return out
def _extract_db_hash(db_path: Path) -> str:
"""Derive the cache subfolder name from udb.txt filenames.
e.g. gfx1030_30.HIP.3_5_1_5454e9e2da.udb.txt -> '3.5.1.5454e9e2da'"""
for f in db_path.glob("*.HIP.*.udb.txt"):
m = re.search(r'\.HIP\.([^.]+)\.udb\.txt$', f.name)
if m:
return m.group(1).replace("_", ".")
return ""
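The filename-to-hash derivation in the docstring above can be exercised in isolation; `extract_hash` is a hypothetical standalone version using the same regex:

```python
import re

# Standalone sketch: pull the MIOpen version-hash segment out of a
# udb.txt filename and convert underscores to dots, as in the docstring
# example gfx1030_30.HIP.3_5_1_5454e9e2da.udb.txt -> '3.5.1.5454e9e2da'.
def extract_hash(filename: str) -> str:
    m = re.search(r'\.HIP\.([^.]+)\.udb\.txt$', filename)
    return m.group(1).replace("_", ".") if m else ""
```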
def _user_cache_summary(path: Path) -> dict:
"""Return {filename: 'N KB'} for binary cache blobs in the resolved cache path."""
out = {}
if not path.exists():
return out
for f in sorted(path.iterdir()):
if f.is_file():
kb = f.stat().st_size // 1024
out[f.name] = f"{kb} KB"
return out
def info() -> dict:
config = load_config()
db_path = Path(_expand_venv(config.get("MIOPEN_SYSTEM_DB_PATH", "")))
@@ -427,20 +446,29 @@ def info() -> dict:
if ufiles:
udb["files"] = ufiles
# User cache (~/.miopen/cache/<version-hash>)
cache_base = Path.home() / ".miopen" / "cache"
db_hash = _extract_db_hash(user_db_path) if user_db_path.exists() else ""
cache_path = cache_base / db_hash if db_hash else cache_base
ucache = {"path": str(cache_path), "exists": cache_path.exists()}
if cache_path.exists():
cfiles = _user_cache_summary(cache_path)
if cfiles:
ucache["files"] = cfiles
return {
"rocm": rocm_section,
"torch": torch_section,
"gpu": gpu_section,
"system_db": sdb,
"user_db": udb,
"user_cache": ucache,
}
# Apply saved config to os.environ at import time (only when ROCm is present)
if is_rocm:
if installer.torch_info.get('type', None) == 'rocm':
try:
apply_env()
except Exception as _e:
print(f"[rocm_mgr] Warning: failed to apply env at import: {_e}", file=sys.stderr)
else:
log.debug('ROCm is not installed — skipping rocm_mgr env apply')
log.debug(f"[rocm_mgr] Warning: failed to apply env at import: {_e}")


@@ -1,4 +1,4 @@
"""
"""
Architecture-specific MIOpen solver profiles for AMD GCN/RDNA GPUs.
Sources:
@@ -6,8 +6,8 @@ Sources:
Key axis: consumer RDNA GPUs have NO XDLOPS hardware (that's CDNA/Instinct only).
RDNA2 (gfx1030): RX 6000 series
RDNA3 (gfx1100): RX 7000 series - adds Fury Winograd, wider MPASS
RDNA4 (gfx1200): RX 9000 series - adds Rage Winograd, wider MPASS
Each profile is a dict of {var: value} that will be MERGED on top of the
current config (general vars like DB path / log level are preserved).
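The merge layering described here can be sketched with toy values; `apply_profile` and the two trimmed-down profile dicts below are illustrative stand-ins, not the real profiles:

```python
# Toy sketch of the profile layering: RDNA3 = RDNA2 plus overrides, and
# applying a profile merges it ON TOP of the current config, so general
# vars (DB path, log level, ...) that the profile does not set survive.
RDNA2 = {"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "0", "MIOPEN_DEBUG_CONV_GEMM": "1"}
RDNA3 = {**RDNA2, "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "1"}  # Fury Winograd enabled

def apply_profile(config: dict, profile: dict) -> dict:
    # Profile values win on conflict; everything else in config is preserved.
    return {**config, **profile}
```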
@@ -15,9 +15,9 @@ current config (general vars like DB path / log level are preserved).
from typing import Dict
# ---------------------------------------------------------------------------
# Shared: everything that must be OFF on ALL consumer RDNA (no XDLOPS hw)
# ---------------------------------------------------------------------------
_XDLOPS_OFF: Dict[str, str] = {
# GTC XDLOPS (CDNA-only)
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS": "0",
@@ -55,7 +55,7 @@ _XDLOPS_OFF: Dict[str, str] = {
# MLIR (CDNA-only in practice)
"MIOPEN_DEBUG_CONV_MLIR_IGEMM_WRW_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_MLIR_IGEMM_BWD_XDLOPS": "0",
# MP BD Winograd (Multi-pass Block-Decomposed - CDNA / high-end only)
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F2X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F3X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F4X3": "0",
@@ -68,17 +68,17 @@ _XDLOPS_OFF: Dict[str, str] = {
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F6X3": "0",
}
# ---------------------------------------------------------------------------
# RDNA2 - gfx1030 (RX 6000 series)
# No XDLOPS, no Fury/Rage Winograd, MPASS limited to F3x2/F3x3
# ASM IGEMM: V4R1 variants only; HIP IGEMM: non-XDLOPS V4R1/R4 only
# ---------------------------------------------------------------------------
RDNA2: Dict[str, str] = {
**_XDLOPS_OFF,
# General settings (architecture-independent; set here so all profiles cover them)
"MIOPEN_SEARCH_CUTOFF": "0",
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": "0",
# Core algo enables - FFT is FP32-only but harmless (IsApplicable rejects it for fp16 tensors)
"MIOPEN_DEBUG_CONV_FFT": "1",
"MIOPEN_DEBUG_CONV_DIRECT": "1",
"MIOPEN_DEBUG_CONV_GEMM": "1",
@@ -93,16 +93,16 @@ RDNA2: Dict[str, str] = {
"MIOPEN_DEBUG_OPENCL_CONVOLUTIONS": "1",
"MIOPEN_DEBUG_OPENCL_WAVE64_NOWGP": "1",
"MIOPEN_DEBUG_ATTN_SOFTMAX": "1",
# Direct ASM - dtype notes
# 3X3U / 1X1U / 1X1UV2: FP32/FP16 forward - enabled
"MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "1",
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "1",
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "1",
# 5X10U2V2: fixed geometry (5*10 stride-2), no SD conv matches - disabled
"MIOPEN_DEBUG_CONV_DIRECT_ASM_5X10U2V2": "0",
# 7X7C3H224W224: hard-coded ImageNet stem (C=3, H=W=224, K=64) - never matches SD - disabled
"MIOPEN_DEBUG_CONV_DIRECT_ASM_7X7C3H224W224": "0",
# WRW3X3 / WRW1X1: FP32-only weight-gradient (training only) - disabled for inference
"MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW3X3": "0",
"MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW1X1": "0",
# PERF_VALS intentionally blank: MIOpen reads this as a config string, not a boolean;
@@ -110,30 +110,30 @@ RDNA2: Dict[str, str] = {
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_PERF_VALS": "",
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_SEARCH_OPTIMIZED": "1",
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_AI_HEUR": "1",
# NAIVE_CONV_FWD: scalar FP32 reference solver - IsApplicable does NOT reliably filter for FP16;
# can be selected for unusual shapes (e.g. VAE decoder 3-ch output) and returns dtype=float32
"MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD": "0",
# Direct OCL - dtype notes
# FWD / FWD1X1: FP32/FP16 forward - enabled
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "1",
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "1",
# FWD11X11: requires 11*11 kernel - no SD match - disabled
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD11X11": "0",
# FWDGEN: FP32 generic OCL fallback - IsApplicable does NOT reliably reject for FP16;
# can produce dtype=float32 output for FP16 inputs - disabled
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWDGEN": "0",
# WRW2 / WRW53 / WRW1X1: training-only weight-gradient - disabled
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW2": "0",
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW53": "0",
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW1X1": "0",
# Winograd RxS - dtype per MIOpen docs
# WINOGRAD_3X3: FP32-only - harmless (IsApplicable rejects for fp16); enabled
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3": "1",
# RXS: covers FP32/FP16 F(3,3) Fwd/Bwd + FP32 F(3,2) WrW - keep enabled (fp16 fwd/bwd path exists)
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS": "1",
# RXS_FWD_BWD: FP32/FP16 - explicitly the fp16-capable subset
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD": "1",
# RXS_WRW: FP32 WrW only - training-only, disabled for inference fp16 profile
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_WRW": "0",
# RXS_F3X2: FP32/FP16 Fwd/Bwd
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2": "1",
@@ -141,15 +141,15 @@ RDNA2: Dict[str, str] = {
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3": "1",
# RXS_F2X3_G1: FP32/FP16 Fwd/Bwd (non-group convolutions)
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1": "1",
# FUSED_WINOGRAD: FP32-only - harmless (IsApplicable rejects for fp16); enabled
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD": "1",
# PERF_VALS intentionally blank: same reason as ASM_1X1U - not a boolean, config string
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_PERF_VALS": "",
# Fury/Rage Winograd - NOT available on RDNA2
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "0",
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "0",
"MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "0",
# MPASS - only F3x2 and F3x3 are safe on RDNA2
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2": "1",
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3": "1",
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4": "0",
@@ -159,50 +159,50 @@ RDNA2: Dict[str, str] = {
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4": "0",
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2": "0",
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3": "0",
# ASM Implicit GEMM - forward V4R1 only; no GTC/XDLOPS on RDNA2
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only - disabled
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1": "1",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1": "1",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_V4R1": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_V4R1": "0",
# HIP Implicit GEMM - non-XDLOPS V4R1/R4 forward only
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only - disabled
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1": "1",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4": "1",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V1R1": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R1": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4": "0",
# Group Conv XDLOPS / CK default kernels - RDNA3/4 only, not available on RDNA2
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "0",
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "0",
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "0",
}
# ---------------------------------------------------------------------------
# RDNA3 - gfx1100 (RX 7000 series)
# Fury Winograd added; MPASS F3x4 enabled; Group Conv XDLOPS + CK default kernels enabled
# ---------------------------------------------------------------------------
RDNA3: Dict[str, str] = {
**RDNA2,
# Fury Winograd - introduced for gfx1100 (RDNA3)
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "1",
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "1",
# Wider MPASS on RDNA3
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4": "1",
# Group Conv XDLOPS / CK - available from gfx1100 (RDNA3) onwards
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "1",
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "1",
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "1",
}
# ---------------------------------------------------------------------------
# RDNA4 - gfx1200 (RX 9000 series)
# Rage Winograd added; MPASS F3x5 enabled
# ---------------------------------------------------------------------------
RDNA4: Dict[str, str] = {
**RDNA3,
# Rage Winograd - introduced for gfx1200 (RDNA4)
"MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "1",
# Wider MPASS on RDNA4
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5": "1",


@@ -1,15 +1,48 @@
from typing import Dict, Any, List, Tuple
# --- General MIOpen/rocBLAS variables (dropdown/textbox/checkbox) ---
GENERAL_VARS: Dict[str, Dict[str, Any]] = {
"MIOPEN_SYSTEM_DB_PATH": {
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\",
"desc": "MIOpen system DB path",
"widget": "textbox",
"options": None,
"restart_required": True,
},
"ROCBLAS_TENSILE_LIBPATH": {
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\rocblas\\library",
"desc": "rocBLAS Tensile library path",
"widget": "textbox",
"options": None,
"restart_required": True,
},
"MIOPEN_GEMM_ENFORCE_BACKEND": {
"default": "1",
"desc": "Enforce GEMM backend",
"desc": "GEMM backend",
"widget": "dropdown",
"options": [("1 - rocBLAS", "1"), ("5 - hipBLASLt", "5")],
"restart_required": False,
},
"PYTORCH_ROCM_USE_ROCBLAS": {
"default": "0",
"desc": "PyTorch: Use rocBLAS",
"widget": "dropdown",
"options": [("0 - Off", "0"), ("1 - On", "1")],
"restart_required": True,
},
"PYTORCH_HIPBLASLT_DISABLE": {
"default": "1",
        "desc": "PyTorch: hipBLASLt",
"widget": "dropdown",
"options": [("0 - Allow hipBLASLt", "0"), ("1 - Disable hipBLASLt", "1")],
"restart_required": True,
},
"ROCBLAS_USE_HIPBLASLT": {
"default": "0",
"desc": "rocBLAS: use hipBLASLt backend",
"widget": "dropdown",
"options": [("0 - Tensile (rocBLAS)", "0"), ("1 - hipBLASLt", "1")],
"restart_required": True,
},
"MIOPEN_FIND_MODE": {
"default": "2",
"desc": "MIOpen Find Mode",
@@ -31,12 +64,69 @@ GENERAL_VARS: Dict[str, Dict[str, Any]] = {
"options": [("0 - Off", "0"), ("1 - On", "1")],
"restart_required": True,
},
"MIOPEN_SYSTEM_DB_PATH": {
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\",
"desc": "MIOpen system DB path",
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": {
"default": "0",
"desc": "Deterministic convolutions",
"widget": "dropdown",
"options": [("0 - Off", "0"), ("1 - On", "1")],
"restart_required": False,
},
"MIOPEN_CONVOLUTION_MAX_WORKSPACE": {
"default": "1073741824",
"desc": "MIOpen convolutions: max workspace (bytes; 1 GB)",
"widget": "textbox",
"options": None,
"restart_required": True,
"restart_required": False,
},
"ROCBLAS_DEVICE_MEMORY_SIZE": {
"default": "",
"desc": "rocBLAS workspace size in bytes (empty = dynamic)",
"widget": "textbox",
"options": None,
"restart_required": False,
},
"PYTORCH_TUNABLEOP_CACHE_DIR": {
"default": "{ROOT}\\models\\tunable",
"desc": "TunableOp cache directory",
"widget": "textbox",
"options": None,
"restart_required": False,
},
"ROCBLAS_STREAM_ORDER_ALLOC": {
"default": "1",
"desc": "rocBLAS stream-ordered memory allocation",
"widget": "dropdown",
"options": [("0 - Standard", "0"), ("1 - Stream-ordered", "1")],
"restart_required": False,
},
"ROCBLAS_DEFAULT_ATOMICS_MODE": {
"default": "1",
"desc": "rocBLAS allow atomics",
"widget": "dropdown",
"options": [("0 - Off (deterministic)", "0"), ("1 - On (performance)", "1")],
"restart_required": False,
},
"PYTORCH_TUNABLEOP_ROCBLAS_ENABLED": {
"default": "0",
"desc": "TunableOp: Enable tuning",
"widget": "dropdown",
"options": [("0 - Off", "0"), ("1 - On", "1")],
"restart_required": False,
},
"PYTORCH_TUNABLEOP_TUNING": {
"default": "0",
"desc": "TunableOp: Tuning mode",
"widget": "dropdown",
"options": [("0 - Use Cache", "0"), ("1 - Benchmark new shapes", "1")],
"restart_required": False,
},
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED": {
"default": "0",
"desc": "TunableOp: benchmark hipBLASLt kernels",
"widget": "dropdown",
"options": [("0 - Off", "0"), ("1 - On", "1")],
"restart_required": False,
},
"MIOPEN_LOG_LEVEL": {
"default": "0",
@@ -66,23 +156,8 @@ GENERAL_VARS: Dict[str, Dict[str, Any]] = {
"options": [("0 - Off", "0"), ("1 - Error", "1"), ("2 - Trace", "2"), ("3 - Hints", "3"), ("4 - Info", "4"), ("5 - API Trace", "5")],
"restart_required": False,
},
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": {
"default": "0",
"desc": "Deterministic convolution (reproducible results, may be slower)",
"widget": "dropdown",
"options": [("0 - Off", "0"), ("1 - On", "1")],
"restart_required": False,
},
}
# --- Solver toggles (inference/FWD only, RDNA2/3/4 compatible) ---
# Removed entirely — not representable in the UI, cannot be set by users:
# WRW (weight-gradient) and BWD (data-gradient) — training passes only, never run during inference
# XDLOPS/CK CDNA-exclusive (MI100/MI200/MI300 matrix engine variants) — not on any RDNA
# Fixed-geometry (5x10, 7x7-ImageNet, 11x11) — shapes never appear in SD/video inference
# FP32-reference (NAIVE_CONV_FWD, FWDGEN) — IsApplicable() unreliable for FP16/BF16
# Wide MPASS (F3x4..F7x3) — kernel sizes that cannot match any SD convolution shape
# Disabled by default (added but off): RDNA3/4-only — Group Conv XDLOPS, CK default kernels
_SOLVER_DESCS: Dict[str, str] = {}
_SOLVER_DESCS.update({
@@ -251,3 +326,13 @@ SOLVER_GROUPS: List[Tuple[str, List[str]]] = [
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
]),
]
# Variables that are relevant only when hipBLASLt is the active GEMM backend.
# These are visually greyed-out in the UI when rocBLAS (MIOPEN_GEMM_ENFORCE_BACKEND="1") is selected.
HIPBLASLT_VARS: set = {
"PYTORCH_HIPBLASLT_DISABLE",
"ROCBLAS_USE_HIPBLASLT",
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED",
"HIPBLASLT_LOG_LEVEL",
}
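The greying-out described above can be sketched as a small CSS generator; `hipblaslt_style` is a hypothetical standalone helper (the UI script builds these rules inside its own `_build_style`), shown here with a trimmed-down var set:

```python
# Hypothetical sketch: emit CSS that disables hipBLASLt-only widgets when
# rocBLAS ("1") is the enforced GEMM backend; any other value means
# hipBLASLt is in play and nothing is greyed out.
HIPBLASLT_VARS = {"PYTORCH_HIPBLASLT_DISABLE", "ROCBLAS_USE_HIPBLASLT"}

def hipblaslt_style(gemm_backend: str) -> str:
    if gemm_backend != "1":
        return ""
    rules = " ".join(
        f"#rocm_var_{v.lower()} {{ opacity: 0.45; pointer-events: none; }}"
        for v in sorted(HIPBLASLT_VARS)
    )
    return f"<style>{rules}</style>"
```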


@@ -5,10 +5,9 @@ from modules import scripts_manager, shared
# rocm_mgr exposes package-internal helpers (prefixed _) that are intentionally called here
# pylint: disable=protected-access
class ROCmScript(scripts_manager.Script):
def title(self):
return "ROCm: Advanced Config"
return "Windows ROCm: Advanced Config"
def show(self, _is_img2img):
if shared.cmd_opts.use_rocm or installer.torch_info.get('type') == 'rocm':
@@ -19,7 +18,7 @@ class ROCmScript(scripts_manager.Script):
if not shared.cmd_opts.use_rocm and not installer.torch_info.get('type') == 'rocm': # skip ui creation if not rocm
return []
from scripts.rocm import rocm_mgr, rocm_vars # pylint: disable=no-name-in-module
from scripts.rocm import rocm_mgr, rocm_vars, rocm_profiles # pylint: disable=no-name-in-module
config = rocm_mgr.load_config()
var_names = []
@@ -59,11 +58,25 @@ class ROCmScript(scripts_manager.Script):
row("path", udb.get("path", ""))
for fname, finfo in udb.get("files", {}).items():
row(fname, finfo)
section("User cache (~/.miopen/cache)")
ucache = d.get("user_cache", {})
row("path", ucache.get("path", ""))
for fname, sz in ucache.get("files", {}).items():
row(fname, sz)
return f"<table style='width:100%;border-collapse:collapse'>{''.join(rows)}</table>"
def _build_style(unavailable, hipblaslt_disabled=False):
rules = []
for v in (unavailable or []):
rules.append(f"#rocm_var_{v.lower()} label {{ text-decoration: line-through; opacity: 0.5; }}")
if hipblaslt_disabled:
for v in rocm_vars.HIPBLASLT_VARS:
rules.append(f"#rocm_var_{v.lower()} {{ opacity: 0.45; pointer-events: none; }}")
return f"<style>{' '.join(rules)}</style>" if rules else ""
with gr.Accordion('ROCm: Advanced Config', open=False, elem_id='rocm_config'):
with gr.Row():
gr.HTML("<p>Advanced configuration for ROCm users.</p><br><p>Set your database and solver selections based on GPU profile or individually.</p><br><p>Enable cuDNN in Backend Settings to activate MIOpen.</p>")
gr.HTML("<p><u>Advanced configuration for ROCm users.</u></p><br><p>This script aims to take the guesswork out of configuring MIOpen and rocBLAS on Windows ROCm, while also exposing MIOpen's working switches for advanced configurations.</p><br><p>For best performance, ensure that cuDNN and PyTorch tunable ops are set to <b><i>default</i></b> in Backend Settings.</p><br><p>This script was written to support Windows ROCm users; it should, however, function identically on Linux.</p><br>")
with gr.Row():
btn_info = gr.Button("Refresh Info", variant="primary", elem_id="rocm_btn_info", size="sm")
btn_apply = gr.Button("Apply", variant="primary", elem_id="rocm_btn_apply", size="sm")
@@ -74,12 +87,15 @@ class ROCmScript(scripts_manager.Script):
btn_rdna2 = gr.Button("RDNA2 (RX 6000)", elem_id="rocm_btn_rdna2")
btn_rdna3 = gr.Button("RDNA3 (RX 7000)", elem_id="rocm_btn_rdna3")
btn_rdna4 = gr.Button("RDNA4 (RX 9000)", elem_id="rocm_btn_rdna4")
style_out = gr.HTML("")
_init_gemm = config.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
_init_arch = config.get(rocm_mgr._ARCH_KEY, "")
_init_unavailable = rocm_profiles.UNAVAILABLE.get(_init_arch, set()) if _init_arch else set()
style_out = gr.HTML(_build_style(_init_unavailable, _init_gemm == "1"))
info_out = gr.HTML(value=_info_html, elem_id="rocm_info_table")
# General vars (dropdowns, textboxes, checkboxes)
with gr.Group():
gr.HTML("<h3>MIOpen Settings</h3><hr>")
gr.HTML("<br><h3>MIOpen Settings</h3><hr>")
for name, meta in rocm_vars.GENERAL_VARS.items():
comp = _make_component(name, meta, config)
var_names.append(name)
@@ -106,13 +122,46 @@ class ROCmScript(scripts_manager.Script):
for name, comp in zip(var_names, components):
meta = rocm_vars.ROCM_ENV_VARS[name]
if meta["widget"] == "dropdown":
if meta["widget"] == "dropdown" and name != "MIOPEN_GEMM_ENFORCE_BACKEND":
comp.change(fn=lambda v, n=name: _autosave_field(n, v), inputs=[comp], outputs=[], show_progress='hidden')
_GEMM_COMPANIONS = {
"PYTORCH_ROCM_USE_ROCBLAS": {"1": "1", "5": "0"},
"PYTORCH_HIPBLASLT_DISABLE": {"1": "1", "5": "0"},
"ROCBLAS_USE_HIPBLASLT": {"1": "0", "5": "1"},
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED": {"1": "0", "5": "1"},
}
def gemm_changed(gemm_display_val):
stored = rocm_mgr._dropdown_stored(str(gemm_display_val), rocm_vars.ROCM_ENV_VARS["MIOPEN_GEMM_ENFORCE_BACKEND"]["options"])
cfg = rocm_mgr.load_config().copy()
cfg["MIOPEN_GEMM_ENFORCE_BACKEND"] = stored
for var, vals in _GEMM_COMPANIONS.items():
cfg[var] = vals.get(stored, cfg.get(var, ""))
rocm_mgr.save_config(cfg)
rocm_mgr.apply_env(cfg)
arch = cfg.get(rocm_mgr._ARCH_KEY, "")
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
result = [gr.update(value=_build_style(unavailable, stored == "1"))]
for pname in var_names:
if pname in _GEMM_COMPANIONS:
meta = rocm_vars.ROCM_ENV_VARS[pname]
val = _GEMM_COMPANIONS[pname].get(stored, cfg.get(pname, ""))
result.append(gr.update(value=rocm_mgr._dropdown_display(val, meta["options"])))
else:
result.append(gr.update())
return result
gemm_comp = components[var_names.index("MIOPEN_GEMM_ENFORCE_BACKEND")]
gemm_comp.change(fn=gemm_changed, inputs=[gemm_comp], outputs=[style_out] + components, show_progress='hidden')
def apply_fn(*values):
rocm_mgr.apply_all(var_names, list(values))
saved = rocm_mgr.load_config()
result = [gr.update(value="")]
arch = saved.get(rocm_mgr._ARCH_KEY, "")
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
gemm_val = saved.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
for name in var_names:
meta = rocm_vars.ROCM_ENV_VARS[name]
val = saved.get(name, meta["default"])
@@ -124,19 +173,13 @@ class ROCmScript(scripts_manager.Script):
result.append(gr.update(value=rocm_mgr._expand_venv(val)))
return result
def _build_style(unavailable):
if not unavailable:
return ""
rules = " ".join(
f"#rocm_var_{v.lower()} label {{ text-decoration: line-through; opacity: 0.5; }}"
for v in unavailable
)
return f"<style>{rules}</style>"
def reset_fn():
rocm_mgr.reset_defaults()
updated = rocm_mgr.load_config()
result = [gr.update(value="")]
arch = updated.get(rocm_mgr._ARCH_KEY, "")
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
gemm_val = updated.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
for name in var_names:
meta = rocm_vars.ROCM_ENV_VARS[name]
val = updated.get(name, meta["default"])
@@ -150,7 +193,9 @@ class ROCmScript(scripts_manager.Script):
def clear_fn():
rocm_mgr.clear_env()
result = [gr.update(value="")]
cfg = rocm_mgr.load_config()
gemm_val = cfg.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
result = [gr.update(value=_build_style(None, gemm_val == "1"))]
for name in var_names:
meta = rocm_vars.ROCM_ENV_VARS[name]
if meta["widget"] == "checkbox":
@@ -163,7 +208,8 @@ class ROCmScript(scripts_manager.Script):
def delete_fn():
rocm_mgr.delete_config()
result = [gr.update(value="")]
gemm_default = rocm_vars.ROCM_ENV_VARS.get("MIOPEN_GEMM_ENFORCE_BACKEND", {}).get("default", "1")
result = [gr.update(value=_build_style(None, gemm_default == "1"))]
for name in var_names:
meta = rocm_vars.ROCM_ENV_VARS[name]
if meta["widget"] == "checkbox":
@@ -175,11 +221,11 @@ class ROCmScript(scripts_manager.Script):
return result
def profile_fn(arch):
from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module
rocm_mgr.apply_profile(arch)
updated = rocm_mgr.load_config()
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
result = [gr.update(value=_build_style(unavailable))]
gemm_val = updated.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
for pname in var_names:
meta = rocm_vars.ROCM_ENV_VARS[pname]
val = updated.get(pname, meta["default"])

wiki

@@ -1 +1 @@
Subproject commit d54ade8e5f79a62b5228de6406400c8eda71b67f
Subproject commit cbbbfc73af2366650cdf8cc71fabbf3a508b607b