mirror of https://github.com/vladmandic/automatic
Compare commits
48 Commits
2026-04-02
...
master
| Author | SHA1 | Date |
|---|---|---|
|
|
0eb4a98e07 | |
|
|
155dabc840 | |
|
|
2fcabc8047 | |
|
|
d98d05ca2d | |
|
|
27a62cfa70 | |
|
|
d97191f342 | |
|
|
fbf1a962f2 | |
|
|
d7904b239f | |
|
|
90b5e7de30 | |
|
|
08c28ab257 | |
|
|
0c94a169ea | |
|
|
32b69bdd3d | |
|
|
b2e071dc52 | |
|
|
470a0d816e | |
|
|
ffeda702c5 | |
|
|
bfd9a0c0f5 | |
|
|
25af3242c3 | |
|
|
dc6f20ec8f | |
|
|
a809b616e6 | |
|
|
88bde026f7 | |
|
|
e49d6262e9 | |
|
|
ac9aacac66 | |
|
|
2177609e54 | |
|
|
ee3b141297 | |
|
|
24f4490a59 | |
|
|
d2a47ee0ed | |
|
|
01d53edb25 | |
|
|
4cafae9350 | |
|
|
b659a06c60 | |
|
|
9d0ecde462 | |
|
|
fdc2f46457 | |
|
|
1ed2811c80 | |
|
|
f5c037a735 | |
|
|
668a94141d | |
|
|
999cbe5d3a | |
|
|
3dd09fde08 | |
|
|
95dadab5c3 | |
|
|
59b9ca50ee | |
|
|
eeb9b6291b | |
|
|
8d6ec348b2 | |
|
|
a1b03a383c | |
|
|
ba362ad3ca | |
|
|
7f07d4cb31 | |
|
|
715b1b0699 | |
|
|
b2b6fdf9d5 | |
|
|
8ef8074467 | |
|
|
1e9bef8d56 | |
|
|
fe7e4b40ff |
16
CHANGELOG.md
16
CHANGELOG.md
|
|
@ -1,5 +1,21 @@
|
|||
# Change Log for SD.Next
|
||||
|
||||
## Update for 2026-04-04
|
||||
|
||||
- **Models**
|
||||
- [AiArtLab SDXS-1B](https://huggingface.co/AiArtLab/sdxs-1b) Simple Diffusion XS *(training still in progress)*
|
||||
this model combines Qwen3.5-1.8B text encoder with SDXL-style UNET with only 1.6B parameters and custom 32ch VAE
|
||||
- **Compute**
|
||||
- **ROCm** further work on advanced configuration and tuning, thanks @resonantsky
|
||||
see *main interface -> scripts -> rocm advanced config*
|
||||
- **Internal**
|
||||
- additional typing and typechecks, thanks @awsr
|
||||
- Prohibit python==3.14 unless `--experimental`
|
||||
- **Fixes**
|
||||
- UI CSS fixes, thanks @awsr
|
||||
- detect/warn if space in system path
|
||||
- add `ftfy` to requirements
|
||||
|
||||
## Update for 2026-04-01
|
||||
|
||||
### Highlights for 2026-04-01
|
||||
|
|
|
|||
2
TODO.md
2
TODO.md
|
|
@ -1,5 +1,7 @@
|
|||
# TODO
|
||||
|
||||
<https://github.com/huggingface/diffusers/pull/13317>
|
||||
|
||||
## Internal
|
||||
|
||||
- Feature: implement `unload_auxiliary_models`
|
||||
|
|
|
|||
|
|
@ -864,6 +864,16 @@
|
|||
"extras": "sampler: Default, cfg_scale: 1.5, steps: 50",
|
||||
"size": 15.3,
|
||||
"date": "2025 January"
|
||||
},
|
||||
|
||||
"AiArtLab SDXS-1B": {
|
||||
"path": "AiArtLab/sdxs-1b",
|
||||
"preview": "AiArtLab--sdxs-1b.jpg",
|
||||
"desc": "Simple Diffusion XS (train in progress) combines Qwen3.5-1.8B text encoder with SDXL-style UNET with only 1.6B parameters and custom 32ch VAE",
|
||||
"skip": true,
|
||||
"extras": "sampler: Default",
|
||||
"size": 15.3,
|
||||
"date": "2026 January"
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Subproject commit c7af727f31758c9fc96cf0429bcf3608858a15e8
|
||||
Subproject commit e3720332f2301fa597c94b40897aa6e983020f1f
|
||||
|
|
@ -474,6 +474,8 @@ def check_python(supported_minors=None, experimental_minors=None, reason=None):
|
|||
else:
|
||||
git_version = git('--version', folder=None, ignore=False)
|
||||
log.debug(f'Git: version={git_version.replace("git version", "").strip()}')
|
||||
if ' ' in sys.executable:
|
||||
log.warning(f'Python: path="{sys.executable}" contains spaces which may cause issues')
|
||||
ts('python', t_start)
|
||||
|
||||
|
||||
|
|
@ -1244,7 +1246,7 @@ def install_requirements():
|
|||
# set environment variables controlling the behavior of various libraries
|
||||
def set_environment():
|
||||
log.debug('Setting environment tuning')
|
||||
os.environ.setdefault('PIP_CONSTRAINT', os.path.abspath('constraints.txt'))
|
||||
os.environ.setdefault('PIP_CONSTRAINT', 'constraints.txt')
|
||||
os.environ.setdefault('ACCELERATE', 'True')
|
||||
os.environ.setdefault('ATTN_PRECISION', 'fp16')
|
||||
os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100')
|
||||
|
|
@ -1277,7 +1279,7 @@ def set_environment():
|
|||
os.environ.setdefault('MIOPEN_FIND_MODE', '2')
|
||||
os.environ.setdefault('UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS', '1')
|
||||
os.environ.setdefault('USE_TORCH', '1')
|
||||
os.environ.setdefault('UV_CONSTRAINT', os.path.abspath('constraints.txt'))
|
||||
os.environ.setdefault('UV_CONSTRAINT', 'constraints.txt')
|
||||
os.environ.setdefault('UV_INDEX_STRATEGY', 'unsafe-any-match')
|
||||
os.environ.setdefault('UV_NO_BUILD_ISOLATION', '1')
|
||||
os.environ.setdefault('UVICORN_TIMEOUT_KEEP_ALIVE', '60')
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 63 KiB |
|
|
@ -42,7 +42,7 @@ Resut should always be: list[ResGPU]
|
|||
class ResGPU(BaseModel):
|
||||
name: str = Field(title="GPU Name")
|
||||
data: dict = Field(title="Name/Value data")
|
||||
chart: list[float, float] = Field(title="Exactly two items to place on chart")
|
||||
chart: tuple[float, float] = Field(title="Exactly two items to place on chart")
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -484,22 +484,22 @@ else:
|
|||
FlagsModel = create_model("Flags", __config__=pydantic_config, **flags)
|
||||
|
||||
class ResEmbeddings(BaseModel):
|
||||
loaded: list = Field(default=None, title="loaded", description="List of loaded embeddings")
|
||||
skipped: list = Field(default=None, title="skipped", description="List of skipped embeddings")
|
||||
loaded: list = Field(title="loaded", description="List of loaded embeddings")
|
||||
skipped: list = Field(title="skipped", description="List of skipped embeddings")
|
||||
|
||||
class ResMemory(BaseModel):
|
||||
ram: dict = Field(title="RAM", description="System memory stats")
|
||||
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
||||
|
||||
class ResScripts(BaseModel):
|
||||
txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
|
||||
img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
|
||||
control: list = Field(default=None, title="Control", description="Titles of scripts (control)")
|
||||
txt2img: list[str] = Field(title="Txt2img", description="Titles of scripts (txt2img)")
|
||||
img2img: list[str] = Field(title="Img2img", description="Titles of scripts (img2img)")
|
||||
control: list[str] = Field(title="Control", description="Titles of scripts (control)")
|
||||
|
||||
class ResGPU(BaseModel): # definition of http response
|
||||
name: str = Field(title="GPU Name", description="GPU device name")
|
||||
data: dict = Field(title="Name/Value data", description="Key-value pairs of GPU metrics (utilization, temperature, clocks, memory, etc.)")
|
||||
chart: list[float, float] = Field(title="Exactly two items to place on chart", description="Two numeric values for chart display (e.g., GPU utilization %, VRAM usage %)")
|
||||
chart: tuple[float, float] = Field(title="Exactly two items to place on chart", description="Two numeric values for chart display (e.g., GPU utilization %, VRAM usage %)")
|
||||
|
||||
class ItemLoadedModel(BaseModel):
|
||||
name: str = Field(title="Model Name", description="Model or component name")
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ def get_nvml():
|
|||
"System load": f'GPU {load.gpu}% | VRAM {load.memory}% | Temp {pynvml.nvmlDeviceGetTemperature(dev, 0)}C | Fan {pynvml.nvmlDeviceGetFanSpeed(dev)}%',
|
||||
'State': get_reason(pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(dev)),
|
||||
}
|
||||
chart = [load.memory, load.gpu]
|
||||
chart = (load.memory, load.gpu)
|
||||
devices.append({
|
||||
'name': name,
|
||||
'data': data,
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ def get_rocm_smi():
|
|||
'Throttle reason': str(ThrottleStatus(int(rocm_smi_data[key].get("throttle_status", 0)))),
|
||||
}
|
||||
name = rocm_smi_data[key].get('Device Name', 'unknown')
|
||||
chart = [load["memory"], load["gpu"]]
|
||||
chart = (load["memory"], load["gpu"])
|
||||
devices.append({
|
||||
'name': name,
|
||||
'data': data,
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def get_xpu_smi():
|
|||
"VRAM usage": f'{round(100 * load["memory"] / total)}% | {load["memory"]} MB used | {total - load["memory"]} MB free | {total} MB total',
|
||||
"RAM usage": f'{round(100 * ram["used"] / ram["total"])}% | {round(1024 * ram["used"])} MB used | {round(1024 * ram["free"])} MB free | {round(1024 * ram["total"])} MB total',
|
||||
}
|
||||
chart = [load["memory"], load["gpu"]]
|
||||
chart = (load["memory"], load["gpu"])
|
||||
devices.append({
|
||||
'name': torch.xpu.get_device_name(),
|
||||
'data': data,
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def ts2utc(timestamp: int) -> datetime:
|
|||
except Exception:
|
||||
return "unknown"
|
||||
|
||||
def active():
|
||||
def active() -> list[Extension]:
|
||||
if shared.opts.disable_all_extensions == "all":
|
||||
return []
|
||||
elif shared.opts.disable_all_extensions == "user":
|
||||
|
|
@ -185,12 +185,12 @@ class Extension:
|
|||
log.error(f"Extension: failed reading data from git repo={self.name}: {ex}")
|
||||
self.remote = None
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts_manager
|
||||
def list_files(self, subdir: str, extension: str):
|
||||
from modules.scripts_manager import ScriptFile
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
res: list[ScriptFile] = []
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
res = []
|
||||
return res
|
||||
for filename in sorted(os.listdir(dirpath)):
|
||||
if not filename.endswith(".py") and not filename.endswith(".js") and not filename.endswith(".mjs"):
|
||||
continue
|
||||
|
|
@ -198,7 +198,7 @@ class Extension:
|
|||
if os.path.isfile(os.path.join(dirpath, "..", ".priority")):
|
||||
with open(os.path.join(dirpath, "..", ".priority"), encoding="utf-8") as f:
|
||||
priority = str(f.read().strip())
|
||||
res.append(scripts_manager.ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
|
||||
res.append(ScriptFile(self.path, filename, os.path.join(dirpath, filename), priority))
|
||||
if priority != '50':
|
||||
log.debug(f'Extension priority override: {os.path.dirname(dirpath)}:{priority}')
|
||||
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
|
|
|||
|
|
@ -1,15 +1,23 @@
|
|||
import math
|
||||
from collections import namedtuple
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFont, ImageDraw
|
||||
from modules import shared, script_callbacks
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
from modules.logger import log
|
||||
|
||||
|
||||
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
|
||||
class Grid(NamedTuple):
|
||||
tiles: list
|
||||
tile_w: int
|
||||
tile_h: int
|
||||
image_w: int
|
||||
image_h: int
|
||||
overlap: int
|
||||
|
||||
|
||||
def check_grid_size(imgs):
|
||||
def check_grid_size(imgs: list[Image.Image] | list[list[Image.Image]] | None):
|
||||
if imgs is None or len(imgs) == 0:
|
||||
return False
|
||||
mp = 0
|
||||
|
|
@ -26,39 +34,36 @@ def check_grid_size(imgs):
|
|||
return ok
|
||||
|
||||
|
||||
def get_grid_size(imgs, batch_size=1, rows: int | None = None, cols: int | None = None):
|
||||
if rows and rows > len(imgs):
|
||||
rows = len(imgs)
|
||||
if cols and cols > len(imgs):
|
||||
cols = len(imgs)
|
||||
def get_grid_size(imgs: list, batch_size=1, rows: int | None = None, cols: int | None = None):
|
||||
rows_int, cols_int = len(imgs), len(imgs)
|
||||
if rows is None and cols is None:
|
||||
if shared.opts.n_rows > 0:
|
||||
rows = shared.opts.n_rows
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
elif shared.opts.n_rows == 0:
|
||||
rows = batch_size
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
elif shared.opts.n_cols > 0:
|
||||
cols = shared.opts.n_cols
|
||||
rows = math.ceil(len(imgs) / cols)
|
||||
elif shared.opts.n_cols == 0:
|
||||
cols = batch_size
|
||||
rows = math.ceil(len(imgs) / cols)
|
||||
n_rows, n_cols = shared.opts.n_rows, shared.opts.n_cols
|
||||
if n_rows >= 0:
|
||||
rows_int: int = batch_size if n_rows == 0 else n_rows
|
||||
cols_int = math.ceil(len(imgs) / rows_int)
|
||||
elif n_cols >= 0:
|
||||
cols_int: int = batch_size if n_cols == 0 else n_cols
|
||||
rows_int = math.ceil(len(imgs) / cols_int)
|
||||
else:
|
||||
rows = math.floor(math.sqrt(len(imgs)))
|
||||
while len(imgs) % rows != 0:
|
||||
rows -= 1
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
elif rows is not None and cols is None:
|
||||
cols = math.ceil(len(imgs) / rows)
|
||||
elif rows is None and cols is not None:
|
||||
rows = math.ceil(len(imgs) / cols)
|
||||
else:
|
||||
pass
|
||||
return rows, cols
|
||||
rows_int = math.floor(math.sqrt(len(imgs)))
|
||||
while len(imgs) % rows_int != 0:
|
||||
rows_int -= 1
|
||||
cols_int = math.ceil(len(imgs) / rows_int)
|
||||
return rows_int, cols_int
|
||||
# Set limits
|
||||
if rows is not None:
|
||||
rows_int = max(min(rows, len(imgs)), 1)
|
||||
if cols is not None:
|
||||
cols_int = max(min(cols, len(imgs)), 1)
|
||||
# Calculate
|
||||
if rows is None:
|
||||
rows_int = math.ceil(len(imgs) / cols_int)
|
||||
if cols is None:
|
||||
cols_int = math.ceil(len(imgs) / rows_int)
|
||||
return rows_int, cols_int
|
||||
|
||||
|
||||
def image_grid(imgs, batch_size=1, rows: int | None = None, cols: int | None = None):
|
||||
def image_grid(imgs: list, batch_size=1, rows: int | None = None, cols: int | None = None):
|
||||
rows, cols = get_grid_size(imgs, batch_size, rows=rows, cols=cols)
|
||||
params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
|
||||
script_callbacks.image_grid_callback(params)
|
||||
|
|
@ -73,7 +78,7 @@ def image_grid(imgs, batch_size=1, rows: int | None = None, cols: int | None = N
|
|||
return grid
|
||||
|
||||
|
||||
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
||||
def split_grid(image: Image.Image, tile_w=512, tile_h=512, overlap=64):
|
||||
w = image.width
|
||||
h = image.height
|
||||
non_overlap_width = tile_w - overlap
|
||||
|
|
@ -98,7 +103,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
|||
return grid
|
||||
|
||||
|
||||
def combine_grid(grid):
|
||||
def combine_grid(grid: Grid):
|
||||
def make_mask_image(r):
|
||||
r = r * 255 / grid.overlap
|
||||
r = r.astype(np.uint8)
|
||||
|
|
@ -127,18 +132,18 @@ class GridAnnotation:
|
|||
def __init__(self, text='', is_active=True):
|
||||
self.text = str(text)
|
||||
self.is_active = is_active
|
||||
self.size = None
|
||||
self.size: tuple[int, int] = (10, 10) # Placeholder values
|
||||
|
||||
|
||||
def get_font(fontsize):
|
||||
def get_font(fontsize: float):
|
||||
try:
|
||||
return ImageFont.truetype(shared.opts.font or "javascript/notosans-nerdfont-regular.ttf", fontsize)
|
||||
except Exception:
|
||||
return ImageFont.truetype("javascript/notosans-nerdfont-regular.ttf", fontsize)
|
||||
|
||||
|
||||
def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=None):
|
||||
def wrap(drawing, text, font, line_length):
|
||||
def draw_grid_annotations(im: Image.Image, width: int, height: int, x_texts: list[list[GridAnnotation]], y_texts: list[list[GridAnnotation]], margin=0, title: list[GridAnnotation] | None = None):
|
||||
def wrap(drawing: ImageDraw.ImageDraw, text, font, line_length):
|
||||
lines = ['']
|
||||
for word in text.split():
|
||||
line = f'{lines[-1]} {word}'.strip()
|
||||
|
|
@ -148,7 +153,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
|
|||
lines.append(word)
|
||||
return lines
|
||||
|
||||
def draw_texts(drawing: ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
||||
def draw_texts(drawing: ImageDraw.ImageDraw, draw_x: float, draw_y: float, lines, initial_fnt: ImageFont.FreeTypeFont, initial_fontsize: int):
|
||||
for line in lines:
|
||||
font = initial_fnt
|
||||
fontsize = initial_fontsize
|
||||
|
|
@ -185,9 +190,10 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
|
|||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in x_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in y_texts]
|
||||
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
|
||||
title_text_heights = []
|
||||
title_pad = 0
|
||||
if title:
|
||||
title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts] # pylint: disable=unsubscriptable-object
|
||||
title_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in title_texts]
|
||||
title_pad = 0 if sum(title_text_heights) == 0 else max(title_text_heights) + line_spacing * 2
|
||||
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + title_pad + margin * (rows-1)), shared.opts.grid_background)
|
||||
for row in range(rows):
|
||||
|
|
@ -210,7 +216,7 @@ def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=N
|
|||
return result
|
||||
|
||||
|
||||
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
|
||||
def draw_prompt_matrix(im: Image.Image, width: int, height: int, all_prompts: list[str], margin=0):
|
||||
prompts = all_prompts[1:]
|
||||
boundary = math.ceil(len(prompts) / 2)
|
||||
prompts_horiz = prompts[:boundary]
|
||||
|
|
|
|||
|
|
@ -100,6 +100,8 @@ def get_model_type(pipe):
|
|||
model_type = 'hunyuanimage3'
|
||||
elif 'HunyuanImage' in name:
|
||||
model_type = 'hunyuanimage'
|
||||
elif 'sdxs-1b' in name:
|
||||
model_type = 'sdxs'
|
||||
# video models
|
||||
elif "Kandinsky5" in name and '2V' in name:
|
||||
model_type = 'kandinsky5video'
|
||||
|
|
|
|||
|
|
@ -449,14 +449,19 @@ def load_upscalers():
|
|||
used_classes[classname] = cls
|
||||
upscaler_types = []
|
||||
for cls in reversed(used_classes.values()):
|
||||
name = cls.__name__
|
||||
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||
commandline_model_path = commandline_options.get(cmd_name, None)
|
||||
scaler = cls(commandline_model_path)
|
||||
scaler.user_path = commandline_model_path
|
||||
scaler.model_download_path = commandline_model_path or scaler.model_path
|
||||
upscalers += scaler.scalers
|
||||
upscaler_types.append(name[8:])
|
||||
try:
|
||||
name = cls.__name__
|
||||
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||
commandline_model_path = commandline_options.get(cmd_name, None)
|
||||
scaler = cls(commandline_model_path)
|
||||
scaler.user_path = commandline_model_path
|
||||
scaler.model_download_path = commandline_model_path or scaler.model_path
|
||||
upscalers += scaler.scalers
|
||||
upscaler_types.append(name[8:])
|
||||
except Exception as e:
|
||||
log.error(f'Upscaler: {cls} {e}')
|
||||
if len(upscalers) == 0:
|
||||
log.error('Upscalers: no data')
|
||||
shared.sd_upscalers = upscalers
|
||||
t1 = time.time()
|
||||
log.info(f"Available Upscalers: items={len(shared.sd_upscalers)} downloaded={len([x for x in shared.sd_upscalers if x.data_path is not None and os.path.isfile(x.data_path)])} user={len([x for x in shared.sd_upscalers if x.custom])} time={t1-t0:.2f} types={upscaler_types}")
|
||||
|
|
|
|||
|
|
@ -325,8 +325,8 @@ class StableDiffusionProcessing:
|
|||
|
||||
# initializers
|
||||
self.prompt = prompt
|
||||
self.seed = seed
|
||||
self.subseed = subseed
|
||||
self.seed = int(seed)
|
||||
self.subseed = int(subseed)
|
||||
self.subseed_strength = subseed_strength
|
||||
self.seed_resize_from_h = seed_resize_from_h
|
||||
self.seed_resize_from_w = seed_resize_from_w
|
||||
|
|
|
|||
|
|
@ -1,14 +1,23 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
import gradio as gr
|
||||
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
|
||||
from modules.logger import log
|
||||
from installer import control_extensions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
from modules.api.models import ItemScript
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import IOComponent
|
||||
from modules.processing import Processed, StableDiffusionProcessing
|
||||
|
||||
|
||||
AlwaysVisible = object()
|
||||
time_component = {}
|
||||
|
|
@ -28,22 +37,22 @@ class PostprocessBatchListArgs:
|
|||
|
||||
@dataclass
|
||||
class OnComponent:
|
||||
component: gr.blocks.Block
|
||||
component: Block
|
||||
|
||||
|
||||
class Script:
|
||||
parent = None
|
||||
name = None
|
||||
filename = None
|
||||
parent: str | None = None
|
||||
name: str | None = None
|
||||
filename: str | None = None
|
||||
args_from = 0
|
||||
args_to = 0
|
||||
alwayson = False
|
||||
is_txt2img = False
|
||||
is_img2img = False
|
||||
api_info = None
|
||||
api_info: ItemScript | None = None
|
||||
group = None
|
||||
infotext_fields = None
|
||||
paste_field_names = None
|
||||
infotext_fields: list | None = None
|
||||
paste_field_names: list[str] | None = None
|
||||
section = None
|
||||
standalone = False
|
||||
external = False
|
||||
|
|
@ -57,14 +66,14 @@ class Script:
|
|||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||
raise NotImplementedError
|
||||
|
||||
def ui(self, is_img2img):
|
||||
def ui(self, is_img2img) -> list[IOComponent]:
|
||||
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||
The return value should be an array of all components that are used in processing.
|
||||
Values of those returned components will be passed to run() and process() functions.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def show(self, is_img2img): # pylint: disable=unused-argument
|
||||
def show(self, is_img2img) -> bool | AlwaysVisible: # pylint: disable=unused-argument
|
||||
"""
|
||||
is_img2img is True if this function is called for the img2img interface, and False otherwise
|
||||
This function should return:
|
||||
|
|
@ -74,7 +83,7 @@ class Script:
|
|||
"""
|
||||
return True
|
||||
|
||||
def run(self, p, *args):
|
||||
def run(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called if the script has been selected in the script dropdown.
|
||||
It must do all processing and return the Processed object with results, same as
|
||||
|
|
@ -84,13 +93,13 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def setup(self, p, *args):
|
||||
def setup(self, p: StableDiffusionProcessing, *args):
|
||||
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
|
||||
args contains all values returned by components from ui().
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_process(self, p, *args):
|
||||
def before_process(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called very early during processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
|
|
@ -98,7 +107,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process(self, p, *args):
|
||||
def process(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called before processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
|
|
@ -106,7 +115,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process_images(self, p, *args):
|
||||
def process_images(self, p: StableDiffusionProcessing, *args):
|
||||
"""
|
||||
This function is called instead of main processing for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
|
|
@ -114,7 +123,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_process_batch(self, p, *args, **kwargs):
|
||||
def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
|
||||
"""
|
||||
Called before extra networks are parsed from the prompt, so you can add
|
||||
new extra network keywords to the prompt with this callback.
|
||||
|
|
@ -126,7 +135,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def process_batch(self, p, *args, **kwargs):
|
||||
def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
|
||||
"""
|
||||
Same as process(), but called for every batch.
|
||||
**kwargs will have those items:
|
||||
|
|
@ -137,7 +146,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_batch(self, p, *args, **kwargs):
|
||||
def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
|
||||
"""
|
||||
Same as process_batch(), but called for every batch after it has been generated.
|
||||
**kwargs will have same items as process_batch, and also:
|
||||
|
|
@ -146,13 +155,13 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
|
||||
def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, *args, **kwargs):
|
||||
"""
|
||||
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
|
||||
This is useful when you want to update the entire batch instead of individual images.
|
||||
|
|
@ -168,14 +177,14 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
def postprocess(self, p: StableDiffusionProcessing, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
def before_component(self, component: IOComponent, **kwargs):
|
||||
"""
|
||||
Called before a component is created.
|
||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||
|
|
@ -184,7 +193,7 @@ class Script:
|
|||
"""
|
||||
pass # pylint: disable=unnecessary-pass
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
def after_component(self, component: IOComponent, **kwargs):
|
||||
"""
|
||||
Called after a component is created. Same as above.
|
||||
"""
|
||||
|
|
@ -194,7 +203,7 @@ class Script:
|
|||
"""unused"""
|
||||
return ""
|
||||
|
||||
def elem_id(self, item_id):
|
||||
def elem_id(self, item_id: str):
|
||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||
return f'script_{self.parent}_{title}_{item_id}'
|
||||
|
|
@ -211,21 +220,33 @@ def basedir():
|
|||
return current_basedir
|
||||
|
||||
|
||||
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "priority"])
|
||||
class ScriptFile(NamedTuple):
|
||||
basedir: str
|
||||
filename: str
|
||||
path: str
|
||||
priority: str
|
||||
|
||||
|
||||
class ScriptClassData(NamedTuple):
|
||||
script_class: type[Script] | type[scripts_postprocessing.ScriptPostprocessing]
|
||||
path: str
|
||||
basedir: str
|
||||
module: ModuleType
|
||||
|
||||
|
||||
scripts_data = []
|
||||
postprocessing_scripts_data = []
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension):
|
||||
tmp_list = []
|
||||
def list_scripts(scriptdirname: str, extension: str):
|
||||
tmp_list: list[ScriptFile] = []
|
||||
base = os.path.join(paths.script_path, scriptdirname)
|
||||
if os.path.exists(base):
|
||||
for filename in sorted(os.listdir(base)):
|
||||
tmp_list.append(ScriptFile(paths.script_path, filename, os.path.join(base, filename), '50'))
|
||||
for ext in extensions.active():
|
||||
tmp_list += ext.list_files(scriptdirname, extension)
|
||||
priority_list = []
|
||||
priority_list: list[ScriptFile] = []
|
||||
for script in tmp_list:
|
||||
if os.path.splitext(script.path)[1].lower() == extension and os.path.isfile(script.path):
|
||||
if script.basedir == paths.script_path:
|
||||
|
|
@ -250,7 +271,7 @@ def list_scripts(scriptdirname, extension):
|
|||
return priority_sort
|
||||
|
||||
|
||||
def list_files_with_name(filename):
|
||||
def list_files_with_name(filename: str):
|
||||
res = []
|
||||
dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
|
||||
for dirpath in dirs:
|
||||
|
|
@ -273,7 +294,7 @@ def load_scripts():
|
|||
scripts_list = sorted(scripts_list, key=lambda item: item.priority + item.path.lower(), reverse=False)
|
||||
syspath = sys.path
|
||||
|
||||
def register_scripts_from_module(module, scriptfile):
|
||||
def register_scripts_from_module(module: ModuleType, scriptfile):
|
||||
for script_class in module.__dict__.values():
|
||||
if type(script_class) != type:
|
||||
continue
|
||||
|
|
@ -305,7 +326,7 @@ def load_scripts():
|
|||
return t, time.time()-t0
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
def wrap_call(func: Callable, filename: str, funcname: str, *args, default=None, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
return res
|
||||
|
|
@ -315,7 +336,7 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
|||
|
||||
|
||||
class ScriptSummary:
|
||||
def __init__(self, op):
|
||||
def __init__(self, op: str):
|
||||
self.start = time.time()
|
||||
self.update = time.time()
|
||||
self.op = op
|
||||
|
|
@ -336,17 +357,17 @@ class ScriptSummary:
|
|||
class ScriptRunner:
|
||||
def __init__(self, name=''):
|
||||
self.name = name
|
||||
self.scripts = []
|
||||
self.selectable_scripts = []
|
||||
self.alwayson_scripts = []
|
||||
self.auto_processing_scripts = []
|
||||
self.titles = []
|
||||
self.alwayson_titles = []
|
||||
self.infotext_fields = []
|
||||
self.paste_field_names = []
|
||||
self.scripts: list[Script] = []
|
||||
self.selectable_scripts: list[Script] = []
|
||||
self.alwayson_scripts: list[Script] = []
|
||||
self.auto_processing_scripts: list[ScriptClassData] = []
|
||||
self.titles: list[str] = []
|
||||
self.alwayson_titles: list[str] = []
|
||||
self.infotext_fields: list[tuple[IOComponent, str]] = []
|
||||
self.paste_field_names: list[str] = []
|
||||
self.script_load_ctr = 0
|
||||
self.is_img2img = False
|
||||
self.inputs = [None]
|
||||
self.inputs: list = [None]
|
||||
self.time = 0
|
||||
|
||||
def add_script(self, script_class, path, is_img2img, is_control):
|
||||
|
|
@ -557,7 +578,7 @@ class ScriptRunner:
|
|||
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts if script.group is not None])
|
||||
return inputs
|
||||
|
||||
def run(self, p, *args):
|
||||
def run(self, p: StableDiffusionProcessing, *args) -> Processed | None:
|
||||
s = ScriptSummary('run')
|
||||
script_index = args[0] if len(args) > 0 else 0
|
||||
if (script_index is None) or (script_index == 0):
|
||||
|
|
@ -582,7 +603,7 @@ class ScriptRunner:
|
|||
s.report()
|
||||
return processed
|
||||
|
||||
def after(self, p, processed, *args):
|
||||
def after(self, p: StableDiffusionProcessing, processed: Processed, *args):
|
||||
s = ScriptSummary('after')
|
||||
script_index = args[0] if len(args) > 0 else 0
|
||||
if (script_index is None) or (script_index == 0):
|
||||
|
|
@ -600,7 +621,7 @@ class ScriptRunner:
|
|||
s.report()
|
||||
return processed
|
||||
|
||||
def before_process(self, p, **kwargs):
|
||||
def before_process(self, p: StableDiffusionProcessing, **kwargs):
|
||||
s = ScriptSummary('before-process')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -612,7 +633,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process(self, p, **kwargs):
|
||||
def process(self, p: StableDiffusionProcessing, **kwargs):
|
||||
s = ScriptSummary('process')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -624,7 +645,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process_images(self, p, **kwargs):
|
||||
def process_images(self, p: StableDiffusionProcessing, **kwargs):
|
||||
s = ScriptSummary('process_images')
|
||||
processed = None
|
||||
for script in self.alwayson_scripts:
|
||||
|
|
@ -640,7 +661,7 @@ class ScriptRunner:
|
|||
s.report()
|
||||
return processed
|
||||
|
||||
def before_process_batch(self, p, **kwargs):
|
||||
def before_process_batch(self, p: StableDiffusionProcessing, **kwargs):
|
||||
s = ScriptSummary('before-process-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -652,7 +673,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def process_batch(self, p, **kwargs):
|
||||
def process_batch(self, p: StableDiffusionProcessing, **kwargs):
|
||||
s = ScriptSummary('process-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -664,7 +685,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess(self, p, processed):
|
||||
def postprocess(self, p: StableDiffusionProcessing, processed):
|
||||
s = ScriptSummary('postprocess')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -676,7 +697,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_batch(self, p, images, **kwargs):
|
||||
def postprocess_batch(self, p: StableDiffusionProcessing, images, **kwargs):
|
||||
s = ScriptSummary('postprocess-batch')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -688,7 +709,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
|
||||
def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, **kwargs):
|
||||
s = ScriptSummary('postprocess-batch-list')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -700,7 +721,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||
def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs):
|
||||
s = ScriptSummary('postprocess-image')
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
|
|
@ -712,7 +733,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
def before_component(self, component: IOComponent, **kwargs):
|
||||
s = ScriptSummary('before-component')
|
||||
for script in self.scripts:
|
||||
try:
|
||||
|
|
@ -722,7 +743,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
def after_component(self, component: IOComponent, **kwargs):
|
||||
s = ScriptSummary('after-component')
|
||||
for script in self.scripts:
|
||||
for elem_id, callback in script.on_after_component_elem_id:
|
||||
|
|
@ -738,7 +759,7 @@ class ScriptRunner:
|
|||
s.record(script.title())
|
||||
s.report()
|
||||
|
||||
def reload_sources(self, cache):
|
||||
def reload_sources(self, cache: dict):
|
||||
s = ScriptSummary('reload-sources')
|
||||
for si, script in list(enumerate(self.scripts)):
|
||||
if hasattr(script, 'args_to') and hasattr(script, 'args_from'):
|
||||
|
|
|
|||
|
|
@ -150,6 +150,8 @@ def guess_by_name(fn, current_guess):
|
|||
new_guess = 'Ovis-Image'
|
||||
elif 'glm-image' in fn.lower():
|
||||
new_guess = 'GLM-Image'
|
||||
elif 'sdxs-1b' in fn.lower():
|
||||
new_guess = 'SDXS'
|
||||
if debug_load:
|
||||
log.trace(f'Autodetect: method=name file="{fn}" previous="{current_guess}" current="{new_guess}"')
|
||||
return new_guess or current_guess
|
||||
|
|
@ -166,6 +168,8 @@ def guess_by_diffusers(fn, current_guess):
|
|||
if name is not None and name in exclude_by_name:
|
||||
return current_guess, None
|
||||
cls = index.get('_class_name', None)
|
||||
if isinstance(cls, list):
|
||||
cls = cls[-1]
|
||||
if cls is not None:
|
||||
pipeline = getattr(diffusers, cls, None)
|
||||
if pipeline is None:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ def hijack_encode_prompt(*args, **kwargs):
|
|||
jobid = shared.state.begin('TE Encode')
|
||||
t0 = time.time()
|
||||
if 'max_sequence_length' in kwargs and kwargs['max_sequence_length'] is not None:
|
||||
kwargs['max_sequence_length'] = max(kwargs['max_sequence_length'], os.environ.get('HIDREAM_MAX_SEQUENCE_LENGTH', 256))
|
||||
kwargs['max_sequence_length'] = max(kwargs['max_sequence_length'], os.environ.get('MAX_SEQUENCE_LENGTH', 256))
|
||||
try:
|
||||
prompt = kwargs.get('prompt', None) or (args[0] if len(args) > 0 else None)
|
||||
if prompt is not None:
|
||||
|
|
@ -20,8 +20,6 @@ def hijack_encode_prompt(*args, **kwargs):
|
|||
res = None
|
||||
t1 = time.time()
|
||||
timer.process.add('te', t1-t0)
|
||||
# if hasattr(shared.sd_model, "maybe_free_model_hooks"):
|
||||
# shared.sd_model.maybe_free_model_hooks()
|
||||
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
|
||||
shared.state.end(jobid)
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -507,6 +507,10 @@ def load_diffuser_force(detected_model_type, checkpoint_info, diffusers_load_con
|
|||
from pipelines.model_glm import load_glm_image
|
||||
sd_model = load_glm_image(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['SDXS']:
|
||||
from pipelines.model_sdxs import load_sdxs
|
||||
sd_model = load_sdxs(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
except Exception as e:
|
||||
log.error(f'Load {op}: path="{checkpoint_info.path}" {e}')
|
||||
# if debug_load:
|
||||
|
|
@ -1418,7 +1422,7 @@ def hf_auth_check(checkpoint_info, force:bool=False):
|
|||
return False
|
||||
|
||||
|
||||
def save_model(name: str, path: str | None = None, shard: str | None = None, overwrite = False):
|
||||
def save_model(name: str, path: str | None = None, shard: str = "5GB", overwrite = False):
|
||||
if (name is None) or len(name.strip()) == 0:
|
||||
log.error('Save model: invalid model name')
|
||||
return 'Invalid model name'
|
||||
|
|
@ -1432,6 +1436,8 @@ def save_model(name: str, path: str | None = None, shard: str | None = None, ove
|
|||
if os.path.exists(model_name) and not overwrite:
|
||||
log.error(f'Save model: path="{model_name}" exists')
|
||||
return f'Path exists: {model_name}'
|
||||
if not shard.strip():
|
||||
shard = "5GB" # Guard against empty input
|
||||
try:
|
||||
t0 = time.time()
|
||||
save_sdnq_model(
|
||||
|
|
|
|||
|
|
@ -38,8 +38,8 @@ def load_unet_sdxl_nunchaku(repo_id):
|
|||
def load_unet(model, repo_id: str | None = None):
|
||||
global loaded_unet # pylint: disable=global-statement
|
||||
|
||||
if ("StableDiffusionXLPipeline" in model.__class__.__name__) and (('stable-diffusion-xl-base' in repo_id) or ('sdxl-turbo' in repo_id)):
|
||||
if model_quant.check_nunchaku('Model'):
|
||||
if ("StableDiffusionXLPipeline" in model.__class__.__name__) and repo_id is not None and (("stable-diffusion-xl-base" in repo_id) or ("sdxl-turbo" in repo_id)):
|
||||
if model_quant.check_nunchaku("Model"):
|
||||
unet = load_unet_sdxl_nunchaku(repo_id)
|
||||
if unet is not None:
|
||||
model.unet = unet
|
||||
|
|
|
|||
|
|
@ -40,18 +40,18 @@ def conv_fp16_matmul(
|
|||
scale = scale.t()
|
||||
elif weight.dtype != torch.float16:
|
||||
weight = weight.to(dtype=torch.float16) # fp8 weights
|
||||
input, scale = quantize_fp_mm_input_tensorwise(input, scale, matmul_dtype="float16")
|
||||
input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype, matmul_dtype="float16")
|
||||
input, weight = check_mats(input, weight)
|
||||
|
||||
if groups == 1:
|
||||
result = fp_mm_func(input, weight)
|
||||
result = fp_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale)
|
||||
else:
|
||||
weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
|
||||
input = input.view(input.shape[0], groups, input.shape[1] // groups)
|
||||
result = []
|
||||
for i in range(groups):
|
||||
result.append(fp_mm_func(input[:, i], weight[:, i]))
|
||||
result = torch.cat(result, dim=-1)
|
||||
result = torch.cat(result, dim=-1).to(dtype=input_scale.dtype).mul_(input_scale)
|
||||
if bias is not None:
|
||||
dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -38,19 +38,19 @@ def conv_fp8_matmul_tensorwise(
|
|||
if quantized_weight_shape is not None:
|
||||
weight = unpack_float(weight, weights_dtype, quantized_weight_shape).to(dtype=torch.float8_e4m3fn).t_()
|
||||
scale = scale.t()
|
||||
input, scale = quantize_fp_mm_input_tensorwise(input, scale)
|
||||
input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype)
|
||||
input, weight = check_mats(input, weight)
|
||||
dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32)
|
||||
|
||||
if groups == 1:
|
||||
result = torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype)
|
||||
result = torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype).mul_(input_scale)
|
||||
else:
|
||||
weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
|
||||
input = input.view(input.shape[0], groups, input.shape[1] // groups)
|
||||
result = []
|
||||
for i in range(groups):
|
||||
result.append(torch._scaled_mm(input[:, i], weight[:, i], scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype))
|
||||
result = torch.cat(result, dim=-1)
|
||||
result.append(torch._scaled_mm(input[:, i], weight[:, i], scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype))
|
||||
result = torch.cat(result, dim=-1).mul_(input_scale)
|
||||
if bias is not None:
|
||||
dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -38,18 +38,18 @@ def conv_int8_matmul(
|
|||
if quantized_weight_shape is not None:
|
||||
weight = unpack_int(weight, weights_dtype, quantized_weight_shape, dtype=torch.int8).t_()
|
||||
scale = scale.t()
|
||||
input, scale = quantize_int_mm_input(input, scale)
|
||||
input, input_scale = quantize_int_mm_input(input, dtype=scale.dtype)
|
||||
input, weight = check_mats(input, weight)
|
||||
|
||||
if groups == 1:
|
||||
result = int_mm_func(input, weight)
|
||||
result = int_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale)
|
||||
else:
|
||||
weight = weight.view(weight.shape[0], groups, weight.shape[1] // groups)
|
||||
input = input.view(input.shape[0], groups, input.shape[1] // groups)
|
||||
result = []
|
||||
for i in range(groups):
|
||||
result.append(int_mm_func(input[:, i], weight[:, i]))
|
||||
result = torch.cat(result, dim=-1)
|
||||
result = torch.cat(result, dim=-1).to(dtype=input_scale.dtype).mul_(input_scale)
|
||||
if bias is not None:
|
||||
result = dequantize_symmetric_with_bias(result, scale, bias, dtype=return_dtype, result_shape=mm_output_shape)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -33,12 +33,12 @@ def fp16_matmul(
|
|||
bias = torch.addmm(bias.to(dtype=svd_down.dtype), torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
|
||||
else:
|
||||
bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
|
||||
input, scale = quantize_fp_mm_input_tensorwise(input, scale, matmul_dtype="float16")
|
||||
input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype, matmul_dtype="float16")
|
||||
input, weight = check_mats(input, weight)
|
||||
if bias is not None:
|
||||
return dequantize_symmetric_with_bias(fp_mm_func(input, weight), scale, bias, dtype=return_dtype, result_shape=output_shape)
|
||||
return dequantize_symmetric_with_bias(fp_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, bias, dtype=return_dtype, result_shape=output_shape)
|
||||
else:
|
||||
return dequantize_symmetric(fp_mm_func(input, weight), scale, dtype=return_dtype, result_shape=output_shape)
|
||||
return dequantize_symmetric(fp_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, dtype=return_dtype, result_shape=output_shape)
|
||||
|
||||
|
||||
def quantized_linear_forward_fp16_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
|
|
|||
|
|
@ -9,13 +9,14 @@ from ...dequantizer import quantize_fp_mm, dequantize_symmetric, dequantize_symm
|
|||
from .forward import check_mats
|
||||
|
||||
|
||||
def quantize_fp_mm_input_tensorwise(input: torch.FloatTensor, scale: torch.FloatTensor, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]:
|
||||
input = input.flatten(0,-2).to(dtype=scale.dtype)
|
||||
def quantize_fp_mm_input_tensorwise(input: torch.FloatTensor, dtype: torch.dtype | None = None, matmul_dtype: str = "float8_e4m3fn") -> tuple[torch.Tensor, torch.FloatTensor]:
|
||||
input = input.flatten(0,-2)
|
||||
if dtype is not None:
|
||||
input = input.to(dtype=dtype)
|
||||
input, input_scale = quantize_fp_mm(input, dim=-1, matmul_dtype=matmul_dtype)
|
||||
scale = torch.mul(input_scale, scale)
|
||||
if scale.dtype == torch.float16: # fp16 will overflow
|
||||
scale = scale.to(dtype=torch.float32)
|
||||
return input, scale
|
||||
if input_scale.dtype == torch.float16: # fp16 will overflow
|
||||
input_scale = input_scale.to(dtype=torch.float32)
|
||||
return input, input_scale
|
||||
|
||||
|
||||
def fp8_matmul_tensorwise(
|
||||
|
|
@ -40,12 +41,12 @@ def fp8_matmul_tensorwise(
|
|||
else:
|
||||
bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
|
||||
dummy_input_scale = torch.ones(1, device=input.device, dtype=torch.float32)
|
||||
input, scale = quantize_fp_mm_input_tensorwise(input, scale)
|
||||
input, input_scale = quantize_fp_mm_input_tensorwise(input, dtype=scale.dtype)
|
||||
input, weight = check_mats(input, weight, allow_contiguous_mm=False)
|
||||
if bias is not None:
|
||||
return dequantize_symmetric_with_bias(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype), scale, bias, dtype=return_dtype, result_shape=output_shape)
|
||||
return dequantize_symmetric_with_bias(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype).to(dtype=input_scale.dtype).mul_(input_scale), scale, bias, dtype=return_dtype, result_shape=output_shape)
|
||||
else:
|
||||
return dequantize_symmetric(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=scale.dtype), scale, dtype=return_dtype, result_shape=output_shape)
|
||||
return dequantize_symmetric(torch._scaled_mm(input, weight, scale_a=dummy_input_scale, scale_b=dummy_input_scale, bias=None, out_dtype=input_scale.dtype).to(dtype=input_scale.dtype).mul_(input_scale), scale, dtype=return_dtype, result_shape=output_shape)
|
||||
|
||||
|
||||
def quantized_linear_forward_fp8_matmul_tensorwise(self, input: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
|
|
|||
|
|
@ -9,13 +9,14 @@ from ...packed_int import unpack_int # noqa: TID252
|
|||
from .forward import check_mats
|
||||
|
||||
|
||||
def quantize_int_mm_input(input: torch.FloatTensor, scale: torch.FloatTensor) -> tuple[torch.CharTensor, torch.FloatTensor]:
|
||||
input = input.flatten(0,-2).to(dtype=scale.dtype)
|
||||
def quantize_int_mm_input(input: torch.FloatTensor, dtype: torch.dtype | None = None) -> tuple[torch.CharTensor, torch.FloatTensor]:
|
||||
input = input.flatten(0,-2)
|
||||
if dtype is not None:
|
||||
input = input.to(dtype=dtype)
|
||||
input, input_scale = quantize_int_mm(input, dim=-1)
|
||||
scale = torch.mul(input_scale, scale)
|
||||
if scale.dtype == torch.float16: # fp16 will overflow
|
||||
scale = scale.to(dtype=torch.float32)
|
||||
return input, scale
|
||||
if input_scale.dtype == torch.float16: # fp16 will overflow
|
||||
input_scale = input_scale.to(dtype=torch.float32)
|
||||
return input, input_scale
|
||||
|
||||
|
||||
def int8_matmul(
|
||||
|
|
@ -39,12 +40,12 @@ def int8_matmul(
|
|||
bias = torch.addmm(bias.to(dtype=svd_down.dtype), torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
|
||||
else:
|
||||
bias = torch.mm(torch.mm(input.to(dtype=svd_down.dtype), svd_down), svd_up)
|
||||
input, scale = quantize_int_mm_input(input, scale)
|
||||
input, input_scale = quantize_int_mm_input(input, dtype=scale.dtype)
|
||||
input, weight = check_mats(input, weight)
|
||||
if bias is not None:
|
||||
return dequantize_symmetric_with_bias(int_mm_func(input, weight), scale, bias, dtype=return_dtype, result_shape=output_shape)
|
||||
return dequantize_symmetric_with_bias(int_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, bias, dtype=return_dtype, result_shape=output_shape)
|
||||
else:
|
||||
return dequantize_symmetric(int_mm_func(input, weight), scale, dtype=return_dtype, result_shape=output_shape)
|
||||
return dequantize_symmetric(int_mm_func(input, weight).to(dtype=input_scale.dtype).mul_(input_scale), scale, dtype=return_dtype, result_shape=output_shape)
|
||||
|
||||
|
||||
def quantized_linear_forward_int8_matmul(self, input: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
"""
|
||||
Modified from Triton MatMul example.
|
||||
PyTorch torch._int_mm is broken on backward pass with Nvidia.
|
||||
AMD RDNA2 doesn't support torch._int_mm, so we use int_mm via Triton.
|
||||
PyTorch doesn't support FP32 output type with FP16 MM so we use Triton for it too.
|
||||
PyTorch torch._int_mm is broken on backward pass with Nvidia, so we use Triton on the backward pass with Nvidia.
|
||||
AMD RDNA2 doesn't support torch._int_mm as it requires INT8 WMMA, so we use INT8 DP4A via Triton.
|
||||
PyTorch doesn't support FP32 output type with FP16 MM, so we use Triton for FP16 MM too.
|
||||
matmul_configs we use takes AMD and Intel into consideration too.
|
||||
SDNQ Triton configs can outperform RocBLAS and OneDNN.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
|
@ -22,7 +24,7 @@ matmul_configs = [
|
|||
]
|
||||
|
||||
|
||||
@triton.autotune(configs=matmul_configs, key=["M", "N", "K", "stride_bk", "ACCUMULATOR_DTYPE"])
|
||||
@triton.autotune(configs=matmul_configs, key=["M", "N", "K", "stride_bk", "ACCUMULATOR_DTYPE"], cache_results=True)
|
||||
@triton.jit
|
||||
def triton_mm_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
|
|
@ -76,6 +78,55 @@ def triton_mm_kernel(
|
|||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
# Intel requires tensor descriptors to perform good
|
||||
@triton.autotune(configs=matmul_configs, key=["M", "N", "K", "stride_bk", "ACCUMULATOR_DTYPE"], cache_results=True)
|
||||
@triton.jit
|
||||
def triton_mm_td_kernel(
|
||||
a_ptr, b_ptr, c_ptr,
|
||||
M: int, N: int, K: int,
|
||||
stride_am: int, stride_ak: int,
|
||||
stride_bk: int, stride_bn: int,
|
||||
stride_cm: int, stride_cn: int,
|
||||
ACCUMULATOR_DTYPE: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
tl.assume(pid_m >= 0)
|
||||
tl.assume(pid_n >= 0)
|
||||
tl.assume(stride_am > 0)
|
||||
tl.assume(stride_ak > 0)
|
||||
tl.assume(stride_bn > 0)
|
||||
tl.assume(stride_bk > 0)
|
||||
tl.assume(stride_cm > 0)
|
||||
tl.assume(stride_cn > 0)
|
||||
|
||||
a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
|
||||
b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||
|
||||
off_k = 0
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
|
||||
for _ in range(0, K, BLOCK_SIZE_K):
|
||||
a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
|
||||
b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
|
||||
accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
|
||||
off_k += BLOCK_SIZE_K
|
||||
|
||||
c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
|
||||
c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], accumulator)
|
||||
|
||||
|
||||
def int_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||
assert a.is_contiguous(), "Matrix A must be contiguous"
|
||||
|
|
@ -84,7 +135,8 @@ def int_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
|||
c = torch.empty((M, N), device=a.device, dtype=torch.int32)
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||
triton_mm_kernel[grid](
|
||||
mm_kernel_func = triton_mm_td_kernel if b.is_contiguous() else triton_mm_kernel
|
||||
mm_kernel_func[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
|
|
@ -103,7 +155,8 @@ def fp_mm(a: torch.FloatTensor, b: torch.FloatTensor) -> torch.FloatTensor:
|
|||
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||
triton_mm_kernel[grid](
|
||||
mm_kernel_func = triton_mm_td_kernel if b.is_contiguous() else triton_mm_kernel
|
||||
mm_kernel_func[grid](
|
||||
a, b, c,
|
||||
M, N, K,
|
||||
a.stride(0), a.stride(1),
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ pipelines = {
|
|||
'HunyuanImage3': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
'ChronoEdit': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
'Anima': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
'SDXS': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -265,7 +265,7 @@ def create_settings(cmd_opts):
|
|||
|
||||
"openvino_sep": OptionInfo("<h2>OpenVINO</h2>", "", gr.HTML, {"visible": cmd_opts.use_openvino}),
|
||||
"openvino_devices": OptionInfo([], "OpenVINO devices to use", gr.CheckboxGroup, {"choices": get_openvino_device_list() if cmd_opts.use_openvino else [], "visible": cmd_opts.use_openvino}),
|
||||
"openvino_accuracy": OptionInfo("performance", "OpenVINO accuracy mode", gr.Radio, {"choices": ["performance", "accuracy"], "visible": cmd_opts.use_openvino}),
|
||||
"openvino_accuracy": OptionInfo("default", "OpenVINO accuracy mode", gr.Radio, {"choices": ["default", "performance", "accuracy"], "visible": cmd_opts.use_openvino}),
|
||||
"openvino_disable_model_caching": OptionInfo(True, "OpenVINO disable model caching", gr.Checkbox, {"visible": cmd_opts.use_openvino}),
|
||||
"openvino_disable_memory_cleanup": OptionInfo(True, "OpenVINO disable memory cleanup after compile", gr.Checkbox, {"visible": cmd_opts.use_openvino}),
|
||||
|
||||
|
|
|
|||
|
|
@ -365,6 +365,8 @@ def create_resize_inputs(tab, images, accordion=True, latent=False, non_zero=Tru
|
|||
with gr.Accordion(open=False, label="Resize", elem_classes=["small-accordion"], elem_id=f"{tab}_resize_group") if accordion else gr.Group():
|
||||
with gr.Row():
|
||||
available_upscalers = [x.name for x in shared.sd_upscalers]
|
||||
if len(available_upscalers) == 0:
|
||||
available_upscalers = ['None']
|
||||
if not latent:
|
||||
available_upscalers = [x for x in available_upscalers if not x.lower().startswith('latent')]
|
||||
resize_mode = gr.Dropdown(label=f"Mode{prefix}" if non_zero else "Resize mode", elem_id=f"{tab}_resize_mode", choices=shared.resize_modes, type="index", value='Fixed')
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
import time
|
||||
import diffusers
|
||||
import transformers
|
||||
from modules import shared, devices, errors, timer, sd_models, model_quant, sd_hijack_vae
|
||||
from modules.logger import log
|
||||
from pipelines import generic
|
||||
|
||||
|
||||
def hijack_encode_text(prompt: str | list[str]):
|
||||
jobid = shared.state.begin('TE Encode')
|
||||
t0 = time.time()
|
||||
try:
|
||||
prompt = shared.sd_model.refine_prompts(prompt)
|
||||
except Exception as e:
|
||||
log.error(f'Encode prompt: {e}')
|
||||
errors.display(e, 'Encode prompt')
|
||||
try:
|
||||
res = shared.sd_model.orig_encode_text(prompt)
|
||||
except Exception as e:
|
||||
log.error(f'Encode prompt: {e}')
|
||||
errors.display(e, 'Encode prompt')
|
||||
res = None
|
||||
t1 = time.time()
|
||||
timer.process.add('te', t1-t0)
|
||||
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
|
||||
shared.state.end(jobid)
|
||||
return res
|
||||
|
||||
|
||||
def load_sdxs(checkpoint_info, diffusers_load_config=None):
|
||||
if diffusers_load_config is None:
|
||||
diffusers_load_config = {}
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info)
|
||||
sd_models.hf_auth_check(checkpoint_info)
|
||||
|
||||
load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
|
||||
log.debug(f'Load model: type=SDXS repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
|
||||
|
||||
text_encoder = generic.load_text_encoder(repo_id, cls_name=transformers.Qwen3_5ForConditionalGeneration, load_config=diffusers_load_config, allow_shared=False)
|
||||
|
||||
pipe = diffusers.DiffusionPipeline.from_pretrained(
|
||||
repo_id,
|
||||
text_encoder=text_encoder,
|
||||
cache_dir=shared.opts.diffusers_dir,
|
||||
trust_remote_code=True,
|
||||
**load_args,
|
||||
)
|
||||
pipe.task_args = {
|
||||
'generator': None,
|
||||
'output_type': 'np',
|
||||
}
|
||||
|
||||
pipe.orig_encode_text = pipe.encode_text
|
||||
pipe.encode_text = hijack_encode_text
|
||||
sd_hijack_vae.init_hijack(pipe)
|
||||
del text_encoder
|
||||
|
||||
devices.torch_gc(force=True, reason='load')
|
||||
return pipe
|
||||
|
|
@ -18,6 +18,7 @@ fasteners
|
|||
limits
|
||||
orjson
|
||||
websockets
|
||||
ftfy
|
||||
|
||||
# versioned
|
||||
fastapi==0.124.4
|
||||
|
|
|
|||
|
|
@ -1,4 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from modules.logger import log
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
|
||||
|
||||
repo_id = 'AIML-TUDA/LlavaGuard-v1.2-0.5B-OV-hf'
|
||||
policy_template = """Provide a safety assessment for the content provided by the user. Determine compliance or violation against our safety policy by reviewing the following policy categories:
|
||||
Hate:
|
||||
|
|
@ -89,11 +96,11 @@ To provide your assessment use the following json template for each category:
|
|||
"rationale": str,
|
||||
}.
|
||||
"""
|
||||
model = None
|
||||
processor = None
|
||||
model: LlavaOnevisionForConditionalGeneration | None = None
|
||||
processor: AutoProcessor | None = None
|
||||
|
||||
|
||||
def image_guard(image, policy:str | None=None) -> str:
|
||||
def image_guard(image, policy:str | None=None):
|
||||
global model, processor # pylint: disable=global-statement
|
||||
import json
|
||||
from installer import install
|
||||
|
|
|
|||
|
|
@ -49,14 +49,14 @@ def create_ui(accordion=True):
|
|||
|
||||
# main processing used in both modes
|
||||
def process(
|
||||
p: processing.StableDiffusionProcessing=None,
|
||||
pp: scripts.PostprocessImageArgs=None,
|
||||
p: processing.StableDiffusionProcessing | None = None,
|
||||
pp: scripts.PostprocessImageArgs | scripts_postprocessing.PostprocessedImage | None = None,
|
||||
enabled=True,
|
||||
lang=False,
|
||||
policy=False,
|
||||
banned=False,
|
||||
metadata=True,
|
||||
copy=False,
|
||||
copy=False, # Compatability
|
||||
score=0.2,
|
||||
blocks=3,
|
||||
censor=[],
|
||||
|
|
@ -74,7 +74,7 @@ def process(
|
|||
nudes = nudenet.detector.censor(image=pp.image, method=method, min_score=score, censor=censor, blocks=blocks, overlay=overlay)
|
||||
t1 = time.time()
|
||||
if len(nudes.censored) > 0: # Check if there are any censored areas
|
||||
if not copy:
|
||||
if p is None:
|
||||
pp.image = nudes.output
|
||||
else:
|
||||
info = processing.create_infotext(p)
|
||||
|
|
@ -85,7 +85,7 @@ def process(
|
|||
if metadata and p is not None:
|
||||
p.extra_generation_params["NudeNet"] = meta
|
||||
p.extra_generation_params["NSFW"] = nsfw
|
||||
if metadata and hasattr(pp, 'info'):
|
||||
if metadata and isinstance(pp, scripts_postprocessing.PostprocessedImage):
|
||||
pp.info['NudeNet'] = meta
|
||||
pp.info['NSFW'] = nsfw
|
||||
log.debug(f'NudeNet detect: {dct} nsfw={nsfw} time={(t1 - t0):.2f}')
|
||||
|
|
@ -118,7 +118,7 @@ def process(
|
|||
if metadata and p is not None:
|
||||
p.extra_generation_params["Rating"] = res.get('rating', 'N/A')
|
||||
p.extra_generation_params["Category"] = res.get('category', 'N/A')
|
||||
if metadata and hasattr(pp, 'info'):
|
||||
if metadata and isinstance(pp, scripts_postprocessing.PostprocessedImage):
|
||||
pp.info["Rating"] = res.get('rating', 'N/A')
|
||||
pp.info["Category"] = res.get('category', 'N/A')
|
||||
|
||||
|
|
@ -131,11 +131,11 @@ class ScriptNudeNet(scripts.Script):
|
|||
def title(self):
|
||||
return 'NudeNet'
|
||||
|
||||
def show(self, _is_img2img):
|
||||
def show(self, *args, **kwargs):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
# return signature is array of gradio components
|
||||
def ui(self, _is_img2img):
|
||||
def ui(self, *args, **kwargs):
|
||||
return create_ui(accordion=True)
|
||||
|
||||
# triggered by callback
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
|
@ -6,39 +7,28 @@ from typing import Dict, Optional
|
|||
import installer
|
||||
from modules.logger import log
|
||||
from modules.json_helpers import readfile, writefile
|
||||
from modules.shared import opts
|
||||
|
||||
from scripts.rocm.rocm_vars import ROCM_ENV_VARS # pylint: disable=no-name-in-module
|
||||
from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module
|
||||
|
||||
|
||||
def _check_rocm() -> bool:
|
||||
from modules import shared
|
||||
if getattr(shared.cmd_opts, 'use_rocm', False):
|
||||
return True
|
||||
if installer.torch_info.get('type') == 'rocm':
|
||||
return True
|
||||
import torch # pylint: disable=import-outside-toplevel
|
||||
return hasattr(torch.version, 'hip') and torch.version.hip is not None
|
||||
|
||||
|
||||
is_rocm = _check_rocm()
|
||||
|
||||
|
||||
CONFIG = Path(os.path.abspath(os.path.join('data', 'rocm.json')))
|
||||
|
||||
_cache: Optional[Dict[str, str]] = None # loaded once, invalidated on save
|
||||
|
||||
# Metadata key written into rocm.json to record which architecture profile is active.
|
||||
# Not an environment variable — always skipped during env application but preserved in the
|
||||
# Not an environment variable - always skipped during env application but preserved in the
|
||||
# saved config so that arch-safety enforcement is consistent across restarts.
|
||||
_ARCH_KEY = "_rocm_arch"
|
||||
|
||||
# Vars that must never appear in the process environment.
|
||||
#
|
||||
# _DTYPE_UNSAFE: alter FP16 inference dtype — must be cleared regardless of config
|
||||
# MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL — DEBUG alias: routes all FP16 convs through BF16 exponent math
|
||||
# MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL — API-level alias: same BF16-exponent effect
|
||||
# MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_EXPEREMENTAL_FP16_TRANSFORM — unstable experimental FP16 path
|
||||
# MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16 — changes FP16 WrW atomic accumulation
|
||||
# _DTYPE_UNSAFE: alter FP16 inference dtype - must be cleared regardless of config
|
||||
# MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL - DEBUG alias: routes all FP16 convs through BF16 exponent math
|
||||
# MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL - API-level alias: same BF16-exponent effect
|
||||
# MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_EXPEREMENTAL_FP16_TRANSFORM - unstable experimental FP16 path
|
||||
# MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16 - changes FP16 WrW atomic accumulation
|
||||
#
|
||||
# SOLVER_DISABLED_BY_DEFAULT: every solver known to be incompatible with this runtime
|
||||
# (FP32-only, training-only WrW/BWD, fixed-geometry mismatches, XDLOPS/CDNA-only, arch-specific).
|
||||
|
|
@ -53,18 +43,18 @@ _DTYPE_UNSAFE = {
|
|||
# regardless of saved config. Limited to dtype-corrupting vars only.
|
||||
# IMPORTANT: SOLVER_DISABLED_BY_DEFAULT is intentionally NOT included here.
|
||||
# When a solver var is absent (unset) MIOpen still calls IsApplicable() on every
|
||||
# conv-find — wasted probing overhead. When a var is explicitly "0" MIOpen skips
|
||||
# conv-find - wasted probing overhead. When a var is explicitly "0" MIOpen skips
|
||||
# IsApplicable() immediately. Solver defaults flow through the config loop as "0"
|
||||
# (their ROCM_ENV_VARS default is "0") so they are explicitly set to "0" in the env.
|
||||
_UNSET_VARS = _DTYPE_UNSAFE
|
||||
|
||||
# Additional environment vars that must be removed from the process before MIOpen loads.
|
||||
# These are not MIOpen solver toggles but can corrupt MIOpen's runtime behaviour:
|
||||
# HIP_PATH / HIP_PATH_71 — point to the system AMD ROCm install; override the venv-bundled
|
||||
# HIP_PATH / HIP_PATH_71 - point to the system AMD ROCm install; override the venv-bundled
|
||||
# _rocm_sdk_devel DLLs with a potentially mismatched system version
|
||||
# QML_*/QT_* — QtQuick shader/disk-cache flags leaked from Qt tools; harmless for
|
||||
# QML_*/QT_* - QtQuick shader/disk-cache flags leaked from Qt tools; harmless for
|
||||
# PyTorch but can conflict with Gradio's embedded Qt helpers
|
||||
# PYENV_VIRTUALENV_DISABLE_PROMPT — pyenv noise that confuses venv detection
|
||||
# PYENV_VIRTUALENV_DISABLE_PROMPT - pyenv noise that confuses venv detection
|
||||
_EXTRA_CLEAR_VARS = {
|
||||
"HIP_PATH",
|
||||
"HIP_PATH_71",
|
||||
|
|
@ -72,7 +62,7 @@ _EXTRA_CLEAR_VARS = {
|
|||
"QML_DISABLE_DISK_CACHE",
|
||||
"QML_FORCE_DISK_CACHE",
|
||||
"QT_DISABLE_SHADER_DISK_CACHE",
|
||||
# PERF_VALS vars are NOT boolean toggles — MIOpen reads them as perf-config strings.
|
||||
# PERF_VALS vars are NOT boolean toggles - MIOpen reads them as perf-config strings.
|
||||
# If inherited from a parent shell with value "1", MIOpen's GetPerfConfFromEnv parses
|
||||
# "1" as a degenerate config and can return dtype=float32 output from FP16 tensors.
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_PERF_VALS",
|
||||
|
|
@ -81,12 +71,12 @@ _EXTRA_CLEAR_VARS = {
|
|||
|
||||
# Solvers whose MIOpen IsApplicable() explicitly rejects non-FP32 tensors.
|
||||
# They are safe to leave enabled in FP32 mode. When the active dtype is FP16 or BF16
|
||||
# we force them OFF so MIOpen skips the IsApplicable probe entirely — avoids overhead on
|
||||
# we force them OFF so MIOpen skips the IsApplicable probe entirely - avoids overhead on
|
||||
# every conv shape find. These are NOT in _UNSET_VARS because they are valid in FP32.
|
||||
_FP32_ONLY_SOLVERS = {
|
||||
"MIOPEN_DEBUG_CONV_FFT", # FFT convolution — FP32 only (MIOpen source: IsFp32 check)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3", # Winograd 3x3 — FP32 only
|
||||
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD", # Fused Winograd — FP32 only
|
||||
"MIOPEN_DEBUG_CONV_FFT", # FFT convolution - FP32 only (MIOpen source: IsFp32 check)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3", # Winograd 3x3 - FP32 only
|
||||
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD", # Fused Winograd - FP32 only
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -106,8 +96,7 @@ def _resolve_dtype() -> str:
|
|||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from modules import shared as _sh # pylint: disable=import-outside-toplevel
|
||||
v = getattr(getattr(_sh, 'opts', None), 'cuda_dtype', None)
|
||||
v = getattr(opts, 'cuda_dtype', None)
|
||||
if v in ('FP16', 'BF16', 'FP32'):
|
||||
return v
|
||||
except Exception:
|
||||
|
|
@ -118,17 +107,25 @@ def _resolve_dtype() -> str:
|
|||
# --- venv helpers ---
|
||||
|
||||
def _get_venv() -> str:
|
||||
return os.environ.get("VIRTUAL_ENV", "") or sys.prefix
|
||||
return sys.prefix
|
||||
|
||||
|
||||
def _get_root() -> str:
|
||||
from modules.paths import script_path # pylint: disable=import-outside-toplevel
|
||||
return str(script_path)
|
||||
|
||||
|
||||
def _expand_venv(value: str) -> str:
|
||||
return value.replace("{VIRTUAL_ENV}", _get_venv())
|
||||
return value.replace("{VIRTUAL_ENV}", _get_venv()).replace("{ROOT}", _get_root())
|
||||
|
||||
|
||||
def _collapse_venv(value: str) -> str:
|
||||
venv = _get_venv()
|
||||
root = _get_root()
|
||||
if venv and value.startswith(venv):
|
||||
return "{VIRTUAL_ENV}" + value[len(venv):]
|
||||
if root and value.startswith(root):
|
||||
return "{ROOT}" + value[len(root):]
|
||||
return value
|
||||
|
||||
|
||||
|
|
@ -163,7 +160,7 @@ def load_config() -> Dict[str, str]:
|
|||
_cache = data if data else {k: v["default"] for k, v in ROCM_ENV_VARS.items()}
|
||||
# Purge unsafe vars from a stale saved config and re-persist only if the file existed.
|
||||
# When running without a saved config (first run / after Delete), load_config() must
|
||||
# never create the file — that only happens via save_config() on Apply or Apply Profile.
|
||||
# never create the file - that only happens via save_config() on Apply or Apply Profile.
|
||||
dirty = {k for k in _cache if k in _UNSET_VARS or (k != _ARCH_KEY and k not in ROCM_ENV_VARS)}
|
||||
if dirty:
|
||||
_cache = {k: v for k, v in _cache.items() if k not in dirty}
|
||||
|
|
@ -212,7 +209,7 @@ def apply_env(config: Optional[Dict[str, str]] = None) -> None:
|
|||
os.environ[var] = expanded
|
||||
# Arch safety net: hard-force all hardware-incompatible vars to "0" in the env.
|
||||
# This runs *after* the config loop so it overrides any stale "1" that survived in the JSON.
|
||||
# Source of truth: rocm_profiles.UNAVAILABLE[arch] — vars with no supporting hardware.
|
||||
# Source of truth: rocm_profiles.UNAVAILABLE[arch] - vars with no supporting hardware.
|
||||
arch = config.get(_ARCH_KEY, "")
|
||||
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||
if unavailable:
|
||||
|
|
@ -240,7 +237,7 @@ def apply_all(names: list, values: list) -> None:
|
|||
meta = ROCM_ENV_VARS[name]
|
||||
if meta["widget"] == "checkbox":
|
||||
if value is None:
|
||||
pass # Gradio passed None (component not interacted with) — leave config unchanged
|
||||
pass # Gradio passed None (component not interacted with) - leave config unchanged
|
||||
else:
|
||||
config[name] = "1" if value else "0"
|
||||
elif meta["widget"] == "radio":
|
||||
|
|
@ -248,7 +245,7 @@ def apply_all(names: list, values: list) -> None:
|
|||
valid = {v for _, v in meta["options"]} if meta["options"] and isinstance(meta["options"][0], tuple) else set(meta["options"] or [])
|
||||
if stored in valid:
|
||||
config[name] = stored
|
||||
# else: value was None/invalid — leave the existing saved value untouched
|
||||
# else: value was None/invalid - leave the existing saved value untouched
|
||||
else:
|
||||
if meta.get("options"):
|
||||
value = _dropdown_stored(str(value), meta["options"])
|
||||
|
|
@ -291,7 +288,7 @@ def delete_config() -> None:
|
|||
CONFIG.unlink()
|
||||
log.info(f'ROCm delete_config: deleted {CONFIG}')
|
||||
_cache = None
|
||||
# Delete the MIOpen user DB (~/.miopen/db) — stale entries can cause solver mismatches
|
||||
# Delete the MIOpen user DB (~/.miopen/db) - stale entries can cause solver mismatches
|
||||
miopen_db = Path(os.path.expanduser('~')) / '.miopen' / 'db'
|
||||
if miopen_db.exists():
|
||||
shutil.rmtree(miopen_db, ignore_errors=True)
|
||||
|
|
@ -365,6 +362,28 @@ def _user_db_summary(path: Path) -> dict:
|
|||
return out
|
||||
|
||||
|
||||
def _extract_db_hash(db_path: Path) -> str:
|
||||
"""Derive the cache subfolder name from udb.txt filenames.
|
||||
e.g. gfx1030_30.HIP.3_5_1_5454e9e2da.udb.txt → '3.5.1.5454e9e2da'"""
|
||||
for f in db_path.glob("*.HIP.*.udb.txt"):
|
||||
m = re.search(r'\.HIP\.([^.]+)\.udb\.txt$', f.name)
|
||||
if m:
|
||||
return m.group(1).replace("_", ".")
|
||||
return ""
|
||||
|
||||
|
||||
def _user_cache_summary(path: Path) -> dict:
|
||||
"""Return {filename: 'N KB'} for binary cache blobs in the resolved cache path."""
|
||||
out = {}
|
||||
if not path.exists():
|
||||
return out
|
||||
for f in sorted(path.iterdir()):
|
||||
if f.is_file():
|
||||
kb = f.stat().st_size // 1024
|
||||
out[f.name] = f"{kb} KB"
|
||||
return out
|
||||
|
||||
|
||||
def info() -> dict:
|
||||
config = load_config()
|
||||
db_path = Path(_expand_venv(config.get("MIOPEN_SYSTEM_DB_PATH", "")))
|
||||
|
|
@ -427,20 +446,29 @@ def info() -> dict:
|
|||
if ufiles:
|
||||
udb["files"] = ufiles
|
||||
|
||||
# User cache (~/.miopen/cache/<version-hash>)
|
||||
cache_base = Path.home() / ".miopen" / "cache"
|
||||
db_hash = _extract_db_hash(user_db_path) if user_db_path.exists() else ""
|
||||
cache_path = cache_base / db_hash if db_hash else cache_base
|
||||
ucache = {"path": str(cache_path), "exists": cache_path.exists()}
|
||||
if cache_path.exists():
|
||||
cfiles = _user_cache_summary(cache_path)
|
||||
if cfiles:
|
||||
ucache["files"] = cfiles
|
||||
|
||||
return {
|
||||
"rocm": rocm_section,
|
||||
"torch": torch_section,
|
||||
"gpu": gpu_section,
|
||||
"system_db": sdb,
|
||||
"user_db": udb,
|
||||
"user_cache": ucache,
|
||||
}
|
||||
|
||||
|
||||
# Apply saved config to os.environ at import time (only when ROCm is present)
|
||||
if is_rocm:
|
||||
if installer.torch_info.get('type', None) == 'rocm':
|
||||
try:
|
||||
apply_env()
|
||||
except Exception as _e:
|
||||
print(f"[rocm_mgr] Warning: failed to apply env at import: {_e}", file=sys.stderr)
|
||||
else:
|
||||
log.debug('ROCm is not installed — skipping rocm_mgr env apply')
|
||||
log.debug(f"[rocm_mgr] Warning: failed to apply env at import: {_e}")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""
|
||||
"""
|
||||
Architecture-specific MIOpen solver profiles for AMD GCN/RDNA GPUs.
|
||||
|
||||
Sources:
|
||||
|
|
@ -6,8 +6,8 @@ Sources:
|
|||
|
||||
Key axis: consumer RDNA GPUs have NO XDLOPS hardware (that's CDNA/Instinct only).
|
||||
RDNA2 (gfx1030): RX 6000 series
|
||||
RDNA3 (gfx1100): RX 7000 series — adds Fury Winograd, wider MPASS
|
||||
RDNA4 (gfx1200): RX 9000 series — adds Rage Winograd, wider MPASS
|
||||
RDNA3 (gfx1100): RX 7000 series - adds Fury Winograd, wider MPASS
|
||||
RDNA4 (gfx1200): RX 9000 series - adds Rage Winograd, wider MPASS
|
||||
|
||||
Each profile is a dict of {var: value} that will be MERGED on top of the
|
||||
current config (general vars like DB path / log level are preserved).
|
||||
|
|
@ -15,9 +15,9 @@ current config (general vars like DB path / log level are preserved).
|
|||
|
||||
from typing import Dict
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Shared: everything that must be OFF on ALL consumer RDNA (no XDLOPS hw)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_XDLOPS_OFF: Dict[str, str] = {
|
||||
# GTC XDLOPS (CDNA-only)
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS": "0",
|
||||
|
|
@ -55,7 +55,7 @@ _XDLOPS_OFF: Dict[str, str] = {
|
|||
# MLIR (CDNA-only in practice)
|
||||
"MIOPEN_DEBUG_CONV_MLIR_IGEMM_WRW_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_MLIR_IGEMM_BWD_XDLOPS": "0",
|
||||
# MP BD Winograd (Multi-pass Block-Decomposed — CDNA / high-end only)
|
||||
# MP BD Winograd (Multi-pass Block-Decomposed - CDNA / high-end only)
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F2X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F3X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F4X3": "0",
|
||||
|
|
@ -68,17 +68,17 @@ _XDLOPS_OFF: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F6X3": "0",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RDNA2 — gfx1030 (RX 6000 series)
|
||||
|
||||
# RDNA2 - gfx1030 (RX 6000 series)
|
||||
# No XDLOPS, no Fury/Rage Winograd, MPASS limited to F3x2/F3x3
|
||||
# ASM IGEMM: V4R1 variants only; HIP IGEMM: non-XDLOPS V4R1/R4 only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
RDNA2: Dict[str, str] = {
|
||||
**_XDLOPS_OFF,
|
||||
# General settings (architecture-independent; set here so all profiles cover them)
|
||||
"MIOPEN_SEARCH_CUTOFF": "0",
|
||||
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": "0",
|
||||
# Core algo enables — FFT is FP32-only but harmless (IsApplicable rejects it for fp16 tensors)
|
||||
# Core algo enables - FFT is FP32-only but harmless (IsApplicable rejects it for fp16 tensors)
|
||||
"MIOPEN_DEBUG_CONV_FFT": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT": "1",
|
||||
"MIOPEN_DEBUG_CONV_GEMM": "1",
|
||||
|
|
@ -93,16 +93,16 @@ RDNA2: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_OPENCL_CONVOLUTIONS": "1",
|
||||
"MIOPEN_DEBUG_OPENCL_WAVE64_NOWGP": "1",
|
||||
"MIOPEN_DEBUG_ATTN_SOFTMAX": "1",
|
||||
# Direct ASM — dtype notes
|
||||
# 3X3U / 1X1U / 1X1UV2: FP32/FP16 forward — enabled
|
||||
# Direct ASM - dtype notes
|
||||
# 3X3U / 1X1U / 1X1UV2: FP32/FP16 forward - enabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "1",
|
||||
# 5X10U2V2: fixed geometry (5*10 stride-2), no SD conv matches — disabled
|
||||
# 5X10U2V2: fixed geometry (5*10 stride-2), no SD conv matches - disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_5X10U2V2": "0",
|
||||
# 7X7C3H224W224: hard-coded ImageNet stem (C=3, H=W=224, K=64) — never matches SD — disabled
|
||||
# 7X7C3H224W224: hard-coded ImageNet stem (C=3, H=W=224, K=64) - never matches SD - disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_7X7C3H224W224": "0",
|
||||
# WRW3X3 / WRW1X1: FP32-only weight-gradient (training only) — disabled for inference
|
||||
# WRW3X3 / WRW1X1: FP32-only weight-gradient (training only) - disabled for inference
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW3X3": "0",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW1X1": "0",
|
||||
# PERF_VALS intentionally blank: MIOpen reads this as a config string not a boolean;
|
||||
|
|
@ -110,30 +110,30 @@ RDNA2: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_PERF_VALS": "",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_SEARCH_OPTIMIZED": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_AI_HEUR": "1",
|
||||
# NAIVE_CONV_FWD: scalar FP32 reference solver — IsApplicable does NOT reliably filter for FP16;
|
||||
# NAIVE_CONV_FWD: scalar FP32 reference solver - IsApplicable does NOT reliably filter for FP16;
|
||||
# can be selected for unusual shapes (e.g. VAE decoder 3-ch output) and returns dtype=float32
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD": "0",
|
||||
# Direct OCL — dtype notes
|
||||
# FWD / FWD1X1: FP32/FP16 forward — enabled
|
||||
# Direct OCL - dtype notes
|
||||
# FWD / FWD1X1: FP32/FP16 forward - enabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "1",
|
||||
# FWD11X11: requires 11*11 kernel — no SD match — disabled
|
||||
# FWD11X11: requires 11*11 kernel - no SD match - disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD11X11": "0",
|
||||
# FWDGEN: FP32 generic OCL fallback — IsApplicable does NOT reliably reject for FP16;
|
||||
# can produce dtype=float32 output for FP16 inputs — disabled
|
||||
# FWDGEN: FP32 generic OCL fallback - IsApplicable does NOT reliably reject for FP16;
|
||||
# can produce dtype=float32 output for FP16 inputs - disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWDGEN": "0",
|
||||
# WRW2 / WRW53 / WRW1X1: training-only weight-gradient — disabled
|
||||
# WRW2 / WRW53 / WRW1X1: training-only weight-gradient - disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW2": "0",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW53": "0",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW1X1": "0",
|
||||
# Winograd RxS — dtype per MIOpen docs
|
||||
# WINOGRAD_3X3: FP32-only — harmless (IsApplicable rejects for fp16); enabled
|
||||
# Winograd RxS - dtype per MIOpen docs
|
||||
# WINOGRAD_3X3: FP32-only - harmless (IsApplicable rejects for fp16); enabled
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3": "1",
|
||||
# RXS: covers FP32/FP16 F(3,3) Fwd/Bwd + FP32 F(3,2) WrW — keep enabled (fp16 fwd/bwd path exists)
|
||||
# RXS: covers FP32/FP16 F(3,3) Fwd/Bwd + FP32 F(3,2) WrW - keep enabled (fp16 fwd/bwd path exists)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS": "1",
|
||||
# RXS_FWD_BWD: FP32/FP16 — explicitly the fp16-capable subset
|
||||
# RXS_FWD_BWD: FP32/FP16 - explicitly the fp16-capable subset
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD": "1",
|
||||
# RXS_WRW: FP32 WrW only — training-only, disabled for inference fp16 profile
|
||||
# RXS_WRW: FP32 WrW only - training-only, disabled for inference fp16 profile
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_WRW": "0",
|
||||
# RXS_F3X2: FP32/FP16 Fwd/Bwd
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2": "1",
|
||||
|
|
@ -141,15 +141,15 @@ RDNA2: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3": "1",
|
||||
# RXS_F2X3_G1: FP32/FP16 Fwd/Bwd (non-group convolutions)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1": "1",
|
||||
# FUSED_WINOGRAD: FP32-only — harmless (IsApplicable rejects for fp16); enabled
|
||||
# FUSED_WINOGRAD: FP32-only - harmless (IsApplicable rejects for fp16); enabled
|
||||
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD": "1",
|
||||
# PERF_VALS intentionally blank: same reason as ASM_1X1U — not a boolean, config string
|
||||
# PERF_VALS intentionally blank: same reason as ASM_1X1U - not a boolean, config string
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_PERF_VALS": "",
|
||||
# Fury/Rage Winograd — NOT available on RDNA2
|
||||
# Fury/Rage Winograd - NOT available on RDNA2
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "0",
|
||||
# MPASS — only F3x2 and F3x3 are safe on RDNA2
|
||||
# MPASS - only F3x2 and F3x3 are safe on RDNA2
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2": "1",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3": "1",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4": "0",
|
||||
|
|
@ -159,50 +159,50 @@ RDNA2: Dict[str, str] = {
|
|||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3": "0",
|
||||
# ASM Implicit GEMM — forward V4R1 only; no GTC/XDLOPS on RDNA2
|
||||
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only — disabled
|
||||
# ASM Implicit GEMM - forward V4R1 only; no GTC/XDLOPS on RDNA2
|
||||
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only - disabled
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_V4R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_V4R1": "0",
|
||||
# HIP Implicit GEMM — non-XDLOPS V4R1/R4 forward only
|
||||
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only — disabled
|
||||
# HIP Implicit GEMM - non-XDLOPS V4R1/R4 forward only
|
||||
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only - disabled
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V1R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4": "0",
|
||||
# Group Conv XDLOPS / CK default kernels — RDNA3/4 only, not available on RDNA2
|
||||
# Group Conv XDLOPS / CK default kernels - RDNA3/4 only, not available on RDNA2
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "0",
|
||||
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "0",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RDNA3 — gfx1100 (RX 7000 series)
|
||||
# RDNA3 - gfx1100 (RX 7000 series)
|
||||
# Fury Winograd added; MPASS F3x4 enabled; Group Conv XDLOPS + CK default kernels enabled
|
||||
# ---------------------------------------------------------------------------
|
||||
RDNA3: Dict[str, str] = {
|
||||
**RDNA2,
|
||||
# Fury Winograd — introduced for gfx1100 (RDNA3)
|
||||
# Fury Winograd - introduced for gfx1100 (RDNA3)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "1",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "1",
|
||||
# Wider MPASS on RDNA3
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4": "1",
|
||||
# Group Conv XDLOPS / CK — available from gfx1100 (RDNA3) onwards
|
||||
# Group Conv XDLOPS / CK - available from gfx1100 (RDNA3) onwards
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "1",
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "1",
|
||||
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "1",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RDNA4 — gfx1200 (RX 9000 series)
|
||||
# RDNA4 - gfx1200 (RX 9000 series)
|
||||
# Rage Winograd added; MPASS F3x5 enabled
|
||||
# ---------------------------------------------------------------------------
|
||||
RDNA4: Dict[str, str] = {
|
||||
**RDNA3,
|
||||
# Rage Winograd — introduced for gfx1200 (RDNA4)
|
||||
# Rage Winograd - introduced for gfx1200 (RDNA4)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "1",
|
||||
# Wider MPASS on RDNA4
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5": "1",
|
||||
|
|
|
|||
|
|
@ -1,15 +1,48 @@
|
|||
from typing import Dict, Any, List, Tuple
|
||||
|
||||
# --- General MIOpen/rocBLAS variables (dropdown/textbox/checkbox) ---
|
||||
GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
||||
|
||||
"MIOPEN_SYSTEM_DB_PATH": {
|
||||
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\",
|
||||
"desc": "MIOpen system DB path",
|
||||
"widget": "textbox",
|
||||
"options": None,
|
||||
"restart_required": True,
|
||||
},
|
||||
"ROCBLAS_TENSILE_LIBPATH": {
|
||||
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\rocblas\\library",
|
||||
"desc": "rocBLAS Tensile library path",
|
||||
"widget": "textbox",
|
||||
"options": None,
|
||||
"restart_required": True,
|
||||
},
|
||||
"MIOPEN_GEMM_ENFORCE_BACKEND": {
|
||||
"default": "1",
|
||||
"desc": "Enforce GEMM backend",
|
||||
"desc": "GEMM backend",
|
||||
"widget": "dropdown",
|
||||
"options": [("1 - rocBLAS", "1"), ("5 - hipBLASLt", "5")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"PYTORCH_ROCM_USE_ROCBLAS": {
|
||||
"default": "0",
|
||||
"desc": "PyTorch: Use rocBLAS",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": True,
|
||||
},
|
||||
"PYTORCH_HIPBLASLT_DISABLE": {
|
||||
"default": "1",
|
||||
"desc": "PyTorch: Use hipBLASLt",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Allow hipBLASLt", "0"), ("1 - Disable hipBLASLt", "1")],
|
||||
"restart_required": True,
|
||||
},
|
||||
"ROCBLAS_USE_HIPBLASLT": {
|
||||
"default": "0",
|
||||
"desc": "rocBLAS: use hipBLASLt backend",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Tensile (rocBLAS)", "0"), ("1 - hipBLASLt", "1")],
|
||||
"restart_required": True,
|
||||
},
|
||||
"MIOPEN_FIND_MODE": {
|
||||
"default": "2",
|
||||
"desc": "MIOpen Find Mode",
|
||||
|
|
@ -31,12 +64,69 @@ GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
|||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": True,
|
||||
},
|
||||
"MIOPEN_SYSTEM_DB_PATH": {
|
||||
"default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\",
|
||||
"desc": "MIOpen system DB path",
|
||||
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": {
|
||||
"default": "0",
|
||||
"desc": "Deterministic convolutions",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"MIOPEN_CONVOLUTION_MAX_WORKSPACE": {
|
||||
"default": "1073741824",
|
||||
"desc": "MIOpen convolutions: max workspace (bytes; 1 GB)",
|
||||
"widget": "textbox",
|
||||
"options": None,
|
||||
"restart_required": True,
|
||||
"restart_required": False,
|
||||
},
|
||||
"ROCBLAS_DEVICE_MEMORY_SIZE": {
|
||||
"default": "",
|
||||
"desc": "rocBLAS workspace size in bytes (empty = dynamic)",
|
||||
"widget": "textbox",
|
||||
"options": None,
|
||||
"restart_required": False,
|
||||
},
|
||||
"PYTORCH_TUNABLEOP_CACHE_DIR": {
|
||||
"default": "{ROOT}\\models\\tunable",
|
||||
"desc": "TunableOp cache directory",
|
||||
"widget": "textbox",
|
||||
"options": None,
|
||||
"restart_required": False,
|
||||
},
|
||||
|
||||
"ROCBLAS_STREAM_ORDER_ALLOC": {
|
||||
"default": "1",
|
||||
"desc": "rocBLAS stream-ordered memory allocation",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Standard", "0"), ("1 - Stream-ordered", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"ROCBLAS_DEFAULT_ATOMICS_MODE": {
|
||||
"default": "1",
|
||||
"desc": "rocBLAS allow atomics",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off (deterministic)", "0"), ("1 - On (performance)", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"PYTORCH_TUNABLEOP_ROCBLAS_ENABLED": {
|
||||
"default": "0",
|
||||
"desc": "TunableOp: Enable tuning",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"PYTORCH_TUNABLEOP_TUNING": {
|
||||
"default": "0",
|
||||
"desc": "TunableOp: Tuning mode",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Use Cache", "0"), ("1 - Benchmark new shapes", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED": {
|
||||
"default": "0",
|
||||
"desc": "TunableOp: benchmark hipBLASLt kernels",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"MIOPEN_LOG_LEVEL": {
|
||||
"default": "0",
|
||||
|
|
@ -66,23 +156,8 @@ GENERAL_VARS: Dict[str, Dict[str, Any]] = {
|
|||
"options": [("0 - Off", "0"), ("1 - Error", "1"), ("2 - Trace", "2"), ("3 - Hints", "3"), ("4 - Info", "4"), ("5 - API Trace", "5")],
|
||||
"restart_required": False,
|
||||
},
|
||||
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": {
|
||||
"default": "0",
|
||||
"desc": "Deterministic convolution (reproducible results, may be slower)",
|
||||
"widget": "dropdown",
|
||||
"options": [("0 - Off", "0"), ("1 - On", "1")],
|
||||
"restart_required": False,
|
||||
},
|
||||
}
|
||||
|
||||
# --- Solver toggles (inference/FWD only, RDNA2/3/4 compatible) ---
|
||||
# Removed entirely — not representable in the UI, cannot be set by users:
|
||||
# WRW (weight-gradient) and BWD (data-gradient) — training passes only, never run during inference
|
||||
# XDLOPS/CK CDNA-exclusive (MI100/MI200/MI300 matrix engine variants) — not on any RDNA
|
||||
# Fixed-geometry (5x10, 7x7-ImageNet, 11x11) — shapes never appear in SD/video inference
|
||||
# FP32-reference (NAIVE_CONV_FWD, FWDGEN) — IsApplicable() unreliable for FP16/BF16
|
||||
# Wide MPASS (F3x4..F7x3) — kernel sizes that cannot match any SD convolution shape
|
||||
# Disabled by default (added but off): RDNA3/4-only — Group Conv XDLOPS, CK default kernels
|
||||
_SOLVER_DESCS: Dict[str, str] = {}
|
||||
|
||||
_SOLVER_DESCS.update({
|
||||
|
|
@ -251,3 +326,13 @@ SOLVER_GROUPS: List[Tuple[str, List[str]]] = [
|
|||
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
|
||||
]),
|
||||
]
|
||||
|
||||
# Variables that are relevant only when hipBLASLt is the active GEMM backend.
|
||||
# These are visually greyed-out in the UI when rocBLAS (MIOPEN_GEMM_ENFORCE_BACKEND="1") is selected.
|
||||
HIPBLASLT_VARS: set = {
|
||||
"PYTORCH_HIPBLASLT_DISABLE",
|
||||
"ROCBLAS_USE_HIPBLASLT",
|
||||
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED",
|
||||
"HIPBLASLT_LOG_LEVEL",
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,9 @@ from modules import scripts_manager, shared
|
|||
# rocm_mgr exposes package-internal helpers (prefixed _) that are intentionally called here
|
||||
# pylint: disable=protected-access
|
||||
|
||||
|
||||
class ROCmScript(scripts_manager.Script):
|
||||
def title(self):
|
||||
return "ROCm: Advanced Config"
|
||||
return "Windows ROCm: Advanced Config"
|
||||
|
||||
def show(self, _is_img2img):
|
||||
if shared.cmd_opts.use_rocm or installer.torch_info.get('type') == 'rocm':
|
||||
|
|
@ -19,7 +18,7 @@ class ROCmScript(scripts_manager.Script):
|
|||
if not shared.cmd_opts.use_rocm and not installer.torch_info.get('type') == 'rocm': # skip ui creation if not rocm
|
||||
return []
|
||||
|
||||
from scripts.rocm import rocm_mgr, rocm_vars # pylint: disable=no-name-in-module
|
||||
from scripts.rocm import rocm_mgr, rocm_vars, rocm_profiles # pylint: disable=no-name-in-module
|
||||
|
||||
config = rocm_mgr.load_config()
|
||||
var_names = []
|
||||
|
|
@ -59,11 +58,25 @@ class ROCmScript(scripts_manager.Script):
|
|||
row("path", udb.get("path", ""))
|
||||
for fname, finfo in udb.get("files", {}).items():
|
||||
row(fname, finfo)
|
||||
section("User cache (~/.miopen/cache)")
|
||||
ucache = d.get("user_cache", {})
|
||||
row("path", ucache.get("path", ""))
|
||||
for fname, sz in ucache.get("files", {}).items():
|
||||
row(fname, sz)
|
||||
return f"<table style='width:100%;border-collapse:collapse'>{''.join(rows)}</table>"
|
||||
|
||||
def _build_style(unavailable, hipblaslt_disabled=False):
|
||||
rules = []
|
||||
for v in (unavailable or []):
|
||||
rules.append(f"#rocm_var_{v.lower()} label {{ text-decoration: line-through; opacity: 0.5; }}")
|
||||
if hipblaslt_disabled:
|
||||
for v in rocm_vars.HIPBLASLT_VARS:
|
||||
rules.append(f"#rocm_var_{v.lower()} {{ opacity: 0.45; pointer-events: none; }}")
|
||||
return f"<style>{' '.join(rules)}</style>" if rules else ""
|
||||
|
||||
with gr.Accordion('ROCm: Advanced Config', open=False, elem_id='rocm_config'):
|
||||
with gr.Row():
|
||||
gr.HTML("<p>Advanced configuration for ROCm users.</p><br><p>Set your database and solver selections based on GPU profile or individually.</p><br><p>Enable cuDNN in Backend Settings to activate MIOpen.</p>")
|
||||
gr.HTML("<p><u>Advanced configuration for ROCm users.</u></p><br><p>This script aims to take the guesswork out of configuring MIOpen and rocBLAS on Windows ROCm, but also to expose the functioning switches of MIOpen for advanced configurations.</p><br><p>For best performance ensure that cuDNN and PyTorch tunable ops are set to <b><i>default</i></b> in Backend Settings.</p><br><p>This script was written with the intent to support ROCm Windows users, it should however, function identically for Linux users.</p><br>")
|
||||
with gr.Row():
|
||||
btn_info = gr.Button("Refresh Info", variant="primary", elem_id="rocm_btn_info", size="sm")
|
||||
btn_apply = gr.Button("Apply", variant="primary", elem_id="rocm_btn_apply", size="sm")
|
||||
|
|
@ -74,12 +87,15 @@ class ROCmScript(scripts_manager.Script):
|
|||
btn_rdna2 = gr.Button("RDNA2 (RX 6000)", elem_id="rocm_btn_rdna2")
|
||||
btn_rdna3 = gr.Button("RDNA3 (RX 7000)", elem_id="rocm_btn_rdna3")
|
||||
btn_rdna4 = gr.Button("RDNA4 (RX 9000)", elem_id="rocm_btn_rdna4")
|
||||
style_out = gr.HTML("")
|
||||
_init_gemm = config.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||
_init_arch = config.get(rocm_mgr._ARCH_KEY, "")
|
||||
_init_unavailable = rocm_profiles.UNAVAILABLE.get(_init_arch, set()) if _init_arch else set()
|
||||
style_out = gr.HTML(_build_style(_init_unavailable, _init_gemm == "1"))
|
||||
info_out = gr.HTML(value=_info_html, elem_id="rocm_info_table")
|
||||
|
||||
# General vars (dropdowns, textboxes, checkboxes)
|
||||
with gr.Group():
|
||||
gr.HTML("<h3>MIOpen Settings</h3><hr>")
|
||||
gr.HTML("<br><h3>MIOpen Settings</h3><hr>")
|
||||
for name, meta in rocm_vars.GENERAL_VARS.items():
|
||||
comp = _make_component(name, meta, config)
|
||||
var_names.append(name)
|
||||
|
|
@ -106,13 +122,46 @@ class ROCmScript(scripts_manager.Script):
|
|||
|
||||
for name, comp in zip(var_names, components):
|
||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||
if meta["widget"] == "dropdown":
|
||||
if meta["widget"] == "dropdown" and name != "MIOPEN_GEMM_ENFORCE_BACKEND":
|
||||
comp.change(fn=lambda v, n=name: _autosave_field(n, v), inputs=[comp], outputs=[], show_progress='hidden')
|
||||
|
||||
_GEMM_COMPANIONS = {
|
||||
"PYTORCH_ROCM_USE_ROCBLAS": {"1": "1", "5": "0"},
|
||||
"PYTORCH_HIPBLASLT_DISABLE": {"1": "1", "5": "0"},
|
||||
"ROCBLAS_USE_HIPBLASLT": {"1": "0", "5": "1"},
|
||||
"PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED": {"1": "0", "5": "1"},
|
||||
}
|
||||
|
||||
def gemm_changed(gemm_display_val):
|
||||
stored = rocm_mgr._dropdown_stored(str(gemm_display_val), rocm_vars.ROCM_ENV_VARS["MIOPEN_GEMM_ENFORCE_BACKEND"]["options"])
|
||||
cfg = rocm_mgr.load_config().copy()
|
||||
cfg["MIOPEN_GEMM_ENFORCE_BACKEND"] = stored
|
||||
for var, vals in _GEMM_COMPANIONS.items():
|
||||
cfg[var] = vals.get(stored, cfg.get(var, ""))
|
||||
rocm_mgr.save_config(cfg)
|
||||
rocm_mgr.apply_env(cfg)
|
||||
arch = cfg.get(rocm_mgr._ARCH_KEY, "")
|
||||
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||
result = [gr.update(value=_build_style(unavailable, stored == "1"))]
|
||||
for pname in var_names:
|
||||
if pname in _GEMM_COMPANIONS:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[pname]
|
||||
val = _GEMM_COMPANIONS[pname].get(stored, cfg.get(pname, ""))
|
||||
result.append(gr.update(value=rocm_mgr._dropdown_display(val, meta["options"])))
|
||||
else:
|
||||
result.append(gr.update())
|
||||
return result
|
||||
|
||||
gemm_comp = components[var_names.index("MIOPEN_GEMM_ENFORCE_BACKEND")]
|
||||
gemm_comp.change(fn=gemm_changed, inputs=[gemm_comp], outputs=[style_out] + components, show_progress='hidden')
|
||||
|
||||
def apply_fn(*values):
|
||||
rocm_mgr.apply_all(var_names, list(values))
|
||||
saved = rocm_mgr.load_config()
|
||||
result = [gr.update(value="")]
|
||||
arch = saved.get(rocm_mgr._ARCH_KEY, "")
|
||||
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||
gemm_val = saved.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
|
||||
for name in var_names:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||
val = saved.get(name, meta["default"])
|
||||
|
|
@ -124,19 +173,13 @@ class ROCmScript(scripts_manager.Script):
|
|||
result.append(gr.update(value=rocm_mgr._expand_venv(val)))
|
||||
return result
|
||||
|
||||
def _build_style(unavailable):
|
||||
if not unavailable:
|
||||
return ""
|
||||
rules = " ".join(
|
||||
f"#rocm_var_{v.lower()} label {{ text-decoration: line-through; opacity: 0.5; }}"
|
||||
for v in unavailable
|
||||
)
|
||||
return f"<style>{rules}</style>"
|
||||
|
||||
def reset_fn():
|
||||
rocm_mgr.reset_defaults()
|
||||
updated = rocm_mgr.load_config()
|
||||
result = [gr.update(value="")]
|
||||
arch = updated.get(rocm_mgr._ARCH_KEY, "")
|
||||
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||
gemm_val = updated.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
|
||||
for name in var_names:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||
val = updated.get(name, meta["default"])
|
||||
|
|
@ -150,7 +193,9 @@ class ROCmScript(scripts_manager.Script):
|
|||
|
||||
def clear_fn():
|
||||
rocm_mgr.clear_env()
|
||||
result = [gr.update(value="")]
|
||||
cfg = rocm_mgr.load_config()
|
||||
gemm_val = cfg.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||
result = [gr.update(value=_build_style(None, gemm_val == "1"))]
|
||||
for name in var_names:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||
if meta["widget"] == "checkbox":
|
||||
|
|
@ -163,7 +208,8 @@ class ROCmScript(scripts_manager.Script):
|
|||
|
||||
def delete_fn():
|
||||
rocm_mgr.delete_config()
|
||||
result = [gr.update(value="")]
|
||||
gemm_default = rocm_vars.ROCM_ENV_VARS.get("MIOPEN_GEMM_ENFORCE_BACKEND", {}).get("default", "1")
|
||||
result = [gr.update(value=_build_style(None, gemm_default == "1"))]
|
||||
for name in var_names:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[name]
|
||||
if meta["widget"] == "checkbox":
|
||||
|
|
@ -175,11 +221,11 @@ class ROCmScript(scripts_manager.Script):
|
|||
return result
|
||||
|
||||
def profile_fn(arch):
|
||||
from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module
|
||||
rocm_mgr.apply_profile(arch)
|
||||
updated = rocm_mgr.load_config()
|
||||
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
|
||||
result = [gr.update(value=_build_style(unavailable))]
|
||||
gemm_val = updated.get("MIOPEN_GEMM_ENFORCE_BACKEND", "1")
|
||||
result = [gr.update(value=_build_style(unavailable, gemm_val == "1"))]
|
||||
for pname in var_names:
|
||||
meta = rocm_vars.ROCM_ENV_VARS[pname]
|
||||
val = updated.get(pname, meta["default"])
|
||||
|
|
|
|||
2
wiki
2
wiki
|
|
@ -1 +1 @@
|
|||
Subproject commit d54ade8e5f79a62b5228de6406400c8eda71b67f
|
||||
Subproject commit cbbbfc73af2366650cdf8cc71fabbf3a508b607b
|
||||
Loading…
Reference in New Issue