Merge branch 'dev' into RUF013

pull/4706/head
awsr 2026-03-24 07:12:51 -07:00 committed by GitHub
commit ff247d8fd2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
66 changed files with 3643 additions and 3019 deletions

2
.gitignore vendored

@ -17,6 +17,7 @@ __pycache__
/data/cache.json
/data/themes.json
/data/installer.json
/data/rocm.json
node_modules
pnpm-lock.yaml
package-lock.json
@ -57,6 +58,7 @@ tunableop_results*.csv
!webui.sh
!package.json
!requirements.txt
!constraints.txt
!/data
!/models/VAE-approx
!/models/VAE-approx/model.pt

CHANGELOG.md

@ -1,38 +1,48 @@
# Change Log for SD.Next
## Update for 2026-03-20
## Update for 2026-03-24
### Highlights for 2026-03-20
### Highlights for 2026-03-24
This release brings massive code refactoring to modernize the codebase and removal of some obsolete features. Leaner & Faster!
And since it's a bit quieter period when it comes to new models, notable additions are: *FireRed-Image-Edit*, *SkyWorks-UniPic-3* and *Anima-Preview-2*
And since it's a bit quieter period when it comes to new models, notable additions are: *FireRed-Image-Edit*, *SkyWorks-UniPic-3* and the new *Anima-Preview*
If you're on the Windows platform, we have a brand new [All-in-one Installer & Launcher](https://github.com/vladmandic/sdnext-launcher): simply download the [exe or zip](https://github.com/vladmandic/sdnext-launcher/releases) and you're done!
*What else*? Really a lot!
New color grading module, updated localization with new languages and improved translations, new CivitAI integration module, several new upscalers, improvements to LLM/VLM in captioning and prompt enhance, a lot of new control preprocessors, new realtime server info panel, and some new UI themes
New color grading module, updated localization with new languages and improved translations, new CivitAI integration module, new finetunes loader, several new upscalers, improvements to LLM/VLM in captioning and prompt enhance, a lot of new control preprocessors, new realtime server info panel, and some new UI themes
And major work on API hardening: security, rate limits, secrets handling, new endpoints, etc.
But also many smaller quality-of-life improvements - for full details, see [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md)
*Note*: Purely due to the size of the changes, a clean install is recommended!
[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic)
### Details for 2026-03-20
### Details for 2026-03-24
- **Models**
- [Google Flash 3.1 Image](https://ai.google.dev/gemini-api/docs/models/gemini-3-flash-preview) a.k.a. *Nano Banana 2*
- [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0)
- [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0) *1.0 and 1.1*
*Note*: FireRed is a fine-tune of Qwen-Image-Edit, despite its claim of being a new base model
- [Skyworks UniPic-3](https://huggingface.co/Skywork/Unipic3), *Consistency and DMD* variants to reference/community section
*Note*: UniPic-3 is a fine-tune of Qwen-Image-Edit with new distillation, despite its claim of major changes
- [Anima Preview-v2](https://huggingface.co/circlestone-labs/Anima)
- **Image manipulation**
- new **color grading** module
- update **latent corrections** *(formerly HDR Corrections)* and expand allowed models
- add support for [spandrel](https://github.com/chaiNNer-org/spandrel)
**upscaling** engine with support for new upscaling model families
- add two new AI upscalers: *RealPLKSR NomosWebPhoto* and *RealPLKSR AnimeSharpV2*
- add two new **interpolation** methods: *HQX* and *ICB*
- use high-quality [sharpfin](https://github.com/drhead/Sharpfin) accelerated library
when available (*cuda-only*)
- **upscalers**: extend chainner support for additional models
- new **Color grading** module
apply basic corrections to your images: brightness, contrast, saturation, shadows, highlights
move to professional photo corrections: hue, gamma, sharpness, temperature
correct tone: shadows, midtones, highlights
add effects: vignette, grain
apply a professional LUT using a `.cube` file
*hint*: color grading is available as a step during generate or as a processing item for existing images
- **Upscaling**
add support for [spandrel](https://github.com/chaiNNer-org/spandrel) engine with support for new upscaling model families
add two new AI upscalers: *RealPLKSR NomosWebPhoto* and *RealPLKSR AnimeSharpV2*
add two new **interpolation** methods: *HQX* and *ICB*
use high-quality [sharpfin](https://github.com/drhead/Sharpfin) accelerated library
extend `chainner` support for additional models
- update **Latent corrections** *(former HDR Corrections)*
expand allowed models
- **Captioning / Prompt Enhance**
- new models: **Qwen-3.5**, **Mistral-3** in multiple variations
- new models: multiple *heretic* and *abliterated* finetunes for **Qwen, Gemma, Mistral**
@ -43,32 +53,38 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
- new **pre-processors**:
*anyline, depth_anything v2, dsine, lotus, marigold normals, oneformer, rtmlib pose, sam2, stablenormal, teed, vitpose*
- **Features**
- **secrets** handling: new `secrets.json` and special handling for tokens/keys/passwords
- **Secrets** handling: new `secrets.json` and special handling for tokens/keys/passwords
used to be treated like any other `config.json` param, which could cause security issues
- pipelines: add **ZImageInpaint**
- rewritten **civitai** module
- rewritten **CivitAI** module
browse/discover mode with sort, period, type/base dropdowns; URL paste; subfolder sorting; auto-browse; dynamic dropdowns
- **hires**: allow using different lora in refiner prompt
- **nunchaku** models are now listed in networks tab as reference models
- **HiRes**: allow using different lora in refiner prompt
- **Nunchaku** models are now listed in networks tab as reference models
instead of being used implicitly via quantization
- improve image **metadata** parser for foreign metadata (e.g. XMP)
- improve image **Metadata** parser for foreign metadata (e.g. XMP)
- **Compute**
- **ROCm** advanced configuration and tuning, thanks @resonantsky
see *main interface -> scripts -> rocm advanced config*
- **ROCm** support for additional AMD GPUs: `gfx103X`, thanks @crashingalexsan
- **Cuda** `torch==2.10` removed support for `rtx1000` series, use following before first startup:
- **Cuda** `torch==2.10` removed support for `rtx1000` series and older GPUs
use the following before first startup to force installation of `torch==2.9.1` with `cuda==12.6`:
> `set TORCH_COMMAND='torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu126'`
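on Linux/macOS the equivalent would be:
> `export TORCH_COMMAND='torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu126'`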
- **UI**
- new panel: **server info** with detailed runtime information
- **localization** improved translation quality and new translation locales:
- legacy panels **T2I** and **I2I** are disabled by default
you can re-enable them in *settings -> ui -> hide legacy tabs*
- new panel: **Server Info** with detailed runtime information
- **Networks** add **UNet/DiT**
- **Localization** improved translation quality and new translation locales:
*en, en1, en2, en3, en4, hr, es, it, fr, de, pt, ru, zh, ja, ko, hi, ar, bn, ur, id, vi, tr, sr, po, he, xx, yy, qq, tlh*
yes, this now includes stuff like *latin, esperanto, arabic, hebrew, klingon* and a lot more!
and also introduces some pseudo-locales such as *techno-babbel* and *for-n00bs*
*hint*: click on locale icon in bottom-left corner to cycle through available locales, or set default in *settings -> ui*
- **server settings** new section in *settings*
- **kanvas** add paste image from clipboard
- **themes** add *CTD-NT64Light*, *CTD-NT64Medium* and *CTD-NT64Dark*, thanks @resonantsky
- **themes** add *Vlad-Neomorph*
- **gallery** add option to auto-refresh gallery, thanks @awsr
- **token counters** add per-section display for supported models, thanks @awsr
- **Server settings** new section in *settings*
- **Kanvas** add paste image from clipboard
- **Themes** add *CTD-NT64Light*, *CTD-NT64Medium* and *CTD-NT64Dark*, thanks @resonantsky
- **Themes** add *Vlad-Neomorph*
- **Gallery** add option to auto-refresh gallery, thanks @awsr
- **Token counters** add per-section display for supported models, thanks @awsr
- **API**
- **rate limiting**: global for all endpoints, guards against abuse and denial-of-service attacks
configurable in *settings -> server settings*
@ -76,7 +92,11 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
- new `/sdapi/v1/torch` endpoint for torch info (backend, version, etc.)
- new `/sdapi/v1/gpu` endpoint for GPU info
- new `/sdapi/v1/rembg` endpoint for background removal
- new `/sdapi/v1/unets` endpoint to list available unets/dits *(see usage sketch below)*
- use rate limiting for API logging
- **Obsoleted**
- removed support for additional quantization engines: *BitsAndBytes, TorchAO, Optimum-Quanto, NNCF*
*note*: SDNQ is the quantization engine of choice for SD.Next
- **Internal**
- `python==3.13` full support
- `python==3.14` initial support
@ -147,6 +167,7 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
- improve video generation progress tracking
- handle startup with bad `scripts` more gracefully
- thread-safety for `error-limiter`, thanks @awsr
- add `lora` support for flux2-klein
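A minimal usage sketch for the new informational API endpoints, assuming a default local server at `http://127.0.0.1:7860` with authentication disabled (the printed fields are illustrative):

```python
# Minimal sketch: query the new info endpoints added in this release.
# Assumes a local SD.Next server at http://127.0.0.1:7860 with no auth enabled.
import requests

BASE = 'http://127.0.0.1:7860'

torch_info = requests.get(f'{BASE}/sdapi/v1/torch', timeout=10).json()  # torch backend, version, etc.
gpu_info = requests.get(f'{BASE}/sdapi/v1/gpu', timeout=10).json()      # GPU details
unets = requests.get(f'{BASE}/sdapi/v1/unets', timeout=10).json()       # entries with "name" and "filename"

print('torch:', torch_info)
print('gpu:', gpu_info)
for unet in unets:
    print(unet['name'], unet['filename'])
```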
## Update for 2026-02-04

113
README.md

@ -1,8 +1,14 @@
<div align="center">
<img src="https://github.com/vladmandic/sdnext/raw/master/html/logo-transparent.png" width=200 alt="SD.Next">
<img src="https://github.com/vladmandic/sdnext/raw/master/html/logo-transparent.png" width=200 alt="SD.Next: AI art generator logo">
# SD.Next: All-in-one WebUI for AI generative image and video creation and captioning
# SD.Next: All-in-one WebUI
SD.Next is a powerful, open-source WebUI app for AI image and video generation, built on Stable Diffusion and supporting dozens of advanced models. Create, caption, and process images and videos with a modern, cross-platform interface—perfect for artists, researchers, and AI enthusiasts.
![Stars](https://img.shields.io/github/stars/vladmandic/sdnext?style=social)
![Forks](https://img.shields.io/github/forks/vladmandic/sdnext?style=social)
![Contributors](https://img.shields.io/github/contributors/vladmandic/sdnext)
![Last update](https://img.shields.io/github/last-commit/vladmandic/sdnext?svg=true)
![License](https://img.shields.io/github/license/vladmandic/sdnext?svg=true)
[![Discord](https://img.shields.io/discord/1101998836328697867?logo=Discord&svg=true)](https://discord.gg/VjvR2tabEX)
@ -17,61 +23,63 @@
## Table of contents
- [Documentation](https://vladmandic.github.io/sdnext-docs/)
- [SD.Next Features](#sdnext-features)
- [Model support](#model-support)
- [Platform support](#platform-support)
- [Features & Capabilities](#features--capabilities)
- [Supported AI Models](#supported-ai-models)
- [Supported Platforms & Hardware](#supported-platforms--hardware)
- [Getting started](#getting-started)
## SD.Next Features
### Screenshot: Desktop interface
Not all individual features are listed here; instead, check the [ChangeLog](CHANGELOG.md) for the full list of changes
- Fully localized:
**English | Chinese | Russian | Spanish | German | French | Italian | Portuguese | Japanese | Korean**
- Desktop and Mobile support!
- Multiple [diffusion models](https://vladmandic.github.io/sdnext-docs/Model-Support/)!
- Multi-platform!
▹ **Windows | Linux | MacOS | nVidia CUDA | AMD ROCm | Intel Arc / IPEX XPU | DirectML | OpenVINO | ONNX+Olive | ZLUDA**
<div align="center">
<img src="https://github.com/vladmandic/sdnext/raw/dev/html/screenshot-robot.jpg" alt="SD.Next: AI art generator desktop interface screenshot" width="90%">
</div>
### Screenshot: Mobile interface
<div align="center">
<img src="https://github.com/user-attachments/assets/ced9fe0c-d2c2-46d1-94a7-8f9f2307ce38" alt="SD.Next: AI art generator mobile interface screenshot" width="35%">
</div>
</div>
<br>
## Features & Capabilities
SD.Next is feature-rich with a focus on performance, flexibility, and user experience. Key features include:
- [Multi-platform](#supported-platforms--hardware)!
- Many [diffusion models](https://vladmandic.github.io/sdnext-docs/Model-Support/)!
- Fully localized into ~15 languages, with support for many [UI themes](https://vladmandic.github.io/sdnext-docs/Themes/)!
- [Desktop](#screenshot-desktop-interface) and [Mobile](#screenshot-mobile-interface) support!
- Platform-specific auto-detection and tuning performed on install
- Optimized processing with the latest `torch` developments, with built-in support for model compile and quantize
Compile backends: *Triton | StableFast | DeepCache | OneDiff | TeaCache | etc.*
Quantization methods: *SDNQ | BitsAndBytes | Optimum-Quanto | TorchAO / LayerWise*
- **Captioning** with 150+ **OpenCLiP** models, **Tagger** with **WaifuDiffusion** and **DeepDanbooru** models, and 20+ built-in **VLMs**
- Built-in installer with automatic updates and dependency management
<br>
### Unique features
**Desktop** interface
<div align="center">
<img src="https://github.com/vladmandic/sdnext/raw/dev/html/screenshot-robot.jpg" alt="screenshot-modernui-desktop" width="90%">
</div>
**Mobile** interface
<div align="center">
<img src="https://github.com/user-attachments/assets/ced9fe0c-d2c2-46d1-94a7-8f9f2307ce38" alt="screenshot-modernui-mobile" width="35%">
</div>
For screenshots and information on other available themes, see [Themes](https://vladmandic.github.io/sdnext-docs/Themes/)
SD.Next includes many features not found in other WebUIs, such as:
- **SDNQ**: State-of-the-Art quantization engine
Use pre-quantized models or quantize on-the-fly for up to 4x VRAM reduction with minimal or no impact on quality and performance
- **Balanced Offload**: Dynamically balance CPU and GPU memory to run larger models on limited hardware
- **Captioning** with 150+ **OpenCLiP** models, **Tagger** with **WaifuDiffusion** and **DeepDanbooru** models, and 25+ built-in **VLMs**
- **Image Processing** with a full suite of image-correction and color-grading tools
<br>
## Model support
## Supported AI Models
SD.Next supports a broad range of models: [supported models](https://vladmandic.github.io/sdnext-docs/Model-Support/) and [model specs](https://vladmandic.github.io/sdnext-docs/Models/)
## Platform support
## Supported Platforms & Hardware
- *nVidia* GPUs using **CUDA** libraries on both *Windows and Linux*
- *AMD* GPUs using **ROCm** libraries on *Linux*
Support will be extended to *Windows* once AMD releases ROCm for Windows
- *AMD* GPUs using **ROCm** libraries on both *Linux and Windows*
- *AMD* GPUs on Windows using **ZLUDA** libraries
- *Intel Arc* GPUs using **OneAPI** with *IPEX XPU* libraries on both *Windows and Linux*
- Any *CPU/GPU* or device compatible with **OpenVINO** libraries on both *Windows and Linux*
- Any GPU compatible with *DirectX* on *Windows* using **DirectML** libraries
This includes support for AMD GPUs that are not supported by native ROCm libraries
- Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux*
- *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations
- *ONNX/Olive*
- *AMD* GPUs on Windows using **ZLUDA** libraries
Plus Docker container recipes for: [CUDA, ROCm, Intel IPEX and OpenVINO](https://vladmandic.github.io/sdnext-docs/Docker/)
Plus **Docker** container recipes for: [CUDA, ROCm, Intel IPEX and OpenVINO](https://vladmandic.github.io/sdnext-docs/Docker/)
## Getting started
@ -84,21 +92,37 @@ Plus Docker container recipes for: [CUDA, ROCm, Intel IPEX and OpenVINO](https:/
> And for platform specific information, check out
> [WSL](https://vladmandic.github.io/sdnext-docs/WSL/) | [Intel Arc](https://vladmandic.github.io/sdnext-docs/Intel-ARC/) | [DirectML](https://vladmandic.github.io/sdnext-docs/DirectML/) | [OpenVINO](https://vladmandic.github.io/sdnext-docs/OpenVINO/) | [ONNX & Olive](https://vladmandic.github.io/sdnext-docs/ONNX-Runtime/) | [ZLUDA](https://vladmandic.github.io/sdnext-docs/ZLUDA/) | [AMD ROCm](https://vladmandic.github.io/sdnext-docs/AMD-ROCm/) | [MacOS](https://vladmandic.github.io/sdnext-docs/MacOS-Python/) | [nVidia](https://vladmandic.github.io/sdnext-docs/nVidia/) | [Docker](https://vladmandic.github.io/sdnext-docs/Docker/)
### Quick Start
```shell
git clone https://github.com/vladmandic/sdnext
cd sdnext
./webui.sh # Linux/Mac
webui.bat # Windows
webui.ps1 # PowerShell
```
> [!WARNING]
> If you run into issues, check out [troubleshooting](https://vladmandic.github.io/sdnext-docs/Troubleshooting/) and [debugging](https://vladmandic.github.io/sdnext-docs/Debug/) guides
## Community & Support
If you're unsure how to use a feature, the best place to start is the [Docs](https://vladmandic.github.io/sdnext-docs/), and if it's not there,
check the [ChangeLog](https://vladmandic.github.io/sdnext-docs/CHANGELOG/) for when the feature was first introduced, as it will always have a short note on how to use it
And for any question, reach out on [Discord](https://discord.gg/VjvR2tabEX) or open an [issue](https://github.com/vladmandic/sdnext/issues) or [discussion](https://github.com/vladmandic/sdnext/discussions)
### Contributing
Please see [Contributing](CONTRIBUTING) for details on how to contribute to this project
And for any question, reach out on [Discord](https://discord.gg/VjvR2tabEX) or open an [issue](https://github.com/vladmandic/sdnext/issues) or [discussion](https://github.com/vladmandic/sdnext/discussions)
### Credits
## License & Credits
- SD.Next is licensed under the [Apache License 2.0](LICENSE.txt)
- Main credit goes to [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for the original codebase
- Additional credits are listed in [Credits](https://github.com/AUTOMATIC1111/stable-diffusion-webui/#credits)
- Licenses for modules are listed in [Licenses](html/licenses.html)
### Evolution
## Evolution
<a href="https://star-history.com/#vladmandic/sdnext&Date">
<picture width=640>
@ -109,9 +133,4 @@ And for any question, reach out on [Discord](https://discord.gg/VjvR2tabEX) or o
- [OSS Stats](https://ossinsight.io/analyze/vladmandic/sdnext#overview)
### Docs
If you're unsure how to use a feature, the best place to start is the [Docs](https://vladmandic.github.io/sdnext-docs/), and if it's not there,
check the [ChangeLog](https://vladmandic.github.io/sdnext-docs/CHANGELOG/) for when the feature was first introduced, as it will always have a short note on how to use it
<br>

17
TODO.md

@ -2,20 +2,18 @@
## Release
- Update **README**
- Bump packages
- Implement `unload_auxiliary_models`
- Release **Launcher**
- Release **Enso**
- Update **ROCm**
- Tips **Color Grading**
- Implement: `unload_auxiliary_models`
- Add notes: **Enso**
- Tips: **Color Grading**
- Regen: **Localization**
## Internal
- Feature: Color grading in processing
- Integrate: [Depth3D](https://github.com/vladmandic/sd-extension-depth3d)
- Feature: RIFE update
- Feature: RIFE in processing
- Feature: SeedVR2 in processing
- Feature: Add video models to `Reference`
- Deploy: Lite vs Expert mode
- Engine: [mmgp](https://github.com/deepbeepmeep/mmgp)
- Engine: `TensorRT` acceleration
@ -64,6 +62,7 @@ TODO: Investigate which models are diffusers-compatible and prioritize!
### Image-Edit
- [Bria FIBO-Edit](https://huggingface.co/briaai/Fibo-Edit-RMBG): Fully JSON-based instruction-following image editing framework
- [Meituan LongCat-Image-Edit-Turbo](https://huggingface.co/meituan-longcat/LongCat-Image-Edit-Turbo): 6B instruction-following image editing with high visual consistency
- [VIBE Image-Edit](https://huggingface.co/iitolstykh/VIBE-Image-Edit): (Sana+Qwen-VL) Fast visual instruction-based image editing framework
- [LucyEdit](https://github.com/huggingface/diffusers/pull/12340): Instruction-guided video editing while preserving motion and identity
@ -145,3 +144,5 @@ TODO: Investigate which models are diffusers-compatible and prioritize!
- modules/modular_guiders.py:65:58: W0511: TODO: guiders
- processing: remove duplicate mask params
- resize image: enable full VAE mode for resize-latent
- modules/sd_samplers_diffusers.py:353:31: W0511: TODO enso-required (fixme)

0
constraints.txt Normal file


@ -148,7 +148,7 @@
"date": "2026 March",
"skip": true
},
"FireRed Image Edit": {
"FireRed Image Edit 1.0": {
"path": "FireRedTeam/FireRed-Image-Edit-1.0",
"preview": "FireRedTeam--FireRed-Image-Edit-1.0.jpg",
"desc": "FireRed-Image-Edit is a general-purpose image editing model that delivers high-fidelity and consistent editing across a wide range of scenarios. FireRed is a fine-tune of Qwen-Image-Edit.",
@ -156,6 +156,14 @@
"date": "2026 February",
"skip": true
},
"FireRed Image Edit 1.1": {
"path": "FireRedTeam/FireRed-Image-Edit-1.1",
"preview": "FireRedTeam--FireRed-Image-Edit-1.0.jpg",
"desc": "FireRed-Image-Edit is a general-purpose image editing model that delivers high-fidelity and consistent editing across a wide range of scenarios. FireRed is a fine-tune of Qwen-Image-Edit.",
"tags": "community",
"date": "2026 February",
"skip": true
},
"Skywork UniPic3": {
"path": "Skywork/Unipic3",
"preview": "Skywork--Unipic3.jpg",

@ -1 +1 @@
Subproject commit 9153b52f3980fe857c4ab9c3dd4f131b6175d20e
Subproject commit 9d584a1bdc0c2aca614aa0e1e34e4374c3aa779d

File diff suppressed because it is too large


@ -475,7 +475,7 @@ def check_diffusers():
t_start = time.time()
if args.skip_all:
return
target_commit = "e5aa719241f9b74d6700be3320a777799bfab70a" # diffusers commit hash
target_commit = "c02c17c6ee7ac508c56925dde4d4a3c587650dc3" # diffusers commit hash
# if args.use_rocm or args.use_zluda or args.use_directml:
# sha = '043ab2520f6a19fce78e6e060a68dbc947edb9f9' # lock diffusers versions for now
pkg = package_spec('diffusers')
@ -687,7 +687,7 @@ def install_ipex():
if args.use_nightly:
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu')
else:
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+xpu torchvision==0.25.0+xpu --index-url https://download.pytorch.org/whl/xpu')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.11.0+xpu torchvision==0.26.0+xpu --index-url https://download.pytorch.org/whl/xpu')
ts('ipex', t_start)
return torch_command
@ -700,13 +700,12 @@ def install_openvino():
#check_python(supported_minors=[10, 11, 12, 13], reason='OpenVINO backend requires a Python version between 3.10 and 3.13')
if sys.platform == 'darwin':
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0 torchvision==0.25.0')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.11.0 torchvision==0.26.0')
else:
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+cpu torchvision==0.25.0 --index-url https://download.pytorch.org/whl/cpu')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.11.0+cpu torchvision==0.26.0 --index-url https://download.pytorch.org/whl/cpu')
if not (args.skip_all or args.skip_requirements):
install(os.environ.get('OPENVINO_COMMAND', 'openvino==2025.4.1'), 'openvino')
install(os.environ.get('NNCF_COMMAND', 'nncf==2.19.0'), 'nncf')
install(os.environ.get('OPENVINO_COMMAND', 'openvino==2026.0.0'), 'openvino')
ts('openvino', t_start)
return torch_command
@ -730,12 +729,8 @@ def install_torch_addons():
install('DeepCache')
if opts.get('cuda_compile_backend', '') == 'olive-ai':
install('olive-ai')
if len(opts.get('optimum_quanto_weights', [])):
install('optimum-quanto==0.2.7', 'optimum-quanto')
if len(opts.get('torchao_quantization', [])):
install('torchao==0.10.0', 'torchao')
if opts.get('samples_format', 'jpg') == 'jxl' or opts.get('grid_format', 'jpg') == 'jxl':
install('pillow-jxl-plugin==1.3.5', 'pillow-jxl-plugin')
install('pillow-jxl-plugin==1.3.7', 'pillow-jxl-plugin')
if not args.experimental:
uninstall('wandb', quiet=True)
uninstall('pynvml', quiet=True)
@ -894,7 +889,7 @@ def check_torch():
elif torch.version.hip and allow_rocm:
torch_info.set(type='rocm', hip=torch.version.hip)
else:
log.warning('Unknown Torch backend')
log.warning('Torch backend: cannot detect type')
log.info(f"Torch backend: {torch_info}")
for device in [torch.cuda.device(i) for i in range(torch.cuda.device_count())]:
gpu = {
@ -1184,14 +1179,11 @@ def install_optional():
install('hf_transfer', ignore=True, quiet=True)
install('hf_xet', ignore=True, quiet=True)
install('nvidia-ml-py', ignore=True, quiet=True)
install('pillow-jxl-plugin==1.3.5', ignore=True, quiet=True)
install('pillow-jxl-plugin==1.3.7', ignore=True, quiet=True)
install('ultralytics==8.3.40', ignore=True, quiet=True)
install('open-clip-torch', no_deps=True, quiet=True)
install('git+https://github.com/tencent-ailab/IP-Adapter.git', 'ip_adapter', ignore=True, quiet=True)
# install('git+https://github.com/openai/CLIP.git', 'clip', quiet=True, no_build_isolation=True)
# install('torchao==0.10.0', ignore=True, quiet=True)
# install('bitsandbytes==0.47.0', ignore=True, quiet=True)
# install('optimum-quanto==0.2.7', ignore=True, quiet=True)
ts('optional', t_start)
@ -1235,6 +1227,7 @@ def install_requirements():
# set environment variables controlling the behavior of various libraries
def set_environment():
log.debug('Setting environment tuning')
os.environ.setdefault('PIP_CONSTRAINT', os.path.abspath('constraints.txt'))
os.environ.setdefault('ACCELERATE', 'True')
os.environ.setdefault('ATTN_PRECISION', 'fp16')
os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100')


@ -524,6 +524,14 @@ function selectVAE(name) {
markSelectedCards([desiredVAEName], 'vae');
}
let desiredUNetName = null;
function selectUNet(name) {
desiredUNetName = name;
gradioApp().getElementById('change_unet').click();
log(`selectUNet: ${desiredUNetName}`);
markSelectedCards([desiredUNetName], 'unet');
}
function selectReference(name) {
log(`selectReference: ${name}`);
desiredCheckpointName = name;


@ -308,6 +308,7 @@ def main():
log.warning('Restart is recommended due to packages updates...')
t_server = time.time()
t_monitor = time.time()
while True:
try:
alive = uv.thread.is_alive()
@ -326,8 +327,10 @@ def main():
if float(monitor_rate) > 0 and t_current - t_monitor > float(monitor_rate):
log.trace(f'Monitor: {get_memory_stats(detailed=True)}')
t_monitor = t_current
from modules.api.validate import get_api_stats
get_api_stats()
# from modules.api.validate import get_api_stats
# get_api_stats()
# from modules import memstats
# memstats.get_objects()
if not alive:
if uv is not None and uv.wants_restart:
clean_server()


@ -88,6 +88,7 @@ class Api:
self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=list[models.ItemVae])
self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=list[models.ItemExtension])
self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=list[models.ItemExtraNetwork])
self.add_api_route("/sdapi/v1/unets", endpoints.get_unets, methods=["GET"], response_model=list[models.ItemUNet])
# functional api
self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo, tags=["Functional"])
@ -98,6 +99,7 @@ class Api:
self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"], tags=["Functional"])
self.add_api_route("/sdapi/v1/lock-checkpoint", endpoints.post_lock_checkpoint, methods=["POST"], tags=["Functional"])
self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"], tags=["Functional"])
self.add_api_route("/sdapi/v1/refresh-unets", endpoints.post_refresh_unets, methods=["POST"], tags=["Functional"])
self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=list[str], tags=["Functional"])
self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int, tags=["Functional"])
self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"], tags=["Functional"])


@ -65,7 +65,6 @@ get_restorers = get_detailers # legacy alias for /sdapi/v1/face-restorers
def get_ip_adapters():
"""
List available IP-Adapter models.
Returns adapter names that can be used for image-prompt conditioning during generation.
"""
from modules import ipadapter
@ -75,6 +74,11 @@ def get_prompt_styles():
"""List all saved prompt styles with their prompt, negative prompt, and preview."""
return [{ 'name': v.name, 'prompt': v.prompt, 'negative_prompt': v.negative_prompt, 'extra': v.extra, 'filename': v.filename, 'preview': v.preview} for v in shared.prompt_styles.styles.values()]
def get_unets():
"""List available UNet models with their names and filenames."""
from modules.sd_unet import unet_dict
return [{"name": k, "filename": v} for k, v in unet_dict.items()]
def get_embeddings():
"""List loaded and skipped textual-inversion embeddings for the current model."""
db = getattr(shared.sd_model, 'embedding_db', None) if shared.sd_loaded else None
@ -221,6 +225,11 @@ def post_lock_checkpoint(lock:bool=False):
modeldata.model_data.locked = lock
return {}
def post_refresh_unets():
"""Rescan UNet directories and update the available UNet list."""
import modules.sd_unet
return modules.sd_unet.refresh_unet_list()
def get_checkpoint():
"""Return information about the currently loaded checkpoint including type, class, title, and hash."""
if not shared.sd_loaded or shared.sd_model is None:


@ -146,6 +146,10 @@ class ItemStyle(BaseModel):
filename: str | None = Field(title="Filename", description="Path to the styles file")
preview: str | None = Field(title="Preview", description="URL to the style preview image")
class ItemUNet(BaseModel):
name: str = Field(title="Name", description="UNet/DiT name")
filename: str | None = Field(title="Filename", description="Path to the UNet/DiT file")
class ItemExtraNetwork(BaseModel):
name: str = Field(title="Name", description="Network short name")
type: str = Field(title="Type", description="Network type (lora, checkpoint, embedding, etc.)")


@ -219,8 +219,8 @@ def civit_search_metadata(title: str | None = None, raw: bool = False):
import concurrent
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_items = {}
for fn in candidates:
future_items[executor.submit(atomic_civit_search_metadata, fn, results)] = fn
for candidate in candidates:
future_items[executor.submit(atomic_civit_search_metadata, candidate, results)] = candidate
for future in concurrent.futures.as_completed(future_items):
future.result()
yield results if raw else create_search_metadata_table(results)


@ -74,8 +74,8 @@ def add_diag_args(p):
p.add_argument('--safe', default=env_flag("SD_SAFE", False), action='store_true', help="Run in safe mode with no user extensions")
p.add_argument('--test', default=env_flag("SD_TEST", False), action='store_true', help="Run test only and exit")
p.add_argument('--version', default=False, action='store_true', help="Print version information")
p.add_argument("--monitor", default=os.environ.get("SD_MONITOR", -1), help="Run memory monitor, default: %(default)s")
p.add_argument("--status", default=os.environ.get("SD_STATUS", -1), help="Run server is-alive status, default: %(default)s")
p.add_argument("--monitor", type=float, default=float(os.environ.get("SD_MONITOR", -1)), help="Run memory monitor, default: %(default)s")
p.add_argument("--status", type=float, default=float(os.environ.get("SD_STATUS", -1)), help="Run server is-alive status, default: %(default)s")
def add_log_args(p):


@ -382,22 +382,6 @@ class ControlNet():
self.model = sdnq_quantize_model(self.model)
except Exception as e:
log.error(f'Control {what} model SDNQ Compression failed: id="{model_id}" {e}')
elif "Control" in opts.optimum_quanto_weights:
try:
log.debug(f'Control {what} model Optimum Quanto: id="{model_id}"')
model_quant.load_quanto('Load model: type=Control')
from modules.model_quant import optimum_quanto_model
self.model = optimum_quanto_model(self.model)
except Exception as e:
log.error(f'Control {what} model Optimum Quanto: id="{model_id}" {e}')
elif "Control" in opts.torchao_quantization:
try:
log.debug(f'Control {what} model Torch AO: id="{model_id}"')
model_quant.load_torchao('Load model: type=Control')
from modules.model_quant import torchao_quantization
self.model = torchao_quantization(self.model)
except Exception as e:
log.error(f'Control {what} model Torch AO: id="{model_id}" {e}')
if self.device is not None:
sd_models.move_model(self.model, self.device)
if "Control" in opts.cuda_compile:


@ -1,5 +1,5 @@
import os
from installer import log, git
from installer import log, git, run_extension_installer
from modules.paths import extensions_dir
@ -14,6 +14,7 @@ def install():
return
log.info(f'Enso: folder="{ENSO_DIR}" installing')
git(f'clone "{ENSO_REPO}" "{ENSO_DIR}"')
run_extension_installer(ENSO_DIR)
def update():
@ -22,3 +23,4 @@ def update():
return
log.info(f'Enso: folder="{ENSO_DIR}" updating')
git('pull', folder=ENSO_DIR)
run_extension_installer(ENSO_DIR)


@ -12,7 +12,7 @@ _lock = Lock()
def _make_unique(name: str):
global _instance_id
global _instance_id # pylint: disable=global-statement
with _lock: # Guard against race conditions
new_name = f"{name}__{_instance_id}"
_instance_id += 1


@ -16,6 +16,28 @@ torch_version[0], torch_version[1] = int(torch_version[0]), int(torch_version[1]
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def return_true(*args, **kwargs):
return True
def return_false(*args, **kwargs):
return False
def return_none(*args, **kwargs):
return None
def return_zero(*args, **kwargs):
return 0
def return_cuda_version(*args, **kwargs):
return (12,1)
def return_xpu_string(*args, **kwargs):
return "xpu"
def return_arch_list(*args, **kwargs):
return ["pvc", "dg2", "ats-m150"]
def ipex_init(): # pylint: disable=too-many-statements
try:
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
@ -26,9 +48,9 @@ def ipex_init(): # pylint: disable=too-many-statements
# import inductor utils to get around lazy import
from torch._inductor import utils as torch_inductor_utils # pylint: disable=import-error, unused-import # noqa: F401,RUF100
torch._inductor.utils.GPU_TYPES = ["xpu"]
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
torch._inductor.utils.get_gpu_type = return_xpu_string
from triton import backends as triton_backends # pylint: disable=import-error
triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False
triton_backends.backends["nvidia"].driver.is_active = return_false
except Exception:
pass
# Replace cuda with xpu:
@ -51,15 +73,12 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.random = torch.xpu.random
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda._device = torch.xpu._device
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.is_current_stream_capturing = return_false
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__builtins__ = torch.xpu.__builtins__
@ -141,12 +160,23 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
if torch_version[0] < 2 or (torch_version[0] == 2 and torch_version[1] < 11):
torch.cuda.Union = torch.xpu.Union
torch.cuda._device = torch.xpu._device
torch.cuda._device_t = torch.xpu._device_t
# Memory:
if "linux" in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.xpu.empty_cache = return_none
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory = torch.xpu.memory
if torch_version[0] >= 2 and torch_version[1] >= 8:
old_cpa = torch.cuda.memory.CUDAPluggableAllocator
torch.cuda.memory = torch.xpu.memory
torch.xpu.memory.CUDAPluggableAllocator = old_cpa
else:
torch.cuda.memory = torch.xpu.memory
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
@ -172,21 +202,24 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed
# Fix functions with ipex:
# torch.xpu.mem_get_info always returns the total memory as free memory
torch.has_cuda = True
torch.version.cuda = "12.1"
torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch._utils._get_available_device_type = lambda: "xpu"
torch.backends.cuda.is_built = return_true
torch._utils._get_available_device_type = return_xpu_string
torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
# torch.xpu.mem_get_info always returns the total memory as free memory
def mem_get_info(device=None):
return [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch.xpu.mem_get_info = mem_get_info
torch.cuda.mem_get_info = torch.xpu.mem_get_info
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = getattr(torch.xpu, "is_bf16_supported", lambda *args, **kwargs: True)
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.cuda.get_arch_list = getattr(torch.xpu, "get_arch_list", lambda: ["pvc", "dg2", "ats-m150"])
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
torch.cuda.is_bf16_supported = getattr(torch.xpu, "is_bf16_supported", return_true)
torch.cuda.is_fp16_supported = getattr(torch.xpu, "is_fp16_supported", return_true)
torch.cuda.get_arch_list = getattr(torch.xpu, "get_arch_list", return_arch_list)
torch.cuda.get_device_capability = return_cuda_version
torch.cuda.ipc_collect = return_none
torch.cuda.utilization = return_zero
device_supports_fp64 = ipex_hijacks()
try:


@ -15,8 +15,10 @@ torch_version[0], torch_version[1] = int(torch_version[0]), int(torch_version[1]
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties(devices.device).has_fp64
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
# pylint: disable=protected-access, missing-function-docstring, line-too-long, no-else-return
def return_false(*args, **kwargs):
return False
@property
def is_cuda(self):
@ -24,7 +26,7 @@ def is_cuda(self):
def check_device_type(device, device_type: str) -> bool:
if device is None or type(device) not in {str, int, torch.device}:
if device is None or not isinstance(device, (str, int, torch.device)):
return False
else:
return bool(torch.device(device).type == device_type)
@ -137,24 +139,9 @@ def as_tensor(data, dtype=None, device=None):
return original_as_tensor(data, dtype=dtype, device=device)
original_torch_tensor = torch.tensor
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
global device_supports_fp64
if check_cuda(device):
device = return_xpu(device)
if not device_supports_fp64 and check_device_type(device, "xpu"):
if dtype == torch.float64:
dtype = torch.float32
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
dtype = torch.float32
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
torch.Tensor.original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs):
global device_supports_fp64
if check_cuda(device):
device = return_xpu(device)
if not device_supports_fp64:
@ -210,6 +197,24 @@ if torch_version[0] > 2 or (torch_version[0] == 2 and torch_version[1] >= 4):
return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs)
original_torch_tensor = torch.tensor
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
if check_cuda(device):
if not device_supports_fp64 and (dtype == torch.float64 or (dtype is None and getattr(data, "dtype", None) in {torch.float64, float})):
return original_torch_tensor(data, *args, dtype=torch.float32, device=return_xpu(device), **kwargs)
else:
return original_torch_tensor(data, *args, dtype=dtype, device=return_xpu(device), **kwargs)
else:
if (
not device_supports_fp64 and check_device_type(device, "xpu")
and (dtype == torch.float64 or (dtype is None and getattr(data, "dtype", None) in {torch.float64, float}))
):
return original_torch_tensor(data, *args, dtype=torch.float32, device=device, **kwargs)
else:
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs):
@ -221,11 +226,11 @@ def torch_empty(*args, device=None, **kwargs):
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, dtype=None, **kwargs):
def torch_randn(*args, device=None, **kwargs):
if check_cuda(device):
return original_torch_randn(*args, device=return_xpu(device), dtype=dtype, **kwargs)
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_randn(*args, device=device, dtype=dtype, **kwargs)
return original_torch_randn(*args, device=device, **kwargs)
original_torch_ones = torch.ones
@ -255,34 +260,6 @@ def torch_full(*args, device=None, **kwargs):
return original_torch_full(*args, device=device, **kwargs)
original_torch_arange = torch.arange
@wraps(torch.arange)
def torch_arange(*args, device=None, dtype=None, **kwargs):
global device_supports_fp64
if check_cuda(device):
if not device_supports_fp64 and dtype == torch.float64:
dtype = torch.float32
return original_torch_arange(*args, device=return_xpu(device), dtype=dtype, **kwargs)
else:
if not device_supports_fp64 and check_device_type(device, "xpu") and dtype == torch.float64:
dtype = torch.float32
return original_torch_arange(*args, device=device, dtype=dtype, **kwargs)
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, dtype=None, **kwargs):
global device_supports_fp64
if check_cuda(device):
if not device_supports_fp64 and dtype == torch.float64:
dtype = torch.float32
return original_torch_linspace(*args, device=return_xpu(device), dtype=dtype, **kwargs)
else:
if not device_supports_fp64 and check_device_type(device, "xpu") and dtype == torch.float64:
dtype = torch.float32
return original_torch_linspace(*args, device=device, dtype=dtype, **kwargs)
original_torch_eye = torch.eye
@wraps(torch.eye)
def torch_eye(*args, device=None, **kwargs):
@ -292,6 +269,36 @@ def torch_eye(*args, device=None, **kwargs):
return original_torch_eye(*args, device=device, **kwargs)
original_torch_arange = torch.arange
@wraps(torch.arange)
def torch_arange(*args, dtype=None, device=None, **kwargs):
if check_cuda(device):
if not device_supports_fp64 and dtype == torch.float64:
return original_torch_arange(*args, dtype=torch.float32, device=return_xpu(device), **kwargs)
else:
return original_torch_arange(*args, dtype=dtype, device=return_xpu(device), **kwargs)
else:
if not device_supports_fp64 and check_device_type(device, "xpu") and dtype == torch.float64:
return original_torch_arange(*args, dtype=torch.float32, device=device, **kwargs)
else:
return original_torch_arange(*args, dtype=dtype, device=device, **kwargs)
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, dtype=None, device=None, **kwargs):
if check_cuda(device):
if not device_supports_fp64 and dtype == torch.float64:
return original_torch_linspace(*args, dtype=torch.float32, device=return_xpu(device), **kwargs)
else:
return original_torch_linspace(*args, dtype=dtype, device=return_xpu(device), **kwargs)
else:
if not device_supports_fp64 and check_device_type(device, "xpu") and dtype == torch.float64:
return original_torch_linspace(*args, dtype=torch.float32, device=device, **kwargs)
else:
return original_torch_linspace(*args, dtype=dtype, device=device, **kwargs)
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
@ -360,24 +367,29 @@ class torch_Generator(original_torch_Generator):
# Hijack Functions:
def ipex_hijacks():
global device_supports_fp64
torch.UntypedStorage.__init__ = UntypedStorage_init
if torch_version[0] > 2 or (torch_version[0] == 2 and torch_version[1] >= 4):
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.UntypedStorage.to = UntypedStorage_to
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
torch.Tensor.pin_memory = Tensor_pin_memory
torch.UntypedStorage.__init__ = UntypedStorage_init
# transformers completely breaks when anything is done to torch.tensor
# even straight passthroughs break transformers for some reason
#torch.tensor = torch_tensor
torch.empty = torch_empty
torch.randn = torch_randn
torch.ones = torch_ones
torch.zeros = torch_zeros
torch.full = torch_full
torch.eye = torch_eye
torch.arange = torch_arange
torch.linspace = torch_linspace
torch.eye = torch_eye
torch.load = torch_load
torch.cuda.synchronize = torch_cuda_synchronize
torch.cuda.device = torch_cuda_device
torch.cuda.set_device = torch_cuda_set_device
@ -437,6 +449,6 @@ def ipex_hijacks():
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
torch.cuda.amp.common.amp_definitely_not_available = return_false
return device_supports_fp64


@ -1,13 +1,11 @@
import os
import sys
import torch
import nncf
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
from openvino.frontend import FrontEndManager
from openvino import Core, Type, PartialShape, serialize
from openvino.properties import hint as ov_hints
from openvino.frontend import FrontEndManager # pylint: disable=no-name-in-module
from openvino import Core, Type, PartialShape, serialize # pylint: disable=no-name-in-module
from openvino.properties import hint as ov_hints # pylint: disable=no-name-in-module
from torch._dynamo.backends.common import fake_tensor_unsupported
from torch._dynamo.backends.registry import register_backend
@ -23,25 +21,6 @@ from modules import shared, devices, sd_models_utils
from modules.logger import log
# importing openvino.runtime forces DeprecationWarning to "always"
# And Intel's own libs (NNCF) imports the deprecated module
# Don't allow openvino to override warning filters:
try:
import warnings
filterwarnings = warnings.filterwarnings
warnings.filterwarnings = lambda *args, **kwargs: None
import openvino.runtime # pylint: disable=unused-import
installer.torch_info.set(openvino=openvino.runtime.get_version())
warnings.filterwarnings = filterwarnings
except Exception:
pass
try:
# silence the pytorch version warning
nncf.common.logging.logger.warn_bkc_version_mismatch = lambda *args, **kwargs: None
except Exception:
pass
# Set default params
torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit) # pylint: disable=protected-access
torch._dynamo.eval_frame.check_if_dynamo_supported = lambda: True # pylint: disable=protected-access
@ -213,11 +192,7 @@ def execute_cached(compiled_model, *args):
def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str | None = None, file_name=""):
core = Core()
device = get_device()
global dont_use_4bit_nncf
global dont_use_nncf
global dont_use_quant
if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
om = core.read_model(file_name + ".xml")
@ -259,26 +234,6 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str | Non
om.inputs[idx-idx_minus].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
om.validate_nodes_and_infer_types()
if shared.opts.nncf_quantize and not dont_use_quant:
new_inputs = []
for idx, _ in enumerate(example_inputs):
new_inputs.append(example_inputs[idx].detach().cpu().numpy())
new_inputs = [new_inputs]
if shared.opts.nncf_quantize_mode == "INT8":
om = nncf.quantize(om, nncf.Dataset(new_inputs))
else:
om = nncf.quantize(om, nncf.Dataset(new_inputs), mode=getattr(nncf.QuantizationMode, shared.opts.nncf_quantize_mode),
advanced_parameters=nncf.quantization.advanced_parameters.AdvancedQuantizationParameters(
overflow_fix=nncf.quantization.advanced_parameters.OverflowFix.DISABLE, backend_params=None))
if shared.opts.nncf_compress_weights and not dont_use_nncf:
if dont_use_4bit_nncf or shared.opts.nncf_compress_weights_mode == "INT8":
om = nncf.compress_weights(om)
else:
compress_group_size = shared.opts.nncf_compress_weights_group_size if shared.opts.nncf_compress_weights_group_size != 0 else None
compress_ratio = shared.opts.nncf_compress_weights_raito if shared.opts.nncf_compress_weights_raito != 0 else None
om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=compress_group_size, ratio=compress_ratio)
hints = {}
if shared.opts.openvino_accuracy == "performance":
hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE
@ -287,9 +242,6 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str | Non
if model_hash_str is not None:
hints['CACHE_DIR'] = shared.opts.openvino_cache_path + '/blob'
core.set_property(hints)
dont_use_nncf = False
dont_use_quant = False
dont_use_4bit_nncf = False
compiled_model = core.compile_model(om, device)
return compiled_model
@ -299,44 +251,17 @@ def openvino_compile_cached_model(cached_model_path, *example_inputs):
core = Core()
om = core.read_model(cached_model_path + ".xml")
global dont_use_4bit_nncf
global dont_use_nncf
global dont_use_quant
for idx, input_data in enumerate(example_inputs):
om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
om.validate_nodes_and_infer_types()
if shared.opts.nncf_quantize and not dont_use_quant:
new_inputs = []
for idx, _ in enumerate(example_inputs):
new_inputs.append(example_inputs[idx].detach().cpu().numpy())
new_inputs = [new_inputs]
if shared.opts.nncf_quantize_mode == "INT8":
om = nncf.quantize(om, nncf.Dataset(new_inputs))
else:
om = nncf.quantize(om, nncf.Dataset(new_inputs), mode=getattr(nncf.QuantizationMode, shared.opts.nncf_quantize_mode),
advanced_parameters=nncf.quantization.advanced_parameters.AdvancedQuantizationParameters(
overflow_fix=nncf.quantization.advanced_parameters.OverflowFix.DISABLE, backend_params=None))
if shared.opts.nncf_compress_weights and not dont_use_nncf:
if dont_use_4bit_nncf or shared.opts.nncf_compress_weights_mode == "INT8":
om = nncf.compress_weights(om)
else:
compress_group_size = shared.opts.nncf_compress_weights_group_size if shared.opts.nncf_compress_weights_group_size != 0 else None
compress_ratio = shared.opts.nncf_compress_weights_raito if shared.opts.nncf_compress_weights_raito != 0 else None
om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=compress_group_size, ratio=compress_ratio)
hints = {'CACHE_DIR': shared.opts.openvino_cache_path + '/blob'}
if shared.opts.openvino_accuracy == "performance":
hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE
elif shared.opts.openvino_accuracy == "accuracy":
hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.ACCURACY
core.set_property(hints)
dont_use_nncf = False
dont_use_quant = False
dont_use_4bit_nncf = False
compiled_model = core.compile_model(om, get_device())
return compiled_model
@ -462,14 +387,8 @@ def get_subgraph_type(tensor):
@fake_tensor_unsupported
def openvino_fx(subgraph, example_inputs, options=None):
global dont_use_4bit_nncf
global dont_use_nncf
global dont_use_quant
global subgraph_type
dont_use_4bit_nncf = False
dont_use_nncf = False
dont_use_quant = False
dont_use_faketensors = False
executor_parameters = None
inputs_reversed = False
@ -478,25 +397,25 @@ def openvino_fx(subgraph, example_inputs, options=None):
subgraph_type = []
subgraph.apply(get_subgraph_type)
"""
# SD 1.5 / SDXL VAE
if (subgraph_type[0] is torch.nn.modules.conv.Conv2d and
if (
subgraph_type[0] is torch.nn.modules.conv.Conv2d and
subgraph_type[1] is torch.nn.modules.conv.Conv2d and
subgraph_type[2] is torch.nn.modules.normalization.GroupNorm and
subgraph_type[3] is torch.nn.modules.activation.SiLU):
dont_use_4bit_nncf = True
dont_use_nncf = bool("VAE" not in shared.opts.nncf_compress_weights)
dont_use_quant = bool("VAE" not in shared.opts.nncf_quantize)
subgraph_type[3] is torch.nn.modules.activation.SiLU
):
pass
"""
# SD 1.5 / SDXL Text Encoder
elif (subgraph_type[0] is torch.nn.modules.sparse.Embedding and
if (
subgraph_type[0] is torch.nn.modules.sparse.Embedding and
subgraph_type[1] is torch.nn.modules.sparse.Embedding and
subgraph_type[2] is torch.nn.modules.normalization.LayerNorm and
subgraph_type[3] is torch.nn.modules.linear.Linear):
subgraph_type[3] is torch.nn.modules.linear.Linear
):
dont_use_faketensors = True
dont_use_nncf = bool("TE" not in shared.opts.nncf_compress_weights)
dont_use_quant = bool("TE" not in shared.opts.nncf_quantize)
# Create a hash to be used for caching
shared.compiled_model_state.model_hash_str = ""


@ -3,6 +3,7 @@ from functools import partial
import os
import re
import sys
import types
import logging
import warnings
import urllib3
@ -133,6 +134,14 @@ timer.startup.record("accelerate")
import pydantic # pylint: disable=W0611,C0411
timer.startup.record("pydantic")
try:
# transformers==5.x has different dependency stack so switching between v4 and v5 becomes very painful
# this temporarily disables dependency version checks so we can use either v4 or v5 until we drop support for v4
fake_version_check = types.ModuleType("transformers.dependency_versions_check")
sys.modules["transformers.dependency_versions_check"] = fake_version_check # disable transformers version checks
fake_version_check.dep_version_check = lambda pkg, hint=None: None
except Exception:
pass
import transformers # pylint: disable=W0611,C0411
from transformers import logging as transformers_logging # pylint: disable=W0611,C0411
transformers_logging.set_verbosity_error()
@ -175,9 +184,10 @@ except Exception as e:
sys.exit(1)
try:
pass # pylint: disable=W0611,C0411
import pillow_jxl # pylint: disable=W0611,C0411
except Exception:
pass
from PIL import Image # pylint: disable=W0611,C0411
timer.startup.record("pillow")


@ -4,13 +4,13 @@ import re
import time
from typing import TYPE_CHECKING
import torch
import diffusers.models.lora
from modules.lora import lora_common as l
from modules import shared, devices, errors, model_quant
from modules.logger import log
if TYPE_CHECKING:
from collections.abc import Callable
import diffusers.models.lora
bnb = None


@ -257,7 +257,11 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
shared.compiled_model_state.lora_model.append(f"{name}:{lora_scale}")
lora_method = lora_overrides.get_method(shorthash)
if lora_method == 'diffusers':
net = lora_diffusers.load_diffusers(name, network_on_disk, lora_scale, lora_module)
if shared.sd_model_type == 'f2':
from pipelines.flux import flux2_lora
net = flux2_lora.try_load_lokr(name, network_on_disk, lora_scale)
if net is None:
net = lora_diffusers.load_diffusers(name, network_on_disk, lora_scale, lora_module)
elif lora_method == 'nunchaku':
pass # handled directly from extra_networks_lora.load_nunchaku
else:
@ -272,7 +276,8 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
continue
if net is None:
failed_to_load_networks.append(name)
log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} not found')
lora_ver = network_on_disk.sd_version if network_on_disk is not None else None
log.error(f'Network load: type=LoRA name="{name}" detected={lora_ver} not loaded')
continue
if hasattr(sd_model, 'embedding_db'):
sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings)
@ -309,6 +314,12 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
errors.display(e, 'LoRA')
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, force=True, silent=True) # some layers may end up on cpu without hook
# Activate native modules loaded via diffusers path (e.g., LoKR on Flux2)
native_nets = [net for net in l.loaded_networks if len(net.modules) > 0]
if native_nets:
from modules.lora import networks
networks.network_activate()
if len(l.loaded_networks) > 0 and l.debug:
log.debug(f'Network load: type=LoRA loaded={[n.name for n in l.loaded_networks]} cache={list(lora_cache)} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers}')

View File

@ -58,6 +58,8 @@ class NetworkOnDisk:
return 'sc'
if base.startswith("sd3"):
return 'sd3'
if base.startswith("flux2") or "klein" in base:
return 'f2'
if base.startswith("flux"):
return 'f1'
if base.startswith("hunyuan_video"):
@ -75,6 +77,8 @@ class NetworkOnDisk:
return 'xl'
if arch.startswith("stable-cascade"):
return 'sc'
if arch.startswith("flux2") or "klein" in arch:
return 'f2'
if arch.startswith("flux"):
return 'f1'
if arch.startswith("hunyuan-video"):
@ -86,6 +90,8 @@ class NetworkOnDisk:
return 'sd1'
if str(self.metadata.get('ss_v2', "")) == "True":
return 'sd2'
if 'klein' in self.name.lower() or 'klein' in self.fullname.lower():
return 'f2'
if 'flux' in self.name.lower():
return 'f1'
if 'xl' in self.name.lower():

View File

@ -55,3 +55,40 @@ class NetworkModuleLokr(network.NetworkModule): # pylint: disable=abstract-metho
output_shape = target.shape
updown = make_kron(output_shape, w1, w2)
return self.finalize_updown(updown, target, output_shape)
class NetworkModuleLokrChunk(NetworkModuleLokr):
"""LoKR module that returns one chunk of the Kronecker product.
Used when a LoKR adapter targets a fused weight (e.g., QKV) but the model
has separate modules (Q, K, V). Computes kron(w1, w2) on-the-fly and
returns only the designated chunk, keeping memory usage minimal.
"""
def __init__(self, net, weights, chunk_index, num_chunks):
super().__init__(net, weights)
self.chunk_index = chunk_index
self.num_chunks = num_chunks
def calc_updown(self, target):
if self.w1 is not None:
w1 = self.w1.to(target.device, dtype=target.dtype)
else:
w1a = self.w1a.to(target.device, dtype=target.dtype)
w1b = self.w1b.to(target.device, dtype=target.dtype)
w1 = w1a @ w1b
if self.w2 is not None:
w2 = self.w2.to(target.device, dtype=target.dtype)
elif self.t2 is None:
w2a = self.w2a.to(target.device, dtype=target.dtype)
w2b = self.w2b.to(target.device, dtype=target.dtype)
w2 = w2a @ w2b
else:
t2 = self.t2.to(target.device, dtype=target.dtype)
w2a = self.w2a.to(target.device, dtype=target.dtype)
w2b = self.w2b.to(target.device, dtype=target.dtype)
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
full_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
updown = make_kron(full_shape, w1, w2)
updown = torch.chunk(updown, self.num_chunks, dim=0)[self.chunk_index]
output_shape = list(updown.shape)
return self.finalize_updown(updown, target, output_shape)
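A toy sketch of the chunking idea above (shapes are illustrative, not the real Flux2 dimensions): the full Kronecker product for a fused weight is computed once, then split along dim 0 into per-projection deltas.
import torch
w1 = torch.randn(6, 2)                            # LoKR factor 1
w2 = torch.randn(4, 3)                            # LoKR factor 2
updown = torch.kron(w1, w2)                       # fused update, shape (24, 6)
q_up, k_up, v_up = torch.chunk(updown, 3, dim=0)  # one (8, 6) delta each for Q/K/V
assert q_up.shape == (8, 6)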

View File

@ -1,9 +1,11 @@
import re
import sys
import os
import types
from collections import deque
import psutil
import torch
from modules import shared, errors
from modules import shared, errors, devices
from modules.logger import log
@ -130,28 +132,53 @@ def reset_stats():
class Object:
pattern = r"'(.*?)'"
def get_size(self, obj, seen=None):
size = sys.getsizeof(obj)
if seen is None:
seen = set()
obj_id = id(obj)
if obj_id in seen:
return 0 # Avoid double counting
seen.add(obj_id)
if isinstance(obj, dict):
size += sum(self.get_size(k, seen) + self.get_size(v, seen) for k, v in obj.items())
elif isinstance(obj, (list, tuple, set, frozenset, deque)):
size += sum(self.get_size(i, seen) for i in obj)
return size
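# e.g. for nested = {'a': list(range(1000))}, sys.getsizeof(nested) reports only the dict
# shell, while get_size() also descends into keys, values and sequence items,
# deduplicating via the seen set so shared objects are counted once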
def __init__(self, name, obj):
self.id = id(obj)
self.name = name
self.fn = sys._getframe(2).f_code.co_name
self.size = sys.getsizeof(obj)
self.refcount = sys.getrefcount(obj)
if torch.is_tensor(obj):
self.type = obj.dtype
self.size = obj.element_size() * obj.nelement()
else:
self.type = re.findall(self.pattern, str(type(obj)))[0]
self.size = sys.getsizeof(obj)
self.size = self.get_size(obj)
def __str__(self):
return f'{self.fn}.{self.name} type={self.type} size={self.size} ref={self.refcount}'
def get_objects(gcl=None, threshold:int=0):
def get_objects(gcl=None, threshold:int=1024*1024):
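# new default threshold: objects smaller than 1 MiB (1024*1024 bytes) are ignored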
devices.torch_gc(force=True)
if gcl is None:
# gcl = globals()
gcl = {}
log.trace(f'Memory: modules={len(sys.modules)}')
for _module_name, module in sys.modules.items():
try:
if not isinstance(module, types.ModuleType):
continue
namespace = vars(module)
gcl.update(namespace)
except Exception:
pass # Some modules may not allow introspection
objects = []
seen = []
log.trace(f'Memory: items={len(gcl)} threshold={threshold}')
for name, obj in gcl.items():
if id(obj) in seen:
continue
@ -169,6 +196,6 @@ def get_objects(gcl=None, threshold:int=0):
objects = sorted(objects, key=lambda x: x.size, reverse=True)
for obj in objects:
log.trace(obj)
log.trace(f'Memory: {obj}')
return objects

View File

@ -1,12 +1,10 @@
import os
import re
import sys
import copy
import json
import time
import diffusers
import transformers
from installer import installed, install, setup_logging
from installer import install
from modules.logger import log
@ -51,70 +49,6 @@ def dont_quant():
return False
def create_bnb_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
from modules import shared, devices
if allow and (module == 'any' or module in shared.opts.bnb_quantization):
load_bnb()
if bnb is None:
return kwargs
bnb_config = diffusers.BitsAndBytesConfig(
load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'],
load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'],
bnb_4bit_quant_storage=shared.opts.bnb_quantization_storage,
bnb_4bit_quant_type=shared.opts.bnb_quantization_type,
bnb_4bit_compute_dtype=devices.dtype,
llm_int8_skip_modules=modules_to_not_convert,
)
log.debug(f'Quantization: module={module} type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
if kwargs is None:
return bnb_config
else:
kwargs['quantization_config'] = bnb_config
return kwargs
return kwargs
def create_ao_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
from modules import shared
if allow and (shared.opts.torchao_quantization_mode in {'pre', 'auto'}) and (module == 'any' or module in shared.opts.torchao_quantization):
torchao = load_torchao()
if torchao is None:
return kwargs
if module in {'TE', 'LLM'}:
ao_config = transformers.TorchAoConfig(quant_type=shared.opts.torchao_quantization_type, modules_to_not_convert=modules_to_not_convert)
else:
ao_config = diffusers.TorchAoConfig(shared.opts.torchao_quantization_type, modules_to_not_convert=modules_to_not_convert)
log.debug(f'Quantization: module={module} type=torchao dtype={shared.opts.torchao_quantization_type}')
if kwargs is None:
return ao_config
else:
kwargs['quantization_config'] = ao_config
return kwargs
return kwargs
def create_quanto_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
from modules import shared
if allow and (module == 'any' or module in shared.opts.quanto_quantization):
load_quanto(silent=True)
if optimum_quanto is None:
return kwargs
if module in {'TE', 'LLM'}:
quanto_config = transformers.QuantoConfig(weights=shared.opts.quanto_quantization_type, modules_to_not_convert=modules_to_not_convert)
quanto_config.weights_dtype = quanto_config.weights
else:
quanto_config = diffusers.QuantoConfig(weights_dtype=shared.opts.quanto_quantization_type, modules_to_not_convert=modules_to_not_convert)
quanto_config.activations = None # patch so it works with transformers
quanto_config.weights = quanto_config.weights_dtype
log.debug(f'Quantization: module={module} type=quanto dtype={shared.opts.quanto_quantization_type}')
if kwargs is None:
return quanto_config
else:
kwargs['quantization_config'] = quanto_config
return kwargs
return kwargs
def create_trt_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
from modules import shared
if allow and (module == 'any' or module in shared.opts.trt_quantization):
@ -249,7 +183,7 @@ def create_sdnq_config(kwargs = None, allow: bool = True, module: str = 'Model',
def check_quant(module: str = ''):
from modules import shared
if module in shared.opts.sdnq_quantize_weights or module in shared.opts.bnb_quantization or module in shared.opts.torchao_quantization or module in shared.opts.quanto_quantization:
if module in shared.opts.sdnq_quantize_weights:
return True
return False
@ -286,21 +220,6 @@ def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modu
if debug:
log.trace(f'Quantization: type=sdnq config={kwargs.get("quantization_config", None)}')
return kwargs
kwargs = create_bnb_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
if kwargs is not None and 'quantization_config' in kwargs:
if debug:
log.trace(f'Quantization: type=bnb config={kwargs.get("quantization_config", None)}')
return kwargs
kwargs = create_quanto_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
if kwargs is not None and 'quantization_config' in kwargs:
if debug:
log.trace(f'Quantization: type=quanto config={kwargs.get("quantization_config", None)}')
return kwargs
kwargs = create_ao_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
if kwargs is not None and 'quantization_config' in kwargs:
if debug:
log.trace(f'Quantization: type=torchao config={kwargs.get("quantization_config", None)}')
return kwargs
kwargs = create_trt_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
if kwargs is not None and 'quantization_config' in kwargs:
if debug:
@ -309,88 +228,6 @@ def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modu
return kwargs
def load_torchao(msg='', silent=False):
global ao # pylint: disable=global-statement
if ao is not None:
return ao
if not installed('torchao'):
install('torchao==0.10.0', quiet=True)
log.warning('Quantization: torchao installed, please restart')
try:
import torchao
ao = torchao
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'Quantization: type=torchao version={ao.__version__} fn={fn}') # pylint: disable=protected-access
from diffusers.utils import import_utils
import_utils.is_torchao_available = lambda: True
import_utils._torchao_available = True # pylint: disable=protected-access
return ao
except Exception as e:
if len(msg) > 0:
log.error(f"{msg} failed to import torchao: {e}")
ao = None
if not silent:
raise
return None
def load_bnb(msg='', silent=False):
from modules import devices
global bnb # pylint: disable=global-statement
if bnb is not None:
return bnb
if not installed('bitsandbytes'):
if devices.backend == 'cuda':
# forcing a version will uninstall the multi-backend-refactor branch of bnb
install('bitsandbytes==0.47.0', quiet=True)
log.warning('Quantization: bitsandbytes installed, please restart')
try:
import bitsandbytes
bnb = bitsandbytes
from diffusers.utils import import_utils
import_utils._bitsandbytes_available = True # pylint: disable=protected-access
import_utils._bitsandbytes_version = '0.43.3' # pylint: disable=protected-access
fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'Quantization: type=bitsandbytes version={bnb.__version__} fn={fn}') # pylint: disable=protected-access
return bnb
except Exception as e:
if len(msg) > 0:
log.error(f"{msg} failed to import bitsandbytes: {e}")
bnb = None
if not silent:
raise
return None
def load_quanto(msg='', silent=False):
global optimum_quanto # pylint: disable=global-statement
if optimum_quanto is not None:
return optimum_quanto
if not installed('optimum-quanto'):
install('optimum-quanto==0.2.7', quiet=True)
log.warning('Quantization: optimum-quanto installed, please restart')
try:
from optimum import quanto # pylint: disable=no-name-in-module
# disable device specific tensors because the model can't be moved between cpu and gpu with them
quanto.tensor.weights.qbits.WeightQBitsTensor.create = lambda *args, **kwargs: quanto.tensor.weights.qbits.WeightQBitsTensor(*args, **kwargs)
optimum_quanto = quanto
fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'Quantization: type=quanto version={quanto.__version__} fn={fn}') # pylint: disable=protected-access
from diffusers.utils import import_utils
import_utils.is_optimum_quanto_available = lambda: True
import_utils._optimum_quanto_available = True # pylint: disable=protected-access
import_utils._optimum_quanto_version = quanto.__version__ # pylint: disable=protected-access
import_utils._replace_with_quanto_layers = diffusers.quantizers.quanto.utils._replace_with_quanto_layers # pylint: disable=protected-access
return optimum_quanto
except Exception as e:
if len(msg) > 0:
log.error(f"{msg} failed to import optimum.quanto: {e}")
optimum_quanto = None
if not silent:
raise
return None
def load_trt(msg='', silent=False):
global trt # pylint: disable=global-statement
if trt is not None:
@ -642,138 +479,6 @@ def sdnq_quantize_weights(sd_model):
return sd_model
def optimum_quanto_model(model, op=None, sd_model=None, weights=None, activations=None):
from modules import devices, shared
quanto = load_quanto('Quantize model: type=Optimum Quanto')
global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
if model.__class__.__name__ in {"FluxTransformer2DModel", "ChromaTransformer2DModel"}: # LayerNorm is not supported
exclude_list = ["transformer_blocks.*.norm1.norm", "transformer_blocks.*.norm2", "transformer_blocks.*.norm1_context.norm", "transformer_blocks.*.norm2_context", "single_transformer_blocks.*.norm.norm", "norm_out.norm"]
if model.__class__.__name__ == "ChromaTransformer2DModel":
# we ignore the distilled guidance layer because it degrades quality too much
# see: https://github.com/huggingface/diffusers/pull/11698#issuecomment-2969717180 for more details
exclude_list.append("distilled_guidance_layer.*")
elif model.__class__.__name__ == "QwenImageTransformer2DModel":
exclude_list = ["transformer_blocks.0.img_mod.1.weight", "time_text_embed", "img_in", "txt_in", "proj_out", "norm_out", "pos_embed"]
else:
exclude_list = None
weights = getattr(quanto, weights) if weights is not None else getattr(quanto, shared.opts.optimum_quanto_weights_type)
if activations is not None:
activations = getattr(quanto, activations) if activations != 'none' else None
elif shared.opts.optimum_quanto_activations_type != 'none':
activations = getattr(quanto, shared.opts.optimum_quanto_activations_type)
else:
activations = None
model.eval()
backup_embeddings = None
if hasattr(model, "get_input_embeddings"):
backup_embeddings = copy.deepcopy(model.get_input_embeddings())
quanto.quantize(model, weights=weights, activations=activations, exclude=exclude_list)
quanto.freeze(model)
if hasattr(model, "set_input_embeddings") and backup_embeddings is not None:
model.set_input_embeddings(backup_embeddings)
if op is not None and shared.opts.optimum_quanto_shuffle_weights:
if quant_last_model_name is not None:
if "." in quant_last_model_name:
last_model_names = quant_last_model_name.split(".")
getattr(getattr(sd_model, last_model_names[0]), last_model_names[1]).to(quant_last_model_device)
else:
getattr(sd_model, quant_last_model_name).to(quant_last_model_device)
devices.torch_gc(force=True, reason='quanto')
if shared.cmd_opts.medvram or shared.cmd_opts.lowvram or shared.opts.diffusers_offload_mode != "none":
quant_last_model_name = op
quant_last_model_device = model.device
else:
quant_last_model_name = None
quant_last_model_device = None
model.to(devices.device)
devices.torch_gc(force=True, reason='quanto')
return model
def optimum_quanto_weights(sd_model):
try:
t0 = time.time()
from modules import shared, devices, sd_models
if shared.opts.diffusers_offload_mode in {"balanced", "sequential"}:
log.warning(f"Quantization: type=Optimum.quanto offload={shared.opts.diffusers_offload_mode} not compatible")
return sd_model
log.info(f"Quantization: type=Optimum.quanto: modules={shared.opts.optimum_quanto_weights}")
global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
quanto = load_quanto()
sd_model = sd_models.apply_function_to_model(sd_model, optimum_quanto_model, shared.opts.optimum_quanto_weights, op="optimum-quanto")
if quant_last_model_name is not None:
if "." in quant_last_model_name:
last_model_names = quant_last_model_name.split(".")
getattr(getattr(sd_model, last_model_names[0]), last_model_names[1]).to(quant_last_model_device)
else:
getattr(sd_model, quant_last_model_name).to(quant_last_model_device)
devices.torch_gc(force=True, reason='quanto')
quant_last_model_name = None
quant_last_model_device = None
if shared.opts.optimum_quanto_activations_type != 'none':
activations = getattr(quanto, shared.opts.optimum_quanto_activations_type)
else:
activations = None
if activations is not None:
def optimum_quanto_freeze(model, op=None, sd_model=None): # pylint: disable=unused-argument
quanto.freeze(model)
return model
if shared.opts.diffusers_offload_mode == "model":
sd_model.enable_model_cpu_offload(device=devices.device)
if hasattr(sd_model, "encode_prompt"):
original_encode_prompt = sd_model.encode_prompt
def encode_prompt(*args, **kwargs):
embeds = original_encode_prompt(*args, **kwargs)
sd_model.maybe_free_model_hooks() # Diffusers keeps the TE on VRAM
return embeds
sd_model.encode_prompt = encode_prompt
else:
sd_models.move_model(sd_model, devices.device)
with quanto.Calibration(momentum=0.9):
sd_model(prompt="dummy prompt", num_inference_steps=10)
sd_model = sd_models.apply_function_to_model(sd_model, optimum_quanto_freeze, shared.opts.optimum_quanto_weights, op="optimum-quanto-freeze")
if shared.opts.diffusers_offload_mode == "model":
sd_models.disable_offload(sd_model)
sd_models.move_model(sd_model, devices.cpu)
if hasattr(sd_model, "encode_prompt"):
sd_model.encode_prompt = original_encode_prompt
devices.torch_gc(force=True, reason='quanto')
t1 = time.time()
log.info(f"Quantization: type=Optimum.quanto time={t1-t0:.2f}")
except Exception as e:
log.warning(f"Quantization: type=Optimum.quanto {e}")
return sd_model
def torchao_quantization(sd_model):
from modules import shared, devices, sd_models
torchao = load_torchao()
q = torchao.quantization
fn = getattr(q, shared.opts.torchao_quantization_type, None)
if fn is None:
log.error(f"Quantization: type=TorchAO type={shared.opts.torchao_quantization_type} not supported")
return sd_model
def torchao_model(model, op=None, sd_model=None): # pylint: disable=unused-argument
q.quantize_(model, fn(), device=devices.device)
return model
log.info(f"Quantization: type=TorchAO pipe={sd_model.__class__.__name__} quant={shared.opts.torchao_quantization_type} fn={fn} targets={shared.opts.torchao_quantization}")
try:
t0 = time.time()
sd_models.apply_function_to_model(sd_model, torchao_model, shared.opts.torchao_quantization, op="torchao")
t1 = time.time()
log.info(f"Quantization: type=TorchAO time={t1-t0:.2f}")
except Exception as e:
log.error(f"Quantization: type=TorchAO {e}")
setup_logging() # torchao uses dynamo which messes with logging so reset is needed
return sd_model
def get_dit_args(load_config: dict | None = None, module: str | None = None, device_map: bool = False, allow_quant: bool = True, modules_to_not_convert: list | None = None, modules_dtype_dict: dict | None = None):
from modules import shared, devices
config = {} if load_config is None else load_config.copy()
@ -810,12 +515,6 @@ def do_post_load_quant(sd_model, allow=True):
if shared.opts.sdnq_quantize_weights and (shared.opts.sdnq_quantize_mode == 'post' or (allow and shared.opts.sdnq_quantize_mode == 'auto')):
log.debug('Load model: post_quant=sdnq')
sd_model = sdnq_quantize_weights(sd_model)
if len(shared.opts.optimum_quanto_weights) > 0:
log.debug('Load model: post_quant=quanto')
sd_model = optimum_quanto_weights(sd_model)
if shared.opts.torchao_quantization and (shared.opts.torchao_quantization_mode == 'post' or (allow and shared.opts.torchao_quantization_mode == 'auto')):
log.debug('Load model: post_quant=torchao')
sd_model = torchao_quantization(sd_model)
if shared.opts.layerwise_quantization:
log.debug('Load model: post_quant=layerwise')
apply_layerwise(sd_model)

View File

@ -62,16 +62,6 @@ def load_t5(name=None, cache_dir=None):
elif 'fp16' in name.lower():
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'fp4' in name.lower():
model_quant.load_bnb('Load model: type=T5')
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True)
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'fp8' in name.lower():
model_quant.load_bnb('Load model: type=T5')
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'int8' in name.lower():
from modules.model_quant import create_sdnq_config
quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='int8')
@ -84,18 +74,6 @@ def load_t5(name=None, cache_dir=None):
if quantization_config is not None:
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'qint4' in name.lower():
model_quant.load_quanto('Load model: type=T5')
quantization_config = transformers.QuantoConfig(weights='int4')
if quantization_config is not None:
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'qint8' in name.lower():
model_quant.load_quanto('Load model: type=T5')
quantization_config = transformers.QuantoConfig(weights='int8')
if quantization_config is not None:
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif '/' in name:
log.debug(f'Load model: type=T5 repo={name}')
quant_config = model_quant.create_config(module='TE')

View File

@ -93,7 +93,10 @@ class Options:
def set(self, key, value):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
oldval = self.data.get(key, None)
if key in self.secrets:
oldval = self.secrets.get(key, None)
else:
oldval = self.data.get(key, None)
if oldval is None:
if key in self.data_labels:
oldval = self.data_labels[key].default

View File

@ -176,6 +176,8 @@ class UpscalerESRGAN(Upscaler):
def upscale_without_tiling(model, img):
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img)
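# reverse channel order RGB -> BGR; ESRGAN-family models conventionally expect OpenCV-style BGR input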
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255

View File

@ -360,7 +360,7 @@ def process_samples(p: StableDiffusionProcessing, samples):
split_tone_balance=getattr(p, 'grading_split_tone_balance', 0.5),
vignette=getattr(p, 'grading_vignette', 0.0),
grain=getattr(p, 'grading_grain', 0.0),
lut_file=getattr(p, 'grading_lut_file', ''),
lut_cube_file=getattr(p, 'grading_lut_file', ''),
lut_strength=getattr(p, 'grading_lut_strength', 1.0),
)
if processing_grading.is_active(grading_params):

View File

@ -66,7 +66,7 @@ class GradingParams:
vignette: float = 0.0
grain: float = 0.0
# lut
lut_file: str = ""
lut_cube_file: str = ""
lut_strength: float = 1.0
def __post_init__(self):
@ -179,17 +179,17 @@ def _apply_color_temp(img: torch.Tensor, kelvin: float) -> torch.Tensor:
return (img * scales).clamp(0, 1)
def _apply_lut(image: Image.Image, lut_file: str, strength: float) -> Image.Image:
def _apply_lut(image: Image.Image, lut_cube_file: str, strength: float) -> Image.Image:
"""Apply .cube LUT file via pillow-lut-tools."""
if not lut_file or not os.path.isfile(lut_file):
if not lut_cube_file or not os.path.isfile(lut_cube_file):
return image
pillow_lut = _ensure_pillow_lut()
try:
cube = pillow_lut.load_cube_file(lut_file)
cube = pillow_lut.load_cube_file(lut_cube_file)
if strength != 1.0:
cube = pillow_lut.amplify_lut(cube, strength)
result = image.filter(cube)
debug(f'Grading LUT: file={os.path.basename(lut_file)} strength={strength}')
debug(f'Grading LUT: file={os.path.basename(lut_cube_file)} strength={strength}')
return result
except Exception as e:
log.error(f'Grading LUT: {e}')
@ -198,8 +198,8 @@ def _apply_lut(image: Image.Image, lut_file: str, strength: float) -> Image.Imag
def grade_image(image: Image.Image, params: GradingParams) -> Image.Image:
"""Full grading pipeline: PIL -> GPU tensor -> kornia ops -> PIL."""
log.debug(f"Grading: params={params}")
kornia = _ensure_kornia()
debug(f'Grading: params={params}')
arr = np.array(image).astype(np.float32) / 255.0
tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
tensor = tensor.to(device=devices.device, dtype=devices.dtype)
@ -246,7 +246,7 @@ def grade_image(image: Image.Image, params: GradingParams) -> Image.Image:
result = Image.fromarray(arr)
# LUT applied last (CPU, via pillow-lut-tools)
if params.lut_file:
result = _apply_lut(result, params.lut_file, params.lut_strength)
if params.lut_cube_file:
result = _apply_lut(result, params.lut_cube_file, params.lut_strength)
return result
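The PIL -> GPU tensor -> PIL round-trip described in the docstring, as a standalone sketch (kornia ops and device placement elided):
import numpy as np
import torch
from PIL import Image
image = Image.new('RGB', (64, 64), 'gray')
arr = np.array(image).astype(np.float32) / 255.0              # HWC in [0, 1]
tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # 1CHW, as kornia expects
# ... grading ops run here on the tensor ...
arr = (tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy() * 255).astype(np.uint8)
result = Image.fromarray(arr)                                 # back to PIL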

View File

@ -76,9 +76,14 @@ class ScriptPostprocessingRunner:
script.controls = wrap_call(script.ui, script.filename, "ui")
if script.controls is None:
script.controls = {}
for control in script.controls.values():
control.custom_script_source = os.path.basename(script.filename)
inputs += list(script.controls.values())
if isinstance(script.controls, (list, tuple)):
for control in script.controls:
control.custom_script_source = os.path.basename(script.filename)
inputs += script.controls
else:
for control in script.controls.values():
control.custom_script_source = os.path.basename(script.filename)
inputs += list(script.controls.values())
script.args_to = len(inputs)
def scripts_in_preferred_order(self):
@ -109,11 +114,16 @@ class ScriptPostprocessingRunner:
for script in self.scripts_in_preferred_order():
jobid = shared.state.begin(script.name)
script_args = args[script.args_from:script.args_to]
process_args = {}
for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_args[name] = value
log.debug(f'Process: script="{script.name}" args={process_args}')
script.process(pp, **process_args)
process_args = []
process_kwargs = {}
if isinstance(script.controls, (list, tuple)):
for _control, value in zip(script.controls, script_args, strict=False):
process_args.append(value)
else:
for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_kwargs[name] = value
log.debug(f'Process: script="{script.name}" args={process_args} kwargs={process_kwargs}')
script.process(pp, *process_args, **process_kwargs)
shared.state.end(jobid)
def create_args_for_run(self, scripts_args):
@ -139,9 +149,14 @@ class ScriptPostprocessingRunner:
continue
jobid = shared.state.begin(script.name)
script_args = args[script.args_from:script.args_to]
process_args = {}
for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_args[name] = value
log.debug(f'Postprocess: script={script.name} args={process_args}')
script.postprocess(filenames, **process_args)
process_args = []
process_kwargs = {}
if isinstance(script.controls, (list, tuple)):
for _control, value in zip(script.controls, script_args, strict=False):
process_args.append(value)
else:
for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_kwargs[name] = value
log.debug(f'Postprocess: script={script.name} args={process_args} kwargs={process_kwargs}')
script.postprocess(filenames, *process_args, **process_kwargs)
shared.state.end(jobid)
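A minimal sketch of the dispatch rule used in both loops above: list/tuple controls are forwarded positionally, dict controls become keyword arguments (names here are illustrative).
def fake_process(pp, *args, **kwargs):  # stand-in for script.process / script.postprocess
    return args, kwargs
assert fake_process(None, 0.5, True) == ((0.5, True), {})  # list-style controls
assert fake_process(None, strength=0.5, enabled=True) == ((), {'strength': 0.5, 'enabled': True})  # dict-style controls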

View File

@ -24,8 +24,8 @@ warn_once = False
class CheckpointInfo:
def __init__(self, filename, sha=None, subfolder=None):
self.name = None
def __init__(self, filename, name=None, sha=None, subfolder=None, model_type: str = 'checkpoint'):
self.name = name
self.hash = sha
self.filename = filename
self.type = ''
@ -62,9 +62,9 @@ class CheckpointInfo:
self.sha256 = None
self.type = 'unknown'
elif os.path.isfile(filename): # ckpt or safetensor
self.name = relname
self.name = self.name or relname
self.filename = filename
self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{relname}")
self.sha256 = hashes.sha256_from_cache(self.filename, f"{model_type}/{relname}") or hashes.sha256_from_cache(self.filename, f"{model_type}/{name}")
self.type = ext
if 'nf4' in filename:
self.type = 'transformer'
@ -74,12 +74,12 @@ class CheckpointInfo:
else:
repo = [r for r in modelloader.diffuser_repos if self.hash == r['hash']]
if len(repo) == 0:
self.name = filename
self.name = self.name or filename
self.filename = filename
self.sha256 = None
self.type = 'unknown'
else:
self.name = os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]['name'])
self.name = self.name or os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]["name"])
self.filename = repo[0]['path']
self.sha256 = repo[0]['hash']
self.type = 'diffusers'
@ -109,7 +109,7 @@ class CheckpointInfo:
return self.shorthash
def __str__(self):
return f'CheckpointInfo(name="{self.name}" filename="{self.filename}" hash={self.shorthash} type={self.type} title="{self.title}" path="{self.path}" subfolder="{self.subfolder}")'
return f'CheckpointInfo(name="{self.name}" filename="{self.filename}" sha256={self.sha256} sha={self.shorthash} type={self.type} title="{self.title}" path="{self.path}" subfolder="{self.subfolder}")'
def setup_model():
@ -160,7 +160,7 @@ def list_models():
checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename))
def update_model_hashes():
def update_model_hashes(model_list: dict = None, model_type: str = 'checkpoint'):
def update_model_hashes_table(rows):
html = """
<table class="simple-table">
@ -186,14 +186,16 @@ def update_model_hashes():
log.error(f'Model list: row={row} {e}')
return html.format(tbody=tbody)
lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
if model_list is None:
model_list = checkpoints_list
lst = [ckpt for ckpt in model_list.values() if ckpt.hash is None]
for ckpt in lst:
ckpt.hash = model_hash(ckpt.filename)
lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
lst = [ckpt for ckpt in model_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
log.info(f'Models list: hash missing={len(lst)} total={len(model_list)}')
updated = []
for ckpt in lst:
ckpt.sha256 = hashes.sha256(ckpt.filename, f"checkpoint/{ckpt.name}")
ckpt.sha256 = hashes.sha256(ckpt.filename, f"{model_type}/{ckpt.name}")
ckpt.shorthash = ckpt.sha256[0:10] if ckpt.sha256 is not None else None
updated.append(ckpt)
yield update_model_hashes_table(updated)

View File

@ -798,8 +798,8 @@ def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: di
"requires_safety_checker": False, # sd15 specific but we cant know ahead of time
# "use_safetensors": True,
}
if shared.opts.huggingface_token and len(shared.opts.huggingface_token) > 0:
diffusers_load_config['token'] = shared.opts.huggingface_token
# if shared.opts.huggingface_token and len(shared.opts.huggingface_token) > 0:
# diffusers_load_config['token'] = shared.opts.huggingface_token
if revision is not None:
diffusers_load_config['revision'] = revision
if shared.opts.diffusers_model_load_variant != 'default':

View File

@ -102,3 +102,4 @@ def refresh_unet_list():
name = os.path.splitext(basename)[0] if ".safetensors" in basename else basename
unet_dict[name] = file
log.info(f'Available UNets: path="{shared.opts.unet_dir}" items={len(unet_dict)}')
return unet_dict

View File

@ -151,6 +151,7 @@ def list_samplers():
modules.sd_samplers.set_samplers()
return modules.sd_samplers.all_samplers
log.debug('Initializing: default modes')
startup_offload_mode, startup_offload_min_gpu, startup_offload_max_gpu, startup_cross_attention, startup_sdp_options, startup_sdp_choices, startup_sdp_override_options, startup_sdp_override_choices, startup_offload_always, startup_offload_never = get_default_modes(cmd_opts=cmd_opts, mem_stat=mem_stat)

View File

@ -439,6 +439,7 @@ def update_token_counter(text: str):
from modules.extra_networks import parse_prompt
count_formatted = '0'
max_length = 0
visible = False
prompt, _ = parse_prompt(text)
@ -475,7 +476,7 @@ def update_token_counter(text: str):
token_counts = [len(group) - int(has_bos_token) - int(has_eos_token) for group in ids]
if len(token_counts) > 1:
visible = True
count_formatted = f"{token_counts} {sum(token_counts)}" if shared.opts.prompt_detailed_tokens else str(sum(token_counts))
count_formatted = f"{token_counts}/{sum(token_counts)}"
elif len(token_counts) == 1 and token_counts[0] > 0:
visible = True
count_formatted = str(token_counts[0])

View File

@ -166,26 +166,6 @@ def create_settings(cmd_opts):
"nunchaku_attention": OptionInfo(False, "Nunchaku attention", gr.Checkbox),
"nunchaku_offload": OptionInfo(False, "Nunchaku offloading", gr.Checkbox),
"bnb_quantization_sep": OptionInfo("<h2>BitsAndBytes</h2>", "", gr.HTML),
"bnb_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM", "VAE"]}),
"bnb_quantization_type": OptionInfo("nf4", "Quantization type", gr.Dropdown, {"choices": ["nf4", "fp8", "fp4"]}),
"bnb_quantization_storage": OptionInfo("uint8", "Backend storage", gr.Dropdown, {"choices": ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]}),
"quanto_quantization_sep": OptionInfo("<h2>Optimum Quanto</h2>", "", gr.HTML),
"quanto_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM"]}),
"quanto_quantization_type": OptionInfo("int8", "Quantization weights type", gr.Dropdown, {"choices": ["float8", "int8", "int4", "int2"]}),
"optimum_quanto_sep": OptionInfo("<h2>Optimum Quanto: post-load</h2>", "", gr.HTML),
"optimum_quanto_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "Control", "VAE"]}),
"optimum_quanto_weights_type": OptionInfo("qint8", "Quantization weights type", gr.Dropdown, {"choices": ["qint8", "qfloat8_e4m3fn", "qfloat8_e5m2", "qint4", "qint2"]}),
"optimum_quanto_activations_type": OptionInfo("none", "Quantization activations type ", gr.Dropdown, {"choices": ["none", "qint8", "qfloat8_e4m3fn", "qfloat8_e5m2"]}),
"optimum_quanto_shuffle_weights": OptionInfo(False, "Shuffle weights in post mode", gr.Checkbox),
"torchao_sep": OptionInfo("<h2>TorchAO</h2>", "", gr.HTML),
"torchao_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM", "Control", "VAE"]}),
"torchao_quantization_mode": OptionInfo("auto", "Quantization mode", gr.Dropdown, {"choices": ["auto", "pre", "post"]}),
"torchao_quantization_type": OptionInfo("int8_weight_only", "Quantization type", gr.Dropdown, {"choices": ["int4_weight_only", "int8_dynamic_activation_int4_weight", "int8_weight_only", "int8_dynamic_activation_int8_weight", "float8_weight_only", "float8_dynamic_activation_float8_weight", "float8_static_activation_float8_weight"]}),
"layerwise_quantization_sep": OptionInfo("<h2>Layerwise Casting</h2>", "", gr.HTML),
"layerwise_quantization": OptionInfo([], "Layerwise casting enabled", gr.CheckboxGroup, {"choices": ["Model", "TE"]}),
"layerwise_quantization_storage": OptionInfo("float8_e4m3fn", "Layerwise casting storage", gr.Dropdown, {"choices": ["float8_e4m3fn", "float8_e5m2"]}),
@ -194,14 +174,6 @@ def create_settings(cmd_opts):
"trt_quantization_sep": OptionInfo("<h2>TensorRT</h2>", "", gr.HTML),
"trt_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model"]}),
"trt_quantization_type": OptionInfo("int8", "Quantization type", gr.Dropdown, {"choices": ["int8", "int4", "fp8", "nf4", "nvfp4"]}),
"nncf_compress_sep": OptionInfo("<h2>NNCF: Neural Network Compression Framework</h2>", "", gr.HTML, {"visible": cmd_opts.use_openvino}),
"nncf_compress_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "VAE"], "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_mode": OptionInfo("INT8_SYM", "Quantization type", gr.Dropdown, {"choices": ["INT8", "INT8_SYM", "FP8", "MXFP8", "INT4_ASYM", "INT4_SYM", "FP4", "MXFP4", "NF4"], "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_raito": OptionInfo(0, "Compress ratio", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_group_size": OptionInfo(0, "Group size", gr.Slider, {"minimum": -1, "maximum": 4096, "step": 1, "visible": cmd_opts.use_openvino}),
"nncf_quantize": OptionInfo([], "Static Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "VAE"], "visible": cmd_opts.use_openvino}),
"nncf_quantize_mode": OptionInfo("INT8", "OpenVINO activations mode", gr.Dropdown, {"choices": ["INT8", "FP8_E4M3", "FP8_E5M2"], "visible": cmd_opts.use_openvino}),
}))
# --- VAE & Text Encoder ---
options_templates.update(options_section(('vae_encoder', "Variational Auto Encoder"), {

View File

@ -584,6 +584,8 @@ def register_pages():
register_page(ExtraNetworksPageLora())
from modules.ui_extra_networks_wildcards import ExtraNetworksPageWildcards
register_page(ExtraNetworksPageWildcards())
from modules.ui_extra_networks_unet import ExtraNetworksPageUNets
register_page(ExtraNetworksPageUNets())
if shared.opts.latent_history > 0:
from modules.ui_extra_networks_history import ExtraNetworksPageHistory
register_page(ExtraNetworksPageHistory())
@ -596,7 +598,7 @@ def get_pages(title=None):
visible = shared.opts.extra_networks
pages: list[ExtraNetworksPage] = []
if 'All' in visible or visible == []: # default en sort order
visible = ['Model', 'Lora', 'Style', 'Wildcards', 'Embedding', 'VAE', 'History', 'Hypernetwork']
visible = ['Model', 'Lora', 'UNet/DiT', 'Style', 'Wildcards', 'Embedding', 'VAE', 'History', 'Hypernetwork']
titles = [page.title for page in shared.extra_networks]
if title is None:
@ -743,7 +745,7 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
with ui.tabs:
def ui_tab_change(page):
scan_visible = page in ['Model', 'Lora', 'VAE', 'Hypernetwork', 'Embedding']
scan_visible = page in ['Model', 'Lora', 'VAE', 'UNet/DiT', 'Hypernetwork', 'Embedding']
save_visible = page in ['Style']
model_visible = page in ['Model']
return [gr.update(visible=scan_visible), gr.update(visible=save_visible), gr.update(visible=model_visible)]

View File

@ -0,0 +1,43 @@
import html
import json
import os
from modules import shared, ui_extra_networks, sd_unet, hashes, modelstats
from modules.logger import log
class ExtraNetworksPageUNets(ui_extra_networks.ExtraNetworksPage):
def __init__(self):
super().__init__('UNet/DiT')
def refresh(self):
return sd_unet.refresh_unet_list()
def list_items(self):
for name, filename in sd_unet.unet_dict.items():
try:
size, mtime = modelstats.stat(filename)
info = self.find_info(filename)
version = self.find_version(None, info)
record = {
"type": 'UNet/DiT',
"name": name,
"alias": os.path.splitext(os.path.basename(filename))[0],
"title": name,
"filename": filename,
"hash": hashes.sha256_from_cache(filename, f"unet/{name}"),
"preview": self.find_preview(filename),
"local_preview": f"{os.path.splitext(filename)[0]}.{shared.opts.samples_format}",
"metadata": {},
"onclick": '"' + html.escape(f"""return selectUNet({json.dumps(name)})""") + '"',
"mtime": mtime,
"size": size,
"info": info,
"description": self.find_description(filename, info),
"version": version.get("baseModel", "N/A") if info else "N/A",
}
yield record
except Exception as e:
log.debug(f'Networks error: type=unet file="{filename}" {e}')
def allowed_directories_for_previews(self):
return [v for v in [shared.opts.unet_dir] if v is not None]

View File

@ -12,6 +12,15 @@ from modules.shared import opts, log
extra_ui = []
def update_model_hashes():
from modules import sd_unet, sd_checkpoint
unets = {}
for k, v in sd_unet.unet_dict.items():
unets[k] = sd_checkpoint.CheckpointInfo(name=k, filename=v, model_type='unet')
yield from sd_models.update_model_hashes(unets, model_type='unet')
yield from sd_models.update_model_hashes(model_type='checkpoint')
def create_ui():
log.debug('UI initialize: tab=models')
dummy_component = gr.Label(visible=False)
@ -143,7 +152,7 @@ def create_ui():
with gr.Row():
model_table = gr.HTML(value='', elem_id="model_list_table")
model_checkhash_btn.click(fn=sd_models.update_model_hashes, inputs=[], outputs=[model_table])
model_checkhash_btn.click(fn=update_model_hashes, inputs=[], outputs=[model_table])
model_list_btn.click(fn=lambda: create_models_table(list(sd_models.checkpoints_list.values())), inputs=[], outputs=[model_table])
with gr.Tab(label="Metadata", elem_id="models_metadata_tab"):

View File

@ -172,15 +172,15 @@ def create_latent_inputs(tab):
hdr_sharpen = gr.Slider(minimum=-4.0, maximum=4.0, step=0.05, value=0, label="Latent sharpen", elem_id=f"{tab}_hdr_sharpen")
hdr_color = gr.Slider(minimum=0.0, maximum=16.0, step=0.1, value=0.0, label="Latent color", elem_id=f"{tab}_hdr_color")
with gr.Row(elem_id=f"{tab}_hdr_clamp_row"):
hdr_clamp = gr.Checkbox(label="Clamp", value=False, elem_id=f"{tab}_hdr_clamp")
hdr_boundary = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=4.0, label="Range", elem_id=f"{tab}_hdr_boundary")
hdr_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.95, label="Threshold", elem_id=f"{tab}_hdr_threshold")
hdr_clamp = gr.Checkbox(label="Latent clamp", value=False, elem_id=f"{tab}_hdr_clamp")
hdr_boundary = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=4.0, label="Latent range", elem_id=f"{tab}_hdr_boundary")
hdr_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.95, label="Latent threshold", elem_id=f"{tab}_hdr_threshold")
with gr.Row(elem_id=f"{tab}_hdr_max_row"):
hdr_maximize = gr.Checkbox(label="Maximize", value=False, elem_id=f"{tab}_hdr_maximize")
hdr_max_center = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=0.6, label="Center", elem_id=f"{tab}_hdr_max_center")
hdr_max_boundary = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Max range", elem_id=f"{tab}_hdr_max_boundary")
hdr_maximize = gr.Checkbox(label="Latent maximize", value=False, elem_id=f"{tab}_hdr_maximize")
hdr_max_center = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=0.6, label="Latent center", elem_id=f"{tab}_hdr_max_center")
hdr_max_boundary = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Latent max range", elem_id=f"{tab}_hdr_max_boundary")
with gr.Row(elem_id=f"{tab}_hdr_color_row"):
hdr_color_picker = gr.ColorPicker(label="Tint color", show_label=True, container=False, value=None, elem_id=f"{tab}_hdr_color_picker")
hdr_color_picker = gr.ColorPicker(label="Latent tint", show_label=True, container=False, value=None, elem_id=f"{tab}_hdr_color_picker")
hdr_tint_ratio = gr.Slider(label="Tint strength", minimum=-4.0, maximum=4.0, step=0.05, value=0.0, elem_id=f"{tab}_hdr_tint_ratio")
return hdr_mode, hdr_brightness, hdr_color, hdr_sharpen, hdr_clamp, hdr_boundary, hdr_threshold, hdr_maximize, hdr_max_center, hdr_max_boundary, hdr_color_picker, hdr_tint_ratio, hdr_apply_hires
@ -197,7 +197,7 @@ def create_color_inputs(tab):
grading_hue = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0, label='Hue', elem_id=f"{tab}_grading_hue")
grading_gamma = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label='Gamma', elem_id=f"{tab}_grading_gamma")
grading_sharpness = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=0, label='Sharpness', elem_id=f"{tab}_grading_sharpness")
grading_color_temp = gr.Slider(minimum=2000, maximum=12000, step=100, value=6500, label='Color temp (K)', elem_id=f"{tab}_grading_color_temp")
grading_color_temp = gr.Slider(minimum=2000, maximum=12000, step=100, value=6500, label='Color temp', elem_id=f"{tab}_grading_color_temp")
with gr.Group():
gr.HTML('<h3>Tone</h3>')
with gr.Row(elem_id=f"{tab}_grading_tone_row"):
@ -211,7 +211,7 @@ def create_color_inputs(tab):
with gr.Row(elem_id=f"{tab}_grading_split_row"):
grading_shadows_tint = gr.ColorPicker(label="Shadows tint", value="#000000", elem_id=f"{tab}_grading_shadows_tint")
grading_highlights_tint = gr.ColorPicker(label="Highlights tint", value="#ffffff", elem_id=f"{tab}_grading_highlights_tint")
grading_split_tone_balance = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label='Balance', elem_id=f"{tab}_grading_split_tone_balance")
grading_split_tone_balance = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label='Split tone balance', elem_id=f"{tab}_grading_split_tone_balance")
with gr.Group():
gr.HTML('<h3>Effects</h3>')
with gr.Row(elem_id=f"{tab}_grading_effects_row"):
@ -220,9 +220,9 @@ def create_color_inputs(tab):
with gr.Group():
gr.HTML('<h3>LUT</h3>')
with gr.Row(elem_id=f"{tab}_grading_lut_row"):
grading_lut_file = gr.File(label='LUT .cube file', file_types=['.cube'], elem_id=f"{tab}_grading_lut_file")
grading_lut_cube_file = gr.File(label='LUT .cube file', file_types=['.cube'], elem_id=f"{tab}_grading_lut_file")
grading_lut_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=1.0, label='LUT strength', elem_id=f"{tab}_grading_lut_strength")
return grading_brightness, grading_contrast, grading_saturation, grading_hue, grading_gamma, grading_sharpness, grading_color_temp, grading_shadows, grading_midtones, grading_highlights, grading_clahe_clip, grading_clahe_grid, grading_shadows_tint, grading_highlights_tint, grading_split_tone_balance, grading_vignette, grading_grain, grading_lut_file, grading_lut_strength
return grading_brightness, grading_contrast, grading_saturation, grading_hue, grading_gamma, grading_sharpness, grading_color_temp, grading_shadows, grading_midtones, grading_highlights, grading_clahe_clip, grading_clahe_grid, grading_shadows_tint, grading_highlights_tint, grading_split_tone_balance, grading_vignette, grading_grain, grading_lut_cube_file, grading_lut_strength
def create_sampler_and_steps_selection(choices, tabname, default_steps:int=20):

View File

@ -396,6 +396,13 @@ def create_quicksettings(interfaces):
inputs=[shared.settings_components['sd_vae'], dummy_component],
outputs=[shared.settings_components['sd_vae'], text_settings],
)
button_set_unet = gr.Button("Change UNet", elem_id="change_unet", visible=False)
button_set_unet.click(
fn=lambda value, _: run_settings_single(value, key="sd_unet"),
_js="function(v){ var res = desiredUNetName; desiredUNetName = ''; return [res || v, null]; }",
inputs=[shared.settings_components["sd_unet"], dummy_component],
outputs=[shared.settings_components["sd_unet"], text_settings],
)
def reference_submit(model):
if '@' not in model: # diffusers

View File

@ -18,7 +18,7 @@
"venv": ". venv/bin/activate",
"start": ". venv/bin/activate; python launch.py --debug",
"localize": "node cli/localize.js",
"packages": ". venv/bin/activate && pip install --upgrade transformers accelerate huggingface_hub safetensors tokenizers peft pytorch_lightning pylint ruff",
"packages": ". venv/bin/activate && pip install --upgrade accelerate huggingface_hub safetensors tokenizers peft pytorch_lightning pylint ruff",
"format": ". venv/bin/activate && pre-commit run --all-files",
"format-win": "venv\\scripts\\activate && pre-commit run --all-files",
"eslint": "eslint . javascript/",

View File

@ -0,0 +1,216 @@
"""Flux2/Klein-specific LoRA loading.
Handles:
- Bare BFL-format keys in state dicts (adds diffusion_model. prefix for converter)
- LoKR adapters via native module loading (bypasses diffusers PEFT system)
Installed via apply_patch() during pipeline loading.
"""
import os
import time
from modules import shared, sd_models
from modules.logger import log
from modules.lora import network, network_lokr, lora_convert
from modules.lora import lora_common as l
BARE_FLUX_PREFIXES = ("single_blocks.", "double_blocks.", "img_in.", "txt_in.",
"final_layer.", "time_in.", "single_stream_modulation.",
"double_stream_modulation_")
# BFL -> diffusers module path mapping for Flux2/Klein
F2_SINGLE_MAP = {
'linear1': 'attn.to_qkv_mlp_proj',
'linear2': 'attn.to_out',
}
F2_DOUBLE_MAP = {
'img_attn.proj': 'attn.to_out.0',
'txt_attn.proj': 'attn.to_add_out',
'img_mlp.0': 'ff.linear_in',
'img_mlp.2': 'ff.linear_out',
'txt_mlp.0': 'ff_context.linear_in',
'txt_mlp.2': 'ff_context.linear_out',
}
F2_QKV_MAP = {
'img_attn.qkv': ('attn', ['to_q', 'to_k', 'to_v']),
'txt_attn.qkv': ('attn', ['add_q_proj', 'add_k_proj', 'add_v_proj']),
}
def apply_lora_alphas(state_dict):
"""Bake kohya-format .alpha scaling into lora_down weights and remove alpha keys.
Diffusers' Flux2 converter only handles lora_A/lora_B (or lora_down/lora_up) keys.
Kohya-format LoRAs store per-layer alpha values as separate .alpha keys that the
converter doesn't consume, causing a ValueError on leftover keys. This matches the
approach used by _convert_kohya_flux_lora_to_diffusers for Flux 1.
"""
alpha_keys = [k for k in state_dict if k.endswith('.alpha')]
if not alpha_keys:
return state_dict
for alpha_key in alpha_keys:
base = alpha_key[:-len('.alpha')]
down_key = f'{base}.lora_down.weight'
if down_key not in state_dict:
continue
down_weight = state_dict[down_key]
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
state_dict[down_key] = down_weight * scale_down
up_key = f'{base}.lora_up.weight'
if up_key in state_dict:
state_dict[up_key] = state_dict[up_key] * scale_up
remaining = [k for k in state_dict if k.endswith('.alpha')]
if remaining:
log.debug(f'Network load: type=LoRA stripped {len(remaining)} orphaned alpha keys')
for k in remaining:
del state_dict[k]
return state_dict
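# worked example (illustrative numbers): alpha=8, rank=16 -> scale=0.5; the loop only
# rebalances when scale < 0.5 and always preserves scale_down * scale_up == alpha / rank,
# so the effective update up @ down * (alpha / rank) is unchanged by the baking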
def preprocess_f2_keys(state_dict):
"""Add 'diffusion_model.' prefix to bare BFL-format keys so
Flux2LoraLoaderMixin's format detection routes them to the converter."""
if any(k.startswith("diffusion_model.") or k.startswith("base_model.model.") for k in state_dict):
return state_dict
if any(k.startswith(p) for k in state_dict for p in BARE_FLUX_PREFIXES):
log.debug('Network load: type=LoRA adding diffusion_model prefix for bare BFL-format keys')
state_dict = {f"diffusion_model.{k}": v for k, v in state_dict.items()}
return state_dict
def try_load_lokr(name, network_on_disk, lora_scale):
"""Try loading a Flux2/Klein LoRA as LoKR native modules.
Returns a Network with native modules if the state dict contains LoKR keys,
or None to fall through to the generic diffusers path.
"""
t0 = time.time()
state_dict = sd_models.read_state_dict(network_on_disk.filename, what='network')
if not any('.lokr_w1' in k for k in state_dict):
return None
net = load_lokr_native(name, network_on_disk, state_dict)
if len(net.modules) == 0:
log.error(f'Network load: type=LoKR name="{name}" no modules matched')
return None
log.debug(f'Network load: type=LoKR name="{name}" native modules={len(net.modules)} scale={lora_scale}')
l.timer.activate += time.time() - t0
return net
def load_lokr_native(name, network_on_disk, state_dict):
"""Load Flux2 LoKR as native modules applied at inference time.
Stores only the compact LoKR factors (w1, w2) and computes kron(w1, w2)
on-the-fly during weight application. For fused QKV modules in double
blocks, NetworkModuleLokrChunk computes the full Kronecker product and
returns only its designated Q/K/V chunk, then frees the temporary.
"""
prefix = "diffusion_model."
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
lora_convert.assign_network_names_to_compvis_modules(sd_model)
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)
for key in list(state_dict.keys()):
if not key.endswith('.lokr_w1'):
continue
if not key.startswith(prefix):
continue
base = key[len(prefix):].rsplit('.lokr_w1', 1)[0]
lokr_weights = {}
for suffix in ['lokr_w1', 'lokr_w2', 'lokr_w1_a', 'lokr_w1_b', 'lokr_w2_a', 'lokr_w2_b', 'lokr_t2', 'alpha']:
full_key = f'{prefix}{base}.{suffix}'
if full_key in state_dict:
lokr_weights[suffix] = state_dict[full_key]
parts = base.split('.')
block_type, block_idx, module_suffix = parts[0], parts[1], '.'.join(parts[2:])
targets = [] # (module_path, chunk_index, num_chunks)
if block_type == 'single_blocks' and module_suffix in F2_SINGLE_MAP:
path = f'single_transformer_blocks.{block_idx}.{F2_SINGLE_MAP[module_suffix]}'
targets.append((path, None, None))
elif block_type == 'double_blocks':
if module_suffix in F2_DOUBLE_MAP:
path = f'transformer_blocks.{block_idx}.{F2_DOUBLE_MAP[module_suffix]}'
targets.append((path, None, None))
elif module_suffix in F2_QKV_MAP:
attn_prefix, proj_keys = F2_QKV_MAP[module_suffix]
for i, proj_key in enumerate(proj_keys):
path = f'transformer_blocks.{block_idx}.{attn_prefix}.{proj_key}'
targets.append((path, i, len(proj_keys)))
for module_path, chunk_index, num_chunks in targets:
network_key = "lora_transformer_" + module_path.replace(".", "_")
sd_module = sd_model.network_layer_mapping.get(network_key)
if sd_module is None:
log.warning(f'Network load: type=LoKR module not found in mapping: {network_key}')
continue
weights = network.NetworkWeights(
network_key=network_key,
sd_key=network_key,
w=dict(lokr_weights),
sd_module=sd_module,
)
if chunk_index is not None:
net.modules[network_key] = network_lokr.NetworkModuleLokrChunk(net, weights, chunk_index, num_chunks)
else:
net.modules[network_key] = network_lokr.NetworkModuleLokr(net, weights)
return net
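# example fan-out (per F2_QKV_MAP above): 'double_blocks.3.img_attn.qkv' maps to
#   transformer_blocks.3.attn.to_q  (chunk 0 of 3)
#   transformer_blocks.3.attn.to_k  (chunk 1 of 3)
#   transformer_blocks.3.attn.to_v  (chunk 2 of 3)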
patched = False
def apply_patch():
"""Patch Flux2LoraLoaderMixin.lora_state_dict to handle bare BFL-format keys.
When a LoRA file has bare BFL keys (no diffusion_model. prefix), the original
lora_state_dict won't detect them as AI toolkit format. This patch preprocesses the
input up front (adding the prefix and baking alphas) before delegating to the original.
"""
global patched # pylint: disable=global-statement
if patched:
return
patched = True
from diffusers.loaders.lora_pipeline import Flux2LoraLoaderMixin
original_lora_state_dict = Flux2LoraLoaderMixin.lora_state_dict.__func__
@classmethod # pylint: disable=no-self-argument
def patched_lora_state_dict(cls, pretrained_model_name_or_path_or_dict, **kwargs):
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = preprocess_f2_keys(pretrained_model_name_or_path_or_dict)
pretrained_model_name_or_path_or_dict = apply_lora_alphas(pretrained_model_name_or_path_or_dict)
elif isinstance(pretrained_model_name_or_path_or_dict, (str, os.PathLike)):
path = str(pretrained_model_name_or_path_or_dict)
if path.endswith('.safetensors'):
try:
from safetensors import safe_open
with safe_open(path, framework="pt") as f:
keys = list(f.keys())
needs_load = (
any(k.endswith('.alpha') for k in keys)
or (not any(k.startswith("diffusion_model.") or k.startswith("base_model.model.") for k in keys)
and any(k.startswith(p) for k in keys for p in BARE_FLUX_PREFIXES))
)
if needs_load:
from safetensors.torch import load_file
sd = load_file(path)
sd = preprocess_f2_keys(sd)
pretrained_model_name_or_path_or_dict = apply_lora_alphas(sd)
except Exception:
pass
return original_lora_state_dict(cls, pretrained_model_name_or_path_or_dict, **kwargs)
Flux2LoraLoaderMixin.lora_state_dict = patched_lora_state_dict
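As the module docstring notes, the patch is installed during pipeline loading; a minimal sketch of the call site (placement is illustrative):
from pipelines.flux import flux2_lora
flux2_lora.apply_patch()  # idempotent: the module-level `patched` guard makes repeat calls a no-op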

View File

@ -1,25 +0,0 @@
import diffusers
import transformers
from modules import devices, model_quant
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
transformer = None
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
model_quant.load_bnb('Load model: type=FLUX')
quant = model_quant.get_quant(repo_path)
if quant == 'fp8':
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
elif quant == 'fp4':
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
elif quant == 'nf4':
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
else:
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
return transformer


@@ -1,361 +0,0 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, sd_unet, model_te, model_quant, sd_hijack_te
from modules.logger import log
debug = log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
def load_flux_quanto(checkpoint_info):
transformer, text_encoder_2 = None, None
quanto = model_quant.load_quanto('Load model: type=FLUX')
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
try:
quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
if not os.path.exists(quantization_map):
repo_id = sd_models.path_to_repo(checkpoint_info)
quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
dtype = state_dict['context_embedder.bias'].dtype
with torch.device("meta"):
transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
transformer_dtype = transformer.dtype
if transformer_dtype != devices.dtype:
try:
transformer = transformer.to(dtype=devices.dtype)
except Exception:
log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
except Exception as e:
log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
try:
quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
if not os.path.exists(quantization_map):
repo_id = sd_models.path_to_repo(checkpoint_info)
quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
t5_config = transformers.T5Config(**json.load(f))
state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
with torch.device("meta"):
text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
text_encoder_2_dtype = text_encoder_2.dtype
if text_encoder_2_dtype != devices.dtype:
try:
text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
except Exception:
log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
except Exception as e:
log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
return transformer, text_encoder_2
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
transformer, text_encoder_2 = None, None
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
model_quant.load_bnb('Load model: type=FLUX')
quant = model_quant.get_quant(repo_path)
try:
if quant == 'fp8':
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
debug(f'Quantization: {quantization_config}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
elif quant == 'fp4':
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
debug(f'Quantization: {quantization_config}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
elif quant == 'nf4':
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
debug(f'Quantization: {quantization_config}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
else:
transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
except Exception as e:
log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}")
transformer, text_encoder_2 = None, None
if debug:
errors.display(e, 'FLUX:')
return transformer, text_encoder_2
def load_quants(kwargs, repo_id, cache_dir, allow_quant): # pylint: disable=unused-argument
try:
diffusers_load_config = {
"torch_dtype": devices.dtype,
"cache_dir": cache_dir,
}
if 'transformer' not in kwargs and model_quant.check_nunchaku('Model'):
import nunchaku
nunchaku_precision = nunchaku.utils.get_precision()
nunchaku_repo = None
if 'flux.1-kontext' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
elif 'flux.1-dev' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
elif 'flux.1-schnell' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'flux.1-fill' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'flux.1-depth' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
elif 'shuttle' in repo_id.lower():
nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
else:
log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
if nunchaku_repo is not None:
log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
kwargs['transformer'] = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype, cache_dir=cache_dir)
kwargs['transformer'].quantization_method = 'SVDQuant'
if shared.opts.nunchaku_attention:
kwargs['transformer'].set_attention_impl("nunchaku-fp16")
if 'transformer' not in kwargs and model_quant.check_quant('Model'):
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", **load_args, **quant_args)
if 'text_encoder_2' not in kwargs and model_quant.check_nunchaku('TE'):
import nunchaku
nunchaku_precision = nunchaku.utils.get_precision()
nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
kwargs['text_encoder_2'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype, cache_dir=cache_dir)
kwargs['text_encoder_2'].quantization_method = 'SVDQuant'
if 'text_encoder_2' not in kwargs and model_quant.check_quant('TE'):
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True)
kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", **load_args, **quant_args)
except Exception as e:
log.error(f'Quantization: {e}')
errors.display(e, 'Quantization:')
return kwargs
def load_transformer(file_path): # triggered by opts.sd_unet change
if file_path is None or not os.path.exists(file_path):
return None
transformer = None
quant = model_quant.get_quant(file_path)
diffusers_load_config = {
"torch_dtype": devices.dtype,
"cache_dir": shared.opts.hfcache_dir,
}
if quant is not None and quant != 'none':
log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
if 'gguf' in file_path.lower():
from modules import ggml
_transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
if _transformer is not None:
transformer = _transformer
elif quant == "fp8":
_transformer = model_quant.load_fp8_model_layerwise(file_path, diffusers.FluxTransformer2DModel.from_single_file, diffusers_load_config)
if _transformer is not None:
transformer = _transformer
elif quant in {'qint8', 'qint4'}:
_transformer, _text_encoder_2 = load_flux_quanto(file_path)
if _transformer is not None:
transformer = _transformer
elif quant in {'fp8', 'fp4', 'nf4'}:
_transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
if _transformer is not None:
transformer = _transformer
elif 'nf4' in quant:
from pipelines.flux.flux_nf4 import load_flux_nf4
_transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
if _transformer is not None:
transformer = _transformer
else:
quant_args = model_quant.create_bnb_config({})
if quant_args:
log.info(f'Load module: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
from pipelines.flux.flux_nf4 import load_flux_nf4
transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
if transformer is not None:
return transformer
load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
log.debug(f'Load model: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} args={load_args}')
transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_args, **quant_args)
if transformer is None:
log.error('Failed to load UNet model')
shared.opts.sd_unet = 'Default'
return transformer
def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_checkpoint change
repo_id = sd_models.path_to_repo(checkpoint_info)
sd_models.hf_auth_check(checkpoint_info)
allow_post_quant = False
prequantized = model_quant.get_quant(checkpoint_info.path)
log.debug(f'Load model: type=FLUX model="{checkpoint_info.name}" repo="{repo_id}" unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
debug(f'Load model: type=FLUX config={diffusers_load_config}')
transformer = None
text_encoder_1 = None
text_encoder_2 = None
vae = None
# unload current model
sd_models.unload_model_weights()
shared.sd_model = None
devices.torch_gc(force=True, reason='load')
if shared.opts.teacache_enabled:
from modules import teacache
log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward # patch must be done before transformer is loaded
# load overrides if any
if shared.opts.sd_unet != 'Default':
try:
debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"')
transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
if transformer is None:
sd_unet.failed_unet.append(shared.opts.sd_unet)  # record the failed unet before resetting the option
shared.opts.sd_unet = 'Default'
except Exception as e:
log.error(f"Load model: type=FLUX failed to load UNet: {e}")
shared.opts.sd_unet = 'Default'
if debug:
errors.display(e, 'FLUX UNet:')
if shared.opts.sd_text_encoder != 'Default':
try:
debug(f'Load model: type=FLUX te="{shared.opts.sd_text_encoder}"')
from modules.model_te import load_t5, load_vit_l
if 'vit-l' in shared.opts.sd_text_encoder.lower():
text_encoder_1 = load_vit_l()
else:
text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
except Exception as e:
log.error(f"Load model: type=FLUX failed to load T5: {e}")
shared.opts.sd_text_encoder = 'Default'
if debug:
errors.display(e, 'FLUX T5:')
if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
try:
debug(f'Load model: type=FLUX vae="{shared.opts.sd_vae}"')
from modules import sd_vae
# vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
if os.path.exists(vae_file):
vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
except Exception as e:
log.error(f"Load model: type=FLUX failed to load VAE: {e}")
shared.opts.sd_vae = 'Default'
if debug:
errors.display(e, 'FLUX VAE:')
# load quantized components if any
if prequantized == 'nf4':
try:
from pipelines.flux.flux_nf4 import load_flux_nf4
_transformer, _text_encoder = load_flux_nf4(checkpoint_info)
if _transformer is not None:
transformer = _transformer
if _text_encoder is not None:
text_encoder_2 = _text_encoder
except Exception as e:
log.error(f"Load model: type=FLUX failed to load NF4 components: {e}")
if debug:
errors.display(e, 'FLUX NF4:')
if prequantized == 'qint8' or prequantized == 'qint4':
try:
_transformer, _text_encoder = load_flux_quanto(checkpoint_info)
if _transformer is not None:
transformer = _transformer
if _text_encoder is not None:
text_encoder_2 = _text_encoder
except Exception as e:
log.error(f"Load model: type=FLUX failed to load Quanto components: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
# initialize pipeline with pre-loaded components
kwargs = {}
if transformer is not None:
kwargs['transformer'] = transformer
sd_unet.loaded_unet = shared.opts.sd_unet
if text_encoder_1 is not None:
kwargs['text_encoder'] = text_encoder_1
model_te.loaded_te = shared.opts.sd_text_encoder
if text_encoder_2 is not None:
kwargs['text_encoder_2'] = text_encoder_2
model_te.loaded_te = shared.opts.sd_text_encoder
if vae is not None:
kwargs['vae'] = vae
if repo_id == 'sayakpaul/flux.1-dev-nf4':
repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
if 'Fill' in repo_id:
cls = diffusers.FluxFillPipeline
elif 'Canny' in repo_id:
cls = diffusers.FluxControlPipeline
elif 'Depth' in repo_id:
cls = diffusers.FluxControlPipeline
elif 'Kontext' in repo_id:
cls = diffusers.FluxKontextPipeline
from diffusers import pipelines
pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline
else:
cls = diffusers.FluxPipeline
log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
for c in kwargs:
if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
try:
kwargs[c] = kwargs[c].to(dtype=devices.dtype)
log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
except Exception:
pass
allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none')
fn = checkpoint_info.path
if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant)
if fn is not None and fn.endswith('.safetensors') and os.path.isfile(fn):
pipe = cls.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
allow_post_quant = True
else:
pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
if shared.opts.teacache_enabled and model_quant.check_nunchaku('Model'):
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
# release memory
transformer = None
text_encoder_1 = None
text_encoder_2 = None
vae = None
for k in kwargs.keys():
kwargs[k] = None
sd_hijack_te.init_hijack(pipe)
devices.torch_gc(force=True, reason='load')
return pipe, allow_post_quant


@@ -1,201 +0,0 @@
"""
Copied from: https://github.com/huggingface/diffusers/issues/9165
"""
import os
import torch
import torch.nn as nn
from transformers.quantizers.quantizers_utils import get_module_from_name
from huggingface_hub import hf_hub_download
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
import safetensors.torch
from modules import shared, devices, model_quant
from modules.logger import log
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
def _replace_with_bnb_linear(
model,
method="nf4",
has_been_replaced=False,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean indicating whether the conversion was successful.
"""
bnb = model_quant.load_bnb('Load model: type=FLUX')
for name, module in model.named_children():
if isinstance(module, nn.Linear):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
if method == "llm_int8":
model._modules[name] = bnb.nn.Linear8bitLt( # pylint: disable=protected-access
in_features,
out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=6.0,
)
has_been_replaced = True
else:
model._modules[name] = bnb.nn.Linear4bit( # pylint: disable=protected-access
in_features,
out_features,
module.bias is not None,
compute_dtype=devices.dtype,
compress_statistics=False,
quant_type="nf4",
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module) # pylint: disable=protected-access
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False) # pylint: disable=protected-access
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bnb_linear(
module,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
return model, has_been_replaced
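# Usage sketch (assumes bitsandbytes is importable via model_quant.load_bnb):
#   transformer, replaced = _replace_with_bnb_linear(transformer, method="nf4")
#   # every nn.Linear is now a bnb.nn.Linear4bit shell awaiting its quantized weights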
def check_quantized_param(
model,
param_name: str,
) -> bool:
bnb = model_quant.load_bnb('Load model: type=FLUX')
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit): # pylint: disable=protected-access
# Add here check for loaded components' dtypes once serialization is implemented
return True
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
# but it would wrongly use uninitialized weight there.
return True
else:
return False
def create_quantized_param(
model,
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict=None,
unexpected_keys=None,
pre_quantized=False
):
bnb = model_quant.load_bnb('Load model: type=FLUX')
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters: # pylint: disable=protected-access
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
old_value = getattr(module, tensor_name)
if tensor_name == "bias":
if param_value is None:
new_value = old_value.to(target_device)
else:
new_value = param_value.to(target_device)
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value # pylint: disable=protected-access
return
if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit): # pylint: disable=protected-access
raise ValueError("this function only loads `Linear4bit components`")
if (
old_value.device == torch.device("meta")
and target_device not in ["meta", torch.device("meta")]
and param_value is None
):
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
if pre_quantized:
if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (param_name + ".quant_state.bitsandbytes__nf4" not in state_dict):
raise ValueError(f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.")
quantized_stats = {}
for k, v in state_dict.items():
# `startswith` to guard against edge cases where the `param_name`
# substring appears in multiple places in the `state_dict`
if param_name + "." in k and k.startswith(param_name):
quantized_stats[k] = v
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)
new_value = bnb.nn.Params4bit.from_prequantized(
data=param_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=target_device,
)
else:
new_value = param_value.to("cpu")
kwargs = old_value.__dict__
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
module._parameters[tensor_name] = new_value # pylint: disable=protected-access
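# Pre-quantized checkpoints carry packed weights plus per-param quant metadata, e.g.
# (illustrative key name):
#   'transformer_blocks.0.ff.net.0.proj.weight'
#   'transformer_blocks.0.ff.net.0.proj.weight.quant_state.bitsandbytes__nf4'
# create_quantized_param above gathers every 'param_name.*' entry into quantized_stats.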
def load_flux_nf4(checkpoint_info, prequantized: bool = True):
transformer = None
text_encoder_2 = None
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
if os.path.exists(repo_path) and os.path.isfile(repo_path):
ckpt_path = repo_path
elif os.path.exists(repo_path) and os.path.isdir(repo_path) and os.path.exists(os.path.join(repo_path, "diffusion_pytorch_model.safetensors")):
ckpt_path = os.path.join(repo_path, "diffusion_pytorch_model.safetensors")
else:
ckpt_path = hf_hub_download(repo_path, filename="diffusion_pytorch_model.safetensors", cache_dir=shared.opts.diffusers_dir)
original_state_dict = safetensors.torch.load_file(ckpt_path)
if 'sayakpaul' in repo_path:
converted_state_dict = original_state_dict # already converted
else:
try:
converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
except Exception as e:
log.error(f"Load model: type=FLUX Failed to convert UNET: {e}")
if debug:
from modules import errors
errors.display(e, 'FLUX convert:')
converted_state_dict = original_state_dict
with init_empty_weights():
from diffusers import FluxTransformer2DModel
config = FluxTransformer2DModel.load_config(os.path.join('configs', 'flux'), subfolder="transformer")
transformer = FluxTransformer2DModel.from_config(config).to(devices.dtype)
expected_state_dict_keys = list(transformer.state_dict().keys())
_replace_with_bnb_linear(transformer, "nf4")
try:
for param_name, param in converted_state_dict.items():
if param_name not in expected_state_dict_keys:
continue
is_param_float8_e4m3fn = hasattr(torch, "float8_e4m3fn") and param.dtype == torch.float8_e4m3fn
if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
param = param.to(devices.dtype)
if not check_quantized_param(transformer, param_name):
set_module_tensor_to_device(transformer, param_name, device=0, value=param)
else:
create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=prequantized)
except Exception as e:
transformer, text_encoder_2 = None, None
log.error(f"Load model: type=FLUX failed to load UNET: {e}")
if debug:
from modules import errors
errors.display(e, 'FLUX:')
del original_state_dict
devices.torch_gc(force=True, reason='load')
return transformer, text_encoder_2


@@ -1,74 +0,0 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, model_quant
from modules.logger import log
debug = log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
def load_flux_quanto(checkpoint_info):
transformer, text_encoder_2 = None, None
quanto = model_quant.load_quanto('Load model: type=FLUX')
if isinstance(checkpoint_info, str):
repo_path = checkpoint_info
else:
repo_path = checkpoint_info.path
try:
quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
if not os.path.exists(quantization_map):
repo_id = sd_models.path_to_repo(checkpoint_info)
quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
dtype = state_dict['context_embedder.bias'].dtype
with torch.device("meta"):
transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
transformer_dtype = transformer.dtype
if transformer_dtype != devices.dtype:
try:
transformer = transformer.to(dtype=devices.dtype)
except Exception:
log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
except Exception as e:
log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
try:
quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
if not os.path.exists(quantization_map):
repo_id = sd_models.path_to_repo(checkpoint_info)
quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
with open(quantization_map, "r", encoding='utf8') as f:
quantization_map = json.load(f)
with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
t5_config = transformers.T5Config(**json.load(f))
state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
with torch.device("meta"):
text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
text_encoder_2_dtype = text_encoder_2.dtype
if text_encoder_2_dtype != devices.dtype:
try:
text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
except Exception:
log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
except Exception as e:
log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
if debug:
errors.display(e, 'FLUX Quanto:')
return transformer, text_encoder_2


@@ -41,18 +41,6 @@ def load_flux(checkpoint_info, diffusers_load_config=None):
transformer = None
text_encoder_2 = None
# handle prequantized models
prequantized = model_quant.get_quant(checkpoint_info.path)
if prequantized == 'nf4':
from pipelines.flux.flux_nf4 import load_flux_nf4
transformer, text_encoder_2 = load_flux_nf4(checkpoint_info)
elif prequantized == 'qint8' or prequantized == 'qint4':
from pipelines.flux.flux_quanto import load_flux_quanto
transformer, text_encoder_2 = load_flux_quanto(checkpoint_info)
elif prequantized == 'fp4' or prequantized == 'fp8':
from pipelines.flux.flux_bnb import load_flux_bnb
transformer = load_flux_bnb(checkpoint_info, diffusers_load_config)
# handle transformer svdquant if available, t5 is handled inside load_text_encoder
if transformer is None and model_quant.check_nunchaku('Model'):
from pipelines.flux.flux_nunchaku import load_flux_nunchaku


@@ -31,6 +31,9 @@ def load_flux2(checkpoint_info, diffusers_load_config=None):
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux2"] = diffusers.Flux2Pipeline
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux2"] = diffusers.Flux2Pipeline
from pipelines.flux import flux2_lora
flux2_lora.apply_patch()
del text_encoder
del transformer
sd_hijack_te.init_hijack(pipe)


@@ -34,6 +34,9 @@ def load_flux2_klein(checkpoint_info, diffusers_load_config=None):
diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux2klein"] = diffusers.Flux2KleinPipeline
diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux2klein"] = diffusers.Flux2KleinPipeline
from pipelines.flux import flux2_lora
flux2_lora.apply_patch()
del text_encoder
del transformer
sd_hijack_te.init_hijack(pipe)


@@ -27,13 +27,13 @@ peft==0.18.1
httpx==0.28.1
requests==2.32.3
tqdm==4.67.3
-accelerate==1.12.0
+accelerate==1.13.0
einops==0.8.1
-huggingface_hub==1.5.0
+huggingface_hub==1.7.2
numpy==2.1.2
pandas==2.3.1
protobuf==6.33.5
-pytorch_lightning==2.6.0
+pytorch_lightning==2.6.1
urllib3==1.26.19
Pillow==10.4.0
timm==1.0.24

scripts/color_grading.py (new file)

@@ -0,0 +1,15 @@
from modules import scripts_postprocessing, ui_sections, processing_grading
class ScriptPostprocessingColorGrading(scripts_postprocessing.ScriptPostprocessing):
name = "Color Grading"
def ui(self):
ui_controls = ui_sections.create_color_inputs('process')
ui_controls_dict = {control.label.replace(" ", "_").replace(".", "").lower(): control for control in ui_controls}
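# e.g. a control labelled 'Temp. Offset' (hypothetical label) would be keyed as 'temp_offset'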
return ui_controls_dict
def process(self, pp: scripts_postprocessing.PostprocessedImage, *args, **kwargs): # pylint: disable=arguments-differ
grading_params = processing_grading.GradingParams(*args, **kwargs)
if processing_grading.is_active(grading_params):
pp.image = processing_grading.grade_image(pp.image, grading_params)


@@ -124,7 +124,7 @@ def process(
# defines script for dual-mode usage
-class Script(scripts.Script):
+class ScriptNudeNet(scripts.Script):
# see below for all available options and callbacks
# <https://github.com/vladmandic/automatic/blob/master/modules/scripts.py#L26>
@@ -148,7 +148,7 @@ class Script(scripts.Script):
# defines postprocessing script for dual-mode usage
-class ScriptPostprocessing(scripts_postprocessing.ScriptPostprocessing):
+class ScriptPostprocessingNudeNet(scripts_postprocessing.ScriptPostprocessing):
name = 'NudeNet'
order = 10000


@@ -2,8 +2,8 @@ import gradio as gr
from modules import video, scripts_postprocessing
-class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
-name = "Video"
+class ScriptPostprocessingVideo(scripts_postprocessing.ScriptPostprocessing):
+name = "Create Video"
def ui(self):
with gr.Accordion('Create video', open = False, elem_id="postprocess_video_accordion"):
@@ -18,7 +18,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
]
with gr.Row():
gr.HTML("<span>&nbsp Video</span><br>")
gr.HTML("<span>&nbsp Create video from generated images</span><br>")
with gr.Row():
video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="extras_video_type")
duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False, elem_id="extras_video_duration")

scripts/rocm/rocm_mgr.py (new file)

@@ -0,0 +1,446 @@
import os
import sys
from pathlib import Path
from typing import Dict, Optional
import installer
from modules.logger import log
from modules.json_helpers import readfile, writefile
from scripts.rocm.rocm_vars import ROCM_ENV_VARS # pylint: disable=no-name-in-module
from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module
def _check_rocm() -> bool:
from modules import shared
if getattr(shared.cmd_opts, 'use_rocm', False):
return True
if installer.torch_info.get('type') == 'rocm':
return True
import torch # pylint: disable=import-outside-toplevel
return hasattr(torch.version, 'hip') and torch.version.hip is not None
is_rocm = _check_rocm()
CONFIG = Path(os.path.abspath(os.path.join('data', 'rocm.json')))
_cache: Optional[Dict[str, str]] = None # loaded once, invalidated on save
# Metadata key written into rocm.json to record which architecture profile is active.
# Not an environment variable — always skipped during env application but preserved in the
# saved config so that arch-safety enforcement is consistent across restarts.
_ARCH_KEY = "_rocm_arch"
# Vars that must never appear in the process environment.
#
# _DTYPE_UNSAFE: alter FP16 inference dtype — must be cleared regardless of config
# MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL — DEBUG alias: routes all FP16 convs through BF16 exponent math
# MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL — API-level alias: same BF16-exponent effect
# MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_EXPEREMENTAL_FP16_TRANSFORM — unstable experimental FP16 path
# MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16 — changes FP16 WrW atomic accumulation
#
# SOLVER_DISABLED_BY_DEFAULT: every solver known to be incompatible with this runtime
# (FP32-only, training-only WrW/BWD, fixed-geometry mismatches, XDLOPS/CDNA-only, arch-specific).
# Actively unsetting these ensures no inherited shell value can re-enable them.
_DTYPE_UNSAFE = {
"MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL",
"MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL",
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_EXPEREMENTAL_FP16_TRANSFORM",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16",
}
# _UNSET_VARS: hard-blocked vars that are DELETED from the process env and never written,
# regardless of saved config. Limited to dtype-corrupting vars only.
# IMPORTANT: SOLVER_DISABLED_BY_DEFAULT is intentionally NOT included here.
# When a solver var is absent (unset) MIOpen still calls IsApplicable() on every
# conv-find — wasted probing overhead. When a var is explicitly "0" MIOpen skips
# IsApplicable() immediately. Solver defaults flow through the config loop as "0"
# (their ROCM_ENV_VARS default is "0") so they are explicitly set to "0" in the env.
_UNSET_VARS = _DTYPE_UNSAFE
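# Sketch of the unset-vs-"0" distinction described above (illustrative, not executed here):
#   os.environ.pop('MIOPEN_DEBUG_CONV_FFT', None)  # absent: MIOpen still probes IsApplicable()
#   os.environ['MIOPEN_DEBUG_CONV_FFT'] = '0'      # explicit "0": MIOpen skips the solver outright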
# Additional environment vars that must be removed from the process before MIOpen loads.
# These are not MIOpen solver toggles but can corrupt MIOpen's runtime behaviour:
# HIP_PATH / HIP_PATH_71 — point to the system AMD ROCm install; override the venv-bundled
# _rocm_sdk_devel DLLs with a potentially mismatched system version
# QML_*/QT_* — QtQuick shader/disk-cache flags leaked from Qt tools; harmless for
# PyTorch but can conflict with Gradio's embedded Qt helpers
# PYENV_VIRTUALENV_DISABLE_PROMPT — pyenv noise that confuses venv detection
_EXTRA_CLEAR_VARS = {
"HIP_PATH",
"HIP_PATH_71",
"PYENV_VIRTUALENV_DISABLE_PROMPT",
"QML_DISABLE_DISK_CACHE",
"QML_FORCE_DISK_CACHE",
"QT_DISABLE_SHADER_DISK_CACHE",
# PERF_VALS vars are NOT boolean toggles — MIOpen reads them as perf-config strings.
# If inherited from a parent shell with value "1", MIOpen's GetPerfConfFromEnv parses
# "1" as a degenerate config and can return dtype=float32 output from FP16 tensors.
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_PERF_VALS",
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_PERF_VALS",
}
# Solvers whose MIOpen IsApplicable() explicitly rejects non-FP32 tensors.
# They are safe to leave enabled in FP32 mode. When the active dtype is FP16 or BF16
# we force them OFF so MIOpen skips the IsApplicable probe entirely — avoids overhead on
# every conv shape find. These are NOT in _UNSET_VARS because they are valid in FP32.
_FP32_ONLY_SOLVERS = {
"MIOPEN_DEBUG_CONV_FFT", # FFT convolution — FP32 only (MIOpen source: IsFp32 check)
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3", # Winograd 3x3 — FP32 only
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD", # Fused Winograd — FP32 only
}
def _resolve_dtype() -> str:
"""Return the resolved active compute dtype: 'FP16', 'BF16', 'FP32', or '' (not yet known).
Prefers the resolved devices.dtype (post test_fp16/bf16) over the raw opts string."""
try:
import torch # pylint: disable=import-outside-toplevel
from modules import devices as _dev # pylint: disable=import-outside-toplevel
if _dev.dtype is not None:
if _dev.dtype == torch.float16:
return 'FP16'
if _dev.dtype == torch.bfloat16:
return 'BF16'
if _dev.dtype == torch.float32:
return 'FP32'
except Exception:
pass
try:
from modules import shared as _sh # pylint: disable=import-outside-toplevel
v = getattr(getattr(_sh, 'opts', None), 'cuda_dtype', None)
if v in ('FP16', 'BF16', 'FP32'):
return v
except Exception:
pass
return ''
# --- venv helpers ---
def _get_venv() -> str:
return os.environ.get("VIRTUAL_ENV", "") or sys.prefix
def _expand_venv(value: str) -> str:
return value.replace("{VIRTUAL_ENV}", _get_venv())
def _collapse_venv(value: str) -> str:
venv = _get_venv()
if venv and value.startswith(venv):
return "{VIRTUAL_ENV}" + value[len(venv):]
return value
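# Round-trip example (assuming VIRTUAL_ENV=/opt/venv):
#   _collapse_venv('/opt/venv/lib/rocm')   -> '{VIRTUAL_ENV}/lib/rocm'
#   _expand_venv('{VIRTUAL_ENV}/lib/rocm') -> '/opt/venv/lib/rocm'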
# --- dropdown helpers ---
def _dropdown_display(stored_val: str, options) -> str:
if options and isinstance(options[0], tuple):
return next((label for label, val in options if val == str(stored_val)), str(stored_val))
return str(stored_val)
def _dropdown_stored(display_val: str, options) -> str:
if options and isinstance(options[0], tuple):
return next((val for label, val in options if label == str(display_val)), str(display_val))
return str(display_val)
def _dropdown_choices(options):
if options and isinstance(options[0], tuple):
return [label for label, _ in options]
return options
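# Example with a hypothetical (label, stored-value) option list:
#   options = [("Verbose", "5"), ("Quiet", "0")]
#   _dropdown_display("5", options)      -> "Verbose"
#   _dropdown_stored("Verbose", options) -> "5"
#   _dropdown_choices(options)           -> ["Verbose", "Quiet"]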
# --- config I/O ---
def load_config() -> Dict[str, str]:
global _cache # pylint: disable=global-statement
if _cache is None:
file_existed = CONFIG.exists()
if file_existed:
data = readfile(str(CONFIG), lock=True, as_type="dict")
_cache = data if data else {k: v["default"] for k, v in ROCM_ENV_VARS.items()}
# Purge unsafe vars from a stale saved config and re-persist only if the file existed.
# When running without a saved config (first run / after Delete), load_config() must
# never create the file — that only happens via save_config() on Apply or Apply Profile.
dirty = {k for k in _cache if k in _UNSET_VARS or (k != _ARCH_KEY and k not in ROCM_ENV_VARS)}
if dirty:
_cache = {k: v for k, v in _cache.items() if k not in dirty}
writefile(_cache, str(CONFIG))
log.debug(f'ROCm load_config: purged {len(dirty)} stale/unsafe var(s) from saved config')
else:
_cache = {k: v["default"] for k, v in ROCM_ENV_VARS.items()}
log.debug(f'ROCm load_config: path={CONFIG} existed={file_existed} items={len(_cache)}')
return _cache
def save_config(config: Dict[str, str]) -> None:
global _cache # pylint: disable=global-statement
sanitized = {k: v for k, v in config.items() if k not in _UNSET_VARS}
# Enforce arch-incompatible solvers to "0" before writing.
# Prevents malformed edits (UI or JSON hand-edit) from persisting incompatible "1" values.
arch = sanitized.get(_ARCH_KEY, "")
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
for var in unavailable:
if var in sanitized and sanitized[var] != "0":
sanitized[var] = "0"
log.debug(f'ROCm save_config: clamped arch-incompatible var={var} arch={arch}')
writefile(sanitized, str(CONFIG))
_cache = sanitized
def apply_env(config: Optional[Dict[str, str]] = None) -> None:
if config is None:
config = load_config()
for var in _UNSET_VARS | _EXTRA_CLEAR_VARS:
if var in os.environ:
del os.environ[var]
for var, value in config.items():
if var == _ARCH_KEY:
continue
if var in _UNSET_VARS:
continue
if var not in ROCM_ENV_VARS:
continue
meta = ROCM_ENV_VARS.get(var, {})
if meta.get("options"):
value = _dropdown_stored(str(value), meta["options"])
expanded = _expand_venv(str(value))
if expanded == "":
continue
os.environ[var] = expanded
# Arch safety net: hard-force all hardware-incompatible vars to "0" in the env.
# This runs *after* the config loop so it overrides any stale "1" that survived in the JSON.
# Source of truth: rocm_profiles.UNAVAILABLE[arch] — vars with no supporting hardware.
arch = config.get(_ARCH_KEY, "")
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
if unavailable:
for var in unavailable:
os.environ[var] = "0"
dtype_str = _resolve_dtype()
if dtype_str in ('FP16', 'BF16'):
for var in _FP32_ONLY_SOLVERS:
os.environ[var] = "0"
def apply_all(names: list, values: list) -> None:
config = load_config().copy()
arch = config.get(_ARCH_KEY, "")
unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
for name, value in zip(names, values):
if name not in ROCM_ENV_VARS:
log.warning(f'ROCm apply_all: unknown variable={name}')
continue
# Arch safety net: silently clamp incompatible solvers back to "0".
# The UI may send the current checkbox state even for greyed-out vars.
if name in unavailable:
config[name] = "0"
continue
meta = ROCM_ENV_VARS[name]
if meta["widget"] == "checkbox":
if value is None:
pass # Gradio passed None (component not interacted with) — leave config unchanged
else:
config[name] = "1" if value else "0"
elif meta["widget"] == "radio":
stored = _dropdown_stored(str(value), meta["options"])
valid = {v for _, v in meta["options"]} if meta["options"] and isinstance(meta["options"][0], tuple) else set(meta["options"] or [])
if stored in valid:
config[name] = stored
# else: value was None/invalid — leave the existing saved value untouched
else:
if meta.get("options"):
value = _dropdown_stored(str(value), meta["options"])
config[name] = _collapse_venv(str(value))
save_config(config)
apply_env(config)
def reset_defaults() -> None:
defaults = {k: v["default"] for k, v in ROCM_ENV_VARS.items()}
# Preserve the active arch key so safety nets survive a defaults reset.
arch = load_config().get(_ARCH_KEY, "")
if arch:
defaults[_ARCH_KEY] = arch
save_config(defaults)
apply_env(defaults)
log.info(f'ROCm reset_defaults: config reset to defaults arch={arch or "(none)"}')
def clear_env() -> None:
"""Remove all managed ROCm vars and known noise vars from os.environ without writing to disk."""
cleared = 0
for var in ROCM_ENV_VARS:
if var in os.environ:
del os.environ[var]
cleared += 1
for var in _UNSET_VARS | _EXTRA_CLEAR_VARS:
if var in os.environ:
del os.environ[var]
cleared += 1
log.info(f'ROCm clear_env: cleared={cleared}')
def delete_config() -> None:
"""Delete the saved config file, clear all vars, and wipe the MIOpen user DB cache."""
import shutil # pylint: disable=import-outside-toplevel
global _cache # pylint: disable=global-statement
clear_env()
if CONFIG.exists():
CONFIG.unlink()
log.info(f'ROCm delete_config: deleted {CONFIG}')
_cache = None
# Delete the MIOpen user DB (~/.miopen/db) — stale entries can cause solver mismatches
miopen_db = Path(os.path.expanduser('~')) / '.miopen' / 'db'
if miopen_db.exists():
shutil.rmtree(miopen_db, ignore_errors=True)
log.info(f'ROCm delete_config: wiped MIOpen user DB at {miopen_db}')
else:
log.debug(f'ROCm delete_config: MIOpen user DB not found at {miopen_db} — nothing to wipe')
def apply_profile(name: str) -> None:
"""Merge an architecture profile on top of the current config, then save and apply."""
profile = rocm_profiles.PROFILES.get(name)
if profile is None:
log.warning(f'ROCm apply_profile: unknown profile={name}')
return
config = load_config().copy()
config.update(profile)
config[_ARCH_KEY] = name # stamp the active arch so safety nets survive restarts
save_config(config)
apply_env(config)
log.info(f'ROCm apply_profile: profile={name} overrides={len(profile)}')
def _hip_version_from_file(db_path: Path) -> str:
"""Parse HIP_VERSION_* keys from .hipVersion in the SDK bin folder."""
hip_ver_file = db_path / ".hipVersion"
if not hip_ver_file.exists():
return ""
kv = {}
for line in hip_ver_file.read_text(errors="ignore").splitlines():
if "=" in line and not line.startswith("#"):
k, _, v = line.partition("=")
kv[k.strip()] = v.strip()
major = kv.get("HIP_VERSION_MAJOR", "")
minor = kv.get("HIP_VERSION_MINOR", "")
patch = kv.get("HIP_VERSION_PATCH", "")
git = kv.get("HIP_VERSION_GITHASH", "")
if major:
return f"{major}.{minor}.{patch} ({git})"
return ""
def _pkg_version(name: str) -> str:
try:
import importlib.metadata as _m # pylint: disable=import-outside-toplevel
return _m.version(name)
except Exception:
return "n/a"
def _db_file_summary(path: Path, patterns: list) -> dict:
"""Return {filename: 'N KB'} for files matching any of the given glob patterns."""
out = {}
for pat in patterns:
for f in sorted(path.glob(pat)):
kb = f.stat().st_size // 1024
out[f.name] = f"{kb} KB"
return out
def _user_db_summary(path: Path) -> dict:
"""Return {filename: 'N KB, M entries'} for user MIOpen DB txt files."""
out = {}
for pat in ("*.udb.txt", "*.ufdb.txt"):
for f in sorted(path.glob(pat)):
kb = f.stat().st_size // 1024
try:
lines = sum(1 for _ in f.open(errors="ignore"))
except Exception:
lines = 0
out[f.name] = f"{kb} KB, {lines} entries"
return out
def info() -> dict:
config = load_config()
db_path = Path(_expand_venv(config.get("MIOPEN_SYSTEM_DB_PATH", "")))
# --- ROCm / HIP package versions ---
rocm_pkgs = {}
for pkg in ("rocm", "rocm-sdk-core", "rocm-sdk-devel"):
v = _pkg_version(pkg)
if v != "n/a":
rocm_pkgs[pkg] = v
libs_pkg = _pkg_version("rocm-sdk-libraries-gfx103x-dgpu")
if libs_pkg != "n/a":
rocm_pkgs["rocm-sdk-libraries (gfx103x)"] = libs_pkg
hip_ver = _hip_version_from_file(db_path)
if not hip_ver:
try:
import torch # pylint: disable=import-outside-toplevel
hip_ver = getattr(torch.version, "hip", "") or ""
except Exception:
pass
rocm_section = {}
if hip_ver:
rocm_section["hip_version"] = hip_ver
rocm_section.update(rocm_pkgs)
# --- Torch ---
torch_section = {}
try:
import torch # pylint: disable=import-outside-toplevel
torch_section["version"] = torch.__version__
torch_section["hip"] = getattr(torch.version, "hip", None) or "n/a"
except Exception:
pass
# --- GPU ---
gpu_section = [dict(g) for g in installer.gpu_info]
# --- System DB ---
sdb = {"path": str(db_path)}
if db_path.exists():
solver_db = _db_file_summary(db_path, ["*.db.txt"])
find_db = _db_file_summary(db_path, ["*.HIP.fdb.txt", "*.fdb.txt"])
kernel_db = _db_file_summary(db_path, ["*.kdb"])
if solver_db:
sdb["solver_db"] = solver_db
if find_db:
sdb["find_db"] = find_db
if kernel_db:
sdb["kernel_db"] = kernel_db
else:
sdb["exists"] = False
# --- User DB (~/.miopen/db) ---
user_db_path = Path.home() / ".miopen" / "db"
udb = {"path": str(user_db_path), "exists": user_db_path.exists()}
if user_db_path.exists():
ufiles = _user_db_summary(user_db_path)
if ufiles:
udb["files"] = ufiles
return {
"rocm": rocm_section,
"torch": torch_section,
"gpu": gpu_section,
"system_db": sdb,
"user_db": udb,
}
# Apply saved config to os.environ at import time (only when ROCm is present)
if is_rocm:
try:
apply_env()
except Exception as _e:
print(f"[rocm_mgr] Warning: failed to apply env at import: {_e}", file=sys.stderr)
else:
log.debug('ROCm is not installed — skipping rocm_mgr env apply')

scripts/rocm/rocm_profiles.py (new file)

@@ -0,0 +1,253 @@
"""
Architecture-specific MIOpen solver profiles for AMD GCN/RDNA GPUs.
Sources:
https://rocm.docs.amd.com/projects/MIOpen/en/develop/reference/env_variables.html
Key axis: consumer RDNA GPUs have NO XDLOPS hardware (that's CDNA/Instinct only).
RDNA2 (gfx1030): RX 6000 series
RDNA3 (gfx1100): RX 7000 series adds Fury Winograd, wider MPASS
RDNA4 (gfx1200): RX 9000 series adds Rage Winograd, wider MPASS
Each profile is a dict of {var: value} that will be MERGED on top of the
current config (general vars like DB path / log level are preserved).
"""
from typing import Dict
# ---------------------------------------------------------------------------
# Shared: everything that must be OFF on ALL consumer RDNA (no XDLOPS hw)
# ---------------------------------------------------------------------------
_XDLOPS_OFF: Dict[str, str] = {
# GTC XDLOPS (CDNA-only)
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_GTC_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS_NHWC": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS_NHWC": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_GTC_XDLOPS_NHWC": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_DLOPS_NCHWC": "0",
# HIP XDLOPS variants (CDNA-only)
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R5_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V1R1_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4_PADDED_GEMM_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4_PADDED_GEMM_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_XDLOPS_EMULATE": "0",
"MIOPEN_DEBUG_IMPLICIT_GEMM_XDLOPS_INLINE_ASM": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_GROUP_BWD_XDLOPS": "0",
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_BWD_XDLOPS_AI_HEUR": "0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_FWD_V4R4_XDLOPS_ADD_VECTOR_LOAD_GEMMN_TUNE_PARAM": "0",
# 3D XDLOPS (CDNA-only; no 3D conv XDLOPS on consumer RDNA)
"MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "0",
"MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_BWD_XDLOPS": "0",
"MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS": "0",
# Composable Kernel (requires XDLOPS / CDNA)
"MIOPEN_DEBUG_CONV_CK_IGEMM_FWD_V6R1_DLOPS_NCHW": "0",
"MIOPEN_DEBUG_CONV_CK_IGEMM_FWD_BIAS_ACTIV": "0",
"MIOPEN_DEBUG_CONV_CK_IGEMM_FWD_BIAS_RES_ADD_ACTIV": "0",
# MLIR (CDNA-only in practice)
"MIOPEN_DEBUG_CONV_MLIR_IGEMM_WRW_XDLOPS": "0",
"MIOPEN_DEBUG_CONV_MLIR_IGEMM_BWD_XDLOPS": "0",
# MP BD Winograd (Multi-pass Block-Decomposed — CDNA / high-end only)
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F2X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F3X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F4X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F5X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F6X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F2X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F3X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F4X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F5X3": "0",
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F6X3": "0",
}
# ---------------------------------------------------------------------------
# RDNA2 — gfx1030 (RX 6000 series)
# No XDLOPS, no Fury/Rage Winograd, MPASS limited to F3x2/F3x3
# ASM IGEMM: V4R1 variants only; HIP IGEMM: non-XDLOPS V4R1/R4 only
# ---------------------------------------------------------------------------
RDNA2: Dict[str, str] = {
**_XDLOPS_OFF,
# General settings (architecture-independent; set here so all profiles cover them)
"MIOPEN_SEARCH_CUTOFF": "0",
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": "0",
# Core algo enables — FFT is FP32-only but harmless (IsApplicable rejects it for fp16 tensors)
"MIOPEN_DEBUG_CONV_FFT": "1",
"MIOPEN_DEBUG_CONV_DIRECT": "1",
"MIOPEN_DEBUG_CONV_GEMM": "1",
"MIOPEN_DEBUG_CONV_WINOGRAD": "1",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM": "1",
"MIOPEN_DEBUG_CONV_IMMED_FALLBACK": "1",
"MIOPEN_DEBUG_ENABLE_AI_IMMED_MODE_FALLBACK": "1",
"MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK": "0",
# Kernel backends
"MIOPEN_DEBUG_GCN_ASM_KERNELS": "1",
"MIOPEN_DEBUG_HIP_KERNELS": "1",
"MIOPEN_DEBUG_OPENCL_CONVOLUTIONS": "1",
"MIOPEN_DEBUG_OPENCL_WAVE64_NOWGP": "1",
"MIOPEN_DEBUG_ATTN_SOFTMAX": "1",
# Direct ASM — dtype notes
    # 3X3U / 1X1U / 1X1UV2: FP32/FP16 forward — enabled
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "1",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "1",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "1",
    # 5X10U2V2: fixed geometry (5x10, stride 2), no SD conv matches — disabled
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_5X10U2V2": "0",
    # 7X7C3H224W224: hard-coded ImageNet stem (C=3, H=W=224, K=64) — never matches SD — disabled
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_7X7C3H224W224": "0",
    # WRW3X3 / WRW1X1: FP32-only weight-gradient (training only) — disabled for inference
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW3X3": "0",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW1X1": "0",
    # PERF_VALS intentionally blank: MIOpen reads this as a config string, not a boolean;
    # setting it to "1" causes GetPerfConfFromEnv to use a degenerate config and return float32
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_PERF_VALS": "",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_SEARCH_OPTIMIZED": "1",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_AI_HEUR": "1",
    # NAIVE_CONV_FWD: scalar FP32 reference solver — IsApplicable does NOT reliably filter for FP16;
    # can be selected for unusual shapes (e.g. VAE decoder 3-ch output) and returns dtype=float32
    "MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD": "0",
    # Direct OCL — dtype notes
    # FWD / FWD1X1: FP32/FP16 forward — enabled
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "1",
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "1",
    # FWD11X11: requires an 11x11 kernel — no SD match — disabled
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD11X11": "0",
    # FWDGEN: FP32 generic OCL fallback — IsApplicable does NOT reliably reject for FP16;
    # can produce dtype=float32 output for FP16 inputs — disabled
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWDGEN": "0",
    # WRW2 / WRW53 / WRW1X1: training-only weight-gradient — disabled
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW2": "0",
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW53": "0",
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW1X1": "0",
    # Winograd RxS — dtype per MIOpen docs
    # WINOGRAD_3X3: FP32-only — harmless (IsApplicable rejects for fp16); enabled
    "MIOPEN_DEBUG_AMD_WINOGRAD_3X3": "1",
    # RXS: covers FP32/FP16 F(3,3) Fwd/Bwd + FP32 F(3,2) WrW — keep enabled (fp16 fwd/bwd path exists)
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS": "1",
    # RXS_FWD_BWD: FP32/FP16 — explicitly the fp16-capable subset
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD": "1",
    # RXS_WRW: FP32 WrW only — training-only, disabled for the fp16 inference profile
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_WRW": "0",
    # RXS_F3X2: FP32/FP16 Fwd/Bwd
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2": "1",
    # RXS_F2X3: FP32/FP16 Fwd/Bwd (group convolutions)
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3": "1",
    # RXS_F2X3_G1: FP32/FP16 Fwd/Bwd (non-group convolutions)
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1": "1",
    # FUSED_WINOGRAD: FP32-only — harmless (IsApplicable rejects for fp16); enabled
    "MIOPEN_DEBUG_AMD_FUSED_WINOGRAD": "1",
    # PERF_VALS intentionally blank: same reason as ASM_1X1U — a config string, not a boolean
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_PERF_VALS": "",
    # Fury/Rage Winograd — NOT available on RDNA2
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "0",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "0",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "0",
    # MPASS — only F3x2 and F3x3 are safe on RDNA2
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2": "1",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3": "1",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4": "0",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5": "0",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X6": "0",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X3": "0",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4": "0",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2": "0",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3": "0",
    # ASM Implicit GEMM — forward V4R1 only; no GTC/XDLOPS on RDNA2
    # BWD (backward data-gradient) and WrW (weight-gradient) are training-only — disabled
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1": "1",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1": "1",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_V4R1": "0",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_V4R1": "0",
    # HIP Implicit GEMM — non-XDLOPS V4R1/R4 forward only
    # BWD (backward data-gradient) and WrW (weight-gradient) are training-only — disabled
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1": "1",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4": "1",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V1R1": "0",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1": "0",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R1": "0",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4": "0",
    # Group Conv XDLOPS / CK default kernels — RDNA3/4 only, not available on RDNA2
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "0",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "0",
    "MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "0",
}
# ---------------------------------------------------------------------------
# RDNA3 — gfx1100 (RX 7000 series)
# Fury Winograd added; MPASS F3x4 enabled; Group Conv XDLOPS + CK default kernels enabled
# ---------------------------------------------------------------------------
RDNA3: Dict[str, str] = {
    **RDNA2,
    # Fury Winograd — introduced for gfx1100 (RDNA3)
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "1",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "1",
    # Wider MPASS on RDNA3
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4": "1",
    # Group Conv XDLOPS / CK — available from gfx1100 (RDNA3) onwards
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "1",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "1",
    "MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "1",
}
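# Because RDNA3 spreads RDNA2 first, only the overridden keys differ, e.g.:
#   RDNA3["MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4"] == "1"   (RDNA2: "0")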
# ---------------------------------------------------------------------------
# RDNA4 — gfx1200 (RX 9000 series)
# Rage Winograd added; MPASS F3x5 enabled
# ---------------------------------------------------------------------------
RDNA4: Dict[str, str] = {
    **RDNA3,
    # Rage Winograd — introduced for gfx1200 (RDNA4)
    "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "1",
    # Wider MPASS on RDNA4
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5": "1",
}
PROFILES: Dict[str, Dict[str, str]] = {
    "RDNA2": RDNA2,
    "RDNA3": RDNA3,
    "RDNA4": RDNA4,
}
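# Minimal sketch (illustrative only; the real apply path lives in rocm_mgr):
# a profile is just a name -> value map, so activating one is a plain
# os.environ update performed before MIOpen is first used.
def _apply_profile_env_sketch(arch: str) -> None:  # hypothetical helper, not used by SD.Next
    import os
    for _key, _value in PROFILES[arch].items():
        os.environ[_key] = _value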
# Vars that are architecturally unavailable (no supporting hardware) per arch.
# These will be visually marked in the UI with strikethrough.
_UNAVAILABLE_ALL_RDNA = set(_XDLOPS_OFF.keys())
UNAVAILABLE: Dict[str, set] = {
    "RDNA2": _UNAVAILABLE_ALL_RDNA | {
        "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3",
        "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2",
        "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X6",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X3",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3",
        "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS",
        "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR",
        "MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
    },
    "RDNA3": _UNAVAILABLE_ALL_RDNA | {
        "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X6",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X3",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3",
    },
    "RDNA4": _UNAVAILABLE_ALL_RDNA | {
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X6",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X3",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2",
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3",
    },
}
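# Example (sketch): how a caller might gate a toggle on availability.
def _is_available(arch: str, var: str) -> bool:  # illustrative helper, not used elsewhere
    return var not in UNAVAILABLE.get(arch, set())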

scripts/rocm/rocm_vars.py Normal file
@ -0,0 +1,253 @@
from typing import Dict, Any, List, Tuple
# --- General MIOpen/rocBLAS variables (dropdown/textbox/checkbox) ---
GENERAL_VARS: Dict[str, Dict[str, Any]] = {
    "MIOPEN_GEMM_ENFORCE_BACKEND": {
        "default": "1",
        "desc": "Enforce GEMM backend",
        "widget": "dropdown",
        "options": [("1 - rocBLAS", "1"), ("5 - hipBLASLt", "5")],
        "restart_required": False,
    },
    "MIOPEN_FIND_MODE": {
        "default": "2",
        "desc": "MIOpen Find Mode",
        "widget": "dropdown",
        "options": [("1 - NORMAL", "1"), ("2 - FAST", "2"), ("3 - HYBRID", "3"), ("5 - DYNAMIC_HYBRID", "5"), ("6 - TRUST_VERIFY", "6"), ("7 - TRUST_VERIFY_FULL", "7")],
        "restart_required": True,
    },
    "MIOPEN_FIND_ENFORCE": {
        "default": "1",
        "desc": "MIOpen Find Enforce",
        "widget": "dropdown",
        "options": [("1 - NONE", "1"), ("2 - DB_UPDATE", "2"), ("3 - SEARCH", "3"), ("4 - SEARCH_DB_UPDATE", "4"), ("5 - DB_CLEAN", "5")],
        "restart_required": True,
    },
    "MIOPEN_SEARCH_CUTOFF": {
        "default": "0",
        "desc": "Enable early termination of suboptimal searches",
        "widget": "dropdown",
        "options": [("0 - Off", "0"), ("1 - On", "1")],
        "restart_required": True,
    },
    "MIOPEN_SYSTEM_DB_PATH": {
        "default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\",
        "desc": "MIOpen system DB path",
        "widget": "textbox",
        "options": None,
        "restart_required": True,
    },
    "MIOPEN_LOG_LEVEL": {
        "default": "0",
        "desc": "MIOpen log verbosity level",
        "widget": "dropdown",
        "options": [("0 - Default", "0"), ("1 - Quiet", "1"), ("3 - Error", "3"), ("4 - Warning", "4"), ("5 - Info", "5"), ("6 - Detail", "6"), ("7 - Trace", "7")],
        "restart_required": False,
    },
    "MIOPEN_DEBUG_ENABLE": {
        "default": "0",
        "desc": "Enable MIOpen logging",
        "widget": "dropdown",
        "options": [("0 - Off", "0"), ("1 - On", "1")],
        "restart_required": False,
    },
    "ROCBLAS_LAYER": {
        "default": "0",
        "desc": "rocBLAS logging",
        "widget": "dropdown",
        "options": [("0 - Off", "0"), ("1 - Trace", "1"), ("2 - Bench", "2"), ("3 - Trace+Bench", "3"), ("4 - Profile", "4"), ("5 - Trace+Profile", "5"), ("6 - Bench+Profile", "6"), ("7 - All", "7")],
        "restart_required": False,
    },
    "HIPBLASLT_LOG_LEVEL": {
        "default": "0",
        "desc": "hipBLASLt logging",
        "widget": "dropdown",
        "options": [("0 - Off", "0"), ("1 - Error", "1"), ("2 - Trace", "2"), ("3 - Hints", "3"), ("4 - Info", "4"), ("5 - API Trace", "5")],
        "restart_required": False,
    },
    "MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": {
        "default": "0",
        "desc": "Deterministic convolution (reproducible results, may be slower)",
        "widget": "dropdown",
        "options": [("0 - Off", "0"), ("1 - On", "1")],
        "restart_required": False,
    },
}
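# Sketch of how the "{VIRTUAL_ENV}" placeholder in the MIOPEN_SYSTEM_DB_PATH
# default could be expanded; the real expansion lives in rocm_mgr and
# `_expand_sketch` is a hypothetical stand-in:
def _expand_sketch(value: str) -> str:
    import os
    return value.replace("{VIRTUAL_ENV}", os.environ.get("VIRTUAL_ENV", ""))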
# --- Solver toggles (inference/FWD only, RDNA2/3/4 compatible) ---
# Removed from the registry entirely — never shown in the UI, cannot be set by users:
#   WRW (weight-gradient) and BWD (data-gradient) — training passes only, never run during inference
#   XDLOPS/CK CDNA-exclusive (MI100/MI200/MI300 matrix engine variants) — not on any RDNA
#   Fixed-geometry (5x10, 7x7-ImageNet, 11x11) — shapes never appear in SD/video inference
#   FP32-reference (NAIVE_CONV_FWD, FWDGEN) — IsApplicable() unreliable for FP16/BF16
#   Wide MPASS (F3x4..F7x3) — kernel sizes that cannot match any SD convolution shape
# Disabled by default (added but off): RDNA3/4-only — Group Conv XDLOPS, CK default kernels
_SOLVER_DESCS: Dict[str, str] = {}
_SOLVER_DESCS.update({
    "MIOPEN_DEBUG_CONV_FFT": "Enable FFT solver",
    "MIOPEN_DEBUG_CONV_DIRECT": "Enable Direct solver",
    "MIOPEN_DEBUG_CONV_GEMM": "Enable GEMM solver",
    "MIOPEN_DEBUG_CONV_WINOGRAD": "Enable Winograd solver",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM": "Enable Implicit GEMM solver",
})
_SOLVER_DESCS.update({
    "MIOPEN_DEBUG_CONV_IMMED_FALLBACK": "Enable Immediate Fallback",
    "MIOPEN_DEBUG_ENABLE_AI_IMMED_MODE_FALLBACK": "Enable AI Immediate Mode Fallback",
    "MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK": "Force Immediate Mode Fallback",
})
_SOLVER_DESCS.update({
    "MIOPEN_DEBUG_GCN_ASM_KERNELS": "Enable GCN ASM kernels",
    "MIOPEN_DEBUG_HIP_KERNELS": "Enable HIP kernels",
    "MIOPEN_DEBUG_OPENCL_CONVOLUTIONS": "Enable OpenCL convolutions",
    "MIOPEN_DEBUG_OPENCL_WAVE64_NOWGP": "Enable OpenCL Wave64 NOWGP",
    "MIOPEN_DEBUG_ATTN_SOFTMAX": "Enable Attention Softmax",
})
_SOLVER_DESCS.update({
    # Direct ASM — FWD inference only (WRW, fixed-geometry, FP32-reference removed)
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "Enable Direct ASM 3x3U",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "Enable Direct ASM 1x1U",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "Enable Direct ASM 1x1UV2",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_SEARCH_OPTIMIZED": "Enable Direct ASM 1x1U Search Optimized",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_AI_HEUR": "Enable Direct ASM 1x1U AI Heuristic",
})
_SOLVER_DESCS.update({
    # Direct OCL — FWD inference only (WRW, FWD11X11 fixed-geom, FWDGEN FP32-ref removed)
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "Enable Direct OCL FWD",
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "Enable Direct OCL FWD1X1",
})
_SOLVER_DESCS.update({
    # Winograd FWD — WRW removed; Fury/Rage kept as RDNA3/4 inference (off by default)
    "MIOPEN_DEBUG_AMD_WINOGRAD_3X3": "Enable AMD Winograd 3x3",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS": "Enable AMD Winograd RxS",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD": "Enable AMD Winograd RxS FWD/BWD",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2": "Enable AMD Winograd RxS F3x2",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3": "Enable AMD Winograd RxS F2x3",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1": "Enable AMD Winograd RxS F2x3 G1",
    "MIOPEN_DEBUG_AMD_FUSED_WINOGRAD": "Enable AMD Fused Winograd",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "Enable AMD Winograd Fury RxS F2x3",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "Enable AMD Winograd Fury RxS F3x2",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "Enable AMD Winograd Rage RxS F2x3",
})
_SOLVER_DESCS.update({
    # Multi-pass Winograd — only F3x2/F3x3 match typical 3x3 SD shapes; wider kernels removed
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2": "Enable AMD Winograd MPASS F3x2",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3": "Enable AMD Winograd MPASS F3x3",
})
_SOLVER_DESCS.update({
    # Implicit GEMM FWD — BWD/WRW (training) and CDNA-exclusive XDLOPS variants removed
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1": "Enable ASM Implicit GEMM FWD V4R1",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1": "Enable ASM Implicit GEMM FWD V4R1 1x1",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1": "Enable HIP Implicit GEMM FWD V4R1",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4": "Enable HIP Implicit GEMM FWD V4R4",
})
_SOLVER_DESCS.update({
    # Group Conv XDLOPS FWD — RDNA3/4 (gfx1100+) only; disabled by default
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "Enable Group Conv Implicit GEMM XDLOPS FWD",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "Enable Group Conv Implicit GEMM XDLOPS FWD AI Heuristic",
    # CK (Composable Kernel) default kernels — RDNA3/4 (gfx1100+); disabled by default
    "MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "Enable CK (Composable Kernel) default kernels",
})
# Solvers still in the registry but disabled by default:
#   FORCE_IMMED_MODE_FALLBACK — overrides FIND_MODE entirely, defeats the tuning DB
#   Fury RxS F2x3/F3x2 — RDNA3/4-only; harmless on RDNA2 but will never be selected
#   Rage RxS F2x3 — RDNA4-only
#   Group Conv XDLOPS — RDNA3/4-only (gfx1100+)
#   CK_DEFAULT_KERNELS — RDNA3/4-only (gfx1100+)
SOLVER_DISABLED_BY_DEFAULT = {
    "MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR",
    "MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
}
SOLVER_DTYPE_TAGS: Dict[str, str] = {
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_3X3": "FP32",
    "MIOPEN_DEBUG_AMD_FUSED_WINOGRAD": "FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4": "FP16/FP32",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "FP16/BF16",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "FP16/BF16",
    "MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "FP16/BF16/FP32",
}
# Build the full merged var registry
ROCM_ENV_VARS: Dict[str, Dict[str, Any]] = {}
ROCM_ENV_VARS.update(GENERAL_VARS)
for _var, _desc in _SOLVER_DESCS.items():
    ROCM_ENV_VARS[_var] = {
        "default": "0" if _var in SOLVER_DISABLED_BY_DEFAULT else "1",
        "desc": _desc,
        "widget": "checkbox",
        "options": None,
        "dtype": SOLVER_DTYPE_TAGS.get(_var),
        "restart_required": False,
    }
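# Example (illustrative): resolving an effective default and dtype tag from the registry:
#   ROCM_ENV_VARS["MIOPEN_DEBUG_AMD_WINOGRAD_RXS"]["default"] -> "1" (enabled by default)
#   ROCM_ENV_VARS["MIOPEN_DEBUG_AMD_WINOGRAD_RXS"]["dtype"]   -> "FP16/FP32"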
# UI group ordering for solver sections
SOLVER_GROUPS: List[Tuple[str, List[str]]] = [
    ("Algorithm/Solver Group Enables", [
        "MIOPEN_DEBUG_CONV_FFT", "MIOPEN_DEBUG_CONV_DIRECT", "MIOPEN_DEBUG_CONV_GEMM",
        "MIOPEN_DEBUG_CONV_WINOGRAD", "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM",
    ]),
    ("Immediate Fallback Mode", [
        "MIOPEN_DEBUG_CONV_IMMED_FALLBACK", "MIOPEN_DEBUG_ENABLE_AI_IMMED_MODE_FALLBACK",
        "MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK",
    ]),
    ("Build Method Toggles", [
        "MIOPEN_DEBUG_GCN_ASM_KERNELS", "MIOPEN_DEBUG_HIP_KERNELS",
        "MIOPEN_DEBUG_OPENCL_CONVOLUTIONS", "MIOPEN_DEBUG_OPENCL_WAVE64_NOWGP",
        "MIOPEN_DEBUG_ATTN_SOFTMAX",
    ]),
    ("Direct ASM Solver Toggles", [
        "MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U", "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U",
        "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2",
        "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_SEARCH_OPTIMIZED", "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_AI_HEUR",
    ]),
    ("Direct OpenCL Solver Toggles", [
        "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD", "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1",
    ]),
    ("Winograd Solver Toggles", [
        "MIOPEN_DEBUG_AMD_WINOGRAD_3X3", "MIOPEN_DEBUG_AMD_WINOGRAD_RXS",
        "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD",
        "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2", "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3",
        "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1", "MIOPEN_DEBUG_AMD_FUSED_WINOGRAD",
        "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3",
        "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2", "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3",
    ]),
    ("Multi-pass Winograd Toggles", [
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2", "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3",
    ]),
    ("Implicit GEMM Toggles", [
        "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1", "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1",
        "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1", "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4",
    ]),
    ("Group Conv / CK Toggles (RDNA3/4+)", [
        "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS",
        "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR",
        "MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
    ]),
]
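# Illustrative consistency check (sketch; not invoked at import time): every
# var listed in SOLVER_GROUPS must exist in the merged registry.
def _check_solver_groups() -> None:  # hypothetical helper, for illustration only
    for _group_name, _group_vars in SOLVER_GROUPS:
        for _v in _group_vars:
            assert _v in ROCM_ENV_VARS, f"unknown var in SOLVER_GROUPS: {_v}"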

scripts/rocm_ext.py Normal file
@ -0,0 +1,203 @@
import gradio as gr
import installer
from modules import scripts_manager, shared
# rocm_mgr exposes package-internal helpers (prefixed _) that are intentionally called here
# pylint: disable=protected-access
class Script(scripts_manager.Script):
    def title(self):
        return "ROCm: Advanced Config"

    def show(self, _is_img2img):
        if shared.cmd_opts.use_rocm or installer.torch_info.get('type') == 'rocm':
            return scripts_manager.AlwaysVisible # script should be visible only if rocm is detected or forced
        return False

    def ui(self, _is_img2img):
        if not shared.cmd_opts.use_rocm and installer.torch_info.get('type') != 'rocm': # skip ui creation if not rocm
            return []
        from scripts.rocm import rocm_mgr, rocm_vars # pylint: disable=no-name-in-module
        config = rocm_mgr.load_config()
        var_names = []
        components = []

        # build one Gradio input per registry entry, seeded from the saved config
        def _make_component(name, meta, cfg):
            val = cfg.get(name, meta["default"])
            widget = meta["widget"]
            if widget == "checkbox":
                dtype_tag = meta.get("dtype")
                label = f"[{dtype_tag}] {meta['desc']}" if dtype_tag else meta["desc"]
                return gr.Checkbox(label=label, value=(val == "1"), elem_id=f"rocm_var_{name.lower()}")
            if widget == "dropdown":
                choices = rocm_mgr._dropdown_choices(meta["options"])
                display = rocm_mgr._dropdown_display(val, meta["options"])
                return gr.Dropdown(label=meta["desc"], choices=choices, value=display, elem_id=f"rocm_var_{name.lower()}")
            return gr.Textbox(label=meta["desc"], value=rocm_mgr._expand_venv(val), lines=1)

        def _info_html():
            d = rocm_mgr.info()
            rows = []
            def section(title):
                rows.append(f"<tr><th colspan='2' style='padding-top:6px;text-align:left;color:var(--sd-main-accent-color)'>{title}</th></tr>")
            def row(k, v):
                rows.append(f"<tr><td style='color:var(--sd-muted-color);width:38%;padding:2px 8px;border-bottom:1px solid var(--sd-panel-border-color)'>{k}</td><td style='color:var(--sd-label-color);padding:2px 8px;border-bottom:1px solid var(--sd-panel-border-color)'>{v}</td></tr>")
            section("ROCm / HIP")
            for k, v in d.get("rocm", {}).items():
                row(k, v)
            section("System DB")
            sdb = d.get("system_db", {})
            row("path", sdb.get("path", ""))
            for sub in ("solver_db", "find_db", "kernel_db"):
                for fname, sz in sdb.get(sub, {}).items():
                    row(sub.replace("_", " "), f"{fname} &nbsp; {sz}")
            section("User DB (~/.miopen/db)")
            udb = d.get("user_db", {})
            row("path", udb.get("path", ""))
            for fname, finfo in udb.get("files", {}).items():
                row(fname, finfo)
            return f"<table style='width:100%;border-collapse:collapse'>{''.join(rows)}</table>"

        with gr.Accordion('ROCm: Advanced Config', open=False, elem_id='rocm_config'):
            with gr.Row():
                gr.HTML("<p>Advanced configuration for ROCm users.</p><br><p>Set your database and solver selections based on GPU profile or individually.</p><br><p>Enable cuDNN in Backend Settings to activate MIOpen.</p>")
            with gr.Row():
                btn_info = gr.Button("Refresh Info", variant="primary", elem_id="rocm_btn_info", size="sm")
                btn_apply = gr.Button("Apply", variant="primary", elem_id="rocm_btn_apply", size="sm")
                btn_reset = gr.Button("Defaults", elem_id="rocm_btn_reset", size="sm")
                btn_clear = gr.Button("Clear Run Vars", elem_id="rocm_btn_clear", size="sm")
                btn_delete = gr.Button("Delete UserDb", variant="stop", elem_id="rocm_btn_delete", size="sm")
            with gr.Row():
                btn_rdna2 = gr.Button("RDNA2 (RX 6000)", elem_id="rocm_btn_rdna2")
                btn_rdna3 = gr.Button("RDNA3 (RX 7000)", elem_id="rocm_btn_rdna3")
                btn_rdna4 = gr.Button("RDNA4 (RX 9000)", elem_id="rocm_btn_rdna4")
            style_out = gr.HTML("")
            info_out = gr.HTML(value=_info_html, elem_id="rocm_info_table")
            # General vars (dropdowns, textboxes, checkboxes)
            with gr.Group():
                gr.HTML("<h3>MIOpen Settings</h3><hr>")
                for name, meta in rocm_vars.GENERAL_VARS.items():
                    comp = _make_component(name, meta, config)
                    var_names.append(name)
                    components.append(comp)
            # Solver groups (all checkboxes, grouped by section)
            for group_name, varlist in rocm_vars.SOLVER_GROUPS:
                with gr.Group():
                    gr.HTML(f"<h3>{group_name}</h3><hr>")
                    for name in varlist:
                        meta = rocm_vars.ROCM_ENV_VARS[name]
                        comp = _make_component(name, meta, config)
                        var_names.append(name)
                        components.append(comp)
            gr.HTML("<br><center><div style='margin:0 auto'><a href='https://rocm.docs.amd.com/projects/MIOpen/en/develop/reference/env_variables.html' target='_blank'>&#128196; MIOpen Environment Variables Reference</a></div></center><br>")

        # dropdowns autosave on change; checkboxes/textboxes persist via the Apply button
        def _autosave_field(name, value):
            meta = rocm_vars.ROCM_ENV_VARS[name]
            stored = rocm_mgr._dropdown_stored(str(value), meta["options"])
            cfg = rocm_mgr.load_config()
            cfg[name] = stored
            rocm_mgr.save_config(cfg)
            rocm_mgr.apply_env(cfg)

        for name, comp in zip(var_names, components):
            meta = rocm_vars.ROCM_ENV_VARS[name]
            if meta["widget"] == "dropdown":
                comp.change(fn=lambda v, n=name: _autosave_field(n, v), inputs=[comp], outputs=[], show_progress='hidden')

        def apply_fn(*values):
            rocm_mgr.apply_all(var_names, list(values))
            saved = rocm_mgr.load_config()
            result = [gr.update(value="")]
            for name in var_names:
                meta = rocm_vars.ROCM_ENV_VARS[name]
                val = saved.get(name, meta["default"])
                if meta["widget"] == "checkbox":
                    result.append(gr.update(value=val == "1"))
                elif meta["widget"] == "dropdown":
                    result.append(gr.update(value=rocm_mgr._dropdown_display(val, meta["options"])))
                else:
                    result.append(gr.update(value=rocm_mgr._expand_venv(val)))
            return result

        def _build_style(unavailable):
            if not unavailable:
                return ""
            rules = " ".join(
                f"#rocm_var_{v.lower()} label {{ text-decoration: line-through; opacity: 0.5; }}"
                for v in unavailable
            )
            return f"<style>{rules}</style>"
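        # e.g. an RDNA2 profile yields rules like:
        #   #rocm_var_miopen_debug_ck_default_kernels label { text-decoration: line-through; opacity: 0.5; }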
        def reset_fn():
            rocm_mgr.reset_defaults()
            updated = rocm_mgr.load_config()
            result = [gr.update(value="")]
            for name in var_names:
                meta = rocm_vars.ROCM_ENV_VARS[name]
                val = updated.get(name, meta["default"])
                if meta["widget"] == "checkbox":
                    result.append(gr.update(value=val == "1"))
                elif meta["widget"] == "dropdown":
                    result.append(gr.update(value=rocm_mgr._dropdown_display(val, meta["options"])))
                else:
                    result.append(gr.update(value=rocm_mgr._expand_venv(val)))
            return result

        def clear_fn():
            rocm_mgr.clear_env()
            result = [gr.update(value="")]
            for name in var_names:
                meta = rocm_vars.ROCM_ENV_VARS[name]
                if meta["widget"] == "checkbox":
                    result.append(gr.update(value=False))
                elif meta["widget"] == "dropdown":
                    result.append(gr.update(value=rocm_mgr._dropdown_display(meta["default"], meta["options"])))
                else:
                    result.append(gr.update(value=""))
            return result

        def delete_fn():
            rocm_mgr.delete_config()
            result = [gr.update(value="")]
            for name in var_names:
                meta = rocm_vars.ROCM_ENV_VARS[name]
                if meta["widget"] == "checkbox":
                    result.append(gr.update(value=False))
                elif meta["widget"] == "dropdown":
                    result.append(gr.update(value=rocm_mgr._dropdown_display(meta["default"], meta["options"])))
                else:
                    result.append(gr.update(value=""))
            return result

        def profile_fn(arch):
            from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module
            rocm_mgr.apply_profile(arch)
            updated = rocm_mgr.load_config()
            unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
            result = [gr.update(value=_build_style(unavailable))]
            for pname in var_names:
                meta = rocm_vars.ROCM_ENV_VARS[pname]
                val = updated.get(pname, meta["default"])
                if meta["widget"] == "checkbox":
                    result.append(gr.update(value=val == "1"))
                elif meta["widget"] == "dropdown":
                    result.append(gr.update(value=rocm_mgr._dropdown_display(val, meta["options"])))
                else:
                    result.append(gr.update(value=rocm_mgr._expand_venv(val)))
            return result
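        # every handler above returns [style_html] + one update per component,
        # matching the outputs=[style_out] + components wiring below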
        btn_info.click(fn=_info_html, inputs=[], outputs=[info_out], show_progress='hidden')
        btn_apply.click(fn=apply_fn, inputs=components, outputs=[style_out] + components, show_progress='hidden')
        btn_reset.click(fn=reset_fn, inputs=[], outputs=[style_out] + components, show_progress='hidden')
        btn_clear.click(fn=clear_fn, inputs=[], outputs=[style_out] + components, show_progress='hidden')
        btn_delete.click(fn=delete_fn, inputs=[], outputs=[style_out] + components, show_progress='hidden')
        btn_rdna2.click(fn=lambda: profile_fn("RDNA2"), inputs=[], outputs=[style_out] + components, show_progress='hidden')
        btn_rdna3.click(fn=lambda: profile_fn("RDNA3"), inputs=[], outputs=[style_out] + components, show_progress='hidden')
        btn_rdna4.click(fn=lambda: profile_fn("RDNA4"), inputs=[], outputs=[style_out] + components, show_progress='hidden')
        return components

test/reformat.js Normal file
@ -0,0 +1,55 @@
const fs = require('fs');
/**
 * Custom stringifier that switches to minified format at a specific depth
 */
const mixedStringify = (data, maxDepth, indent = 2, currentDepth = 0) => {
  if (currentDepth >= maxDepth) {
    return JSON.stringify(data);
  }
  const spacing = ' '.repeat(indent * currentDepth);
  const nextSpacing = ' '.repeat(indent * (currentDepth + 1));
  if (Array.isArray(data)) {
    if (data.length === 0) return '[]';
    const items = data.map((item) => nextSpacing + mixedStringify(item, maxDepth, indent, currentDepth + 1));
    return `[\n${items.join(',\n')}\n${spacing}]`;
  }
  if (typeof data === 'object' && data !== null) {
    const keys = Object.keys(data);
    if (keys.length === 0) return '{}';
    const items = keys.map((key) => {
      const value = mixedStringify(data[key], maxDepth, indent, currentDepth + 1);
      return `${nextSpacing}"${key}": ${value}`;
    });
    return `{\n${items.join(',\n')}\n${spacing}}`;
  }
  return JSON.stringify(data);
};
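// Example: mixedStringify({ a: [1, 2], b: { c: 3 } }, 1) expands the top level
// and minifies everything below it:
// {
//   "a": [1,2],
//   "b": {"c":3}
// }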
// Capture CLI arguments
const [,, inputFile, outputFile, maxDepth] = process.argv;
console.log(`Input File: ${inputFile}, Output File: ${outputFile}, Max Depth: ${maxDepth}`);
if (!inputFile || !outputFile || !maxDepth) {
  console.log('Usage: node reformat.js <input.json> <output.json> <depth>');
  process.exit(1);
}
try {
  // Read input file
  const rawData = fs.readFileSync(inputFile, 'utf8');
  const jsonData = JSON.parse(rawData);
  // Reformat with mixed depth
  const result = mixedStringify(jsonData, parseInt(maxDepth, 10));
  // Write output file
  fs.writeFileSync(outputFile, result);
  console.log(`Success! File saved to ${outputFile} (levels expanded: ${maxDepth})`);
} catch (err) {
  console.error('Error processing JSON:', err.message);
}


@ -95,7 +95,7 @@ def test_grading_params_defaults():
assert p.split_tone_balance == 0.5
assert p.vignette == 0.0
assert p.grain == 0.0
assert p.lut_file == ""
assert p.lut_cube_file == ""
assert p.lut_strength == 1.0
return True

wiki

@ -1 +1 @@
Subproject commit 33dbd026a2e2fb7311d545a3b2d2db0363bb887f
Subproject commit 99f4e13d03191b5269b869c71283d7fcf9c98f60