mirror of https://github.com/vladmandic/automatic
Merge branch 'dev' into RUF013
commit ff247d8fd2

@@ -17,6 +17,7 @@ __pycache__
/data/cache.json
/data/themes.json
/data/installer.json
+/data/rocm.json
node_modules
pnpm-lock.yaml
package-lock.json
@@ -57,6 +58,7 @@ tunableop_results*.csv
!webui.sh
!package.json
!requirements.txt
+!constraints.txt
!/data
!/models/VAE-approx
!/models/VAE-approx/model.pt
CHANGELOG.md

@@ -1,38 +1,48 @@
# Change Log for SD.Next

-## Update for 2026-03-20
+## Update for 2026-03-24

-### Highlights for 2026-03-20
+### Highlights for 2026-03-24

This release brings massive code refactoring to modernize the codebase and removal of some obsolete features. Leaner & Faster!
-And since its a bit quieter period when it comes to new models, notable additions would be : *FireRed-Image-Edit* *SkyWorks-UniPic-3* and *Anima-Preview-2*
+And since it's a bit quieter period when it comes to new models, notable additions would be: *FireRed-Image-Edit*, *SkyWorks-UniPic-3* and new *Anima-Preview*

If you're on Windows platform, we have a brand new [All-in-one Installer & Launcher](https://github.com/vladmandic/sdnext-launcher): simply download [exe or zip](https://github.com/vladmandic/sdnext-launcher/releases) and done!

*What else*? Really a lot!
-New color grading module, updated localization with new languages and improved translations, new civitai integration module, several new upscalers, improvements to LLM/VLM in captioning and prompt enhance, a lot of new control preprocessors, new realtime server info panel, some new UI themes
+New color grading module, updated localization with new languages and improved translations, new civitai integration module, new finetunes loader, several new upscalers, improvements to LLM/VLM in captioning and prompt enhance, a lot of new control preprocessors, new realtime server info panel, some new UI themes
And major work on API hardening: security, rate limits, secrets handling, new endpoints, etc.
But also many smaller quality-of-life improvements - for full details, see [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md)

*Note*: Purely due to the size of changes, a clean install is recommended!

[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic)
-### Details for 2026-03-20
+### Details for 2026-03-24

- **Models**
  - [Google Flash 3.1 Image](https://ai.google.dev/gemini-api/docs/models/gemini-3-flash-preview) a.k.a. *Nano Banana 2*
-  - [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0)
+  - [FireRed Image Edit](https://huggingface.co/FireRedTeam/FireRed-Image-Edit-1.0) *1.0 and 1.1*
    *Note*: FireRed is a fine-tune of Qwen-Image-Edit regardless of its claim as a new base-model
  - [Skyworks UniPic-3](https://huggingface.co/Skywork/Unipic3), *Consistency and DMD* variants to reference/community section
    *Note*: UniPic-3 is a fine-tune of Qwen-Image-Edit with new distillation regardless of its claim of major changes
  - [Anima Preview-v2](https://huggingface.co/circlestone-labs/Anima)
-- **Image manipulation**
-  - new **color grading** module
-  - update **latent corrections** *(former HDR Corrections)* and expand allowed models
-  - add support for [spandrel](https://github.com/chaiNNer-org/spandrel) **upscaling** engine with support for new upscaling model families
-  - add two new AI upscalers: *RealPLKSR NomosWebPhoto* and *RealPLKSR AnimeSharpV2*
-  - add two new **interpolation** methods: *HQX* and *ICB*
-  - use high-quality [sharpfin](https://github.com/drhead/Sharpfin) accelerated library when available (*cuda-only*)
-  - **upscalers**: extend chainner support for additional models
+- new **Color grading** module
+  apply basic corrections to your images: brightness, contrast, saturation, shadows, highlights
+  move to professional photo corrections: hue, gamma, sharpness, temperature
+  correct tone: shadows, midtones, highlights
+  add effects: vignette, grain
+  apply a professional LUT table using a .cube file
+  *hint*: color grading is available as a step during generate or as a processing item for already existing images
+- **Upscaling**
+  add support for [spandrel](https://github.com/chaiNNer-org/spandrel) engine with support for new upscaling model families
+  add two new AI upscalers: *RealPLKSR NomosWebPhoto* and *RealPLKSR AnimeSharpV2*
+  add two new **interpolation** methods: *HQX* and *ICB*
+  use high-quality [sharpfin](https://github.com/drhead/Sharpfin) accelerated library
+  extend `chainner` support for additional models
+- update **Latent corrections** *(former HDR Corrections)*
+  expand allowed models
- **Captioning / Prompt Enhance**
  - new models: **Qwen-3.5**, **Mistral-3** in multiple variations
  - new models: multiple *heretic* and *abliterated* finetunes for **Qwen, Gemma, Mistral**
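As a rough illustration of what such per-channel corrections do (a hypothetical sketch, not the module's actual code), brightness, contrast and gamma applied to a normalized channel value:

```python
def grade(pixel: float, brightness: float = 0.0, contrast: float = 1.0, gamma: float = 1.0) -> float:
    """Apply basic color-grading corrections to a single channel value in [0, 1]."""
    v = pixel ** (1.0 / gamma)          # gamma: lift or crush midtones
    v = (v - 0.5) * contrast + 0.5      # contrast: scale around mid-gray pivot
    v = v + brightness                  # brightness: uniform offset
    return min(1.0, max(0.0, v))        # clamp back to the valid range

# mid-gray stays put under pure contrast; brightness shifts values up
assert grade(0.5, contrast=1.5) == 0.5
assert grade(0.25, brightness=0.25) == 0.5
```

A real implementation would operate on whole tensors/images at once, but the per-pixel math is the same idea.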
@@ -43,32 +53,38 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
- new **pre-processors**:
  *anyline, depth_anything v2, dsine, lotus, marigold normals, oneformer, rtmlib pose, sam2, stablenormal, teed, vitpose*
- **Features**
-  - **secrets** handling: new `secrets.json` and special handling for tokens/keys/passwords
+  - **Secrets** handling: new `secrets.json` and special handling for tokens/keys/passwords
    these used to be treated like any other `config.json` param, which can cause security issues
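A minimal sketch of the idea (hypothetical code, not SD.Next's implementation; the key names are invented for illustration): secret-like params are split out of the main config and written to a separate file that can be permissioned and git-ignored independently:

```python
import json
from pathlib import Path

# keys whose values should never live in the main config (assumed list, for illustration only)
SECRET_KEYS = {'huggingface_token', 'civitai_token', 'api_password'}

def split_secrets(config: dict, secrets_file: Path) -> dict:
    """Move token/key/password params out of the config dict into a separate secrets file."""
    secrets = {k: v for k, v in config.items() if k in SECRET_KEYS}
    public = {k: v for k, v in config.items() if k not in SECRET_KEYS}
    secrets_file.write_text(json.dumps(secrets, indent=2))
    return public  # safe to persist as config.json
```

The benefit is that sharing or committing the main config no longer leaks credentials.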
  - pipelines: add **ZImageInpaint**
-  - rewritten **civitai** module
+  - rewritten **CivitAI** module
    browse/discover mode with sort, period, type/base dropdowns; URL paste; subfolder sorting; auto-browse; dynamic dropdowns
-  - **hires**: allow using different lora in refiner prompt
-  - **nunchaku** models are now listed in networks tab as reference models
+  - **HiRes**: allow using a different lora in the refiner prompt
+  - **Nunchaku** models are now listed in the networks tab as reference models
    instead of being used implicitly via quantization
-  - improve image **metadata** parser for foreign metadata (e.g. XMP)
+  - improve image **Metadata** parser for foreign metadata (e.g. XMP)
- **Compute**
  - **ROCm** advanced configuration and tuning, thanks @resonantsky
    see *main interface -> scripts -> rocm advanced config*
  - **ROCm** support for additional AMD GPUs: `gfx103X`, thanks @crashingalexsan
-  - **Cuda** `torch==2.10` removed support for `rtx1000` series, use following before first startup:
+  - **Cuda** `torch==2.10` removed support for `rtx1000` series and older GPUs
+    use the following before first startup to force installation of `torch==2.9.1` with `cuda==12.6`:
    > `set TORCH_COMMAND='torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu126'`
- **UI**
-  - new panel: **server info** with detailed runtime informaton
-  - **localization** improved translation quality and new translations locales:
+  - legacy panels **T2I** and **I2I** are disabled by default
+    you can re-enable them in *settings -> ui -> hide legacy tabs*
+  - new panel: **Server Info** with detailed runtime information
+  - **Networks** add **UNet/DiT**
+  - **Localization** improved translation quality and new translation locales:
+    *en, en1, en2, en3, en4, hr, es, it, fr, de, pt, ru, zh, ja, ko, hi, ar, bn, ur, id, vi, tr, sr, po, he, xx, yy, qq, tlh*
+    yes, this now includes stuff like *latin, esperanto, arabic, hebrew, klingon* and a lot more!
+    and also introduces some pseudo-locales such as: *techno-babbel*, *for-n00bs*
+    *hint*: click on the locale icon in the bottom-left corner to cycle through available locales, or set the default in *settings -> ui*
-  - **server settings** new section in *settings*
-  - **kanvas** add paste image from clipboard
-  - **themes** add *CTD-NT64Light*, *CTD-NT64Medium* and *CTD-NT64Dark*, thanks @resonantsky
-  - **themes** add *Vlad-Neomorph*
-  - **gallery** add option to auto-refresh gallery, thanks @awsr
-  - **token counters** add per-section display for supported models, thanks @awsr
+  - **Server settings** new section in *settings*
+  - **Kanvas** add paste image from clipboard
+  - **Themes** add *CTD-NT64Light*, *CTD-NT64Medium* and *CTD-NT64Dark*, thanks @resonantsky
+  - **Themes** add *Vlad-Neomorph*
+  - **Gallery** add option to auto-refresh gallery, thanks @awsr
+  - **Token counters** add per-section display for supported models, thanks @awsr
- **API**
  - **rate limiting**: global for all endpoints, guards against abuse and denial-of-service type attacks
    configurable in *settings -> server settings*
@@ -76,7 +92,11 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
  - new `/sdapi/v1/torch` endpoint for torch info (backend, version, etc.)
+  - new `/sdapi/v1/gpu` endpoint for GPU info
+  - new `/sdapi/v1/rembg` endpoint for background removal
+  - new `/sdapi/v1/unets` endpoint to list available unets/dits
+  - apply rate limiting to api logging
- **Obsoleted**
  - removed support for additional quantization engines: *BitsAndBytes, TorchAO, Optimum-Quanto, NNCF*
    *note*: SDNQ is the quantization engine of choice for SD.Next
- **Internal**
  - `python==3.13` full support
  - `python==3.14` initial support
@@ -147,6 +167,7 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
- improve video generation progress tracking
- handle startup with bad `scripts` more gracefully
- thread-safety for `error-limiter`, thanks @awsr
+- add `lora` support for flux2-klein

## Update for 2026-02-04
README.md

@@ -1,8 +1,14 @@
<div align="center">
-<img src="https://github.com/vladmandic/sdnext/raw/master/html/logo-transparent.png" width=200 alt="SD.Next">
+<img src="https://github.com/vladmandic/sdnext/raw/master/html/logo-transparent.png" width=200 alt="SD.Next: AI art generator logo">

-# SD.Next: All-in-one WebUI for AI generative image and video creation and captioning
+# SD.Next: All-in-one WebUI

+SD.Next is a powerful, open-source WebUI app for AI image and video generation, built on Stable Diffusion and supporting dozens of advanced models. Create, caption, and process images and videos with a modern, cross-platform interface—perfect for artists, researchers, and AI enthusiasts.






[](https://discord.gg/VjvR2tabEX)
@@ -17,61 +23,63 @@
## Table of contents

- [Documentation](https://vladmandic.github.io/sdnext-docs/)
-- [SD.Next Features](#sdnext-features)
-- [Model support](#model-support)
-- [Platform support](#platform-support)
+- [SD.Next Features](#features--capabilities)
+- [Supported AI Models](#supported-ai-models)
+- [Supported Platforms & Hardware](#supported-platforms--hardware)
- [Getting started](#getting-started)
-## SD.Next Features
+### Screenshot: Desktop interface

-All individual features are not listed here, instead check [ChangeLog](CHANGELOG.md) for full list of changes
-- Fully localized:
-  ▹ **English | Chinese | Russian | Spanish | German | French | Italian | Portuguese | Japanese | Korean**
-- Desktop and Mobile support!
-- Multiple [diffusion models](https://vladmandic.github.io/sdnext-docs/Model-Support/)!
-- Multi-platform!
-  ▹ **Windows | Linux | MacOS | nVidia CUDA | AMD ROCm | Intel Arc / IPEX XPU | DirectML | OpenVINO | ONNX+Olive | ZLUDA**
+<div align="center">
+<img src="https://github.com/vladmandic/sdnext/raw/dev/html/screenshot-robot.jpg" alt="SD.Next: AI art generator desktop interface screenshot" width="90%">
+</div>

+### Screenshot: Mobile interface

+<div align="center">
+<img src="https://github.com/user-attachments/assets/ced9fe0c-d2c2-46d1-94a7-8f9f2307ce38" alt="SD.Next: AI art generator mobile interface screenshot" width="35%">
+</div>

<br>
## Features & Capabilities

SD.Next is feature-rich with a focus on performance, flexibility, and user experience. Key features include:
- [Multi-platform](#supported-platforms--hardware)!
- Many [diffusion models](https://vladmandic.github.io/sdnext-docs/Model-Support/)!
- Fully localized to ~15 languages and with support for many [UI themes](https://vladmandic.github.io/sdnext-docs/Themes/)!
- [Desktop](#screenshot-desktop-interface) and [Mobile](#screenshot-mobile-interface) support!
- Platform-specific auto-detection and tuning performed on install
- Optimized processing with latest `torch` developments with built-in support for model compile and quantize
  Compile backends: *Triton | StableFast | DeepCache | OneDiff | TeaCache | etc.*
  Quantization methods: *SDNQ | BitsAndBytes | Optimum-Quanto | TorchAO / LayerWise*
- **Captioning** with 150+ **OpenCLiP** models, **Tagger** with **WaifuDiffusion** and **DeepDanbooru** models, and 20+ built-in **VLMs**
- Built-in installer with automatic updates and dependency management

<br>
### Unique features

-**Desktop** interface
-<div align="center">
-<img src="https://github.com/vladmandic/sdnext/raw/dev/html/screenshot-robot.jpg" alt="screenshot-modernui-desktop" width="90%">
-</div>

-**Mobile** interface
-<div align="center">
-<img src="https://github.com/user-attachments/assets/ced9fe0c-d2c2-46d1-94a7-8f9f2307ce38" alt="screenshot-modernui-mobile" width="35%">
-</div>

-For screenshots and information on other available themes, see [Themes](https://vladmandic.github.io/sdnext-docs/Themes/)
+SD.Next includes many features not found in other WebUIs, such as:
+- **SDNQ**: State-of-the-Art quantization engine
+  Use pre-quantized models or run with quantization on-the-fly for up to 4x VRAM reduction with no or minimal quality and performance impact
+- **Balanced Offload**: Dynamically balance CPU and GPU memory to run larger models on limited hardware
+- **Captioning** with 150+ **OpenCLiP** models, **Tagger** with **WaifuDiffusion** and **DeepDanbooru** models, and 25+ built-in **VLMs**
+- **Image Processing** with a full image-correction and color-grading suite of tools

<br>
-## Model support
+## Supported AI Models

SD.Next supports a broad range of models: [supported models](https://vladmandic.github.io/sdnext-docs/Model-Support/) and [model specs](https://vladmandic.github.io/sdnext-docs/Models/)

-## Platform support
+## Supported Platforms & Hardware
- *nVidia* GPUs using **CUDA** libraries on both *Windows and Linux*
-- *AMD* GPUs using **ROCm** libraries on *Linux*
-  Support will be extended to *Windows* once AMD releases ROCm for Windows
+- *AMD* GPUs using **ROCm** libraries on both *Linux and Windows*
+- *AMD* GPUs on Windows using **ZLUDA** libraries
- *Intel Arc* GPUs using **OneAPI** with *IPEX XPU* libraries on both *Windows and Linux*
-- Any *CPU/GPU* or device compatible with **OpenVINO** libraries on both *Windows and Linux*
- Any GPU compatible with *DirectX* on *Windows* using **DirectML** libraries
  This includes support for AMD GPUs that are not supported by native ROCm libraries
+- Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux*
- *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations
- *ONNX/Olive*
-- *AMD* GPUs on Windows using **ZLUDA** libraries

-Plus Docker container recipes for: [CUDA, ROCm, Intel IPEX and OpenVINO](https://vladmandic.github.io/sdnext-docs/Docker/)
+Plus **Docker** container recipes for: [CUDA, ROCm, Intel IPEX and OpenVINO](https://vladmandic.github.io/sdnext-docs/Docker/)

## Getting started
@@ -84,21 +92,37 @@ Plus Docker container recipes for: [CUDA, ROCm, Intel IPEX and OpenVINO](https:/
> And for platform specific information, check out
> [WSL](https://vladmandic.github.io/sdnext-docs/WSL/) | [Intel Arc](https://vladmandic.github.io/sdnext-docs/Intel-ARC/) | [DirectML](https://vladmandic.github.io/sdnext-docs/DirectML/) | [OpenVINO](https://vladmandic.github.io/sdnext-docs/OpenVINO/) | [ONNX & Olive](https://vladmandic.github.io/sdnext-docs/ONNX-Runtime/) | [ZLUDA](https://vladmandic.github.io/sdnext-docs/ZLUDA/) | [AMD ROCm](https://vladmandic.github.io/sdnext-docs/AMD-ROCm/) | [MacOS](https://vladmandic.github.io/sdnext-docs/MacOS-Python/) | [nVidia](https://vladmandic.github.io/sdnext-docs/nVidia/) | [Docker](https://vladmandic.github.io/sdnext-docs/Docker/)

### Quick Start

```shell
git clone https://github.com/vladmandic/sdnext
cd sdnext
./webui.sh  # Linux/Mac
webui.bat   # Windows
webui.ps1   # PowerShell
```

> [!WARNING]
> If you run into issues, check out the [troubleshooting](https://vladmandic.github.io/sdnext-docs/Troubleshooting/) and [debugging](https://vladmandic.github.io/sdnext-docs/Debug/) guides
## Community & Support

If you're unsure how to use a feature, the best place to start is the [Docs](https://vladmandic.github.io/sdnext-docs/) and if it's not there,
check the [ChangeLog](https://vladmandic.github.io/sdnext-docs/CHANGELOG/) for when a feature was first introduced as it will always have a short note on how to use it

And for any question, reach out on [Discord](https://discord.gg/VjvR2tabEX) or open an [issue](https://github.com/vladmandic/sdnext/issues) or [discussion](https://github.com/vladmandic/sdnext/discussions)

### Contributing

Please see [Contributing](CONTRIBUTING) for details on how to contribute to this project
-And for any question, reach out on [Discord](https://discord.gg/VjvR2tabEX) or open an [issue](https://github.com/vladmandic/sdnext/issues) or [discussion](https://github.com/vladmandic/sdnext/discussions)

-### Credits
+## License & Credits
- SD.Next is licensed under the [Apache License 2.0](LICENSE.txt)
- Main credit goes to [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) for the original codebase
- Additional credits are listed in [Credits](https://github.com/AUTOMATIC1111/stable-diffusion-webui/#credits)
- Licenses for modules are listed in [Licenses](html/licenses.html)

-### Evolution
+## Evolution

<a href="https://star-history.com/#vladmandic/sdnext&Date">
<picture width=640>
@@ -109,9 +133,4 @@ And for any question, reach out on [Discord](https://discord.gg/VjvR2tabEX) or o

- [OSS Stats](https://ossinsight.io/analyze/vladmandic/sdnext#overview)

-### Docs

-If you're unsure how to use a feature, best place to start is [Docs](https://vladmandic.github.io/sdnext-docs/) and if its not there,
-check [ChangeLog](https://vladmandic.github.io/sdnext-docs/CHANGELOG/) for when feature was first introduced as it will always have a short note on how to use it

<br>
TODO.md

@@ -2,20 +2,18 @@
## Release

- Update **README**
- Bump packages
-- Implement `unload_auxiliary_models`
-- Release **Launcher**
-- Release **Enso**
-- Update **ROCm**
-- Tips **Color Grading**
+- Implement: `unload_auxiliary_models`
+- Add notes: **Enso**
+- Tips: **Color Grading**
+- Regen: **Localization**

## Internal

-- Feature: Color grading in processing
- Integrate: [Depth3D](https://github.com/vladmandic/sd-extension-depth3d)
-- Feature: RIFE update
+- Feature: RIFE in processing
- Feature: SeedVR2 in processing
- Feature: Add video models to `Reference`
- Deploy: Lite vs Expert mode
- Engine: [mmgp](https://github.com/deepbeepmeep/mmgp)
- Engine: `TensorRT` acceleration
@@ -64,6 +62,7 @@ TODO: Investigate which models are diffusers-compatible and prioritize!

### Image-Edit

- [Bria FIBO-Edit](https://huggingface.co/briaai/Fibo-Edit-RMBG): Fully JSON-based instruction-following image editing framework
- [Meituan LongCat-Image-Edit-Turbo](https://huggingface.co/meituan-longcat/LongCat-Image-Edit-Turbo): 6B instruction-following image editing with high visual consistency
+- [VIBE Image-Edit](https://huggingface.co/iitolstykh/VIBE-Image-Edit): (Sana+Qwen-VL) Fast visual instruction-based image editing framework
- [LucyEdit](https://github.com/huggingface/diffusers/pull/12340): Instruction-guided video editing while preserving motion and identity
@@ -145,3 +144,5 @@ TODO: Investigate which models are diffusers-compatible and prioritize!
- modules/modular_guiders.py:65:58: W0511: TODO: guiders
- processing: remove duplicate mask params
- resize image: enable full VAE mode for resize-latent
+
+modules/sd_samplers_diffusers.py:353:31: W0511: TODO enso-required (fixme)
@@ -148,7 +148,7 @@
        "date": "2026 March",
        "skip": true
    },
-   "FireRed Image Edit": {
+   "FireRed Image Edit 1.0": {
        "path": "FireRedTeam/FireRed-Image-Edit-1.0",
        "preview": "FireRedTeam--FireRed-Image-Edit-1.0.jpg",
        "desc": "FireRed-Image-Edit is a general-purpose image editing model that delivers high-fidelity and consistent editing across a wide range of scenarios. FireRed is a fine-tune of Qwen-Image-Edit.",
@@ -156,6 +156,14 @@
        "date": "2026 February",
        "skip": true
    },
+   "FireRed Image Edit 1.1": {
+       "path": "FireRedTeam/FireRed-Image-Edit-1.1",
+       "preview": "FireRedTeam--FireRed-Image-Edit-1.0.jpg",
+       "desc": "FireRed-Image-Edit is a general-purpose image editing model that delivers high-fidelity and consistent editing across a wide range of scenarios. FireRed is a fine-tune of Qwen-Image-Edit.",
+       "tags": "community",
+       "date": "2026 February",
+       "skip": true
+   },
    "Skywork UniPic3": {
        "path": "Skywork/Unipic3",
        "preview": "Skywork--Unipic3.jpg",
@@ -1 +1 @@
-Subproject commit 9153b52f3980fe857c4ab9c3dd4f131b6175d20e
+Subproject commit 9d584a1bdc0c2aca614aa0e1e34e4374c3aa779d
html/locale_en.json (file diff suppressed because it is too large)

installer.py
@@ -475,7 +475,7 @@ def check_diffusers():
    t_start = time.time()
    if args.skip_all:
        return
-   target_commit = "e5aa719241f9b74d6700be3320a777799bfab70a" # diffusers commit hash
+   target_commit = "c02c17c6ee7ac508c56925dde4d4a3c587650dc3" # diffusers commit hash
    # if args.use_rocm or args.use_zluda or args.use_directml:
    #     sha = '043ab2520f6a19fce78e6e060a68dbc947edb9f9' # lock diffusers versions for now
    pkg = package_spec('diffusers')
@@ -687,7 +687,7 @@ def install_ipex():
    if args.use_nightly:
        torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu')
    else:
-       torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+xpu torchvision==0.25.0+xpu --index-url https://download.pytorch.org/whl/xpu')
+       torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.11.0+xpu torchvision==0.26.0+xpu --index-url https://download.pytorch.org/whl/xpu')

    ts('ipex', t_start)
    return torch_command
@@ -700,13 +700,12 @@ def install_openvino():

    #check_python(supported_minors=[10, 11, 12, 13], reason='OpenVINO backend requires a Python version between 3.10 and 3.13')
    if sys.platform == 'darwin':
-       torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0 torchvision==0.25.0')
+       torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.11.0 torchvision==0.26.0')
    else:
-       torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+cpu torchvision==0.25.0 --index-url https://download.pytorch.org/whl/cpu')
+       torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.11.0+cpu torchvision==0.26.0 --index-url https://download.pytorch.org/whl/cpu')

    if not (args.skip_all or args.skip_requirements):
-       install(os.environ.get('OPENVINO_COMMAND', 'openvino==2025.4.1'), 'openvino')
-       install(os.environ.get('NNCF_COMMAND', 'nncf==2.19.0'), 'nncf')
+       install(os.environ.get('OPENVINO_COMMAND', 'openvino==2026.0.0'), 'openvino')
    ts('openvino', t_start)
    return torch_command
@@ -730,12 +729,8 @@ def install_torch_addons():
        install('DeepCache')
    if opts.get('cuda_compile_backend', '') == 'olive-ai':
        install('olive-ai')
-   if len(opts.get('optimum_quanto_weights', [])):
-       install('optimum-quanto==0.2.7', 'optimum-quanto')
-   if len(opts.get('torchao_quantization', [])):
-       install('torchao==0.10.0', 'torchao')
    if opts.get('samples_format', 'jpg') == 'jxl' or opts.get('grid_format', 'jpg') == 'jxl':
-       install('pillow-jxl-plugin==1.3.5', 'pillow-jxl-plugin')
+       install('pillow-jxl-plugin==1.3.7', 'pillow-jxl-plugin')
    if not args.experimental:
        uninstall('wandb', quiet=True)
        uninstall('pynvml', quiet=True)
@@ -894,7 +889,7 @@ def check_torch():
    elif torch.version.hip and allow_rocm:
        torch_info.set(type='rocm', hip=torch.version.hip)
    else:
-       log.warning('Unknown Torch backend')
+       log.warning('Torch backend: cannot detect type')
    log.info(f"Torch backend: {torch_info}")
    for device in [torch.cuda.device(i) for i in range(torch.cuda.device_count())]:
        gpu = {
@@ -1184,14 +1179,11 @@ def install_optional():
    install('hf_transfer', ignore=True, quiet=True)
    install('hf_xet', ignore=True, quiet=True)
    install('nvidia-ml-py', ignore=True, quiet=True)
-   install('pillow-jxl-plugin==1.3.5', ignore=True, quiet=True)
+   install('pillow-jxl-plugin==1.3.7', ignore=True, quiet=True)
    install('ultralytics==8.3.40', ignore=True, quiet=True)
    install('open-clip-torch', no_deps=True, quiet=True)
    install('git+https://github.com/tencent-ailab/IP-Adapter.git', 'ip_adapter', ignore=True, quiet=True)
    # install('git+https://github.com/openai/CLIP.git', 'clip', quiet=True, no_build_isolation=True)
-   # install('torchao==0.10.0', ignore=True, quiet=True)
-   # install('bitsandbytes==0.47.0', ignore=True, quiet=True)
-   # install('optimum-quanto==0.2.7', ignore=True, quiet=True)
    ts('optional', t_start)
@@ -1235,6 +1227,7 @@ def install_requirements():
# set environment variables controlling the behavior of various libraries
def set_environment():
    log.debug('Setting environment tuning')
+   os.environ.setdefault('PIP_CONSTRAINT', os.path.abspath('constraints.txt'))
    os.environ.setdefault('ACCELERATE', 'True')
    os.environ.setdefault('ATTN_PRECISION', 'fp16')
    os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100')
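These `setdefault` calls are what make user overrides such as `TORCH_COMMAND` or `PIP_CONSTRAINT` work: a value exported before launch wins, and the installer only fills in a default when the variable is unset. A small self-contained demonstration (the variable name here is invented for illustration):

```python
import os

# setdefault only fills in a default when the variable is not already set
os.environ.pop('SDNEXT_DEMO', None)
os.environ.setdefault('SDNEXT_DEMO', 'default')
assert os.environ['SDNEXT_DEMO'] == 'default'

# a user-provided value is left untouched
os.environ['SDNEXT_DEMO'] = 'user'
os.environ.setdefault('SDNEXT_DEMO', 'default')  # no effect, already set
assert os.environ['SDNEXT_DEMO'] == 'user'
```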
@@ -524,6 +524,14 @@ function selectVAE(name) {
    markSelectedCards([desiredVAEName], 'vae');
}

+let desiredUNetName = null;
+function selectUNet(name) {
+    desiredUNetName = name;
+    gradioApp().getElementById('change_unet').click();
+    log(`selectUNet: ${desiredUNetName}`);
+    markSelectedCards([desiredUNetName], 'unet');
+}

function selectReference(name) {
    log(`selectReference: ${name}`);
    desiredCheckpointName = name;
@@ -308,6 +308,7 @@ def main():
        log.warning('Restart is recommended due to packages updates...')
    t_server = time.time()
+   t_monitor = time.time()

    while True:
        try:
            alive = uv.thread.is_alive()
@@ -326,8 +327,10 @@ def main():
            if float(monitor_rate) > 0 and t_current - t_monitor > float(monitor_rate):
                log.trace(f'Monitor: {get_memory_stats(detailed=True)}')
                t_monitor = t_current
-               from modules.api.validate import get_api_stats
-               get_api_stats()
+               # from modules.api.validate import get_api_stats
+               # get_api_stats()
+               # from modules import memstats
+               # memstats.get_objects()
            if not alive:
                if uv is not None and uv.wants_restart:
                    clean_server()
@@ -88,6 +88,7 @@ class Api:
        self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=list[models.ItemVae])
        self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=list[models.ItemExtension])
        self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=list[models.ItemExtraNetwork])
+       self.add_api_route("/sdapi/v1/unets", endpoints.get_unets, methods=["GET"], response_model=list[models.ItemUNet])

        # functional api
        self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo, tags=["Functional"])
@@ -98,6 +99,7 @@ class Api:
        self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"], tags=["Functional"])
        self.add_api_route("/sdapi/v1/lock-checkpoint", endpoints.post_lock_checkpoint, methods=["POST"], tags=["Functional"])
        self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"], tags=["Functional"])
+       self.add_api_route("/sdapi/v1/refresh-unets", endpoints.post_refresh_unets, methods=["POST"], tags=["Functional"])
        self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=list[str], tags=["Functional"])
        self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int, tags=["Functional"])
        self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"], tags=["Functional"])
@@ -65,7 +65,6 @@ get_restorers = get_detailers # legacy alias for /sdapi/v1/face-restorers
def get_ip_adapters():
    """
    List available IP-Adapter models.
    Returns adapter names that can be used for image-prompt conditioning during generation.
    """
    from modules import ipadapter
@@ -75,6 +74,11 @@ def get_prompt_styles():
    """List all saved prompt styles with their prompt, negative prompt, and preview."""
    return [{ 'name': v.name, 'prompt': v.prompt, 'negative_prompt': v.negative_prompt, 'extra': v.extra, 'filename': v.filename, 'preview': v.preview} for v in shared.prompt_styles.styles.values()]

+def get_unets():
+    """List available UNet models with their names and filenames."""
+    from modules.sd_unet import unet_dict
+    return [{"name": k, "filename": v} for k, v in unet_dict.items()]
|
||||
def get_embeddings():
|
||||
"""List loaded and skipped textual-inversion embeddings for the current model."""
|
||||
db = getattr(shared.sd_model, 'embedding_db', None) if shared.sd_loaded else None
|
||||
|
|
@ -221,6 +225,11 @@ def post_lock_checkpoint(lock:bool=False):
|
|||
modeldata.model_data.locked = lock
|
||||
return {}
|
||||
|
||||
def post_refresh_unets():
|
||||
"""Rescan UNet directories and update the available UNet list."""
|
||||
import modules.sd_unet
|
||||
return modules.sd_unet.refresh_unet_list()
|
||||
|
||||
def get_checkpoint():
|
||||
"""Return information about the currently loaded checkpoint including type, class, title, and hash."""
|
||||
if not shared.sd_loaded or shared.sd_model is None:
|
||||
|
|
|
|||
|
|
@@ -146,6 +146,10 @@ class ItemStyle(BaseModel):
     filename: str | None = Field(title="Filename", description="Path to the styles file")
     preview: str | None = Field(title="Preview", description="URL to the style preview image")
+
+class ItemUNet(BaseModel):
+    name: str = Field(title="Name", description="UNet/DiT name")
+    filename: str | None = Field(title="Filename", description="Path to the UNet/DiT file")

 class ItemExtraNetwork(BaseModel):
     name: str = Field(title="Name", description="Network short name")
     type: str = Field(title="Type", description="Network type (lora, checkpoint, embedding, etc.)")
@@ -219,8 +219,8 @@ def civit_search_metadata(title: str | None = None, raw: bool = False):
     import concurrent
     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
         future_items = {}
-        for fn in candidates:
-            future_items[executor.submit(atomic_civit_search_metadata, fn, results)] = fn
+        for candidate in candidates:
+            future_items[executor.submit(atomic_civit_search_metadata, candidate, results)] = candidate
         for future in concurrent.futures.as_completed(future_items):
             future.result()
     yield results if raw else create_search_metadata_table(results)
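The hunk above only renames the loop variable; the underlying fan-out pattern is unchanged: submit one task per candidate, keep a future-to-candidate map, then drain with `as_completed` so worker exceptions surface via `future.result()`. A minimal standalone sketch of that pattern (the `process` worker and its inputs are illustrative, not the repo's API):

```python
import concurrent.futures

def process(candidate, results):
    # stand-in for the real worker: record one result per candidate
    results[candidate] = candidate.upper()

candidates = ["alpha", "beta", "gamma"]
results = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    # map each future back to its candidate so failures can be attributed
    future_items = {executor.submit(process, candidate, results): candidate for candidate in candidates}
    for future in concurrent.futures.as_completed(future_items):
        future.result()  # re-raises any exception raised inside the worker
```

Keying the dict by future (rather than appending to a list) is what lets the caller know which candidate a completed or failed future belongs to.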
@@ -74,8 +74,8 @@ def add_diag_args(p):
     p.add_argument('--safe', default=env_flag("SD_SAFE", False), action='store_true', help="Run in safe mode with no user extensions")
     p.add_argument('--test', default=env_flag("SD_TEST", False), action='store_true', help="Run test only and exit")
     p.add_argument('--version', default=False, action='store_true', help="Print version information")
-    p.add_argument("--monitor", default=os.environ.get("SD_MONITOR", -1), help="Run memory monitor, default: %(default)s")
-    p.add_argument("--status", default=os.environ.get("SD_STATUS", -1), help="Run server is-alive status, default: %(default)s")
+    p.add_argument("--monitor", type=float, default=float(os.environ.get("SD_MONITOR", -1)), help="Run memory monitor, default: %(default)s")
+    p.add_argument("--status", type=float, default=float(os.environ.get("SD_STATUS", -1)), help="Run server is-alive status, default: %(default)s")


 def add_log_args(p):
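The `--monitor`/`--status` change above does two things: `type=float` coerces CLI strings, and wrapping the environment fallback in `float()` makes the default the same type as parsed values. A hedged sketch of the same idiom, reusing the `SD_MONITOR` variable name from the hunk:

```python
import argparse
import os

os.environ.pop("SD_MONITOR", None)  # ensure the fallback path is exercised in this sketch

p = argparse.ArgumentParser()
# float() on the env fallback makes the default the same type as parsed CLI values
p.add_argument("--monitor", type=float, default=float(os.environ.get("SD_MONITOR", -1)),
               help="Run memory monitor, default: %(default)s")

cli = p.parse_args(["--monitor", "2.5"])
fallback = p.parse_args([])
assert cli.monitor == 2.5        # CLI string coerced by type=float
assert fallback.monitor == -1.0  # default is a float, not the raw str/int fallback
```

Without the `float()` wrapper, code downstream would see a `str` when the env var is set and an `int` when it is not, which is exactly the inconsistency the diff removes.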
@@ -382,22 +382,6 @@ class ControlNet():
                 self.model = sdnq_quantize_model(self.model)
             except Exception as e:
                 log.error(f'Control {what} model SDNQ Compression failed: id="{model_id}" {e}')
-        elif "Control" in opts.optimum_quanto_weights:
-            try:
-                log.debug(f'Control {what} model Optimum Quanto: id="{model_id}"')
-                model_quant.load_quanto('Load model: type=Control')
-                from modules.model_quant import optimum_quanto_model
-                self.model = optimum_quanto_model(self.model)
-            except Exception as e:
-                log.error(f'Control {what} model Optimum Quanto: id="{model_id}" {e}')
-        elif "Control" in opts.torchao_quantization:
-            try:
-                log.debug(f'Control {what} model Torch AO: id="{model_id}"')
-                model_quant.load_torchao('Load model: type=Control')
-                from modules.model_quant import torchao_quantization
-                self.model = torchao_quantization(self.model)
-            except Exception as e:
-                log.error(f'Control {what} model Torch AO: id="{model_id}" {e}')
         if self.device is not None:
             sd_models.move_model(self.model, self.device)
         if "Control" in opts.cuda_compile:
@@ -1,5 +1,5 @@
 import os
-from installer import log, git
+from installer import log, git, run_extension_installer
 from modules.paths import extensions_dir


@@ -14,6 +14,7 @@ def install():
         return
     log.info(f'Enso: folder="{ENSO_DIR}" installing')
     git(f'clone "{ENSO_REPO}" "{ENSO_DIR}"')
+    run_extension_installer(ENSO_DIR)


 def update():

@@ -22,3 +23,4 @@ def update():
         return
     log.info(f'Enso: folder="{ENSO_DIR}" updating')
     git('pull', folder=ENSO_DIR)
+    run_extension_installer(ENSO_DIR)
@@ -12,7 +12,7 @@ _lock = Lock()


 def _make_unique(name: str):
-    global _instance_id
+    global _instance_id  # pylint: disable=global-statement
     with _lock:  # Guard against race conditions
        new_name = f"{name}__{_instance_id}"
        _instance_id += 1
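The `_make_unique` hunk above only adds a pylint pragma; the underlying pattern is a module-level counter guarded by a lock so concurrent callers never mint the same suffix. A self-contained sketch of that shape, assuming the same `Lock`-plus-global structure as the source:

```python
from threading import Lock, Thread

_instance_id = 0
_lock = Lock()

def _make_unique(name: str) -> str:
    global _instance_id  # pylint: disable=global-statement
    with _lock:  # guard against races on the shared counter
        new_name = f"{name}__{_instance_id}"
        _instance_id += 1
    return new_name

# exercise it from several threads: every caller must get a distinct name
names = []
threads = [Thread(target=lambda: names.append(_make_unique("op"))) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert len(set(names)) == 8
```

Without the lock, two threads could read the same `_instance_id` before either increments it and produce duplicate names.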
@@ -16,6 +16,28 @@ torch_version[0], torch_version[1] = int(torch_version[0]), int(torch_version[1]

 # pylint: disable=protected-access, missing-function-docstring, line-too-long

+def return_true(*args, **kwargs):
+    return True
+
+def return_false(*args, **kwargs):
+    return False
+
+def return_none(*args, **kwargs):
+    return None
+
+def return_zero(*args, **kwargs):
+    return 0
+
+def return_cuda_version(*args, **kwargs):
+    return (12,1)
+
+def return_xpu_string(*args, **kwargs):
+    return "xpu"
+
+def return_arch_list(*args, **kwargs):
+    return ["pvc", "dg2", "ats-m150"]
+
+
 def ipex_init(): # pylint: disable=too-many-statements
     try:
         if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:

@@ -26,9 +48,9 @@ def ipex_init(): # pylint: disable=too-many-statements
         # import inductor utils to get around lazy import
         from torch._inductor import utils as torch_inductor_utils # pylint: disable=import-error, unused-import # noqa: F401,RUF100
         torch._inductor.utils.GPU_TYPES = ["xpu"]
-        torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
+        torch._inductor.utils.get_gpu_type = return_xpu_string
         from triton import backends as triton_backends # pylint: disable=import-error
-        triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False
+        triton_backends.backends["nvidia"].driver.is_active = return_false
     except Exception:
         pass
     # Replace cuda with xpu:

@@ -51,15 +73,12 @@ def ipex_init(): # pylint: disable=too-many-statements
     torch.cuda.default_generators = torch.xpu.default_generators
     torch.cuda.set_stream = torch.xpu.set_stream
     torch.cuda.torch = torch.xpu.torch
-    torch.cuda.Union = torch.xpu.Union
     torch.cuda.StreamContext = torch.xpu.StreamContext
     torch.cuda.random = torch.xpu.random
     torch.cuda._get_device_index = torch.xpu._get_device_index
     torch.cuda._lazy_init = torch.xpu._lazy_init
     torch.cuda._lazy_call = torch.xpu._lazy_call
-    torch.cuda._device = torch.xpu._device
-    torch.cuda._device_t = torch.xpu._device_t
-    torch.cuda.is_current_stream_capturing = lambda: False
+    torch.cuda.is_current_stream_capturing = return_false

     torch.cuda.__annotations__ = torch.xpu.__annotations__
     torch.cuda.__builtins__ = torch.xpu.__builtins__

@@ -141,12 +160,23 @@ def ipex_init(): # pylint: disable=too-many-statements
     torch.cuda.memory_summary = torch.xpu.memory_summary
     torch.cuda.memory_snapshot = torch.xpu.memory_snapshot

+    if torch_version[0] < 2 or (torch_version[0] == 2 and torch_version[1] < 11):
+        torch.cuda.Union = torch.xpu.Union
+        torch.cuda._device = torch.xpu._device
+        torch.cuda._device_t = torch.xpu._device_t
+
     # Memory:
     if "linux" in sys.platform and "WSL2" in os.popen("uname -a").read():
-        torch.xpu.empty_cache = lambda: None
+        torch.xpu.empty_cache = return_none
     torch.cuda.empty_cache = torch.xpu.empty_cache

-    torch.cuda.memory = torch.xpu.memory
+    if torch_version[0] >= 2 and torch_version[1] >= 8:
+        old_cpa = torch.cuda.memory.CUDAPluggableAllocator
+        torch.cuda.memory = torch.xpu.memory
+        torch.xpu.memory.CUDAPluggableAllocator = old_cpa
+    else:
+        torch.cuda.memory = torch.xpu.memory

     torch.cuda.memory_stats = torch.xpu.memory_stats
     torch.cuda.memory_allocated = torch.xpu.memory_allocated
     torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated

@@ -172,21 +202,24 @@ def ipex_init(): # pylint: disable=too-many-statements
     torch.cuda.initial_seed = torch.xpu.initial_seed

     # Fix functions with ipex:
-    # torch.xpu.mem_get_info always returns the total memory as free memory
     torch.has_cuda = True
     torch.version.cuda = "12.1"
-    torch.backends.cuda.is_built = lambda *args, **kwargs: True
-    torch._utils._get_available_device_type = lambda: "xpu"
+    torch.backends.cuda.is_built = return_true
+    torch._utils._get_available_device_type = return_xpu_string

-    torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
+    # torch.xpu.mem_get_info always returns the total memory as free memory
+    def mem_get_info(device=None):
+        return [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
+    torch.xpu.mem_get_info = mem_get_info
     torch.cuda.mem_get_info = torch.xpu.mem_get_info

     torch.cuda.has_half = True
-    torch.cuda.is_bf16_supported = getattr(torch.xpu, "is_bf16_supported", lambda *args, **kwargs: True)
-    torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
-    torch.cuda.get_arch_list = getattr(torch.xpu, "get_arch_list", lambda: ["pvc", "dg2", "ats-m150"])
-    torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
-    torch.cuda.ipc_collect = lambda *args, **kwargs: None
-    torch.cuda.utilization = lambda *args, **kwargs: 0
+    torch.cuda.is_bf16_supported = getattr(torch.xpu, "is_bf16_supported", return_true)
+    torch.cuda.is_fp16_supported = getattr(torch.xpu, "is_fp16_supported", return_true)
+    torch.cuda.get_arch_list = getattr(torch.xpu, "get_arch_list", return_arch_list)
+    torch.cuda.get_device_capability = return_cuda_version
+    torch.cuda.ipc_collect = return_none
+    torch.cuda.utilization = return_zero

     device_supports_fp64 = ipex_hijacks()
     try:
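A recurring change in the IPEX shim above is replacing ad-hoc lambdas with named module-level stubs (`return_false`, `return_xpu_string`, ...). Beyond satisfying pylint, named functions carry a real `__name__`, which makes monkey-patched attributes on `torch.cuda` self-describing in reprs and tracebacks. A small plain-Python illustration of the difference (no torch required):

```python
def return_false(*args, **kwargs):
    # named module-level stub, in the style introduced by the hunk above
    return False

anonymous = lambda *args, **kwargs: False  # the style being replaced

assert return_false(1, key="x") is False
assert anonymous(1, key="x") is False
assert return_false.__name__ == "return_false"  # readable wherever it is patched in
assert anonymous.__name__ == "<lambda>"         # every lambda reports the same anonymous name
```

A single shared stub can also be assigned to many attributes (as `return_false` is here), instead of allocating a fresh anonymous function at each patch site.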
@@ -15,8 +15,10 @@ torch_version[0], torch_version[1] = int(torch_version[0]), int(torch_version[1]

 device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties(devices.device).has_fp64

-# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
+# pylint: disable=protected-access, missing-function-docstring, line-too-long, no-else-return

+def return_false(*args, **kwargs):
+    return False
+
 @property
 def is_cuda(self):

@@ -24,7 +26,7 @@ def is_cuda(self):


 def check_device_type(device, device_type: str) -> bool:
-    if device is None or type(device) not in {str, int, torch.device}:
+    if device is None or not isinstance(device, (str, int, torch.device)):
         return False
     else:
         return bool(torch.device(device).type == device_type)

@@ -137,24 +139,9 @@ def as_tensor(data, dtype=None, device=None):
     return original_as_tensor(data, dtype=dtype, device=device)


-original_torch_tensor = torch.tensor
-@wraps(torch.tensor)
-def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
-    global device_supports_fp64
-    if check_cuda(device):
-        device = return_xpu(device)
-    if not device_supports_fp64 and check_device_type(device, "xpu"):
-        if dtype == torch.float64:
-            dtype = torch.float32
-        elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
-            dtype = torch.float32
-    return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
-
-
 torch.Tensor.original_Tensor_to = torch.Tensor.to
 @wraps(torch.Tensor.to)
 def Tensor_to(self, device=None, *args, **kwargs):
-    global device_supports_fp64
     if check_cuda(device):
         device = return_xpu(device)
     if not device_supports_fp64:

@@ -210,6 +197,24 @@ if torch_version[0] > 2 or (torch_version[0] == 2 and torch_version[1] >= 4):
         return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs)


+original_torch_tensor = torch.tensor
+@wraps(torch.tensor)
+def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
+    if check_cuda(device):
+        if not device_supports_fp64 and (dtype == torch.float64 or (dtype is None and getattr(data, "dtype", None) in {torch.float64, float})):
+            return original_torch_tensor(data, *args, dtype=torch.float32, device=return_xpu(device), **kwargs)
+        else:
+            return original_torch_tensor(data, *args, dtype=dtype, device=return_xpu(device), **kwargs)
+    else:
+        if (
+            not device_supports_fp64 and check_device_type(device, "xpu")
+            and (dtype == torch.float64 or (dtype is None and getattr(data, "dtype", None) in {torch.float64, float}))
+        ):
+            return original_torch_tensor(data, *args, dtype=torch.float32, device=device, **kwargs)
+        else:
+            return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
+
+
 original_torch_empty = torch.empty
 @wraps(torch.empty)
 def torch_empty(*args, device=None, **kwargs):

@@ -221,11 +226,11 @@ def torch_empty(*args, device=None, **kwargs):

 original_torch_randn = torch.randn
 @wraps(torch.randn)
-def torch_randn(*args, device=None, dtype=None, **kwargs):
+def torch_randn(*args, device=None, **kwargs):
     if check_cuda(device):
-        return original_torch_randn(*args, device=return_xpu(device), dtype=dtype, **kwargs)
+        return original_torch_randn(*args, device=return_xpu(device), **kwargs)
     else:
-        return original_torch_randn(*args, device=device, dtype=dtype, **kwargs)
+        return original_torch_randn(*args, device=device, **kwargs)


 original_torch_ones = torch.ones

@@ -255,34 +260,6 @@ def torch_full(*args, device=None, **kwargs):
         return original_torch_full(*args, device=device, **kwargs)


-original_torch_arange = torch.arange
-@wraps(torch.arange)
-def torch_arange(*args, device=None, dtype=None, **kwargs):
-    global device_supports_fp64
-    if check_cuda(device):
-        if not device_supports_fp64 and dtype == torch.float64:
-            dtype = torch.float32
-        return original_torch_arange(*args, device=return_xpu(device), dtype=dtype, **kwargs)
-    else:
-        if not device_supports_fp64 and check_device_type(device, "xpu") and dtype == torch.float64:
-            dtype = torch.float32
-        return original_torch_arange(*args, device=device, dtype=dtype, **kwargs)
-
-
-original_torch_linspace = torch.linspace
-@wraps(torch.linspace)
-def torch_linspace(*args, device=None, dtype=None, **kwargs):
-    global device_supports_fp64
-    if check_cuda(device):
-        if not device_supports_fp64 and dtype == torch.float64:
-            dtype = torch.float32
-        return original_torch_linspace(*args, device=return_xpu(device), dtype=dtype, **kwargs)
-    else:
-        if not device_supports_fp64 and check_device_type(device, "xpu") and dtype == torch.float64:
-            dtype = torch.float32
-        return original_torch_linspace(*args, device=device, dtype=dtype, **kwargs)
-
-
 original_torch_eye = torch.eye
 @wraps(torch.eye)
 def torch_eye(*args, device=None, **kwargs):

@@ -292,6 +269,36 @@ def torch_eye(*args, device=None, **kwargs):
         return original_torch_eye(*args, device=device, **kwargs)


+original_torch_arange = torch.arange
+@wraps(torch.arange)
+def torch_arange(*args, dtype=None, device=None, **kwargs):
+    if check_cuda(device):
+        if not device_supports_fp64 and dtype == torch.float64:
+            return original_torch_arange(*args, dtype=torch.float32, device=return_xpu(device), **kwargs)
+        else:
+            return original_torch_arange(*args, dtype=dtype, device=return_xpu(device), **kwargs)
+    else:
+        if not device_supports_fp64 and check_device_type(device, "xpu") and dtype == torch.float64:
+            return original_torch_arange(*args, dtype=torch.float32, device=device, **kwargs)
+        else:
+            return original_torch_arange(*args, dtype=dtype, device=device, **kwargs)
+
+
+original_torch_linspace = torch.linspace
+@wraps(torch.linspace)
+def torch_linspace(*args, dtype=None, device=None, **kwargs):
+    if check_cuda(device):
+        if not device_supports_fp64 and dtype == torch.float64:
+            return original_torch_linspace(*args, dtype=torch.float32, device=return_xpu(device), **kwargs)
+        else:
+            return original_torch_linspace(*args, dtype=dtype, device=return_xpu(device), **kwargs)
+    else:
+        if not device_supports_fp64 and check_device_type(device, "xpu") and dtype == torch.float64:
+            return original_torch_linspace(*args, dtype=torch.float32, device=device, **kwargs)
+        else:
+            return original_torch_linspace(*args, dtype=dtype, device=device, **kwargs)
+
+
 original_torch_load = torch.load
 @wraps(torch.load)
 def torch_load(f, map_location=None, *args, **kwargs):

@@ -360,24 +367,29 @@ class torch_Generator(original_torch_Generator):

 # Hijack Functions:
 def ipex_hijacks():
-    global device_supports_fp64
-    torch.UntypedStorage.__init__ = UntypedStorage_init
     if torch_version[0] > 2 or (torch_version[0] == 2 and torch_version[1] >= 4):
         torch.UntypedStorage.cuda = UntypedStorage_cuda
         torch.UntypedStorage.to = UntypedStorage_to
-    torch.tensor = torch_tensor

     torch.Tensor.to = Tensor_to
     torch.Tensor.cuda = Tensor_cuda
     torch.Tensor.pin_memory = Tensor_pin_memory
+    torch.UntypedStorage.__init__ = UntypedStorage_init
+
+    # transformers completely breaks when anything is done to torch.tensor
+    # even straight passthroughs breaks transformers for some reason
+    #torch.tensor = torch_tensor

     torch.empty = torch_empty
     torch.randn = torch_randn
     torch.ones = torch_ones
     torch.zeros = torch_zeros
     torch.full = torch_full
-    torch.eye = torch_eye
     torch.arange = torch_arange
     torch.linspace = torch_linspace
+    torch.eye = torch_eye
     torch.load = torch_load

     torch.cuda.synchronize = torch_cuda_synchronize
     torch.cuda.device = torch_cuda_device
     torch.cuda.set_device = torch_cuda_set_device

@@ -437,6 +449,6 @@ def ipex_hijacks():

     if not hasattr(torch.cuda.amp, "common"):
         torch.cuda.amp.common = nullcontext()
-        torch.cuda.amp.common.amp_definitely_not_available = lambda: False
+        torch.cuda.amp.common.amp_definitely_not_available = return_false

     return device_supports_fp64
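The hijacks file above wraps each torch factory with `@wraps` so the replacement keeps the original's name and docstring while redirecting one argument. The shape of that wrapper, sketched without torch (the `original_factory` here is a made-up stand-in for the wrapped torch function):

```python
from functools import wraps

def original_factory(size, device="cpu"):
    """Create a fake tensor descriptor."""
    return {"size": size, "device": device}

@wraps(original_factory)
def patched_factory(size, device="cpu"):
    # redirect requests for "cuda" to "xpu", mirroring the check_cuda/return_xpu pattern above
    if device == "cuda":
        device = "xpu"
    return original_factory(size, device=device)

assert patched_factory(4, device="cuda") == {"size": 4, "device": "xpu"}
assert patched_factory(4) == {"size": 4, "device": "cpu"}
assert patched_factory.__name__ == "original_factory"  # @wraps preserved the metadata
```

Keeping a reference to the original (`original_torch_empty = torch.empty` in the source) before assigning the wrapper is what lets the wrapper delegate without infinite recursion.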
@@ -1,13 +1,11 @@
 import os
 import sys
 import torch
 import nncf

 from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
 from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
-from openvino.frontend import FrontEndManager
-from openvino import Core, Type, PartialShape, serialize
-from openvino.properties import hint as ov_hints
+from openvino.frontend import FrontEndManager # pylint: disable=no-name-in-module
+from openvino import Core, Type, PartialShape, serialize # pylint: disable=no-name-in-module
+from openvino.properties import hint as ov_hints # pylint: disable=no-name-in-module

 from torch._dynamo.backends.common import fake_tensor_unsupported
 from torch._dynamo.backends.registry import register_backend

@@ -23,25 +21,6 @@ from modules import shared, devices, sd_models_utils
 from modules.logger import log


-# importing openvino.runtime forces DeprecationWarning to "always"
-# And Intel's own libs (NNCF) imports the deprecated module
-# Don't allow openvino to override warning filters:
-try:
-    import warnings
-    filterwarnings = warnings.filterwarnings
-    warnings.filterwarnings = lambda *args, **kwargs: None
-    import openvino.runtime # pylint: disable=unused-import
-    installer.torch_info.set(openvino=openvino.runtime.get_version())
-    warnings.filterwarnings = filterwarnings
-except Exception:
-    pass
-
-try:
-    # silence the pytorch version warning
-    nncf.common.logging.logger.warn_bkc_version_mismatch = lambda *args, **kwargs: None
-except Exception:
-    pass
-
 # Set default params
 torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit) # pylint: disable=protected-access
 torch._dynamo.eval_frame.check_if_dynamo_supported = lambda: True # pylint: disable=protected-access

@@ -213,11 +192,7 @@ def execute_cached(compiled_model, *args):

 def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str | None = None, file_name=""):
     core = Core()

     device = get_device()
-    global dont_use_4bit_nncf
-    global dont_use_nncf
-    global dont_use_quant

     if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
         om = core.read_model(file_name + ".xml")

@@ -259,26 +234,6 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str | Non
             om.inputs[idx-idx_minus].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
         om.validate_nodes_and_infer_types()

-    if shared.opts.nncf_quantize and not dont_use_quant:
-        new_inputs = []
-        for idx, _ in enumerate(example_inputs):
-            new_inputs.append(example_inputs[idx].detach().cpu().numpy())
-        new_inputs = [new_inputs]
-        if shared.opts.nncf_quantize_mode == "INT8":
-            om = nncf.quantize(om, nncf.Dataset(new_inputs))
-        else:
-            om = nncf.quantize(om, nncf.Dataset(new_inputs), mode=getattr(nncf.QuantizationMode, shared.opts.nncf_quantize_mode),
-                advanced_parameters=nncf.quantization.advanced_parameters.AdvancedQuantizationParameters(
-                overflow_fix=nncf.quantization.advanced_parameters.OverflowFix.DISABLE, backend_params=None))
-
-    if shared.opts.nncf_compress_weights and not dont_use_nncf:
-        if dont_use_4bit_nncf or shared.opts.nncf_compress_weights_mode == "INT8":
-            om = nncf.compress_weights(om)
-        else:
-            compress_group_size = shared.opts.nncf_compress_weights_group_size if shared.opts.nncf_compress_weights_group_size != 0 else None
-            compress_ratio = shared.opts.nncf_compress_weights_raito if shared.opts.nncf_compress_weights_raito != 0 else None
-            om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=compress_group_size, ratio=compress_ratio)
-
     hints = {}
     if shared.opts.openvino_accuracy == "performance":
         hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE

@@ -287,9 +242,6 @@ def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str | Non
     if model_hash_str is not None:
         hints['CACHE_DIR'] = shared.opts.openvino_cache_path + '/blob'
     core.set_property(hints)
-    dont_use_nncf = False
-    dont_use_quant = False
-    dont_use_4bit_nncf = False

     compiled_model = core.compile_model(om, device)
     return compiled_model

@@ -299,44 +251,17 @@ def openvino_compile_cached_model(cached_model_path, *example_inputs):
     core = Core()
     om = core.read_model(cached_model_path + ".xml")

-    global dont_use_4bit_nncf
-    global dont_use_nncf
-    global dont_use_quant
-
     for idx, input_data in enumerate(example_inputs):
         om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
         om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
     om.validate_nodes_and_infer_types()

-    if shared.opts.nncf_quantize and not dont_use_quant:
-        new_inputs = []
-        for idx, _ in enumerate(example_inputs):
-            new_inputs.append(example_inputs[idx].detach().cpu().numpy())
-        new_inputs = [new_inputs]
-        if shared.opts.nncf_quantize_mode == "INT8":
-            om = nncf.quantize(om, nncf.Dataset(new_inputs))
-        else:
-            om = nncf.quantize(om, nncf.Dataset(new_inputs), mode=getattr(nncf.QuantizationMode, shared.opts.nncf_quantize_mode),
-                advanced_parameters=nncf.quantization.advanced_parameters.AdvancedQuantizationParameters(
-                overflow_fix=nncf.quantization.advanced_parameters.OverflowFix.DISABLE, backend_params=None))
-
-    if shared.opts.nncf_compress_weights and not dont_use_nncf:
-        if dont_use_4bit_nncf or shared.opts.nncf_compress_weights_mode == "INT8":
-            om = nncf.compress_weights(om)
-        else:
-            compress_group_size = shared.opts.nncf_compress_weights_group_size if shared.opts.nncf_compress_weights_group_size != 0 else None
-            compress_ratio = shared.opts.nncf_compress_weights_raito if shared.opts.nncf_compress_weights_raito != 0 else None
-            om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=compress_group_size, ratio=compress_ratio)
-
     hints = {'CACHE_DIR': shared.opts.openvino_cache_path + '/blob'}
     if shared.opts.openvino_accuracy == "performance":
         hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.PERFORMANCE
     elif shared.opts.openvino_accuracy == "accuracy":
         hints[ov_hints.execution_mode] = ov_hints.ExecutionMode.ACCURACY
     core.set_property(hints)
-    dont_use_nncf = False
-    dont_use_quant = False
-    dont_use_4bit_nncf = False

     compiled_model = core.compile_model(om, get_device())
     return compiled_model

@@ -462,14 +387,8 @@ def get_subgraph_type(tensor):

 @fake_tensor_unsupported
 def openvino_fx(subgraph, example_inputs, options=None):
-    global dont_use_4bit_nncf
-    global dont_use_nncf
-    global dont_use_quant
-    global subgraph_type
-
-    dont_use_4bit_nncf = False
     dont_use_nncf = False
     dont_use_quant = False
     dont_use_faketensors = False
     executor_parameters = None
     inputs_reversed = False

@@ -478,25 +397,25 @@ def openvino_fx(subgraph, example_inputs, options=None):
         subgraph_type = []
         subgraph.apply(get_subgraph_type)

+        """
         # SD 1.5 / SDXL VAE
-        if (subgraph_type[0] is torch.nn.modules.conv.Conv2d and
+        if (
+            subgraph_type[0] is torch.nn.modules.conv.Conv2d and
             subgraph_type[1] is torch.nn.modules.conv.Conv2d and
             subgraph_type[2] is torch.nn.modules.normalization.GroupNorm and
-            subgraph_type[3] is torch.nn.modules.activation.SiLU):
-
-            dont_use_4bit_nncf = True
-            dont_use_nncf = bool("VAE" not in shared.opts.nncf_compress_weights)
-            dont_use_quant = bool("VAE" not in shared.opts.nncf_quantize)
+            subgraph_type[3] is torch.nn.modules.activation.SiLU
+        ):
+            pass
+        """

         # SD 1.5 / SDXL Text Encoder
-        elif (subgraph_type[0] is torch.nn.modules.sparse.Embedding and
+        if (
+            subgraph_type[0] is torch.nn.modules.sparse.Embedding and
             subgraph_type[1] is torch.nn.modules.sparse.Embedding and
             subgraph_type[2] is torch.nn.modules.normalization.LayerNorm and
-            subgraph_type[3] is torch.nn.modules.linear.Linear):
-
+            subgraph_type[3] is torch.nn.modules.linear.Linear
+        ):
             dont_use_faketensors = True
             dont_use_nncf = bool("TE" not in shared.opts.nncf_compress_weights)
             dont_use_quant = bool("TE" not in shared.opts.nncf_quantize)

         # Create a hash to be used for caching
         shared.compiled_model_state.model_hash_str = ""
@@ -3,6 +3,7 @@ from functools import partial
 import os
 import re
 import sys
+import types
 import logging
 import warnings
 import urllib3

@@ -133,6 +134,14 @@ timer.startup.record("accelerate")
 import pydantic # pylint: disable=W0611,C0411
 timer.startup.record("pydantic")

+try:
+    # transformers==5.x has different dependency stack so switching between v4 and v5 becomes very painful
+    # this temporarily disables dependency version checks so we can use either v4 or v5 until we drop support for v4
+    fake_version_check = types.ModuleType("transformers.dependency_versions_check")
+    sys.modules["transformers.dependency_versions_check"] = fake_version_check # disable transformers version checks
+    fake_version_check.dep_version_check = lambda pkg, hint=None: None
+except Exception:
+    pass
 import transformers # pylint: disable=W0611,C0411
 from transformers import logging as transformers_logging # pylint: disable=W0611,C0411
 transformers_logging.set_verbosity_error()

@@ -175,9 +184,10 @@ except Exception as e:
     sys.exit(1)

 try:
-    pass # pylint: disable=W0611,C0411
+    import pillow_jxl # pylint: disable=W0611,C0411
 except Exception:
     pass
 from PIL import Image # pylint: disable=W0611,C0411
 timer.startup.record("pillow")

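The transformers shim above pre-registers a fake `transformers.dependency_versions_check` module in `sys.modules`, so the real module is never imported and its pip version checks never run. The import system consults `sys.modules` before touching the filesystem, which is why the stub wins. A minimal sketch of the mechanism using made-up module names (`somepkg` does not exist anywhere; only the cache entry does):

```python
import importlib
import sys
import types

# "somepkg" and "somepkg.version_check" are hypothetical names, used only for illustration
pkg = types.ModuleType("somepkg")
stub = types.ModuleType("somepkg.version_check")
stub.dep_version_check = lambda pkg_name, hint=None: None  # no-op replaces the real check
sys.modules["somepkg"] = pkg
sys.modules["somepkg.version_check"] = stub

# any later import of the module resolves to the stub, so the real checks never run
mod = importlib.import_module("somepkg.version_check")
assert mod is stub
assert mod.dep_version_check("torch") is None
```

The source wraps its version in `try/except` so that, if the stub cannot be installed for any reason, startup falls back to importing the real module with its checks intact.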
@ -4,13 +4,13 @@ import re
|
|||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
import torch
|
||||
import diffusers.models.lora
|
||||
from modules.lora import lora_common as l
|
||||
from modules import shared, devices, errors, model_quant
|
||||
from modules.logger import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
import diffusers.models.lora
|
||||
|
||||
|
||||
bnb = None
|
||||
|
|
|
|||
|
|
@@ -257,7 +257,11 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
                 shared.compiled_model_state.lora_model.append(f"{name}:{lora_scale}")
             lora_method = lora_overrides.get_method(shorthash)
             if lora_method == 'diffusers':
-                net = lora_diffusers.load_diffusers(name, network_on_disk, lora_scale, lora_module)
+                if shared.sd_model_type == 'f2':
+                    from pipelines.flux import flux2_lora
+                    net = flux2_lora.try_load_lokr(name, network_on_disk, lora_scale)
+                if net is None:
+                    net = lora_diffusers.load_diffusers(name, network_on_disk, lora_scale, lora_module)
             elif lora_method == 'nunchaku':
                 pass # handled directly from extra_networks_lora.load_nunchaku
             else:

@@ -272,7 +276,8 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
                 continue
             if net is None:
                 failed_to_load_networks.append(name)
-                log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} not found')
+                lora_ver = network_on_disk.sd_version if network_on_disk is not None else None
+                log.error(f'Network load: type=LoRA name="{name}" detected={lora_ver} not loaded')
                 continue
             if hasattr(sd_model, 'embedding_db'):
                 sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings)

@@ -309,6 +314,12 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
         errors.display(e, 'LoRA')
     shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, force=True, silent=True) # some layers may end up on cpu without hook
+
+    # Activate native modules loaded via diffusers path (e.g., LoKR on Flux2)
+    native_nets = [net for net in l.loaded_networks if len(net.modules) > 0]
+    if native_nets:
+        from modules.lora import networks
+        networks.network_activate()
     if len(l.loaded_networks) > 0 and l.debug:
         log.debug(f'Network load: type=LoRA loaded={[n.name for n in l.loaded_networks]} cache={list(lora_cache)} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers}')

@@ -58,6 +58,8 @@ class NetworkOnDisk:
             return 'sc'
         if base.startswith("sd3"):
             return 'sd3'
+        if base.startswith("flux2") or "klein" in base:
+            return 'f2'
         if base.startswith("flux"):
             return 'f1'
         if base.startswith("hunyuan_video"):

@@ -75,6 +77,8 @@ class NetworkOnDisk:
             return 'xl'
         if arch.startswith("stable-cascade"):
             return 'sc'
+        if arch.startswith("flux2") or "klein" in arch:
+            return 'f2'
         if arch.startswith("flux"):
             return 'f1'
         if arch.startswith("hunyuan-video"):

@@ -86,6 +90,8 @@ class NetworkOnDisk:
             return 'sd1'
         if str(self.metadata.get('ss_v2', "")) == "True":
             return 'sd2'
+        if 'klein' in self.name.lower() or 'klein' in self.fullname.lower():
+            return 'f2'
         if 'flux' in self.name.lower():
             return 'f1'
         if 'xl' in self.name.lower():

@@ -55,3 +55,40 @@ class NetworkModuleLokr(network.NetworkModule): # pylint: disable=abstract-method
         output_shape = target.shape
         updown = make_kron(output_shape, w1, w2)
         return self.finalize_updown(updown, target, output_shape)
+
+
+class NetworkModuleLokrChunk(NetworkModuleLokr):
+    """LoKR module that returns one chunk of the Kronecker product.
+
+    Used when a LoKR adapter targets a fused weight (e.g., QKV) but the model
+    has separate modules (Q, K, V). Computes kron(w1, w2) on-the-fly and
+    returns only the designated chunk, keeping memory usage minimal.
+    """
+    def __init__(self, net, weights, chunk_index, num_chunks):
+        super().__init__(net, weights)
+        self.chunk_index = chunk_index
+        self.num_chunks = num_chunks
+
+    def calc_updown(self, target):
+        if self.w1 is not None:
+            w1 = self.w1.to(target.device, dtype=target.dtype)
+        else:
+            w1a = self.w1a.to(target.device, dtype=target.dtype)
+            w1b = self.w1b.to(target.device, dtype=target.dtype)
+            w1 = w1a @ w1b
+        if self.w2 is not None:
+            w2 = self.w2.to(target.device, dtype=target.dtype)
+        elif self.t2 is None:
+            w2a = self.w2a.to(target.device, dtype=target.dtype)
+            w2b = self.w2b.to(target.device, dtype=target.dtype)
+            w2 = w2a @ w2b
+        else:
+            t2 = self.t2.to(target.device, dtype=target.dtype)
+            w2a = self.w2a.to(target.device, dtype=target.dtype)
+            w2b = self.w2b.to(target.device, dtype=target.dtype)
+            w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
+        full_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
+        updown = make_kron(full_shape, w1, w2)
+        updown = torch.chunk(updown, self.num_chunks, dim=0)[self.chunk_index]
+        output_shape = list(updown.shape)
+        return self.finalize_updown(updown, target, output_shape)

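The chunked LoKR delta above is a row block of the full Kronecker product: chunking `kron(w1, w2)` along dim 0 into as many parts as `w1` has rows gives `kron(w1[i:i+1], w2)` for chunk `i`. A numpy sketch of that identity (shapes are illustrative, not the model's real dimensions):

```python
import numpy as np

w1 = np.arange(6, dtype=np.float32).reshape(3, 2)  # small LoKR factor
w2 = np.ones((4, 5), dtype=np.float32)             # second factor

full = np.kron(w1, w2)              # (12, 10) fused delta, e.g. for a QKV weight
chunks = np.split(full, 3, axis=0)  # per-projection row blocks, (4, 10) each
print(full.shape, chunks[1].shape)
```

Each chunk could therefore also be built directly from one row of `w1`, which is why the on-the-fly kron-then-chunk keeps memory bounded by a single fused delta.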
@@ -1,9 +1,11 @@
 import re
 import sys
+import os
+import types
 from collections import deque
 import psutil
 import torch
-from modules import shared, errors
+from modules import shared, errors, devices
 from modules.logger import log

@@ -130,28 +132,53 @@ def reset_stats():
 class Object:
     pattern = r"'(.*?)'"

+    def get_size(self, obj, seen=None):
+        size = sys.getsizeof(obj)
+        if seen is None:
+            seen = set()
+        obj_id = id(obj)
+        if obj_id in seen:
+            return 0 # Avoid double counting
+        seen.add(obj_id)
+        if isinstance(obj, dict):
+            size += sum(self.get_size(k, seen) + self.get_size(v, seen) for k, v in obj.items())
+        elif isinstance(obj, (list, tuple, set, frozenset, deque)):
+            size += sum(self.get_size(i, seen) for i in obj)
+        return size
+
     def __init__(self, name, obj):
         self.id = id(obj)
         self.name = name
         self.fn = sys._getframe(2).f_code.co_name
         self.size = sys.getsizeof(obj)
         self.refcount = sys.getrefcount(obj)
         if torch.is_tensor(obj):
             self.type = obj.dtype
             self.size = obj.element_size() * obj.nelement()
         else:
             self.type = re.findall(self.pattern, str(type(obj)))[0]
-            self.size = sys.getsizeof(obj)
+            self.size = self.get_size(obj)

     def __str__(self):
         return f'{self.fn}.{self.name} type={self.type} size={self.size} ref={self.refcount}'


-def get_objects(gcl=None, threshold:int=0):
+def get_objects(gcl=None, threshold:int=1024*1024):
+    devices.torch_gc(force=True)
     if gcl is None:
+        # gcl = globals()
         gcl = {}
+        log.trace(f'Memory: modules={len(sys.modules)}')
+        for _module_name, module in sys.modules.items():
+            try:
+                if not isinstance(module, types.ModuleType):
+                    continue
+                namespace = vars(module)
+                gcl.update(namespace)
+            except Exception:
+                pass # Some modules may not allow introspection
     objects = []
     seen = []
+    log.trace(f'Memory: items={len(gcl)} threshold={threshold}')
     for name, obj in gcl.items():
         if id(obj) in seen:
             continue

@@ -169,6 +196,6 @@ def get_objects(gcl=None, threshold:int=0):
     objects = sorted(objects, key=lambda x: x.size, reverse=True)
     for obj in objects:
-        log.trace(obj)
+        log.trace(f'Memory: {obj}')
     return objects

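The new `get_size` above walks containers recursively and carries a `seen` set of object ids so that aliased objects are counted only once. A standalone sketch of the same pattern:

```python
import sys
from collections import deque

def deep_size(obj, seen=None):
    """Recursive sys.getsizeof: follow containers, count each object once."""
    size = sys.getsizeof(obj)
    if seen is None:
        seen = set()
    if id(obj) in seen:
        return 0  # already counted through another reference
    seen.add(id(obj))
    if isinstance(obj, dict):
        size += sum(deep_size(k, seen) + deep_size(v, seen) for k, v in obj.items())
    elif isinstance(obj, (list, tuple, set, frozenset, deque)):
        size += sum(deep_size(i, seen) for i in obj)
    return size

shared_list = [1, 2, 3]
aliased = {"a": shared_list, "b": shared_list}   # the same list referenced twice
distinct = {"a": [1, 2, 3], "b": [4, 5, 6]}      # two separate lists
print(deep_size(aliased) < deep_size(distinct))  # True: the alias is counted once
```

Without the `seen` set, the aliased dict would report roughly the same total as the distinct one, double-counting the shared list.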
@@ -1,12 +1,10 @@
 import os
 import re
 import sys
-import copy
 import json
-import time
 import diffusers
 import transformers
-from installer import installed, install, setup_logging
+from installer import install
 from modules.logger import log

@@ -51,70 +49,6 @@ def dont_quant():
     return False


-def create_bnb_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
-    from modules import shared, devices
-    if allow and (module == 'any' or module in shared.opts.bnb_quantization):
-        load_bnb()
-        if bnb is None:
-            return kwargs
-        bnb_config = diffusers.BitsAndBytesConfig(
-            load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'],
-            load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'],
-            bnb_4bit_quant_storage=shared.opts.bnb_quantization_storage,
-            bnb_4bit_quant_type=shared.opts.bnb_quantization_type,
-            bnb_4bit_compute_dtype=devices.dtype,
-            llm_int8_skip_modules=modules_to_not_convert,
-        )
-        log.debug(f'Quantization: module={module} type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
-        if kwargs is None:
-            return bnb_config
-        else:
-            kwargs['quantization_config'] = bnb_config
-            return kwargs
-    return kwargs
-
-
-def create_ao_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
-    from modules import shared
-    if allow and (shared.opts.torchao_quantization_mode in {'pre', 'auto'}) and (module == 'any' or module in shared.opts.torchao_quantization):
-        torchao = load_torchao()
-        if torchao is None:
-            return kwargs
-        if module in {'TE', 'LLM'}:
-            ao_config = transformers.TorchAoConfig(quant_type=shared.opts.torchao_quantization_type, modules_to_not_convert=modules_to_not_convert)
-        else:
-            ao_config = diffusers.TorchAoConfig(shared.opts.torchao_quantization_type, modules_to_not_convert=modules_to_not_convert)
-        log.debug(f'Quantization: module={module} type=torchao dtype={shared.opts.torchao_quantization_type}')
-        if kwargs is None:
-            return ao_config
-        else:
-            kwargs['quantization_config'] = ao_config
-            return kwargs
-    return kwargs
-
-
-def create_quanto_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
-    from modules import shared
-    if allow and (module == 'any' or module in shared.opts.quanto_quantization):
-        load_quanto(silent=True)
-        if optimum_quanto is None:
-            return kwargs
-        if module in {'TE', 'LLM'}:
-            quanto_config = transformers.QuantoConfig(weights=shared.opts.quanto_quantization_type, modules_to_not_convert=modules_to_not_convert)
-            quanto_config.weights_dtype = quanto_config.weights
-        else:
-            quanto_config = diffusers.QuantoConfig(weights_dtype=shared.opts.quanto_quantization_type, modules_to_not_convert=modules_to_not_convert)
-            quanto_config.activations = None # patch so it works with transformers
-            quanto_config.weights = quanto_config.weights_dtype
-        log.debug(f'Quantization: module={module} type=quanto dtype={shared.opts.quanto_quantization_type}')
-        if kwargs is None:
-            return quanto_config
-        else:
-            kwargs['quantization_config'] = quanto_config
-            return kwargs
-    return kwargs
-
-
 def create_trt_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
     from modules import shared
     if allow and (module == 'any' or module in shared.opts.trt_quantization):

@@ -249,7 +183,7 @@ def create_sdnq_config(kwargs = None, allow: bool = True, module: str = 'Model',

 def check_quant(module: str = ''):
     from modules import shared
-    if module in shared.opts.sdnq_quantize_weights or module in shared.opts.bnb_quantization or module in shared.opts.torchao_quantization or module in shared.opts.quanto_quantization:
+    if module in shared.opts.sdnq_quantize_weights:
         return True
     return False

@@ -286,21 +220,6 @@ def create_config(kwargs = None, allow: bool = True, module: str = 'Model', modules_to_not_convert: list | None = None):
         if debug:
             log.trace(f'Quantization: type=sdnq config={kwargs.get("quantization_config", None)}')
         return kwargs
-    kwargs = create_bnb_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
-    if kwargs is not None and 'quantization_config' in kwargs:
-        if debug:
-            log.trace(f'Quantization: type=bnb config={kwargs.get("quantization_config", None)}')
-        return kwargs
-    kwargs = create_quanto_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
-    if kwargs is not None and 'quantization_config' in kwargs:
-        if debug:
-            log.trace(f'Quantization: type=quanto config={kwargs.get("quantization_config", None)}')
-        return kwargs
-    kwargs = create_ao_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
-    if kwargs is not None and 'quantization_config' in kwargs:
-        if debug:
-            log.trace(f'Quantization: type=torchao config={kwargs.get("quantization_config", None)}')
-        return kwargs
     kwargs = create_trt_config(kwargs, allow=allow, module=module, modules_to_not_convert=modules_to_not_convert)
     if kwargs is not None and 'quantization_config' in kwargs:
         if debug:

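`create_config` above tries each backend's config creator in turn and returns as soon as one attaches a `quantization_config` to the kwargs. The dispatch pattern, sketched with stand-in creators (names and config values are illustrative, not the project's real ones):

```python
def disabled_backend(kwargs):
    """A backend whose option is off: pass kwargs through unchanged."""
    return kwargs

def sdnq_backend(kwargs):
    """An enabled backend: attach its config (values are placeholders)."""
    kwargs['quantization_config'] = {'type': 'sdnq', 'dtype': 'int8'}
    return kwargs

def chain_configs(kwargs, creators):
    """Try backend creators in order; the first one that attaches a
    'quantization_config' wins, mirroring the hunk above."""
    for create in creators:
        kwargs = create(kwargs)
        if kwargs is not None and 'quantization_config' in kwargs:
            return kwargs
    return kwargs

result = chain_configs({}, [disabled_backend, sdnq_backend, disabled_backend])
print(result['quantization_config']['type'])  # sdnq
```

This chain-of-responsibility shape is why the hunk can simply delete three links without touching the rest of the function.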
@@ -309,88 +228,6 @@
             return kwargs


-def load_torchao(msg='', silent=False):
-    global ao # pylint: disable=global-statement
-    if ao is not None:
-        return ao
-    if not installed('torchao'):
-        install('torchao==0.10.0', quiet=True)
-        log.warning('Quantization: torchao installed please restart')
-    try:
-        import torchao
-        ao = torchao
-        fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
-        log.debug(f'Quantization: type=torchao version={ao.__version__} fn={fn}') # pylint: disable=protected-access
-        from diffusers.utils import import_utils
-        import_utils.is_torchao_available = lambda: True
-        import_utils._torchao_available = True # pylint: disable=protected-access
-        return ao
-    except Exception as e:
-        if len(msg) > 0:
-            log.error(f"{msg} failed to import torchao: {e}")
-        ao = None
-        if not silent:
-            raise
-    return None
-
-
-def load_bnb(msg='', silent=False):
-    from modules import devices
-    global bnb # pylint: disable=global-statement
-    if bnb is not None:
-        return bnb
-    if not installed('bitsandbytes'):
-        if devices.backend == 'cuda':
-            # forcing a version will uninstall the multi-backend-refactor branch of bnb
-            install('bitsandbytes==0.47.0', quiet=True)
-            log.warning('Quantization: bitsandbytes installed please restart')
-    try:
-        import bitsandbytes
-        bnb = bitsandbytes
-        from diffusers.utils import import_utils
-        import_utils._bitsandbytes_available = True # pylint: disable=protected-access
-        import_utils._bitsandbytes_version = '0.43.3' # pylint: disable=protected-access
-        fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
-        log.debug(f'Quantization: type=bitsandbytes version={bnb.__version__} fn={fn}') # pylint: disable=protected-access
-        return bnb
-    except Exception as e:
-        if len(msg) > 0:
-            log.error(f"{msg} failed to import bitsandbytes: {e}")
-        bnb = None
-        if not silent:
-            raise
-    return None
-
-
-def load_quanto(msg='', silent=False):
-    global optimum_quanto # pylint: disable=global-statement
-    if optimum_quanto is not None:
-        return optimum_quanto
-    if not installed('optimum-quanto'):
-        install('optimum-quanto==0.2.7', quiet=True)
-        log.warning('Quantization: optimum-quanto installed please restart')
-    try:
-        from optimum import quanto # pylint: disable=no-name-in-module
-        # disable device specific tensors because the model can't be moved between cpu and gpu with them
-        quanto.tensor.weights.qbits.WeightQBitsTensor.create = lambda *args, **kwargs: quanto.tensor.weights.qbits.WeightQBitsTensor(*args, **kwargs)
-        optimum_quanto = quanto
-        fn = f'{sys._getframe(3).f_code.co_name}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
-        log.debug(f'Quantization: type=quanto version={quanto.__version__} fn={fn}') # pylint: disable=protected-access
-        from diffusers.utils import import_utils
-        import_utils.is_optimum_quanto_available = lambda: True
-        import_utils._optimum_quanto_available = True # pylint: disable=protected-access
-        import_utils._optimum_quanto_version = quanto.__version__ # pylint: disable=protected-access
-        import_utils._replace_with_quanto_layers = diffusers.quantizers.quanto.utils._replace_with_quanto_layers # pylint: disable=protected-access
-        return optimum_quanto
-    except Exception as e:
-        if len(msg) > 0:
-            log.error(f"{msg} failed to import optimum.quanto: {e}")
-        optimum_quanto = None
-        if not silent:
-            raise
-    return None
-
-
 def load_trt(msg='', silent=False):
     global trt # pylint: disable=global-statement
     if trt is not None:

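The removed `load_torchao`/`load_bnb`/`load_quanto` helpers all share one shape: import a heavy optional dependency once, memoize it in a module-level global, and return `None` (instead of raising) when it is unavailable. A generic sketch of that lazy-load pattern, with hypothetical names:

```python
import importlib

_cache = {}

def lazy_load(module_name, silent=True):
    """Import an optional dependency once, memoize the result, and return
    None when it is unavailable (generic sketch of the load_* helpers)."""
    if module_name in _cache:
        return _cache[module_name]
    try:
        _cache[module_name] = importlib.import_module(module_name)
    except Exception:
        _cache[module_name] = None
        if not silent:
            raise
    return _cache[module_name]

print(lazy_load("json") is lazy_load("json"))  # True: imported once, then cached
```

Memoizing in a global keeps repeated callers cheap while letting the failure mode (missing package) stay a soft `None` that call sites can branch on.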
@@ -642,138 +479,6 @@ def sdnq_quantize_weights(sd_model):
     return sd_model


-def optimum_quanto_model(model, op=None, sd_model=None, weights=None, activations=None):
-    from modules import devices, shared
-    quanto = load_quanto('Quantize model: type=Optimum Quanto')
-    global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
-    if model.__class__.__name__ in {"FluxTransformer2DModel", "ChromaTransformer2DModel"}: # LayerNorm is not supported
-        exclude_list = ["transformer_blocks.*.norm1.norm", "transformer_blocks.*.norm2", "transformer_blocks.*.norm1_context.norm", "transformer_blocks.*.norm2_context", "single_transformer_blocks.*.norm.norm", "norm_out.norm"]
-        if model.__class__.__name__ == "ChromaTransformer2DModel":
-            # we ignore the distilled guidance layer because it degrades quality too much
-            # see: https://github.com/huggingface/diffusers/pull/11698#issuecomment-2969717180 for more details
-            exclude_list.append("distilled_guidance_layer.*")
-    elif model.__class__.__name__ == "QwenImageTransformer2DModel":
-        exclude_list = ["transformer_blocks.0.img_mod.1.weight", "time_text_embed", "img_in", "txt_in", "proj_out", "norm_out", "pos_embed"]
-    else:
-        exclude_list = None
-    weights = getattr(quanto, weights) if weights is not None else getattr(quanto, shared.opts.optimum_quanto_weights_type)
-    if activations is not None:
-        activations = getattr(quanto, activations) if activations != 'none' else None
-    elif shared.opts.optimum_quanto_activations_type != 'none':
-        activations = getattr(quanto, shared.opts.optimum_quanto_activations_type)
-    else:
-        activations = None
-    model.eval()
-    backup_embeddings = None
-    if hasattr(model, "get_input_embeddings"):
-        backup_embeddings = copy.deepcopy(model.get_input_embeddings())
-    quanto.quantize(model, weights=weights, activations=activations, exclude=exclude_list)
-    quanto.freeze(model)
-    if hasattr(model, "set_input_embeddings") and backup_embeddings is not None:
-        model.set_input_embeddings(backup_embeddings)
-    if op is not None and shared.opts.optimum_quanto_shuffle_weights:
-        if quant_last_model_name is not None:
-            if "." in quant_last_model_name:
-                last_model_names = quant_last_model_name.split(".")
-                getattr(getattr(sd_model, last_model_names[0]), last_model_names[1]).to(quant_last_model_device)
-            else:
-                getattr(sd_model, quant_last_model_name).to(quant_last_model_device)
-            devices.torch_gc(force=True, reason='quanto')
-        if shared.cmd_opts.medvram or shared.cmd_opts.lowvram or shared.opts.diffusers_offload_mode != "none":
-            quant_last_model_name = op
-            quant_last_model_device = model.device
-        else:
-            quant_last_model_name = None
-            quant_last_model_device = None
-        model.to(devices.device)
-    devices.torch_gc(force=True, reason='quanto')
-    return model
-
-
-def optimum_quanto_weights(sd_model):
-    try:
-        t0 = time.time()
-        from modules import shared, devices, sd_models
-        if shared.opts.diffusers_offload_mode in {"balanced", "sequential"}:
-            log.warning(f"Quantization: type=Optimum.quanto offload={shared.opts.diffusers_offload_mode} not compatible")
-            return sd_model
-        log.info(f"Quantization: type=Optimum.quanto: modules={shared.opts.optimum_quanto_weights}")
-        global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
-        quanto = load_quanto()
-        sd_model = sd_models.apply_function_to_model(sd_model, optimum_quanto_model, shared.opts.optimum_quanto_weights, op="optimum-quanto")
-        if quant_last_model_name is not None:
-            if "." in quant_last_model_name:
-                last_model_names = quant_last_model_name.split(".")
-                getattr(getattr(sd_model, last_model_names[0]), last_model_names[1]).to(quant_last_model_device)
-            else:
-                getattr(sd_model, quant_last_model_name).to(quant_last_model_device)
-            devices.torch_gc(force=True, reason='quanto')
-            quant_last_model_name = None
-            quant_last_model_device = None
-        if shared.opts.optimum_quanto_activations_type != 'none':
-            activations = getattr(quanto, shared.opts.optimum_quanto_activations_type)
-        else:
-            activations = None
-        if activations is not None:
-            def optimum_quanto_freeze(model, op=None, sd_model=None): # pylint: disable=unused-argument
-                quanto.freeze(model)
-                return model
-            if shared.opts.diffusers_offload_mode == "model":
-                sd_model.enable_model_cpu_offload(device=devices.device)
-                if hasattr(sd_model, "encode_prompt"):
-                    original_encode_prompt = sd_model.encode_prompt
-                    def encode_prompt(*args, **kwargs):
-                        embeds = original_encode_prompt(*args, **kwargs)
-                        sd_model.maybe_free_model_hooks() # Diffusers keeps the TE on VRAM
-                        return embeds
-                    sd_model.encode_prompt = encode_prompt
-            else:
-                sd_models.move_model(sd_model, devices.device)
-            with quanto.Calibration(momentum=0.9):
-                sd_model(prompt="dummy prompt", num_inference_steps=10)
-            sd_model = sd_models.apply_function_to_model(sd_model, optimum_quanto_freeze, shared.opts.optimum_quanto_weights, op="optimum-quanto-freeze")
-            if shared.opts.diffusers_offload_mode == "model":
-                sd_models.disable_offload(sd_model)
-                sd_models.move_model(sd_model, devices.cpu)
-            if hasattr(sd_model, "encode_prompt"):
-                sd_model.encode_prompt = original_encode_prompt
-            devices.torch_gc(force=True, reason='quanto')
-        t1 = time.time()
-        log.info(f"Quantization: type=Optimum.quanto time={t1-t0:.2f}")
-    except Exception as e:
-        log.warning(f"Quantization: type=Optimum.quanto {e}")
-    return sd_model
-
-
-def torchao_quantization(sd_model):
-    from modules import shared, devices, sd_models
-    torchao = load_torchao()
-    q = torchao.quantization
-    fn = getattr(q, shared.opts.torchao_quantization_type, None)
-    if fn is None:
-        log.error(f"Quantization: type=TorchAO type={shared.opts.torchao_quantization_type} not supported")
-        return sd_model
-    def torchao_model(model, op=None, sd_model=None): # pylint: disable=unused-argument
-        q.quantize_(model, fn(), device=devices.device)
-        return model
-    log.info(f"Quantization: type=TorchAO pipe={sd_model.__class__.__name__} quant={shared.opts.torchao_quantization_type} fn={fn} targets={shared.opts.torchao_quantization}")
-    try:
-        t0 = time.time()
-        sd_models.apply_function_to_model(sd_model, torchao_model, shared.opts.torchao_quantization, op="torchao")
-        t1 = time.time()
-        log.info(f"Quantization: type=TorchAO time={t1-t0:.2f}")
-    except Exception as e:
-        log.error(f"Quantization: type=TorchAO {e}")
-    setup_logging() # torchao uses dynamo which messes with logging so reset is needed
-    return sd_model
-
-
 def get_dit_args(load_config: dict | None = None, module: str | None = None, device_map: bool = False, allow_quant: bool = True, modules_to_not_convert: list | None = None, modules_dtype_dict: dict | None = None):
     from modules import shared, devices
     config = {} if load_config is None else load_config.copy()

@@ -810,12 +515,6 @@ def do_post_load_quant(sd_model, allow=True):
     if shared.opts.sdnq_quantize_weights and (shared.opts.sdnq_quantize_mode == 'post' or (allow and shared.opts.sdnq_quantize_mode == 'auto')):
         log.debug('Load model: post_quant=sdnq')
         sd_model = sdnq_quantize_weights(sd_model)
-    if len(shared.opts.optimum_quanto_weights) > 0:
-        log.debug('Load model: post_quant=quanto')
-        sd_model = optimum_quanto_weights(sd_model)
-    if shared.opts.torchao_quantization and (shared.opts.torchao_quantization_mode == 'post' or (allow and shared.opts.torchao_quantization_mode == 'auto')):
-        log.debug('Load model: post_quant=torchao')
-        sd_model = torchao_quantization(sd_model)
     if shared.opts.layerwise_quantization:
         log.debug('Load model: post_quant=layerwise')
         apply_layerwise(sd_model)

@@ -62,16 +62,6 @@ def load_t5(name=None, cache_dir=None):
     elif 'fp16' in name.lower():
         t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', cache_dir=cache_dir, torch_dtype=devices.dtype)

-    elif 'fp4' in name.lower():
-        model_quant.load_bnb('Load model: type=T5')
-        quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True)
-        t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
-
-    elif 'fp8' in name.lower():
-        model_quant.load_bnb('Load model: type=T5')
-        quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
-        t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
-
     elif 'int8' in name.lower():
         from modules.model_quant import create_sdnq_config
         quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='int8')

@@ -84,18 +74,6 @@ def load_t5(name=None, cache_dir=None):
         if quantization_config is not None:
             t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)

-    elif 'qint4' in name.lower():
-        model_quant.load_quanto('Load model: type=T5')
-        quantization_config = transformers.QuantoConfig(weights='int4')
-        if quantization_config is not None:
-            t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
-
-    elif 'qint8' in name.lower():
-        model_quant.load_quanto('Load model: type=T5')
-        quantization_config = transformers.QuantoConfig(weights='int8')
-        if quantization_config is not None:
-            t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
-
     elif '/' in name:
         log.debug(f'Load model: type=T5 repo={name}')
         quant_config = model_quant.create_config(module='TE')

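`load_t5` above selects a loading path by substring-matching the variant name in an `elif` chain; after this change the bnb (`fp4`/`fp8`) and quanto (`qint4`/`qint8`) branches are gone. The dispatch shape can be sketched like this (function name and returned dicts are illustrative placeholders, not the project's real configs):

```python
def pick_t5_path(name):
    """Name-substring dispatch mirroring load_t5's elif chain (sketch)."""
    n = name.lower()
    if 'fp16' in n:
        return {'quantization': None}            # plain fp16 load, no quant config
    if 'int8' in n:
        return {'quantization': 'sdnq-int8'}     # sdnq config path
    if '/' in name:
        return {'quantization': 'from-options'}  # arbitrary repo + user config
    return {'quantization': None}

print(pick_t5_path('t5-int8'))
```

Order matters in such chains: the most specific substrings must be tested before the catch-all repo path.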
@@ -93,7 +93,10 @@ class Options:

     def set(self, key, value):
         """sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
-        oldval = self.data.get(key, None)
+        if key in self.secrets:
+            oldval = self.secrets.get(key, None)
+        else:
+            oldval = self.data.get(key, None)
         if oldval is None:
             if key in self.data_labels:
                 oldval = self.data_labels[key].default

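The hunk above makes `Options.set` read the previous value from a separate `secrets` store when the key is a secret, instead of always consulting `data`. A minimal sketch of that changed-value check (class and attribute names here are simplified stand-ins for the real `Options`):

```python
class MiniOptions:
    """Minimal sketch of an options store that keeps secret values in a
    separate dict; set() reports whether the stored value actually changed."""
    def __init__(self, secret_keys=()):
        self.data = {}
        self.secrets = {k: None for k in secret_keys}

    def set(self, key, value):
        if key in self.secrets:           # secrets are never read from self.data
            oldval = self.secrets.get(key, None)
        else:
            oldval = self.data.get(key, None)
        if oldval == value:
            return False                  # unchanged: callers can skip callbacks
        target = self.secrets if key in self.secrets else self.data
        target[key] = value
        return True

opts = MiniOptions(secret_keys=("api_key",))
changed_first = opts.set("width", 512)
changed_again = opts.set("width", 512)
opts.set("api_key", "s3cret")
print(changed_first, changed_again, "api_key" in opts.data)  # True False False
```

Keeping secrets out of the main dict means they never leak through whatever serializes or displays `data`, while change detection still works for them.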
@@ -176,6 +176,8 @@ class UpscalerESRGAN(Upscaler):


 def upscale_without_tiling(model, img):
+    if img.mode != 'RGB':
+        img = img.convert('RGB')
     img = np.array(img)
     img = img[:, :, ::-1]
     img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255

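The hunk above normalizes non-RGB inputs before the numpy pipeline runs; the array steps that follow (RGB to BGR flip, HWC to CHW transpose, scaling to [0, 1]) can be sketched on a tiny synthetic image:

```python
import numpy as np

# A tiny 2x2 stand-in for np.array(pil_rgb_image): shape (H, W, C), uint8.
img = np.array([[[255, 0, 0], [0, 255, 0]],
                [[0, 0, 255], [255, 255, 255]]], dtype=np.uint8)

bgr = img[:, :, ::-1]  # RGB -> BGR channel flip (ESRGAN-family models expect BGR)
chw = np.ascontiguousarray(np.transpose(bgr, (2, 0, 1))) / 255  # HWC -> CHW, [0, 1]
print(chw.shape)  # (3, 2, 2)
```

`np.ascontiguousarray` matters after the transpose: the view produced by `transpose` is non-contiguous, and downstream tensor constructors typically want a contiguous buffer.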
@@ -360,7 +360,7 @@ def process_samples(p: StableDiffusionProcessing, samples):
             split_tone_balance=getattr(p, 'grading_split_tone_balance', 0.5),
             vignette=getattr(p, 'grading_vignette', 0.0),
             grain=getattr(p, 'grading_grain', 0.0),
-            lut_file=getattr(p, 'grading_lut_file', ''),
+            lut_cube_file=getattr(p, 'grading_lut_file', ''),
             lut_strength=getattr(p, 'grading_lut_strength', 1.0),
         )
         if processing_grading.is_active(grading_params):

@@ -66,7 +66,7 @@ class GradingParams:
     vignette: float = 0.0
     grain: float = 0.0
     # lut
-    lut_file: str = ""
+    lut_cube_file: str = ""
     lut_strength: float = 1.0

     def __post_init__(self):

@@ -179,17 +179,17 @@ def _apply_color_temp(img: torch.Tensor, kelvin: float) -> torch.Tensor:
     return (img * scales).clamp(0, 1)


-def _apply_lut(image: Image.Image, lut_file: str, strength: float) -> Image.Image:
+def _apply_lut(image: Image.Image, lut_cube_file: str, strength: float) -> Image.Image:
     """Apply .cube LUT file via pillow-lut-tools."""
-    if not lut_file or not os.path.isfile(lut_file):
+    if not lut_cube_file or not os.path.isfile(lut_cube_file):
         return image
     pillow_lut = _ensure_pillow_lut()
     try:
-        cube = pillow_lut.load_cube_file(lut_file)
+        cube = pillow_lut.load_cube_file(lut_cube_file)
         if strength != 1.0:
             cube = pillow_lut.amplify_lut(cube, strength)
         result = image.filter(cube)
-        debug(f'Grading LUT: file={os.path.basename(lut_file)} strength={strength}')
+        debug(f'Grading LUT: file={os.path.basename(lut_cube_file)} strength={strength}')
         return result
     except Exception as e:
         log.error(f'Grading LUT: {e}')

@@ -198,8 +198,8 @@ def _apply_lut(image: Image.Image, lut_cube_file: str, strength: float) -> Image.Image:

 def grade_image(image: Image.Image, params: GradingParams) -> Image.Image:
     """Full grading pipeline: PIL -> GPU tensor -> kornia ops -> PIL."""
-    log.debug(f"Grading: params={params}")
     kornia = _ensure_kornia()
+    debug(f'Grading: params={params}')
     arr = np.array(image).astype(np.float32) / 255.0
     tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
     tensor = tensor.to(device=devices.device, dtype=devices.dtype)

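`grade_image` above converts the PIL image to a float array in [0, 1], reshapes it to a batched NCHW tensor for the per-channel ops, and converts back at the end. The roundtrip can be sketched with numpy alone (no torch/kornia), using a fixed tiny array as a stand-in for the image:

```python
import numpy as np

# Stand-in for np.array(pil_image): a tiny uint8 HWC image.
arr = np.array([[[10, 20, 30], [40, 50, 60]],
                [[70, 80, 90], [100, 110, 120]]], dtype=np.uint8)

f = arr.astype(np.float32) / 255.0            # [0, 1] float, HWC
nchw = np.transpose(f, (2, 0, 1))[None, ...]  # (1, 3, 2, 2): batch of CHW planes
# ...per-channel grading ops would run on `nchw` here...
hwc = np.clip(nchw[0].transpose(1, 2, 0), 0, 1)
back = (hwc * 255).round().astype(np.uint8)   # back to uint8 for Image.fromarray
print(np.array_equal(back, arr))  # True: roundtrip is lossless with no ops applied
```

The `clip` before requantizing mirrors what the grading ops need anyway: any op can push values slightly outside [0, 1], and clipping before the `* 255` cast avoids uint8 wraparound.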
@@ -246,7 +246,7 @@ def grade_image(image: Image.Image, params: GradingParams) -> Image.Image:
     result = Image.fromarray(arr)

     # LUT applied last (CPU, via pillow-lut-tools)
-    if params.lut_file:
-        result = _apply_lut(result, params.lut_file, params.lut_strength)
+    if params.lut_cube_file:
+        result = _apply_lut(result, params.lut_cube_file, params.lut_strength)

     return result

@@ -76,9 +76,14 @@ class ScriptPostprocessingRunner:
         script.controls = wrap_call(script.ui, script.filename, "ui")
         if script.controls is None:
             script.controls = {}
-        for control in script.controls.values():
-            control.custom_script_source = os.path.basename(script.filename)
-        inputs += list(script.controls.values())
+        if isinstance(script.controls, list) or isinstance(script.controls, tuple):
+            for control in script.controls:
+                control.custom_script_source = os.path.basename(script.filename)
+            inputs += script.controls
+        else:
+            for control in script.controls.values():
+                control.custom_script_source = os.path.basename(script.filename)
+            inputs += list(script.controls.values())
         script.args_to = len(inputs)

     def scripts_in_preferred_order(self):

@@ -109,11 +114,16 @@ class ScriptPostprocessingRunner:
         for script in self.scripts_in_preferred_order():
             jobid = shared.state.begin(script.name)
             script_args = args[script.args_from:script.args_to]
-            process_args = {}
-            for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
-                process_args[name] = value
-            log.debug(f'Process: script="{script.name}" args={process_args}')
-            script.process(pp, **process_args)
+            process_args = []
+            process_kwargs = {}
+            if isinstance(script.controls, list) or isinstance(script.controls, tuple):
+                for _control, value in zip(script.controls, script_args, strict=False):
+                    process_args.append(value)
+            else:
+                for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
+                    process_kwargs[name] = value
+            log.debug(f'Process: script="{script.name}" args={process_args} kwargs={process_kwargs}')
+            script.process(pp, *process_args, **process_kwargs)
             shared.state.end(jobid)

     def create_args_for_run(self, scripts_args):

@@ -139,9 +149,14 @@ class ScriptPostprocessingRunner:
                 continue
             jobid = shared.state.begin(script.name)
             script_args = args[script.args_from:script.args_to]
-            process_args = {}
-            for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
-                process_args[name] = value
-            log.debug(f'Postprocess: script={script.name} args={process_args}')
-            script.postprocess(filenames, **process_args)
+            process_args = []
+            process_kwargs = {}
+            if isinstance(script.controls, list) or isinstance(script.controls, tuple):
+                for _control, value in zip(script.controls, script_args, strict=False):
+                    process_args.append(value)
+            else:
+                for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
+                    process_kwargs[name] = value
+            log.debug(f'Postprocess: script={script.name} args={process_args} kwargs={process_kwargs}')
+            script.postprocess(filenames, *process_args, **process_kwargs)
             shared.state.end(jobid)

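The list-vs-dict dispatch introduced in the two runner hunks above can be exercised on its own. A standalone sketch with hypothetical `controls`/`values` inputs (not the runner class itself):

```python
def split_script_args(controls, values):
    # positional args for list/tuple controls, keyword args for dict controls,
    # mirroring how the runner chooses between *args and **kwargs dispatch
    args, kwargs = [], {}
    if isinstance(controls, (list, tuple)):
        for _control, value in zip(controls, values):
            args.append(value)
    else:
        for name, value in zip(controls.keys(), values):
            kwargs[name] = value
    return args, kwargs

print(split_script_args(['slider', 'checkbox'], [0.5, True]))
print(split_script_args({'scale': 'slider', 'enable': 'checkbox'}, [0.5, True]))
```

Legacy scripts returning a dict from `ui()` keep keyword dispatch; scripts returning a plain list fall back to positional dispatch.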
@@ -24,8 +24,8 @@ warn_once = False


 class CheckpointInfo:
-    def __init__(self, filename, sha=None, subfolder=None):
-        self.name = None
+    def __init__(self, filename, name=None, sha=None, subfolder=None, model_type: str = 'checkpoint'):
+        self.name = name
         self.hash = sha
         self.filename = filename
         self.type = ''

@ -62,9 +62,9 @@ class CheckpointInfo:
|
|||
self.sha256 = None
|
||||
self.type = 'unknown'
|
||||
elif os.path.isfile(filename): # ckpt or safetensor
|
||||
self.name = relname
|
||||
self.name = self.name or relname
|
||||
self.filename = filename
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{relname}")
|
||||
self.sha256 = hashes.sha256_from_cache(self.filename, f"{model_type}/{relname}") or hashes.sha256_from_cache(self.filename, f"{model_type}/{name}")
|
||||
self.type = ext
|
||||
if 'nf4' in filename:
|
||||
self.type = 'transformer'
|
||||
|
|
@@ -74,12 +74,12 @@ class CheckpointInfo:
         else:
             repo = [r for r in modelloader.diffuser_repos if self.hash == r['hash']]
             if len(repo) == 0:
-                self.name = filename
+                self.name = self.name or filename
                 self.filename = filename
                 self.sha256 = None
                 self.type = 'unknown'
             else:
-                self.name = os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]['name'])
+                self.name = self.name or os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]["name"])
                 self.filename = repo[0]['path']
                 self.sha256 = repo[0]['hash']
                 self.type = 'diffusers'

@@ -109,7 +109,7 @@ class CheckpointInfo:
         return self.shorthash

     def __str__(self):
-        return f'CheckpointInfo(name="{self.name}" filename="{self.filename}" hash={self.shorthash} type={self.type} title="{self.title}" path="{self.path}" subfolder="{self.subfolder}")'
+        return f'CheckpointInfo(name="{self.name}" filename="{self.filename}" sha256={self.sha256} sha={self.shorthash} type={self.type} title="{self.title}" path="{self.path}" subfolder="{self.subfolder}")'


 def setup_model():

@@ -160,7 +160,7 @@ def list_models():
     checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename))


-def update_model_hashes():
+def update_model_hashes(model_list: dict = None, model_type: str = 'checkpoint'):
     def update_model_hashes_table(rows):
         html = """
         <table class="simple-table">

@@ -186,14 +186,16 @@ def update_model_hashes():
                 log.error(f'Model list: row={row} {e}')
         return html.format(tbody=tbody)

-    lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
+    if model_list is None:
+        model_list = checkpoints_list
+    lst = [ckpt for ckpt in model_list.values() if ckpt.hash is None]
     for ckpt in lst:
         ckpt.hash = model_hash(ckpt.filename)
-    lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
-    log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
+    lst = [ckpt for ckpt in model_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
+    log.info(f'Models list: hash missing={len(lst)} total={len(model_list)}')
     updated = []
     for ckpt in lst:
-        ckpt.sha256 = hashes.sha256(ckpt.filename, f"checkpoint/{ckpt.name}")
+        ckpt.sha256 = hashes.sha256(ckpt.filename, f"{model_type}/{ckpt.name}")
         ckpt.shorthash = ckpt.sha256[0:10] if ckpt.sha256 is not None else None
         updated.append(ckpt)
         yield update_model_hashes_table(updated)

@@ -798,8 +798,8 @@ def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: di
         "requires_safety_checker": False, # sd15 specific but we cant know ahead of time
         # "use_safetensors": True,
     }
-    if shared.opts.huggingface_token and len(shared.opts.huggingface_token) > 0:
-        diffusers_load_config['token'] = shared.opts.huggingface_token
+    # if shared.opts.huggingface_token and len(shared.opts.huggingface_token) > 0:
+    #     diffusers_load_config['token'] = shared.opts.huggingface_token
     if revision is not None:
         diffusers_load_config['revision'] = revision
     if shared.opts.diffusers_model_load_variant != 'default':

@@ -102,3 +102,4 @@ def refresh_unet_list():
             name = os.path.splitext(basename)[0] if ".safetensors" in basename else basename
             unet_dict[name] = file
     log.info(f'Available UNets: path="{shared.opts.unet_dir}" items={len(unet_dict)}')
+    return unet_dict

@@ -151,6 +151,7 @@ def list_samplers():
     modules.sd_samplers.set_samplers()
     return modules.sd_samplers.all_samplers


+log.debug('Initializing: default modes')
 startup_offload_mode, startup_offload_min_gpu, startup_offload_max_gpu, startup_cross_attention, startup_sdp_options, startup_sdp_choices, startup_sdp_override_options, startup_sdp_override_choices, startup_offload_always, startup_offload_never = get_default_modes(cmd_opts=cmd_opts, mem_stat=mem_stat)

@@ -439,6 +439,7 @@ def update_token_counter(text: str):
     from modules.extra_networks import parse_prompt

     count_formatted = '0'
+    max_length = 0
     visible = False

     prompt, _ = parse_prompt(text)

@@ -475,7 +476,7 @@ def update_token_counter(text: str):
         token_counts = [len(group) - int(has_bos_token) - int(has_eos_token) for group in ids]
         if len(token_counts) > 1:
             visible = True
-            count_formatted = f"{token_counts} {sum(token_counts)}" if shared.opts.prompt_detailed_tokens else str(sum(token_counts))
+            count_formatted = f"{token_counts}/{sum(token_counts)}"
         elif len(token_counts) == 1 and token_counts[0] > 0:
             visible = True
             count_formatted = str(token_counts[0])

@@ -166,26 +166,6 @@ def create_settings(cmd_opts):
         "nunchaku_attention": OptionInfo(False, "Nunchaku attention", gr.Checkbox),
         "nunchaku_offload": OptionInfo(False, "Nunchaku offloading", gr.Checkbox),

-        "bnb_quantization_sep": OptionInfo("<h2>BitsAndBytes</h2>", "", gr.HTML),
-        "bnb_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM", "VAE"]}),
-        "bnb_quantization_type": OptionInfo("nf4", "Quantization type", gr.Dropdown, {"choices": ["nf4", "fp8", "fp4"]}),
-        "bnb_quantization_storage": OptionInfo("uint8", "Backend storage", gr.Dropdown, {"choices": ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]}),
-
-        "quanto_quantization_sep": OptionInfo("<h2>Optimum Quanto</h2>", "", gr.HTML),
-        "quanto_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM"]}),
-        "quanto_quantization_type": OptionInfo("int8", "Quantization weights type", gr.Dropdown, {"choices": ["float8", "int8", "int4", "int2"]}),
-
-        "optimum_quanto_sep": OptionInfo("<h2>Optimum Quanto: post-load</h2>", "", gr.HTML),
-        "optimum_quanto_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "Control", "VAE"]}),
-        "optimum_quanto_weights_type": OptionInfo("qint8", "Quantization weights type", gr.Dropdown, {"choices": ["qint8", "qfloat8_e4m3fn", "qfloat8_e5m2", "qint4", "qint2"]}),
-        "optimum_quanto_activations_type": OptionInfo("none", "Quantization activations type ", gr.Dropdown, {"choices": ["none", "qint8", "qfloat8_e4m3fn", "qfloat8_e5m2"]}),
-        "optimum_quanto_shuffle_weights": OptionInfo(False, "Shuffle weights in post mode", gr.Checkbox),
-
-        "torchao_sep": OptionInfo("<h2>TorchAO</h2>", "", gr.HTML),
-        "torchao_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "LLM", "Control", "VAE"]}),
-        "torchao_quantization_mode": OptionInfo("auto", "Quantization mode", gr.Dropdown, {"choices": ["auto", "pre", "post"]}),
-        "torchao_quantization_type": OptionInfo("int8_weight_only", "Quantization type", gr.Dropdown, {"choices": ["int4_weight_only", "int8_dynamic_activation_int4_weight", "int8_weight_only", "int8_dynamic_activation_int8_weight", "float8_weight_only", "float8_dynamic_activation_float8_weight", "float8_static_activation_float8_weight"]}),
-
         "layerwise_quantization_sep": OptionInfo("<h2>Layerwise Casting</h2>", "", gr.HTML),
         "layerwise_quantization": OptionInfo([], "Layerwise casting enabled", gr.CheckboxGroup, {"choices": ["Model", "TE"]}),
         "layerwise_quantization_storage": OptionInfo("float8_e4m3fn", "Layerwise casting storage", gr.Dropdown, {"choices": ["float8_e4m3fn", "float8_e5m2"]}),

@@ -194,14 +174,6 @@ def create_settings(cmd_opts):
         "trt_quantization_sep": OptionInfo("<h2>TensorRT</h2>", "", gr.HTML),
         "trt_quantization": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model"]}),
         "trt_quantization_type": OptionInfo("int8", "Quantization type", gr.Dropdown, {"choices": ["int8", "int4", "fp8", "nf4", "nvfp4"]}),
-
-        "nncf_compress_sep": OptionInfo("<h2>NNCF: Neural Network Compression Framework</h2>", "", gr.HTML, {"visible": cmd_opts.use_openvino}),
-        "nncf_compress_weights": OptionInfo([], "Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "VAE"], "visible": cmd_opts.use_openvino}),
-        "nncf_compress_weights_mode": OptionInfo("INT8_SYM", "Quantization type", gr.Dropdown, {"choices": ["INT8", "INT8_SYM", "FP8", "MXFP8", "INT4_ASYM", "INT4_SYM", "FP4", "MXFP4", "NF4"], "visible": cmd_opts.use_openvino}),
-        "nncf_compress_weights_raito": OptionInfo(0, "Compress ratio", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}),
-        "nncf_compress_weights_group_size": OptionInfo(0, "Group size", gr.Slider, {"minimum": -1, "maximum": 4096, "step": 1, "visible": cmd_opts.use_openvino}),
-        "nncf_quantize": OptionInfo([], "Static Quantization enabled", gr.CheckboxGroup, {"choices": ["Model", "TE", "VAE"], "visible": cmd_opts.use_openvino}),
-        "nncf_quantize_mode": OptionInfo("INT8", "OpenVINO activations mode", gr.Dropdown, {"choices": ["INT8", "FP8_E4M3", "FP8_E5M2"], "visible": cmd_opts.use_openvino}),
     }))
     # --- VAE & Text Encoder ---
     options_templates.update(options_section(('vae_encoder', "Variational Auto Encoder"), {

@@ -584,6 +584,8 @@ def register_pages():
     register_page(ExtraNetworksPageLora())
     from modules.ui_extra_networks_wildcards import ExtraNetworksPageWildcards
     register_page(ExtraNetworksPageWildcards())
+    from modules.ui_extra_networks_unet import ExtraNetworksPageUNets
+    register_page(ExtraNetworksPageUNets())
     if shared.opts.latent_history > 0:
         from modules.ui_extra_networks_history import ExtraNetworksPageHistory
         register_page(ExtraNetworksPageHistory())

@@ -596,7 +598,7 @@ def get_pages(title=None):
     visible = shared.opts.extra_networks
     pages: list[ExtraNetworksPage] = []
     if 'All' in visible or visible == []: # default en sort order
-        visible = ['Model', 'Lora', 'Style', 'Wildcards', 'Embedding', 'VAE', 'History', 'Hypernetwork']
+        visible = ['Model', 'Lora', 'UNet/DiT', 'Style', 'Wildcards', 'Embedding', 'VAE', 'History', 'Hypernetwork']

     titles = [page.title for page in shared.extra_networks]
     if title is None:

@@ -743,7 +745,7 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):

     with ui.tabs:
         def ui_tab_change(page):
-            scan_visible = page in ['Model', 'Lora', 'VAE', 'Hypernetwork', 'Embedding']
+            scan_visible = page in ['Model', 'Lora', 'VAE', 'UNet/DiT', 'Hypernetwork', 'Embedding']
            save_visible = page in ['Style']
            model_visible = page in ['Model']
            return [gr.update(visible=scan_visible), gr.update(visible=save_visible), gr.update(visible=model_visible)]

@@ -0,0 +1,43 @@
+import html
+import json
+import os
+from modules import shared, ui_extra_networks, sd_unet, hashes, modelstats
+from modules.logger import log
+
+
+class ExtraNetworksPageUNets(ui_extra_networks.ExtraNetworksPage):
+    def __init__(self):
+        super().__init__('UNet/DiT')
+
+    def refresh(self):
+        return sd_unet.refresh_unet_list()
+
+    def list_items(self):
+        for name, filename in sd_unet.unet_dict.items():
+            try:
+                size, mtime = modelstats.stat(filename)
+                info = self.find_info(filename)
+                version = self.find_version(None, info)
+                record = {
+                    "type": 'UNet/DiT',
+                    "name": name,
+                    "alias": os.path.splitext(os.path.basename(filename))[0],
+                    "title": name,
+                    "filename": filename,
+                    "hash": hashes.sha256_from_cache(filename, f"unet/{name}"),
+                    "preview": self.find_preview(filename),
+                    "local_preview": f"{os.path.splitext(filename)[0]}.{shared.opts.samples_format}",
+                    "metadata": {},
+                    "onclick": '"' + html.escape(f"""return selectUNet({json.dumps(name)})""") + '"',
+                    "mtime": mtime,
+                    "size": size,
+                    "info": info,
+                    "description": self.find_description(filename, info),
+                    "version": version.get("baseModel", "N/A") if info else "N/A",
+                }
+                yield record
+            except Exception as e:
+                log.debug(f'Networks error: type=unet file="(unknown)" {e}')
+
+    def allowed_directories_for_previews(self):
+        return [v for v in [shared.opts.unet_dir] if v is not None]

@@ -12,6 +12,15 @@ from modules.shared import opts, log
 extra_ui = []


+def update_model_hashes():
+    from modules import sd_unet, sd_checkpoint
+    unets = {}
+    for k, v in sd_unet.unet_dict.items():
+        unets[k] = sd_checkpoint.CheckpointInfo(name=k, filename=v, model_type='unet')
+    yield from sd_models.update_model_hashes(unets, model_type='unet')
+    yield from sd_models.update_model_hashes(model_type='checkpoint')
+
+
 def create_ui():
     log.debug('UI initialize: tab=models')
     dummy_component = gr.Label(visible=False)

@@ -143,7 +152,7 @@ def create_ui():
                 with gr.Row():
                     model_table = gr.HTML(value='', elem_id="model_list_table")

-                model_checkhash_btn.click(fn=sd_models.update_model_hashes, inputs=[], outputs=[model_table])
+                model_checkhash_btn.click(fn=update_model_hashes, inputs=[], outputs=[model_table])
                 model_list_btn.click(fn=lambda: create_models_table(list(sd_models.checkpoints_list.values())), inputs=[], outputs=[model_table])

             with gr.Tab(label="Metadata", elem_id="models_metadata_tab"):

@@ -172,15 +172,15 @@ def create_latent_inputs(tab):
         hdr_sharpen = gr.Slider(minimum=-4.0, maximum=4.0, step=0.05, value=0, label="Latent sharpen", elem_id=f"{tab}_hdr_sharpen")
         hdr_color = gr.Slider(minimum=0.0, maximum=16.0, step=0.1, value=0.0, label="Latent color", elem_id=f"{tab}_hdr_color")
     with gr.Row(elem_id=f"{tab}_hdr_clamp_row"):
-        hdr_clamp = gr.Checkbox(label="Clamp", value=False, elem_id=f"{tab}_hdr_clamp")
-        hdr_boundary = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=4.0, label="Range", elem_id=f"{tab}_hdr_boundary")
-        hdr_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.95, label="Threshold", elem_id=f"{tab}_hdr_threshold")
+        hdr_clamp = gr.Checkbox(label="Latent clamp", value=False, elem_id=f"{tab}_hdr_clamp")
+        hdr_boundary = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=4.0, label="Latent range", elem_id=f"{tab}_hdr_boundary")
+        hdr_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.95, label="Latent threshold", elem_id=f"{tab}_hdr_threshold")
     with gr.Row(elem_id=f"{tab}_hdr_max_row"):
-        hdr_maximize = gr.Checkbox(label="Maximize", value=False, elem_id=f"{tab}_hdr_maximize")
-        hdr_max_center = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=0.6, label="Center", elem_id=f"{tab}_hdr_max_center")
-        hdr_max_boundary = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Max range", elem_id=f"{tab}_hdr_max_boundary")
+        hdr_maximize = gr.Checkbox(label="Latent maximize", value=False, elem_id=f"{tab}_hdr_maximize")
+        hdr_max_center = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=0.6, label="Latent center", elem_id=f"{tab}_hdr_max_center")
+        hdr_max_boundary = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Latent max range", elem_id=f"{tab}_hdr_max_boundary")
     with gr.Row(elem_id=f"{tab}_hdr_color_row"):
-        hdr_color_picker = gr.ColorPicker(label="Tint color", show_label=True, container=False, value=None, elem_id=f"{tab}_hdr_color_picker")
+        hdr_color_picker = gr.ColorPicker(label="Latent tint", show_label=True, container=False, value=None, elem_id=f"{tab}_hdr_color_picker")
         hdr_tint_ratio = gr.Slider(label="Tint strength", minimum=-4.0, maximum=4.0, step=0.05, value=0.0, elem_id=f"{tab}_hdr_tint_ratio")
     return hdr_mode, hdr_brightness, hdr_color, hdr_sharpen, hdr_clamp, hdr_boundary, hdr_threshold, hdr_maximize, hdr_max_center, hdr_max_boundary, hdr_color_picker, hdr_tint_ratio, hdr_apply_hires

@@ -197,7 +197,7 @@ def create_color_inputs(tab):
             grading_hue = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0, label='Hue', elem_id=f"{tab}_grading_hue")
             grading_gamma = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label='Gamma', elem_id=f"{tab}_grading_gamma")
             grading_sharpness = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=0, label='Sharpness', elem_id=f"{tab}_grading_sharpness")
-            grading_color_temp = gr.Slider(minimum=2000, maximum=12000, step=100, value=6500, label='Color temp (K)', elem_id=f"{tab}_grading_color_temp")
+            grading_color_temp = gr.Slider(minimum=2000, maximum=12000, step=100, value=6500, label='Color temp', elem_id=f"{tab}_grading_color_temp")
     with gr.Group():
         gr.HTML('<h3>Tone</h3>')
         with gr.Row(elem_id=f"{tab}_grading_tone_row"):

@@ -211,7 +211,7 @@ def create_color_inputs(tab):
         with gr.Row(elem_id=f"{tab}_grading_split_row"):
             grading_shadows_tint = gr.ColorPicker(label="Shadows tint", value="#000000", elem_id=f"{tab}_grading_shadows_tint")
             grading_highlights_tint = gr.ColorPicker(label="Highlights tint", value="#ffffff", elem_id=f"{tab}_grading_highlights_tint")
-            grading_split_tone_balance = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label='Balance', elem_id=f"{tab}_grading_split_tone_balance")
+            grading_split_tone_balance = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label='Split tone balance', elem_id=f"{tab}_grading_split_tone_balance")
     with gr.Group():
         gr.HTML('<h3>Effects</h3>')
         with gr.Row(elem_id=f"{tab}_grading_effects_row"):

@@ -220,9 +220,9 @@ def create_color_inputs(tab):
     with gr.Group():
         gr.HTML('<h3>LUT</h3>')
         with gr.Row(elem_id=f"{tab}_grading_lut_row"):
-            grading_lut_file = gr.File(label='LUT .cube file', file_types=['.cube'], elem_id=f"{tab}_grading_lut_file")
+            grading_lut_cube_file = gr.File(label='LUT .cube file', file_types=['.cube'], elem_id=f"{tab}_grading_lut_file")
             grading_lut_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=1.0, label='LUT strength', elem_id=f"{tab}_grading_lut_strength")
-    return grading_brightness, grading_contrast, grading_saturation, grading_hue, grading_gamma, grading_sharpness, grading_color_temp, grading_shadows, grading_midtones, grading_highlights, grading_clahe_clip, grading_clahe_grid, grading_shadows_tint, grading_highlights_tint, grading_split_tone_balance, grading_vignette, grading_grain, grading_lut_file, grading_lut_strength
+    return grading_brightness, grading_contrast, grading_saturation, grading_hue, grading_gamma, grading_sharpness, grading_color_temp, grading_shadows, grading_midtones, grading_highlights, grading_clahe_clip, grading_clahe_grid, grading_shadows_tint, grading_highlights_tint, grading_split_tone_balance, grading_vignette, grading_grain, grading_lut_cube_file, grading_lut_strength


 def create_sampler_and_steps_selection(choices, tabname, default_steps:int=20):

@@ -396,6 +396,13 @@ def create_quicksettings(interfaces):
         inputs=[shared.settings_components['sd_vae'], dummy_component],
         outputs=[shared.settings_components['sd_vae'], text_settings],
     )
+    button_set_unet = gr.Button("Change UNet", elem_id="change_unet", visible=False)
+    button_set_unet.click(
+        fn=lambda value, _: run_settings_single(value, key="sd_unet"),
+        _js="function(v){ var res = desiredUNetName; desiredUNetName = ''; return [res || v, null]; }",
+        inputs=[shared.settings_components["sd_unet"], dummy_component],
+        outputs=[shared.settings_components["sd_unet"], text_settings],
+    )

     def reference_submit(model):
         if '@' not in model: # diffusers

@@ -18,7 +18,7 @@
     "venv": ". venv/bin/activate",
     "start": ". venv/bin/activate; python launch.py --debug",
     "localize": "node cli/localize.js",
-    "packages": ". venv/bin/activate && pip install --upgrade transformers accelerate huggingface_hub safetensors tokenizers peft pytorch_lightning pylint ruff",
+    "packages": ". venv/bin/activate && pip install --upgrade accelerate huggingface_hub safetensors tokenizers peft pytorch_lightning pylint ruff",
     "format": ". venv/bin/activate && pre-commit run --all-files",
    "format-win": "venv\\scripts\\activate && pre-commit run --all-files",
    "eslint": "eslint . javascript/",

@@ -0,0 +1,216 @@
+"""Flux2/Klein-specific LoRA loading.
+
+Handles:
+- Bare BFL-format keys in state dicts (adds diffusion_model. prefix for converter)
+- LoKR adapters via native module loading (bypasses diffusers PEFT system)
+
+Installed via apply_patch() during pipeline loading.
+"""
+
+import os
+import time
+from modules import shared, sd_models
+from modules.logger import log
+from modules.lora import network, network_lokr, lora_convert
+from modules.lora import lora_common as l
+
+
+BARE_FLUX_PREFIXES = ("single_blocks.", "double_blocks.", "img_in.", "txt_in.",
+                      "final_layer.", "time_in.", "single_stream_modulation.",
+                      "double_stream_modulation_")
+
+# BFL -> diffusers module path mapping for Flux2/Klein
+F2_SINGLE_MAP = {
+    'linear1': 'attn.to_qkv_mlp_proj',
+    'linear2': 'attn.to_out',
+}
+F2_DOUBLE_MAP = {
+    'img_attn.proj': 'attn.to_out.0',
+    'txt_attn.proj': 'attn.to_add_out',
+    'img_mlp.0': 'ff.linear_in',
+    'img_mlp.2': 'ff.linear_out',
+    'txt_mlp.0': 'ff_context.linear_in',
+    'txt_mlp.2': 'ff_context.linear_out',
+}
+F2_QKV_MAP = {
+    'img_attn.qkv': ('attn', ['to_q', 'to_k', 'to_v']),
+    'txt_attn.qkv': ('attn', ['add_q_proj', 'add_k_proj', 'add_v_proj']),
+}
+
+
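For illustration, the mapping tables above can be driven by a small resolver. This sketch mirrors the target-selection logic used later in `load_lokr_native`, but with trimmed-down map contents and a hypothetical helper name:

```python
F2_SINGLE_MAP = {'linear1': 'attn.to_qkv_mlp_proj', 'linear2': 'attn.to_out'}
F2_DOUBLE_MAP = {'img_attn.proj': 'attn.to_out.0', 'txt_attn.proj': 'attn.to_add_out'}
F2_QKV_MAP = {'img_attn.qkv': ('attn', ['to_q', 'to_k', 'to_v'])}

def resolve_bfl_module(base):
    # 'double_blocks.3.img_attn.qkv' -> block type, block index, module suffix
    block_type, block_idx, suffix = base.split('.', 2)
    if block_type == 'single_blocks' and suffix in F2_SINGLE_MAP:
        return [f'single_transformer_blocks.{block_idx}.{F2_SINGLE_MAP[suffix]}']
    if block_type == 'double_blocks' and suffix in F2_DOUBLE_MAP:
        return [f'transformer_blocks.{block_idx}.{F2_DOUBLE_MAP[suffix]}']
    if block_type == 'double_blocks' and suffix in F2_QKV_MAP:
        # one fused BFL qkv module fans out to three diffusers projections
        attn_prefix, projs = F2_QKV_MAP[suffix]
        return [f'transformer_blocks.{block_idx}.{attn_prefix}.{p}' for p in projs]
    return []

print(resolve_bfl_module('double_blocks.3.img_attn.qkv'))
```

The fan-out case is why the LoKR loader below needs per-chunk modules: three diffusers layers share one BFL weight.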
+def apply_lora_alphas(state_dict):
+    """Bake kohya-format .alpha scaling into lora_down weights and remove alpha keys.
+
+    Diffusers' Flux2 converter only handles lora_A/lora_B (or lora_down/lora_up) keys.
+    Kohya-format LoRAs store per-layer alpha values as separate .alpha keys that the
+    converter doesn't consume, causing a ValueError on leftover keys. This matches the
+    approach used by _convert_kohya_flux_lora_to_diffusers for Flux 1.
+    """
+    alpha_keys = [k for k in state_dict if k.endswith('.alpha')]
+    if not alpha_keys:
+        return state_dict
+    for alpha_key in alpha_keys:
+        base = alpha_key[:-len('.alpha')]
+        down_key = f'{base}.lora_down.weight'
+        if down_key not in state_dict:
+            continue
+        down_weight = state_dict[down_key]
+        rank = down_weight.shape[0]
+        alpha = state_dict.pop(alpha_key).item()
+        scale = alpha / rank
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        state_dict[down_key] = down_weight * scale_down
+        up_key = f'{base}.lora_up.weight'
+        if up_key in state_dict:
+            state_dict[up_key] = state_dict[up_key] * scale_up
+    remaining = [k for k in state_dict if k.endswith('.alpha')]
+    if remaining:
+        log.debug(f'Network load: type=LoRA stripped {len(remaining)} orphaned alpha keys')
+        for k in remaining:
+            del state_dict[k]
+    return state_dict
+
+
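The power-of-two split in `apply_lora_alphas` keeps the two baked factors numerically close (which behaves better in low-precision dtypes) while their product stays exactly alpha/rank. The arithmetic in isolation:

```python
def split_scale(alpha, rank):
    # distribute alpha/rank across the lora_down and lora_up factors so that
    # neither factor ends up far smaller than the other; the invariant is
    # scale_down * scale_up == alpha / rank throughout the loop
    scale_down, scale_up = alpha / rank, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

down, up = split_scale(alpha=8, rank=32)
print(down, up, down * up)  # product is always alpha / rank
```

Both factors are powers of two times the inputs, so the rebalancing itself is lossless in floating point.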
+def preprocess_f2_keys(state_dict):
+    """Add 'diffusion_model.' prefix to bare BFL-format keys so
+    Flux2LoraLoaderMixin's format detection routes them to the converter."""
+    if any(k.startswith("diffusion_model.") or k.startswith("base_model.model.") for k in state_dict):
+        return state_dict
+    if any(k.startswith(p) for k in state_dict for p in BARE_FLUX_PREFIXES):
+        log.debug('Network load: type=LoRA adding diffusion_model prefix for bare BFL-format keys')
+        state_dict = {f"diffusion_model.{k}": v for k, v in state_dict.items()}
+    return state_dict
+
+
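The effect of `preprocess_f2_keys` on a bare state dict, sketched with toy keys and only a subset of the real prefix tuple:

```python
BARE_FLUX_PREFIXES = ("single_blocks.", "double_blocks.", "img_in.", "txt_in.")

def add_prefix(state_dict):
    # dicts that already carry a recognized prefix pass through untouched;
    # bare BFL keys get 'diffusion_model.' prepended so format detection fires
    if any(k.startswith(("diffusion_model.", "base_model.model.")) for k in state_dict):
        return state_dict
    if any(k.startswith(BARE_FLUX_PREFIXES) for k in state_dict):
        return {f"diffusion_model.{k}": v for k, v in state_dict.items()}
    return state_dict

print(add_prefix({"double_blocks.0.img_attn.qkv.lora_down.weight": 0}))
```

Note the rewrite is all-or-nothing: one bare BFL key is enough to prefix every key in the dict, matching the real helper.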
def try_load_lokr(name, network_on_disk, lora_scale):
|
||||
"""Try loading a Flux2/Klein LoRA as LoKR native modules.
|
||||
|
||||
Returns a Network with native modules if the state dict contains LoKR keys,
|
||||
or None to fall through to the generic diffusers path.
|
||||
"""
|
||||
t0 = time.time()
|
||||
state_dict = sd_models.read_state_dict(network_on_disk.filename, what='network')
|
||||
if not any('.lokr_w1' in k for k in state_dict):
|
||||
return None
|
||||
net = load_lokr_native(name, network_on_disk, state_dict)
|
||||
if len(net.modules) == 0:
|
||||
log.error(f'Network load: type=LoKR name="{name}" no modules matched')
|
||||
return None
|
||||
log.debug(f'Network load: type=LoKR name="{name}" native modules={len(net.modules)} scale={lora_scale}')
|
||||
l.timer.activate += time.time() - t0
|
||||
return net
|
||||
|
||||
|
||||
def load_lokr_native(name, network_on_disk, state_dict):
|
||||
"""Load Flux2 LoKR as native modules applied at inference time.
|
||||
|
||||
Stores only the compact LoKR factors (w1, w2) and computes kron(w1, w2)
|
||||
on-the-fly during weight application. For fused QKV modules in double
|
||||
blocks, NetworkModuleLokrChunk computes the full Kronecker product and
|
||||
returns only its designated Q/K/V chunk, then frees the temporary.
|
||||
"""
|
||||
prefix = "diffusion_model."
|
||||
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
lora_convert.assign_network_names_to_compvis_modules(sd_model)
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)

for key in list(state_dict.keys()):
    if not key.endswith('.lokr_w1'):
        continue
    if not key.startswith(prefix):
        continue

    base = key[len(prefix):].rsplit('.lokr_w1', 1)[0]
    lokr_weights = {}
    for suffix in ['lokr_w1', 'lokr_w2', 'lokr_w1_a', 'lokr_w1_b', 'lokr_w2_a', 'lokr_w2_b', 'lokr_t2', 'alpha']:
        full_key = f'{prefix}{base}.{suffix}'
        if full_key in state_dict:
            lokr_weights[suffix] = state_dict[full_key]

    parts = base.split('.')
    block_type, block_idx, module_suffix = parts[0], parts[1], '.'.join(parts[2:])

    targets = []  # (module_path, chunk_index, num_chunks)
    if block_type == 'single_blocks' and module_suffix in F2_SINGLE_MAP:
        path = f'single_transformer_blocks.{block_idx}.{F2_SINGLE_MAP[module_suffix]}'
        targets.append((path, None, None))
    elif block_type == 'double_blocks':
        if module_suffix in F2_DOUBLE_MAP:
            path = f'transformer_blocks.{block_idx}.{F2_DOUBLE_MAP[module_suffix]}'
            targets.append((path, None, None))
        elif module_suffix in F2_QKV_MAP:
            attn_prefix, proj_keys = F2_QKV_MAP[module_suffix]
            for i, proj_key in enumerate(proj_keys):
                path = f'transformer_blocks.{block_idx}.{attn_prefix}.{proj_key}'
                targets.append((path, i, len(proj_keys)))

    for module_path, chunk_index, num_chunks in targets:
        network_key = "lora_transformer_" + module_path.replace(".", "_")
        sd_module = sd_model.network_layer_mapping.get(network_key)
        if sd_module is None:
            log.warning(f'Network load: type=LoKR module not found in mapping: {network_key}')
            continue
        weights = network.NetworkWeights(
            network_key=network_key,
            sd_key=network_key,
            w=dict(lokr_weights),
            sd_module=sd_module,
        )
        if chunk_index is not None:
            net.modules[network_key] = network_lokr.NetworkModuleLokrChunk(net, weights, chunk_index, num_chunks)
        else:
            net.modules[network_key] = network_lokr.NetworkModuleLokr(net, weights)

return net
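As background for the weight dictionary collected above: LoKR (Kronecker-product LoRA) reconstructs its weight delta as `kron(w1, w2)`, optionally rebuilding a factor from a low-rank pair such as `w1_a @ w1_b`. A minimal numpy sketch — the function name, shapes, and scaling parameter are illustrative, not from this codebase:

```python
import numpy as np

def lokr_delta(w1, w2, alpha=None, dim=None):
    # LoKR stores the weight delta as a Kronecker product of two factors;
    # the low-rank variants first rebuild a factor from an (a, b) pair
    delta = np.kron(w1, w2)
    if alpha is not None and dim is not None:
        delta = delta * (alpha / dim)  # LoRA-style alpha scaling
    return delta

w1 = np.eye(2)                     # stand-in for lokr_w1
w2 = np.arange(4.0).reshape(2, 2)  # stand-in for lokr_w2
delta = lokr_delta(w1, w2)
print(delta.shape)  # (4, 4)
```

The Kronecker factorization is what lets a small pair of factors represent a much larger full-rank delta.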
patched = False


def apply_patch():
    """Patch Flux2LoraLoaderMixin.lora_state_dict to handle bare BFL-format keys.

    When a LoRA file has bare BFL keys (no diffusion_model. prefix), the original
    lora_state_dict won't detect them as AI toolkit format. This patch detects such
    keys, preprocesses the state dict and applies alpha scaling before delegating
    to the original.
    """
    global patched  # pylint: disable=global-statement
    if patched:
        return
    patched = True

    from diffusers.loaders.lora_pipeline import Flux2LoraLoaderMixin
    original_lora_state_dict = Flux2LoraLoaderMixin.lora_state_dict.__func__

    @classmethod  # pylint: disable=no-self-argument
    def patched_lora_state_dict(cls, pretrained_model_name_or_path_or_dict, **kwargs):
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = preprocess_f2_keys(pretrained_model_name_or_path_or_dict)
            pretrained_model_name_or_path_or_dict = apply_lora_alphas(pretrained_model_name_or_path_or_dict)
        elif isinstance(pretrained_model_name_or_path_or_dict, (str, os.PathLike)):
            path = str(pretrained_model_name_or_path_or_dict)
            if path.endswith('.safetensors'):
                try:
                    from safetensors import safe_open
                    with safe_open(path, framework="pt") as f:
                        keys = list(f.keys())
                    needs_load = (
                        any(k.endswith('.alpha') for k in keys)
                        or (not any(k.startswith("diffusion_model.") or k.startswith("base_model.model.") for k in keys)
                            and any(k.startswith(p) for k in keys for p in BARE_FLUX_PREFIXES))
                    )
                    if needs_load:
                        from safetensors.torch import load_file
                        sd = load_file(path)
                        sd = preprocess_f2_keys(sd)
                        pretrained_model_name_or_path_or_dict = apply_lora_alphas(sd)
                except Exception:
                    pass
        return original_lora_state_dict(cls, pretrained_model_name_or_path_or_dict, **kwargs)

    Flux2LoraLoaderMixin.lora_state_dict = patched_lora_state_dict
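The unwrap-and-rebind trick used by `apply_patch` above — `.__func__` to pull the raw function out of the classmethod, then assigning a new `@classmethod` onto the class — is a general monkey-patching pattern. A self-contained sketch with a toy class (all names here are illustrative):

```python
class Loader:
    @classmethod
    def load(cls, src):
        return f"{cls.__name__} loaded {src}"

# unwrap the underlying function so the patch can delegate to it later
_original_load = Loader.load.__func__

@classmethod
def _patched_load(cls, src):
    src = src.lower()                # preprocess the argument first
    return _original_load(cls, src)  # then delegate to the original

Loader.load = _patched_load
print(Loader.load("FILE.SAFETENSORS"))  # Loader loaded file.safetensors
```

The `patched` guard in `apply_patch` matters with this pattern: applying the patch twice would capture the already-patched function as the "original" and preprocess the input twice.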
@@ -1,25 +0,0 @@
import diffusers
import transformers
from modules import devices, model_quant


def load_flux_bnb(checkpoint_info, diffusers_load_config):  # pylint: disable=unused-argument
    transformer = None
    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path
    model_quant.load_bnb('Load model: type=FLUX')
    quant = model_quant.get_quant(repo_path)
    if quant == 'fp8':
        quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
        transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
    elif quant == 'fp4':
        quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
        transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
    elif quant == 'nf4':
        quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
        transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
    else:
        transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
    return transformer
@@ -1,361 +0,0 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, sd_unet, model_te, model_quant, sd_hijack_te
from modules.logger import log


debug = log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None


def load_flux_quanto(checkpoint_info):
    transformer, text_encoder_2 = None, None
    quanto = model_quant.load_quanto('Load model: type=FLUX')

    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path

    try:
        quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
        dtype = state_dict['context_embedder.bias'].dtype
        with torch.device("meta"):
            transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
        quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
        transformer_dtype = transformer.dtype
        if transformer_dtype != devices.dtype:
            try:
                transformer = transformer.to(dtype=devices.dtype)
            except Exception:
                log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
    except Exception as e:
        log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')

    try:
        quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
        dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
        with torch.device("meta"):
            text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
        quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
        text_encoder_2_dtype = text_encoder_2.dtype
        if text_encoder_2_dtype != devices.dtype:
            try:
                text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
            except Exception:
                log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
    except Exception as e:
        log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')

    return transformer, text_encoder_2


def load_flux_bnb(checkpoint_info, diffusers_load_config):  # pylint: disable=unused-argument
    transformer, text_encoder_2 = None, None
    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path
    model_quant.load_bnb('Load model: type=FLUX')
    quant = model_quant.get_quant(repo_path)
    try:
        if quant == 'fp8':
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        elif quant == 'fp4':
            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        elif quant == 'nf4':
            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        else:
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
    except Exception as e:
        log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}")
        transformer, text_encoder_2 = None, None
        if debug:
            errors.display(e, 'FLUX:')
    return transformer, text_encoder_2


def load_quants(kwargs, repo_id, cache_dir, allow_quant):  # pylint: disable=unused-argument
    try:
        diffusers_load_config = {
            "torch_dtype": devices.dtype,
            "cache_dir": cache_dir,
        }
        if 'transformer' not in kwargs and model_quant.check_nunchaku('Model'):
            import nunchaku
            nunchaku_precision = nunchaku.utils.get_precision()
            nunchaku_repo = None
            if 'flux.1-kontext' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
            elif 'flux.1-dev' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
            elif 'flux.1-schnell' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
            elif 'flux.1-fill' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
            elif 'flux.1-depth' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
            elif 'shuttle' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
            else:
                log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
            if nunchaku_repo is not None:
                log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
                kwargs['transformer'] = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype, cache_dir=cache_dir)
                kwargs['transformer'].quantization_method = 'SVDQuant'
                if shared.opts.nunchaku_attention:
                    kwargs['transformer'].set_attention_impl("nunchaku-fp16")
        if 'transformer' not in kwargs and model_quant.check_quant('Model'):
            load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
            kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", **load_args, **quant_args)
        if 'text_encoder_2' not in kwargs and model_quant.check_nunchaku('TE'):
            import nunchaku
            nunchaku_precision = nunchaku.utils.get_precision()
            nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
            log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
            kwargs['text_encoder_2'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype, cache_dir=cache_dir)
            kwargs['text_encoder_2'].quantization_method = 'SVDQuant'
        if 'text_encoder_2' not in kwargs and model_quant.check_quant('TE'):
            load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True)
            kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", **load_args, **quant_args)
    except Exception as e:
        log.error(f'Quantization: {e}')
        errors.display(e, 'Quantization:')
    return kwargs


def load_transformer(file_path):  # triggered by opts.sd_unet change
    if file_path is None or not os.path.exists(file_path):
        return None
    transformer = None
    quant = model_quant.get_quant(file_path)
    diffusers_load_config = {
        "torch_dtype": devices.dtype,
        "cache_dir": shared.opts.hfcache_dir,
    }
    if quant is not None and quant != 'none':
        log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
        if 'gguf' in file_path.lower():
            from modules import ggml
            _transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
            if _transformer is not None:
                transformer = _transformer
        elif quant == "fp8":
            _transformer = model_quant.load_fp8_model_layerwise(file_path, diffusers.FluxTransformer2DModel.from_single_file, diffusers_load_config)
            if _transformer is not None:
                transformer = _transformer
        elif quant in {'qint8', 'qint4'}:
            _transformer, _text_encoder_2 = load_flux_quanto(file_path)
            if _transformer is not None:
                transformer = _transformer
        elif quant in {'fp8', 'fp4', 'nf4'}:
            _transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
            if _transformer is not None:
                transformer = _transformer
        elif 'nf4' in quant:
            from pipelines.flux.flux_nf4 import load_flux_nf4
            _transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
            if _transformer is not None:
                transformer = _transformer
    else:
        quant_args = model_quant.create_bnb_config({})
        if quant_args:
            log.info(f'Load module: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
            from pipelines.flux.flux_nf4 import load_flux_nf4
            transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
            if transformer is not None:
                return transformer
        load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
        log.debug(f'Load model: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} args={load_args}')
        transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_args, **quant_args)
    if transformer is None:
        log.error('Failed to load UNet model')
        shared.opts.sd_unet = 'Default'
    return transformer


def load_flux(checkpoint_info, diffusers_load_config):  # triggered by opts.sd_checkpoint change
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)
    allow_post_quant = False

    prequantized = model_quant.get_quant(checkpoint_info.path)
    log.debug(f'Load model: type=FLUX model="{checkpoint_info.name}" repo="{repo_id}" unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
    debug(f'Load model: type=FLUX config={diffusers_load_config}')

    transformer = None
    text_encoder_1 = None
    text_encoder_2 = None
    vae = None

    # unload current model
    sd_models.unload_model_weights()
    shared.sd_model = None
    devices.torch_gc(force=True, reason='load')

    if shared.opts.teacache_enabled:
        from modules import teacache
        log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
        diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward  # patch must be done before transformer is loaded

    # load overrides if any
    if shared.opts.sd_unet != 'Default':
        try:
            debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"')
            transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
            if transformer is None:
                sd_unet.failed_unet.append(shared.opts.sd_unet)  # record the failing unet name before resetting the option
                shared.opts.sd_unet = 'Default'
        except Exception as e:
            log.error(f"Load model: type=FLUX failed to load UNet: {e}")
            shared.opts.sd_unet = 'Default'
            if debug:
                errors.display(e, 'FLUX UNet:')
    if shared.opts.sd_text_encoder != 'Default':
        try:
            debug(f'Load model: type=FLUX te="{shared.opts.sd_text_encoder}"')
            from modules.model_te import load_t5, load_vit_l
            if 'vit-l' in shared.opts.sd_text_encoder.lower():
                text_encoder_1 = load_vit_l()
            else:
                text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
        except Exception as e:
            log.error(f"Load model: type=FLUX failed to load T5: {e}")
            shared.opts.sd_text_encoder = 'Default'
            if debug:
                errors.display(e, 'FLUX T5:')
    if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
        try:
            debug(f'Load model: type=FLUX vae="{shared.opts.sd_vae}"')
            from modules import sd_vae
            # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
            vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
            if os.path.exists(vae_file):
                vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
                vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
        except Exception as e:
            log.error(f"Load model: type=FLUX failed to load VAE: {e}")
            shared.opts.sd_vae = 'Default'
            if debug:
                errors.display(e, 'FLUX VAE:')

    # load quantized components if any
    if prequantized == 'nf4':
        try:
            from pipelines.flux.flux_nf4 import load_flux_nf4
            _transformer, _text_encoder = load_flux_nf4(checkpoint_info)
            if _transformer is not None:
                transformer = _transformer
            if _text_encoder is not None:
                text_encoder_2 = _text_encoder
        except Exception as e:
            log.error(f"Load model: type=FLUX failed to load NF4 components: {e}")
            if debug:
                errors.display(e, 'FLUX NF4:')
    if prequantized == 'qint8' or prequantized == 'qint4':
        try:
            _transformer, _text_encoder = load_flux_quanto(checkpoint_info)
            if _transformer is not None:
                transformer = _transformer
            if _text_encoder is not None:
                text_encoder_2 = _text_encoder
        except Exception as e:
            log.error(f"Load model: type=FLUX failed to load Quanto components: {e}")
            if debug:
                errors.display(e, 'FLUX Quanto:')

    # initialize pipeline with pre-loaded components
    kwargs = {}
    if transformer is not None:
        kwargs['transformer'] = transformer
        sd_unet.loaded_unet = shared.opts.sd_unet
    if text_encoder_1 is not None:
        kwargs['text_encoder'] = text_encoder_1
        model_te.loaded_te = shared.opts.sd_text_encoder
    if text_encoder_2 is not None:
        kwargs['text_encoder_2'] = text_encoder_2
        model_te.loaded_te = shared.opts.sd_text_encoder
    if vae is not None:
        kwargs['vae'] = vae
    if repo_id == 'sayakpaul/flux.1-dev-nf4':
        repo_id = 'black-forest-labs/FLUX.1-dev'  # workaround since sayakpaul model is missing model_index.json
    if 'Fill' in repo_id:
        cls = diffusers.FluxFillPipeline
    elif 'Canny' in repo_id:
        cls = diffusers.FluxControlPipeline
    elif 'Depth' in repo_id:
        cls = diffusers.FluxControlPipeline
    elif 'Kontext' in repo_id:
        cls = diffusers.FluxKontextPipeline
        from diffusers import pipelines
        pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
        pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
        pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline
    else:
        cls = diffusers.FluxPipeline
    log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
    for c in kwargs:
        if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
            log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
        if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
            try:
                kwargs[c] = kwargs[c].to(dtype=devices.dtype)
                log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
            except Exception:
                pass

    allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none')
    fn = checkpoint_info.path
    if (fn is None) or (not os.path.exists(fn)) or os.path.isdir(fn):
        kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant)
    if fn is not None and fn.endswith('.safetensors') and os.path.isfile(fn):  # guard against fn=None before calling endswith
        pipe = cls.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
        allow_post_quant = True
    else:
        pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)

    if shared.opts.teacache_enabled and model_quant.check_nunchaku('Model'):
        from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
        apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)

    # release memory
    transformer = None
    text_encoder_1 = None
    text_encoder_2 = None
    vae = None
    for k in kwargs.keys():
        kwargs[k] = None
    sd_hijack_te.init_hijack(pipe)
    devices.torch_gc(force=True, reason='load')
    return pipe, allow_post_quant
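The pipeline-class dispatch in load_flux keys off substrings of the repo name; the same decision table can be sketched in isolation. The helper name is hypothetical; the returned strings are the real diffusers pipeline class names used above:

```python
def pick_flux_pipeline(repo_id: str) -> str:
    # mirrors the branch order in load_flux: Fill and Canny/Depth are
    # checked before Kontext, and plain FluxPipeline is the fallback
    if 'Fill' in repo_id:
        return 'FluxFillPipeline'
    if 'Canny' in repo_id or 'Depth' in repo_id:
        return 'FluxControlPipeline'
    if 'Kontext' in repo_id:
        return 'FluxKontextPipeline'
    return 'FluxPipeline'

print(pick_flux_pipeline('black-forest-labs/FLUX.1-Kontext-dev'))  # FluxKontextPipeline
```

Because the match is a plain substring test, branch order matters: a repo name containing both markers resolves to whichever branch is checked first.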
@@ -1,201 +0,0 @@
"""
|
||||
Copied from: https://github.com/huggingface/diffusers/issues/9165
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.quantizers.quantizers_utils import get_module_from_name
|
||||
from huggingface_hub import hf_hub_download
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
|
||||
import safetensors.torch
|
||||
from modules import shared, devices, model_quant
|
||||
from modules.logger import log
|
||||
|
||||
|
||||
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
|
||||
|
||||
|
||||
def _replace_with_bnb_linear(
    model,
    method="nf4",
    has_been_replaced=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    bnb = model_quant.load_bnb('Load model: type=FLUX')
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            with init_empty_weights():
                in_features = module.in_features
                out_features = module.out_features

                if method == "llm_int8":
                    model._modules[name] = bnb.nn.Linear8bitLt(  # pylint: disable=protected-access
                        in_features,
                        out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=6.0,
                    )
                    has_been_replaced = True
                else:
                    model._modules[name] = bnb.nn.Linear4bit(  # pylint: disable=protected-access
                        in_features,
                        out_features,
                        module.bias is not None,
                        compute_dtype=devices.dtype,
                        compress_statistics=False,
                        quant_type="nf4",
                    )
                    has_been_replaced = True
                # Store the module class in case we need to transpose the weight later
                model._modules[name].source_cls = type(module)  # pylint: disable=protected-access
                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)  # pylint: disable=protected-access

        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_bnb_linear(
                module,
                has_been_replaced=has_been_replaced,
            )
    return model, has_been_replaced


def check_quantized_param(
    model,
    param_name: str,
) -> bool:
    bnb = model_quant.load_bnb('Load model: type=FLUX')
    module, tensor_name = get_module_from_name(model, param_name)
    if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):  # pylint: disable=protected-access
        # Add here check for loaded components' dtypes once serialization is implemented
        return True
    elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
        # bias could be loaded by regular set_module_tensor_to_device() from accelerate,
        # but it would wrongly use uninitialized weight there.
        return True
    else:
        return False


def create_quantized_param(
    model,
    param_value: "torch.Tensor",
    param_name: str,
    target_device: "torch.device",
    state_dict=None,
    unexpected_keys=None,
    pre_quantized=False
):
    bnb = model_quant.load_bnb('Load model: type=FLUX')
    module, tensor_name = get_module_from_name(model, param_name)

    if tensor_name not in module._parameters:  # pylint: disable=protected-access
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

    old_value = getattr(module, tensor_name)

    if tensor_name == "bias":
        if param_value is None:
            new_value = old_value.to(target_device)
        else:
            new_value = param_value.to(target_device)
        new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
        module._parameters[tensor_name] = new_value  # pylint: disable=protected-access
        return

    if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):  # pylint: disable=protected-access
        raise ValueError("this function only loads `Linear4bit` components")
    if (
        old_value.device == torch.device("meta")
        and target_device not in ["meta", torch.device("meta")]
        and param_value is None
    ):
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {target_device}.")

    if pre_quantized:
        if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (param_name + ".quant_state.bitsandbytes__nf4" not in state_dict):
            raise ValueError(f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.")
        quantized_stats = {}
        for k, v in state_dict.items():
            # `startswith` to counter for edge cases where `param_name`
            # substring can be present in multiple places in the `state_dict`
            if param_name + "." in k and k.startswith(param_name):
                quantized_stats[k] = v
                if unexpected_keys is not None and k in unexpected_keys:
                    unexpected_keys.remove(k)
        new_value = bnb.nn.Params4bit.from_prequantized(
            data=param_value,
            quantized_stats=quantized_stats,
            requires_grad=False,
            device=target_device,
        )
    else:
        new_value = param_value.to("cpu")
        kwargs = old_value.__dict__
        new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
    module._parameters[tensor_name] = new_value  # pylint: disable=protected-access


def load_flux_nf4(checkpoint_info, prequantized: bool = True):
    transformer = None
    text_encoder_2 = None
    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path
    if os.path.exists(repo_path) and os.path.isfile(repo_path):
        ckpt_path = repo_path
    elif os.path.exists(repo_path) and os.path.isdir(repo_path) and os.path.exists(os.path.join(repo_path, "diffusion_pytorch_model.safetensors")):
        ckpt_path = os.path.join(repo_path, "diffusion_pytorch_model.safetensors")
    else:
        ckpt_path = hf_hub_download(repo_path, filename="diffusion_pytorch_model.safetensors", cache_dir=shared.opts.diffusers_dir)
    original_state_dict = safetensors.torch.load_file(ckpt_path)

    if 'sayakpaul' in repo_path:
        converted_state_dict = original_state_dict  # already converted
    else:
        try:
            converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
        except Exception as e:
            log.error(f"Load model: type=FLUX Failed to convert UNET: {e}")
            if debug:
                from modules import errors
                errors.display(e, 'FLUX convert:')
            converted_state_dict = original_state_dict

    with init_empty_weights():
        from diffusers import FluxTransformer2DModel
        config = FluxTransformer2DModel.load_config(os.path.join('configs', 'flux'), subfolder="transformer")
        transformer = FluxTransformer2DModel.from_config(config).to(devices.dtype)
        expected_state_dict_keys = list(transformer.state_dict().keys())

    _replace_with_bnb_linear(transformer, "nf4")

    try:
        for param_name, param in converted_state_dict.items():
            if param_name not in expected_state_dict_keys:
                continue
            is_param_float8_e4m3fn = hasattr(torch, "float8_e4m3fn") and param.dtype == torch.float8_e4m3fn
            if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
                param = param.to(devices.dtype)
            if not check_quantized_param(transformer, param_name):
                set_module_tensor_to_device(transformer, param_name, device=0, value=param)
            else:
                create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=prequantized)
    except Exception as e:
        transformer, text_encoder_2 = None, None
        log.error(f"Load model: type=FLUX failed to load UNET: {e}")
        if debug:
            from modules import errors
            errors.display(e, 'FLUX:')

    del original_state_dict
    devices.torch_gc(force=True, reason='load')
    return transformer, text_encoder_2
@@ -1,74 +0,0 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, model_quant
from modules.logger import log


debug = log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None


def load_flux_quanto(checkpoint_info):
    transformer, text_encoder_2 = None, None
    quanto = model_quant.load_quanto('Load model: type=FLUX')

    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path

    try:
        quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
        dtype = state_dict['context_embedder.bias'].dtype
        with torch.device("meta"):
            transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
        quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
        transformer_dtype = transformer.dtype
        if transformer_dtype != devices.dtype:
            try:
                transformer = transformer.to(dtype=devices.dtype)
            except Exception:
                log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
    except Exception as e:
        log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')

    try:
        quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
        dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
        with torch.device("meta"):
            text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
        quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
        text_encoder_2_dtype = text_encoder_2.dtype
        if text_encoder_2_dtype != devices.dtype:
            try:
                text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
            except Exception:
                log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
    except Exception as e:
        log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')

    return transformer, text_encoder_2
@@ -41,18 +41,6 @@ def load_flux(checkpoint_info, diffusers_load_config=None):
    transformer = None
    text_encoder_2 = None

    # handle prequantized models
    prequantized = model_quant.get_quant(checkpoint_info.path)
    if prequantized == 'nf4':
        from pipelines.flux.flux_nf4 import load_flux_nf4
        transformer, text_encoder_2 = load_flux_nf4(checkpoint_info)
    elif prequantized == 'qint8' or prequantized == 'qint4':
        from pipelines.flux.flux_quanto import load_flux_quanto
        transformer, text_encoder_2 = load_flux_quanto(checkpoint_info)
    elif prequantized == 'fp4' or prequantized == 'fp8':
        from pipelines.flux.flux_bnb import load_flux_bnb
        transformer = load_flux_bnb(checkpoint_info, diffusers_load_config)

    # handle transformer svdquant if available, t5 is handled inside load_text_encoder
    if transformer is None and model_quant.check_nunchaku('Model'):
        from pipelines.flux.flux_nunchaku import load_flux_nunchaku
@@ -31,6 +31,9 @@ def load_flux2(checkpoint_info, diffusers_load_config=None):
    diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux2"] = diffusers.Flux2Pipeline
    diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux2"] = diffusers.Flux2Pipeline

+    from pipelines.flux import flux2_lora
+    flux2_lora.apply_patch()
+
    del text_encoder
    del transformer
    sd_hijack_te.init_hijack(pipe)
@@ -34,6 +34,9 @@ def load_flux2_klein(checkpoint_info, diffusers_load_config=None):
    diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux2klein"] = diffusers.Flux2KleinPipeline
    diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux2klein"] = diffusers.Flux2KleinPipeline

+    from pipelines.flux import flux2_lora
+    flux2_lora.apply_patch()
+
    del text_encoder
    del transformer
    sd_hijack_te.init_hijack(pipe)
@@ -27,13 +27,13 @@ peft==0.18.1
httpx==0.28.1
requests==2.32.3
tqdm==4.67.3
-accelerate==1.12.0
+accelerate==1.13.0
einops==0.8.1
-huggingface_hub==1.5.0
+huggingface_hub==1.7.2
numpy==2.1.2
pandas==2.3.1
protobuf==6.33.5
-pytorch_lightning==2.6.0
+pytorch_lightning==2.6.1
urllib3==1.26.19
Pillow==10.4.0
timm==1.0.24
@@ -0,0 +1,15 @@
from modules import scripts_postprocessing, ui_sections, processing_grading


class ScriptPostprocessingColorGrading(scripts_postprocessing.ScriptPostprocessing):
    name = "Color Grading"

    def ui(self):
        ui_controls = ui_sections.create_color_inputs('process')
        ui_controls_dict = {control.label.replace(" ", "_").replace(".", "").lower(): control for control in ui_controls}
        return ui_controls_dict

    def process(self, pp: scripts_postprocessing.PostprocessedImage, *args, **kwargs): # pylint: disable=arguments-differ
        grading_params = processing_grading.GradingParams(*args, **kwargs)
        if processing_grading.is_active(grading_params):
            pp.image = processing_grading.grade_image(pp.image, grading_params)
@@ -124,7 +124,7 @@ def process(
 # defines script for dual-mode usage
-class Script(scripts.Script):
+class ScriptNudeNet(scripts.Script):
     # see below for all available options and callbacks
     # <https://github.com/vladmandic/automatic/blob/master/modules/scripts.py#L26>
@@ -148,7 +148,7 @@ class Script(scripts.Script):
 # defines postprocessing script for dual-mode usage
-class ScriptPostprocessing(scripts_postprocessing.ScriptPostprocessing):
+class ScriptPostprocessingNudeNet(scripts_postprocessing.ScriptPostprocessing):
     name = 'NudeNet'
     order = 10000
@@ -2,8 +2,8 @@ import gradio as gr
 from modules import video, scripts_postprocessing


-class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
-    name = "Video"
+class ScriptPostprocessingVideo(scripts_postprocessing.ScriptPostprocessing):
+    name = "Create Video"

     def ui(self):
         with gr.Accordion('Create video', open = False, elem_id="postprocess_video_accordion"):
@@ -18,7 +18,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
             ]

         with gr.Row():
-            gr.HTML("<span>  Video</span><br>")
+            gr.HTML("<span>  Create video from generated images</span><br>")
         with gr.Row():
             video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="extras_video_type")
             duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False, elem_id="extras_video_duration")
@@ -0,0 +1,446 @@
import os
import sys
from pathlib import Path
from typing import Dict, Optional

import installer
from modules.logger import log
from modules.json_helpers import readfile, writefile
from scripts.rocm.rocm_vars import ROCM_ENV_VARS # pylint: disable=no-name-in-module
from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module


def _check_rocm() -> bool:
    from modules import shared
    if getattr(shared.cmd_opts, 'use_rocm', False):
        return True
    if installer.torch_info.get('type') == 'rocm':
        return True
    import torch # pylint: disable=import-outside-toplevel
    return hasattr(torch.version, 'hip') and torch.version.hip is not None


is_rocm = _check_rocm()

CONFIG = Path(os.path.abspath(os.path.join('data', 'rocm.json')))

_cache: Optional[Dict[str, str]] = None # loaded once, invalidated on save

# Metadata key written into rocm.json to record which architecture profile is active.
# Not an environment variable — always skipped during env application but preserved in the
# saved config so that arch-safety enforcement is consistent across restarts.
_ARCH_KEY = "_rocm_arch"

# Vars that must never appear in the process environment.
#
# _DTYPE_UNSAFE: alter FP16 inference dtype — must be cleared regardless of config
#   MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL — DEBUG alias: routes all FP16 convs through BF16 exponent math
#   MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL — API-level alias: same BF16-exponent effect
#   MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_EXPEREMENTAL_FP16_TRANSFORM — unstable experimental FP16 path
#   MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16 — changes FP16 WrW atomic accumulation
#
# SOLVER_DISABLED_BY_DEFAULT: every solver known to be incompatible with this runtime
# (FP32-only, training-only WrW/BWD, fixed-geometry mismatches, XDLOPS/CDNA-only, arch-specific).
# Actively unsetting these ensures no inherited shell value can re-enable them.
_DTYPE_UNSAFE = {
    "MIOPEN_DEBUG_CONVOLUTION_ATTRIB_FP16_ALT_IMPL",
    "MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL",
    "MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_EXPEREMENTAL_FP16_TRANSFORM",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16",
}

# _UNSET_VARS: hard-blocked vars that are DELETED from the process env and never written,
# regardless of saved config. Limited to dtype-corrupting vars only.
# IMPORTANT: SOLVER_DISABLED_BY_DEFAULT is intentionally NOT included here.
# When a solver var is absent (unset) MIOpen still calls IsApplicable() on every
# conv-find — wasted probing overhead. When a var is explicitly "0" MIOpen skips
# IsApplicable() immediately. Solver defaults flow through the config loop as "0"
# (their ROCM_ENV_VARS default is "0") so they are explicitly set to "0" in the env.
_UNSET_VARS = _DTYPE_UNSAFE

# Additional environment vars that must be removed from the process before MIOpen loads.
# These are not MIOpen solver toggles but can corrupt MIOpen's runtime behaviour:
#   HIP_PATH / HIP_PATH_71 — point to the system AMD ROCm install; override the venv-bundled
#     _rocm_sdk_devel DLLs with a potentially mismatched system version
#   QML_*/QT_* — QtQuick shader/disk-cache flags leaked from Qt tools; harmless for
#     PyTorch but can conflict with Gradio's embedded Qt helpers
#   PYENV_VIRTUALENV_DISABLE_PROMPT — pyenv noise that confuses venv detection
_EXTRA_CLEAR_VARS = {
    "HIP_PATH",
    "HIP_PATH_71",
    "PYENV_VIRTUALENV_DISABLE_PROMPT",
    "QML_DISABLE_DISK_CACHE",
    "QML_FORCE_DISK_CACHE",
    "QT_DISABLE_SHADER_DISK_CACHE",
    # PERF_VALS vars are NOT boolean toggles — MIOpen reads them as perf-config strings.
    # If inherited from a parent shell with value "1", MIOpen's GetPerfConfFromEnv parses
    # "1" as a degenerate config and can return dtype=float32 output from FP16 tensors.
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_PERF_VALS",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_PERF_VALS",
}

# Solvers whose MIOpen IsApplicable() explicitly rejects non-FP32 tensors.
# They are safe to leave enabled in FP32 mode. When the active dtype is FP16 or BF16
# we force them OFF so MIOpen skips the IsApplicable probe entirely — avoids overhead on
# every conv shape find. These are NOT in _UNSET_VARS because they are valid in FP32.
_FP32_ONLY_SOLVERS = {
    "MIOPEN_DEBUG_CONV_FFT", # FFT convolution — FP32 only (MIOpen source: IsFp32 check)
    "MIOPEN_DEBUG_AMD_WINOGRAD_3X3", # Winograd 3x3 — FP32 only
    "MIOPEN_DEBUG_AMD_FUSED_WINOGRAD", # Fused Winograd — FP32 only
}


def _resolve_dtype() -> str:
    """Return the resolved active compute dtype: 'FP16', 'BF16', 'FP32', or '' (not yet known).
    Prefers the resolved devices.dtype (post test_fp16/bf16) over the raw opts string."""
    try:
        import torch # pylint: disable=import-outside-toplevel
        from modules import devices as _dev # pylint: disable=import-outside-toplevel
        if _dev.dtype is not None:
            if _dev.dtype == torch.float16:
                return 'FP16'
            if _dev.dtype == torch.bfloat16:
                return 'BF16'
            if _dev.dtype == torch.float32:
                return 'FP32'
    except Exception:
        pass
    try:
        from modules import shared as _sh # pylint: disable=import-outside-toplevel
        v = getattr(getattr(_sh, 'opts', None), 'cuda_dtype', None)
        if v in ('FP16', 'BF16', 'FP32'):
            return v
    except Exception:
        pass
    return ''
# --- venv helpers ---

def _get_venv() -> str:
    return os.environ.get("VIRTUAL_ENV", "") or sys.prefix


def _expand_venv(value: str) -> str:
    return value.replace("{VIRTUAL_ENV}", _get_venv())


def _collapse_venv(value: str) -> str:
    venv = _get_venv()
    if venv and value.startswith(venv):
        return "{VIRTUAL_ENV}" + value[len(venv):]
    return value
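The venv helpers above keep rocm.json portable: values are saved with a `{VIRTUAL_ENV}` placeholder and expanded back to the live venv root on apply. A minimal standalone sketch of that round-trip, with the venv root passed explicitly instead of read from `os.environ` so it runs anywhere:

```python
def expand_venv(value: str, venv: str) -> str:
    # replace the portable placeholder with the concrete venv root
    return value.replace("{VIRTUAL_ENV}", venv)

def collapse_venv(value: str, venv: str) -> str:
    # turn a concrete path back into the portable placeholder form;
    # paths outside the venv are left untouched
    if venv and value.startswith(venv):
        return "{VIRTUAL_ENV}" + value[len(venv):]
    return value

venv = "/home/user/sdnext/venv"                    # illustrative venv root
stored = "{VIRTUAL_ENV}/lib/rocm/miopen/db"        # illustrative stored value
concrete = expand_venv(stored, venv)
print(concrete)                       # /home/user/sdnext/venv/lib/rocm/miopen/db
print(collapse_venv(concrete, venv))  # {VIRTUAL_ENV}/lib/rocm/miopen/db
```

Collapsing on save and expanding on apply means the same rocm.json works even when the install is moved to a different directory.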
# --- dropdown helpers ---

def _dropdown_display(stored_val: str, options) -> str:
    if options and isinstance(options[0], tuple):
        return next((label for label, val in options if val == str(stored_val)), str(stored_val))
    return str(stored_val)


def _dropdown_stored(display_val: str, options) -> str:
    if options and isinstance(options[0], tuple):
        return next((val for label, val in options if label == str(display_val)), str(display_val))
    return str(display_val)


def _dropdown_choices(options):
    if options and isinstance(options[0], tuple):
        return [label for label, _ in options]
    return options
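Option lists may be plain strings or `(label, value)` tuples; the dropdown helpers translate between the UI label and the stored value, falling back to the raw input when no mapping matches. A self-contained sketch of the same mapping (the option list here is hypothetical, not taken from `ROCM_ENV_VARS`):

```python
def dropdown_display(stored_val, options):
    # stored value -> UI label (falls back to the raw value when unmapped)
    if options and isinstance(options[0], tuple):
        return next((label for label, val in options if val == str(stored_val)), str(stored_val))
    return str(stored_val)

def dropdown_stored(display_val, options):
    # UI label -> stored value (falls back to the raw label when unmapped)
    if options and isinstance(options[0], tuple):
        return next((val for label, val in options if label == str(display_val)), str(display_val))
    return str(display_val)

opts = [("Enabled", "1"), ("Disabled", "0")]  # hypothetical (label, value) options
print(dropdown_display("1", opts))       # Enabled
print(dropdown_stored("Disabled", opts)) # 0
print(dropdown_display("3", opts))       # 3 (no mapping, raw fallback)
```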
# --- config I/O ---

def load_config() -> Dict[str, str]:
    global _cache # pylint: disable=global-statement
    if _cache is None:
        file_existed = CONFIG.exists()
        if file_existed:
            data = readfile(str(CONFIG), lock=True, as_type="dict")
            _cache = data if data else {k: v["default"] for k, v in ROCM_ENV_VARS.items()}
            # Purge unsafe vars from a stale saved config and re-persist only if the file existed.
            # When running without a saved config (first run / after Delete), load_config() must
            # never create the file — that only happens via save_config() on Apply or Apply Profile.
            dirty = {k for k in _cache if k in _UNSET_VARS or (k != _ARCH_KEY and k not in ROCM_ENV_VARS)}
            if dirty:
                _cache = {k: v for k, v in _cache.items() if k not in dirty}
                writefile(_cache, str(CONFIG))
                log.debug(f'ROCm load_config: purged {len(dirty)} stale/unsafe var(s) from saved config')
        else:
            _cache = {k: v["default"] for k, v in ROCM_ENV_VARS.items()}
        log.debug(f'ROCm load_config: path={CONFIG} existed={file_existed} items={len(_cache)}')
    return _cache
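`load_config()` drops both hard-blocked vars and keys that `ROCM_ENV_VARS` no longer defines, while preserving the `_rocm_arch` metadata key. A standalone sketch of that purge filter (the known/blocked sets and the saved keys below are illustrative stand-ins):

```python
KNOWN = {"MIOPEN_LOG_LEVEL", "MIOPEN_DEBUG_CONV_GEMM"}  # stand-in for ROCM_ENV_VARS keys
BLOCKED = {"MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL"}   # stand-in for _UNSET_VARS
ARCH_KEY = "_rocm_arch"

saved = {
    "MIOPEN_LOG_LEVEL": "5",                          # known: kept
    "MIOPEN_CONVOLUTION_ATTRIB_FP16_ALT_IMPL": "1",   # blocked: purged
    "OLD_REMOVED_VAR": "1",                           # unknown: purged
    ARCH_KEY: "RDNA3",                                # metadata: preserved
}

# same shape as the `dirty` set in load_config(): blocked keys plus unknown keys,
# with the metadata key explicitly exempt
dirty = {k for k in saved if k in BLOCKED or (k != ARCH_KEY and k not in KNOWN)}
clean = {k: v for k, v in saved.items() if k not in dirty}
print(sorted(clean))  # ['MIOPEN_LOG_LEVEL', '_rocm_arch']
```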
def save_config(config: Dict[str, str]) -> None:
    global _cache # pylint: disable=global-statement
    sanitized = {k: v for k, v in config.items() if k not in _UNSET_VARS}
    # Enforce arch-incompatible solvers to "0" before writing.
    # Prevents malformed edits (UI or JSON hand-edit) from persisting incompatible "1" values.
    arch = sanitized.get(_ARCH_KEY, "")
    unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
    for var in unavailable:
        if var in sanitized and sanitized[var] != "0":
            sanitized[var] = "0"
            log.debug(f'ROCm save_config: clamped arch-incompatible var={var} arch={arch}')
    writefile(sanitized, str(CONFIG))
    _cache = sanitized


def apply_env(config: Optional[Dict[str, str]] = None) -> None:
    if config is None:
        config = load_config()
    for var in _UNSET_VARS | _EXTRA_CLEAR_VARS:
        if var in os.environ:
            del os.environ[var]
    for var, value in config.items():
        if var == _ARCH_KEY:
            continue
        if var in _UNSET_VARS:
            continue
        if var not in ROCM_ENV_VARS:
            continue
        meta = ROCM_ENV_VARS.get(var, {})
        if meta.get("options"):
            value = _dropdown_stored(str(value), meta["options"])
        expanded = _expand_venv(str(value))
        if expanded == "":
            continue
        os.environ[var] = expanded
    # Arch safety net: hard-force all hardware-incompatible vars to "0" in the env.
    # This runs *after* the config loop so it overrides any stale "1" that survived in the JSON.
    # Source of truth: rocm_profiles.UNAVAILABLE[arch] — vars with no supporting hardware.
    arch = config.get(_ARCH_KEY, "")
    unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
    if unavailable:
        for var in unavailable:
            os.environ[var] = "0"
    dtype_str = _resolve_dtype()
    if dtype_str in ('FP16', 'BF16'):
        for var in _FP32_ONLY_SOLVERS:
            os.environ[var] = "0"
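The two-pass ordering in `apply_env` is the point: configured values are written first, then the arch safety net overwrites any hardware-incompatible solver back to "0", so a stale "1" in the JSON can never win. A sketch of that precedence over a plain dict standing in for `os.environ` (variable names are illustrative):

```python
def apply(config: dict, unavailable: set) -> dict:
    env = {}
    # pass 1: write configured values (empty strings are skipped, as in apply_env)
    for var, value in config.items():
        if value != "":
            env[var] = value
    # pass 2: arch safety net — runs after the config loop, so it always wins
    for var in unavailable:
        env[var] = "0"
    return env

config = {"SOLVER_A": "1", "SOLVER_B": "1"}  # illustrative names; B has a stale "1"
unavailable = {"SOLVER_B"}                   # no supporting hardware on this arch
env = apply(config, unavailable)
print(env)  # {'SOLVER_A': '1', 'SOLVER_B': '0'}
```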
def apply_all(names: list, values: list) -> None:
    config = load_config().copy()
    arch = config.get(_ARCH_KEY, "")
    unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
    for name, value in zip(names, values):
        if name not in ROCM_ENV_VARS:
            log.warning(f'ROCm apply_all: unknown variable={name}')
            continue
        # Arch safety net: silently clamp incompatible solvers back to "0".
        # The UI may send the current checkbox state even for greyed-out vars.
        if name in unavailable:
            config[name] = "0"
            continue
        meta = ROCM_ENV_VARS[name]
        if meta["widget"] == "checkbox":
            if value is None:
                pass # Gradio passed None (component not interacted with) — leave config unchanged
            else:
                config[name] = "1" if value else "0"
        elif meta["widget"] == "radio":
            stored = _dropdown_stored(str(value), meta["options"])
            valid = {v for _, v in meta["options"]} if meta["options"] and isinstance(meta["options"][0], tuple) else set(meta["options"] or [])
            if stored in valid:
                config[name] = stored
            # else: value was None/invalid — leave the existing saved value untouched
        else:
            if meta.get("options"):
                value = _dropdown_stored(str(value), meta["options"])
            config[name] = _collapse_venv(str(value))
    save_config(config)
    apply_env(config)


def reset_defaults() -> None:
    defaults = {k: v["default"] for k, v in ROCM_ENV_VARS.items()}
    # Preserve the active arch key so safety nets survive a defaults reset.
    arch = load_config().get(_ARCH_KEY, "")
    if arch:
        defaults[_ARCH_KEY] = arch
    save_config(defaults)
    apply_env(defaults)
    log.info(f'ROCm reset_defaults: config reset to defaults arch={arch or "(none)"}')


def clear_env() -> None:
    """Remove all managed ROCm vars and known noise vars from os.environ without writing to disk."""
    cleared = 0
    for var in ROCM_ENV_VARS:
        if var in os.environ:
            del os.environ[var]
            cleared += 1
    for var in _UNSET_VARS | _EXTRA_CLEAR_VARS:
        if var in os.environ:
            del os.environ[var]
            cleared += 1
    log.info(f'ROCm clear_env: cleared={cleared}')


def delete_config() -> None:
    """Delete the saved config file, clear all vars, and wipe the MIOpen user DB cache."""
    import shutil # pylint: disable=import-outside-toplevel
    global _cache # pylint: disable=global-statement
    clear_env()
    if CONFIG.exists():
        CONFIG.unlink()
        log.info(f'ROCm delete_config: deleted {CONFIG}')
    _cache = None
    # Delete the MIOpen user DB (~/.miopen/db) — stale entries can cause solver mismatches
    miopen_db = Path(os.path.expanduser('~')) / '.miopen' / 'db'
    if miopen_db.exists():
        shutil.rmtree(miopen_db, ignore_errors=True)
        log.info(f'ROCm delete_config: wiped MIOpen user DB at {miopen_db}')
    else:
        log.debug(f'ROCm delete_config: MIOpen user DB not found at {miopen_db} — nothing to wipe')


def apply_profile(name: str) -> None:
    """Merge an architecture profile on top of the current config, then save and apply."""
    profile = rocm_profiles.PROFILES.get(name)
    if profile is None:
        log.warning(f'ROCm apply_profile: unknown profile={name}')
        return
    config = load_config().copy()
    config.update(profile)
    config[_ARCH_KEY] = name # stamp the active arch so safety nets survive restarts
    save_config(config)
    apply_env(config)
    log.info(f'ROCm apply_profile: profile={name} overrides={len(profile)}')
def _hip_version_from_file(db_path: Path) -> str:
    """Parse HIP_VERSION_* keys from .hipVersion in the SDK bin folder."""
    hip_ver_file = db_path / ".hipVersion"
    if not hip_ver_file.exists():
        return ""
    kv = {}
    for line in hip_ver_file.read_text(errors="ignore").splitlines():
        if "=" in line and not line.startswith("#"):
            k, _, v = line.partition("=")
            kv[k.strip()] = v.strip()
    major = kv.get("HIP_VERSION_MAJOR", "")
    minor = kv.get("HIP_VERSION_MINOR", "")
    patch = kv.get("HIP_VERSION_PATCH", "")
    git = kv.get("HIP_VERSION_GITHASH", "")
    if major:
        return f"{major}.{minor}.{patch} ({git})"
    return ""
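`.hipVersion` is a plain `KEY=VALUE` text file; the parser above skips comment lines and splits on the first `=` only. The same parse over an in-memory sample (the sample contents are illustrative, not from a real SDK install):

```python
sample = """# HIP version file
HIP_VERSION_MAJOR=6
HIP_VERSION_MINOR=4
HIP_VERSION_PATCH=2
HIP_VERSION_GITHASH=abc1234
"""

kv = {}
for line in sample.splitlines():
    if "=" in line and not line.startswith("#"):
        k, _, v = line.partition("=")  # split on the first '=' only
        kv[k.strip()] = v.strip()

hip_ver = f"{kv['HIP_VERSION_MAJOR']}.{kv['HIP_VERSION_MINOR']}.{kv['HIP_VERSION_PATCH']} ({kv['HIP_VERSION_GITHASH']})"
print(hip_ver)  # 6.4.2 (abc1234)
```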
def _pkg_version(name: str) -> str:
    try:
        import importlib.metadata as _m # pylint: disable=import-outside-toplevel
        return _m.version(name)
    except Exception:
        return "n/a"


def _db_file_summary(path: Path, patterns: list) -> dict:
    """Return {filename: 'N KB'} for files matching any of the given glob patterns."""
    out = {}
    for pat in patterns:
        for f in sorted(path.glob(pat)):
            kb = f.stat().st_size // 1024
            out[f.name] = f"{kb} KB"
    return out


def _user_db_summary(path: Path) -> dict:
    """Return {filename: 'N KB, M entries'} for user MIOpen DB txt files."""
    out = {}
    for pat in ("*.udb.txt", "*.ufdb.txt"):
        for f in sorted(path.glob(pat)):
            kb = f.stat().st_size // 1024
            try:
                lines = sum(1 for _ in f.open(errors="ignore"))
            except Exception:
                lines = 0
            out[f.name] = f"{kb} KB, {lines} entries"
    return out
def info() -> dict:
    config = load_config()
    db_path = Path(_expand_venv(config.get("MIOPEN_SYSTEM_DB_PATH", "")))

    # --- ROCm / HIP package versions ---
    rocm_pkgs = {}
    for pkg in ("rocm", "rocm-sdk-core", "rocm-sdk-devel"):
        v = _pkg_version(pkg)
        if v != "n/a":
            rocm_pkgs[pkg] = v
    libs_pkg = _pkg_version("rocm-sdk-libraries-gfx103x-dgpu")
    if libs_pkg != "n/a":
        rocm_pkgs["rocm-sdk-libraries (gfx103x)"] = libs_pkg

    hip_ver = _hip_version_from_file(db_path)
    if not hip_ver:
        try:
            import torch # pylint: disable=import-outside-toplevel
            hip_ver = getattr(torch.version, "hip", "") or ""
        except Exception:
            pass

    rocm_section = {}
    if hip_ver:
        rocm_section["hip_version"] = hip_ver
    rocm_section.update(rocm_pkgs)

    # --- Torch ---
    torch_section = {}
    try:
        import torch # pylint: disable=import-outside-toplevel
        torch_section["version"] = torch.__version__
        torch_section["hip"] = getattr(torch.version, "hip", None) or "n/a"
    except Exception:
        pass

    # --- GPU ---
    gpu_section = [dict(g) for g in installer.gpu_info]

    # --- System DB ---
    sdb = {"path": str(db_path)}
    if db_path.exists():
        solver_db = _db_file_summary(db_path, ["*.db.txt"])
        find_db = _db_file_summary(db_path, ["*.HIP.fdb.txt", "*.fdb.txt"])
        kernel_db = _db_file_summary(db_path, ["*.kdb"])
        if solver_db:
            sdb["solver_db"] = solver_db
        if find_db:
            sdb["find_db"] = find_db
        if kernel_db:
            sdb["kernel_db"] = kernel_db
    else:
        sdb["exists"] = False

    # --- User DB (~/.miopen/db) ---
    user_db_path = Path.home() / ".miopen" / "db"
    udb = {"path": str(user_db_path), "exists": user_db_path.exists()}
    if user_db_path.exists():
        ufiles = _user_db_summary(user_db_path)
        if ufiles:
            udb["files"] = ufiles

    return {
        "rocm": rocm_section,
        "torch": torch_section,
        "gpu": gpu_section,
        "system_db": sdb,
        "user_db": udb,
    }
# Apply saved config to os.environ at import time (only when ROCm is present)
if is_rocm:
    try:
        apply_env()
    except Exception as _e:
        print(f"[rocm_mgr] Warning: failed to apply env at import: {_e}", file=sys.stderr)
else:
    log.debug('ROCm is not installed — skipping rocm_mgr env apply')
@@ -0,0 +1,253 @@
"""
|
||||
Architecture-specific MIOpen solver profiles for AMD GCN/RDNA GPUs.
|
||||
|
||||
Sources:
|
||||
https://rocm.docs.amd.com/projects/MIOpen/en/develop/reference/env_variables.html
|
||||
|
||||
Key axis: consumer RDNA GPUs have NO XDLOPS hardware (that's CDNA/Instinct only).
|
||||
RDNA2 (gfx1030): RX 6000 series
|
||||
RDNA3 (gfx1100): RX 7000 series — adds Fury Winograd, wider MPASS
|
||||
RDNA4 (gfx1200): RX 9000 series — adds Rage Winograd, wider MPASS
|
||||
|
||||
Each profile is a dict of {var: value} that will be MERGED on top of the
|
||||
current config (general vars like DB path / log level are preserved).
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared: everything that must be OFF on ALL consumer RDNA (no XDLOPS hw)
|
||||
# ---------------------------------------------------------------------------
|
||||
_XDLOPS_OFF: Dict[str, str] = {
|
||||
# GTC XDLOPS (CDNA-only)
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_GTC_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS_NHWC": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS_NHWC": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_GTC_XDLOPS_NHWC": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_DLOPS_NCHWC": "0",
|
||||
# HIP XDLOPS variants (CDNA-only)
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R5_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V1R1_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4_PADDED_GEMM_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4_PADDED_GEMM_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_XDLOPS_EMULATE": "0",
|
||||
"MIOPEN_DEBUG_IMPLICIT_GEMM_XDLOPS_INLINE_ASM": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_GROUP_BWD_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_BWD_XDLOPS_AI_HEUR": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_FWD_V4R4_XDLOPS_ADD_VECTOR_LOAD_GEMMN_TUNE_PARAM": "0",
|
||||
# 3D XDLOPS (CDNA-only; no 3D conv XDLOPS on consumer RDNA)
|
||||
"MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_BWD_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_WRW_XDLOPS": "0",
|
||||
# Composable Kernel (requires XDLOPS / CDNA)
|
||||
"MIOPEN_DEBUG_CONV_CK_IGEMM_FWD_V6R1_DLOPS_NCHW": "0",
|
||||
"MIOPEN_DEBUG_CONV_CK_IGEMM_FWD_BIAS_ACTIV": "0",
|
||||
"MIOPEN_DEBUG_CONV_CK_IGEMM_FWD_BIAS_RES_ADD_ACTIV": "0",
|
||||
# MLIR (CDNA-only in practice)
|
||||
"MIOPEN_DEBUG_CONV_MLIR_IGEMM_WRW_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_CONV_MLIR_IGEMM_BWD_XDLOPS": "0",
|
||||
# MP BD Winograd (Multi-pass Block-Decomposed — CDNA / high-end only)
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F2X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F3X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F4X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F5X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_WINOGRAD_F6X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F2X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F3X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F4X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F5X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_MP_BD_XDLOPS_WINOGRAD_F6X3": "0",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RDNA2 — gfx1030 (RX 6000 series)
|
||||
# No XDLOPS, no Fury/Rage Winograd, MPASS limited to F3x2/F3x3
|
||||
# ASM IGEMM: V4R1 variants only; HIP IGEMM: non-XDLOPS V4R1/R4 only
|
||||
# ---------------------------------------------------------------------------
|
||||
RDNA2: Dict[str, str] = {
|
||||
**_XDLOPS_OFF,
|
||||
# General settings (architecture-independent; set here so all profiles cover them)
|
||||
"MIOPEN_SEARCH_CUTOFF": "0",
|
||||
"MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": "0",
|
||||
# Core algo enables — FFT is FP32-only but harmless (IsApplicable rejects it for fp16 tensors)
|
||||
"MIOPEN_DEBUG_CONV_FFT": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT": "1",
|
||||
"MIOPEN_DEBUG_CONV_GEMM": "1",
|
||||
"MIOPEN_DEBUG_CONV_WINOGRAD": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMMED_FALLBACK": "1",
|
||||
"MIOPEN_DEBUG_ENABLE_AI_IMMED_MODE_FALLBACK": "1",
|
||||
"MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK": "0",
|
||||
# Kernel backends
|
||||
"MIOPEN_DEBUG_GCN_ASM_KERNELS": "1",
|
||||
"MIOPEN_DEBUG_HIP_KERNELS": "1",
|
||||
"MIOPEN_DEBUG_OPENCL_CONVOLUTIONS": "1",
|
||||
"MIOPEN_DEBUG_OPENCL_WAVE64_NOWGP": "1",
|
||||
"MIOPEN_DEBUG_ATTN_SOFTMAX": "1",
|
||||
# Direct ASM — dtype notes
|
||||
# 3X3U / 1X1U / 1X1UV2: FP32/FP16 forward — enabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "1",
|
||||
# 5X10U2V2: fixed geometry (5*10 stride-2), no SD conv matches — disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_5X10U2V2": "0",
|
||||
# 7X7C3H224W224: hard-coded ImageNet stem (C=3, H=W=224, K=64) — never matches SD — disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_7X7C3H224W224": "0",
|
||||
# WRW3X3 / WRW1X1: FP32-only weight-gradient (training only) — disabled for inference
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW3X3": "0",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_WRW1X1": "0",
|
||||
# PERF_VALS intentionally blank: MIOpen reads this as a config string not a boolean;
|
||||
# setting to "1" causes GetPerfConfFromEnv to use a degenerate config and return float32
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_PERF_VALS": "",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_SEARCH_OPTIMIZED": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_AI_HEUR": "1",
|
||||
# NAIVE_CONV_FWD: scalar FP32 reference solver — IsApplicable does NOT reliably filter for FP16;
|
||||
# can be selected for unusual shapes (e.g. VAE decoder 3-ch output) and returns dtype=float32
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD": "0",
|
||||
# Direct OCL — dtype notes
|
||||
# FWD / FWD1X1: FP32/FP16 forward — enabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "1",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "1",
|
||||
# FWD11X11: requires 11*11 kernel — no SD match — disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD11X11": "0",
|
||||
# FWDGEN: FP32 generic OCL fallback — IsApplicable does NOT reliably reject for FP16;
|
||||
# can produce dtype=float32 output for FP16 inputs — disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_FWDGEN": "0",
|
||||
# WRW2 / WRW53 / WRW1X1: training-only weight-gradient — disabled
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW2": "0",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW53": "0",
|
||||
"MIOPEN_DEBUG_CONV_DIRECT_OCL_WRW1X1": "0",
|
||||
# Winograd RxS — dtype per MIOpen docs
|
||||
# WINOGRAD_3X3: FP32-only — harmless (IsApplicable rejects for fp16); enabled
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_3X3": "1",
|
||||
# RXS: covers FP32/FP16 F(3,3) Fwd/Bwd + FP32 F(3,2) WrW — keep enabled (fp16 fwd/bwd path exists)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS": "1",
|
||||
# RXS_FWD_BWD: FP32/FP16 — explicitly the fp16-capable subset
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD": "1",
|
||||
# RXS_WRW: FP32 WrW only — training-only, disabled for inference fp16 profile
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_WRW": "0",
|
||||
# RXS_F3X2: FP32/FP16 Fwd/Bwd
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2": "1",
|
||||
# RXS_F2X3: FP32/FP16 Fwd/Bwd (group convolutions)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3": "1",
|
||||
# RXS_F2X3_G1: FP32/FP16 Fwd/Bwd (non-group convolutions)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1": "1",
|
||||
# FUSED_WINOGRAD: FP32-only — harmless (IsApplicable rejects for fp16); enabled
|
||||
"MIOPEN_DEBUG_AMD_FUSED_WINOGRAD": "1",
|
||||
# PERF_VALS intentionally blank: same reason as ASM_1X1U — not a boolean, config string
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_PERF_VALS": "",
|
||||
# Fury/Rage Winograd — NOT available on RDNA2
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "0",
|
||||
# MPASS — only F3x2 and F3x3 are safe on RDNA2
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2": "1",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3": "1",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X6": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X3": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2": "0",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3": "0",
|
||||
# ASM Implicit GEMM — forward V4R1 only; no GTC/XDLOPS on RDNA2
|
||||
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only — disabled
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_V4R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_V4R1": "0",
|
||||
# HIP Implicit GEMM — non-XDLOPS V4R1/R4 forward only
|
||||
# BWD (backward data-gradient) and WrW (weight-gradient) are training-only — disabled
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4": "1",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V1R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_BWD_V4R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R1": "0",
|
||||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R4": "0",
|
||||
# Group Conv XDLOPS / CK default kernels — RDNA3/4 only, not available on RDNA2
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "0",
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "0",
|
||||
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "0",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RDNA3 — gfx1100 (RX 7000 series)
|
||||
# Fury Winograd added; MPASS F3x4 enabled; Group Conv XDLOPS + CK default kernels enabled
|
||||
# ---------------------------------------------------------------------------
|
||||
RDNA3: Dict[str, str] = {
|
||||
**RDNA2,
|
||||
# Fury Winograd — introduced for gfx1100 (RDNA3)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "1",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "1",
|
||||
# Wider MPASS on RDNA3
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4": "1",
|
||||
# Group Conv XDLOPS / CK — available from gfx1100 (RDNA3) onwards
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "1",
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "1",
|
||||
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "1",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RDNA4 — gfx1200 (RX 9000 series)
|
||||
# Rage Winograd added; MPASS F3x5 enabled
|
||||
# ---------------------------------------------------------------------------
|
||||
RDNA4: Dict[str, str] = {
|
||||
**RDNA3,
|
||||
# Rage Winograd — introduced for gfx1200 (RDNA4)
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "1",
|
||||
# Wider MPASS on RDNA4
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5": "1",
|
||||
}
|
||||
|
||||
PROFILES: Dict[str, Dict[str, str]] = {
|
||||
"RDNA2": RDNA2,
|
||||
"RDNA3": RDNA3,
|
||||
"RDNA4": RDNA4,
|
||||
}
|
||||
|
||||
# Vars that are architecturally unavailable (no supporting hardware) per arch.
|
||||
# These will be visually marked in the UI with strikethrough.
|
||||
_UNAVAILABLE_ALL_RDNA = set(_XDLOPS_OFF.keys())
|
||||
|
||||
UNAVAILABLE: Dict[str, set] = {
|
||||
"RDNA2": _UNAVAILABLE_ALL_RDNA | {
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X4",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X6",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X3",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3",
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS",
|
||||
"MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR",
|
||||
"MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
|
||||
},
|
||||
"RDNA3": _UNAVAILABLE_ALL_RDNA | {
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X5",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X6",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X3",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3",
|
||||
},
|
||||
"RDNA4": _UNAVAILABLE_ALL_RDNA | {
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X6",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X3",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F5X4",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X2",
|
||||
"MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F7X3",
|
||||
},
|
||||
}
|
||||
|
|
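The profiles above are plain env-var dictionaries keyed by architecture. A minimal sketch of how such a profile can be pushed into the process environment before MIOpen initializes; `apply_profile_env` is a hypothetical helper written for illustration, not the module's actual API:

```python
import os
from typing import Dict

def apply_profile_env(profile: Dict[str, str]) -> int:
    """Set each profile variable in os.environ; return how many values changed."""
    changed = 0
    for key, value in profile.items():
        if os.environ.get(key) != value:
            os.environ[key] = value
            changed += 1
    return changed

# miniature stand-in for one of the profiles above
sample = {"MIOPEN_DEBUG_CONV_WINOGRAD": "1", "MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "0"}
```

Since profiles build on each other via `**RDNA2` / `**RDNA3` dict merges, applying a profile this way is idempotent: a second call with the same dict changes nothing.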
@ -0,0 +1,253 @@
from typing import Dict, Any, List, Tuple

# --- General MIOpen/rocBLAS variables (dropdown/textbox/checkbox) ---
GENERAL_VARS: Dict[str, Dict[str, Any]] = {
    "MIOPEN_GEMM_ENFORCE_BACKEND": {
        "default": "1",
        "desc": "Enforce GEMM backend",
        "widget": "dropdown",
        "options": [("1 - rocBLAS", "1"), ("5 - hipBLASLt", "5")],
        "restart_required": False,
    },
    "MIOPEN_FIND_MODE": {
        "default": "2",
        "desc": "MIOpen Find Mode",
        "widget": "dropdown",
        "options": [("1 - NORMAL", "1"), ("2 - FAST", "2"), ("3 - HYBRID", "3"), ("5 - DYNAMIC_HYBRID", "5"), ("6 - TRUST_VERIFY", "6"), ("7 - TRUST_VERIFY_FULL", "7")],
        "restart_required": True,
    },
    "MIOPEN_FIND_ENFORCE": {
        "default": "1",
        "desc": "MIOpen Find Enforce",
        "widget": "dropdown",
        "options": [("1 - NONE", "1"), ("2 - DB_UPDATE", "2"), ("3 - SEARCH", "3"), ("4 - SEARCH_DB_UPDATE", "4"), ("5 - DB_CLEAN", "5")],
        "restart_required": True,
    },
    "MIOPEN_SEARCH_CUTOFF": {
        "default": "0",
        "desc": "Enable early termination of suboptimal searches",
        "widget": "dropdown",
        "options": [("0 - Off", "0"), ("1 - On", "1")],
        "restart_required": True,
    },
    "MIOPEN_SYSTEM_DB_PATH": {
        "default": "{VIRTUAL_ENV}\\Lib\\site-packages\\_rocm_sdk_devel\\bin\\",
        "desc": "MIOpen system DB path",
        "widget": "textbox",
        "options": None,
        "restart_required": True,
    },
    "MIOPEN_LOG_LEVEL": {
        "default": "0",
        "desc": "MIOpen log verbosity level",
        "widget": "dropdown",
        "options": [("0 - Default", "0"), ("1 - Quiet", "1"), ("3 - Error", "3"), ("4 - Warning", "4"), ("5 - Info", "5"), ("6 - Detail", "6"), ("7 - Trace", "7")],
        "restart_required": False,
    },
    "MIOPEN_DEBUG_ENABLE": {
        "default": "0",
        "desc": "Enable MIOpen logging",
        "widget": "dropdown",
        "options": [("0 - Off", "0"), ("1 - On", "1")],
        "restart_required": False,
    },
    "ROCBLAS_LAYER": {
        "default": "0",
        "desc": "rocBLAS logging",
        "widget": "dropdown",
        "options": [("0 - Off", "0"), ("1 - Trace", "1"), ("2 - Bench", "2"), ("3 - Trace+Bench", "3"), ("4 - Profile", "4"), ("5 - Trace+Profile", "5"), ("6 - Bench+Profile", "6"), ("7 - All", "7")],
        "restart_required": False,
    },
    "HIPBLASLT_LOG_LEVEL": {
        "default": "0",
        "desc": "hipBLASLt logging",
        "widget": "dropdown",
        "options": [("0 - Off", "0"), ("1 - Error", "1"), ("2 - Trace", "2"), ("3 - Hints", "3"), ("4 - Info", "4"), ("5 - API Trace", "5")],
        "restart_required": False,
    },
    "MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC": {
        "default": "0",
        "desc": "Deterministic convolution (reproducible results, may be slower)",
        "widget": "dropdown",
        "options": [("0 - Off", "0"), ("1 - On", "1")],
        "restart_required": False,
    },
}

# --- Solver toggles (inference/FWD only, RDNA2/3/4 compatible) ---
# Removed entirely (not representable in the UI, cannot be set by users):
#   WRW (weight-gradient) and BWD (data-gradient): training passes only, never run during inference
#   XDLOPS/CK CDNA-exclusive (MI100/MI200/MI300 matrix engine variants): not on any RDNA
#   Fixed-geometry (5x10, 7x7-ImageNet, 11x11): shapes never appear in SD/video inference
#   FP32-reference (NAIVE_CONV_FWD, FWDGEN): IsApplicable() unreliable for FP16/BF16
#   Wide MPASS (F3x4..F7x3): kernel sizes that cannot match any SD convolution shape
# Disabled by default (added but off): RDNA3/4-only Group Conv XDLOPS, CK default kernels
_SOLVER_DESCS: Dict[str, str] = {}

_SOLVER_DESCS.update({
    "MIOPEN_DEBUG_CONV_FFT": "Enable FFT solver",
    "MIOPEN_DEBUG_CONV_DIRECT": "Enable Direct solver",
    "MIOPEN_DEBUG_CONV_GEMM": "Enable GEMM solver",
    "MIOPEN_DEBUG_CONV_WINOGRAD": "Enable Winograd solver",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM": "Enable Implicit GEMM solver",
})
_SOLVER_DESCS.update({
    "MIOPEN_DEBUG_CONV_IMMED_FALLBACK": "Enable Immediate Fallback",
    "MIOPEN_DEBUG_ENABLE_AI_IMMED_MODE_FALLBACK": "Enable AI Immediate Mode Fallback",
    "MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK": "Force Immediate Mode Fallback",
})
_SOLVER_DESCS.update({
    "MIOPEN_DEBUG_GCN_ASM_KERNELS": "Enable GCN ASM kernels",
    "MIOPEN_DEBUG_HIP_KERNELS": "Enable HIP kernels",
    "MIOPEN_DEBUG_OPENCL_CONVOLUTIONS": "Enable OpenCL convolutions",
    "MIOPEN_DEBUG_OPENCL_WAVE64_NOWGP": "Enable OpenCL Wave64 NOWGP",
    "MIOPEN_DEBUG_ATTN_SOFTMAX": "Enable Attention Softmax",
})
_SOLVER_DESCS.update({
    # Direct ASM: FWD inference only (WRW, fixed-geometry, FP32-reference removed)
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "Enable Direct ASM 3x3U",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "Enable Direct ASM 1x1U",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "Enable Direct ASM 1x1UV2",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_SEARCH_OPTIMIZED": "Enable Direct ASM 1x1U Search Optimized",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_AI_HEUR": "Enable Direct ASM 1x1U AI Heuristic",
})
_SOLVER_DESCS.update({
    # Direct OCL: FWD inference only (WRW, FWD11X11 fixed-geom, FWDGEN FP32-ref removed)
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "Enable Direct OCL FWD",
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "Enable Direct OCL FWD1X1",
})
_SOLVER_DESCS.update({
    # Winograd FWD: WRW removed; Fury/Rage kept as RDNA3/4 inference (off by default)
    "MIOPEN_DEBUG_AMD_WINOGRAD_3X3": "Enable AMD Winograd 3x3",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS": "Enable AMD Winograd RxS",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD": "Enable AMD Winograd RxS FWD",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2": "Enable AMD Winograd RxS F3x2",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3": "Enable AMD Winograd RxS F2x3",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1": "Enable AMD Winograd RxS F2x3 G1",
    "MIOPEN_DEBUG_AMD_FUSED_WINOGRAD": "Enable AMD Fused Winograd",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "Enable AMD Winograd Fury RxS F2x3",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "Enable AMD Winograd Fury RxS F3x2",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "Enable AMD Winograd Rage RxS F2x3",
})
_SOLVER_DESCS.update({
    # Multi-pass Winograd: only F3x2/F3x3 match typical 3x3 SD shapes; wider kernels removed
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2": "Enable AMD Winograd MPASS F3x2",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3": "Enable AMD Winograd MPASS F3x3",
})
_SOLVER_DESCS.update({
    # Implicit GEMM FWD: BWD/WRW (training), CDNA-exclusive XDLOPS variants removed
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1": "Enable ASM Implicit GEMM FWD V4R1",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1": "Enable ASM Implicit GEMM FWD V4R1 1x1",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1": "Enable HIP Implicit GEMM FWD V4R1",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4": "Enable HIP Implicit GEMM FWD V4R4",
})
_SOLVER_DESCS.update({
    # Group Conv XDLOPS FWD: RDNA3/4 (gfx1100+) only; disabled by default
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "Enable Group Conv Implicit GEMM XDLOPS FWD",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "Enable Group Conv Implicit GEMM XDLOPS FWD AI Heuristic",
    # CK (Composable Kernel) default kernels: RDNA3/4 (gfx1100+); disabled by default
    "MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "Enable CK (Composable Kernel) default kernels",
})


# Solvers still in the registry but disabled by default:
#   FORCE_IMMED_MODE_FALLBACK: overrides FIND_MODE entirely, defeats tuning DB
#   Fury RxS F2x3/F3x2: RDNA3/4-only; harmless on RDNA2 but won't select
#   Rage RxS F2x3: RDNA4-only
#   Group Conv XDLOPS: RDNA3/4-only (gfx1100+)
#   CK_DEFAULT_KERNELS: RDNA3/4-only (gfx1100+)
SOLVER_DISABLED_BY_DEFAULT = {
    "MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR",
    "MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
}

SOLVER_DTYPE_TAGS: Dict[str, str] = {
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_3X3": "FP32",
    "MIOPEN_DEBUG_AMD_FUSED_WINOGRAD": "FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2": "FP16/FP32",
    "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1": "FP16/FP32",
    "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4": "FP16/FP32",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS": "FP16/BF16",
    "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR": "FP16/BF16",
    "MIOPEN_DEBUG_CK_DEFAULT_KERNELS": "FP16/BF16/FP32",
}

# Build full merged var registry
ROCM_ENV_VARS: Dict[str, Dict[str, Any]] = {}
ROCM_ENV_VARS.update(GENERAL_VARS)
for _var, _desc in _SOLVER_DESCS.items():
    ROCM_ENV_VARS[_var] = {
        "default": "0" if _var in SOLVER_DISABLED_BY_DEFAULT else "1",
        "desc": _desc,
        "widget": "checkbox",
        "options": None,
        "dtype": SOLVER_DTYPE_TAGS.get(_var),
        "restart_required": False,
    }

# UI group ordering for solver sections
SOLVER_GROUPS: List[Tuple[str, List[str]]] = [
    ("Algorithm/Solver Group Enables", [
        "MIOPEN_DEBUG_CONV_FFT", "MIOPEN_DEBUG_CONV_DIRECT", "MIOPEN_DEBUG_CONV_GEMM",
        "MIOPEN_DEBUG_CONV_WINOGRAD", "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM",
    ]),
    ("Immediate Fallback Mode", [
        "MIOPEN_DEBUG_CONV_IMMED_FALLBACK", "MIOPEN_DEBUG_ENABLE_AI_IMMED_MODE_FALLBACK",
        "MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK",
    ]),
    ("Build Method Toggles", [
        "MIOPEN_DEBUG_GCN_ASM_KERNELS", "MIOPEN_DEBUG_HIP_KERNELS",
        "MIOPEN_DEBUG_OPENCL_CONVOLUTIONS", "MIOPEN_DEBUG_OPENCL_WAVE64_NOWGP",
        "MIOPEN_DEBUG_ATTN_SOFTMAX",
    ]),
    ("Direct ASM Solver Toggles", [
        "MIOPEN_DEBUG_CONV_DIRECT_ASM_3X3U", "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U",
        "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1UV2",
        "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_SEARCH_OPTIMIZED", "MIOPEN_DEBUG_CONV_DIRECT_ASM_1X1U_AI_HEUR",
    ]),
    ("Direct OpenCL Solver Toggles", [
        "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD", "MIOPEN_DEBUG_CONV_DIRECT_OCL_FWD1X1",
    ]),
    ("Winograd Solver Toggles", [
        "MIOPEN_DEBUG_AMD_WINOGRAD_3X3", "MIOPEN_DEBUG_AMD_WINOGRAD_RXS",
        "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_FWD_BWD",
        "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2", "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3",
        "MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F2X3_G1", "MIOPEN_DEBUG_AMD_FUSED_WINOGRAD",
        "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3",
        "MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F3X2", "MIOPEN_DEBUG_AMD_WINOGRAD_RAGE_RXS_F2X3",
    ]),
    ("Multi-pass Winograd Toggles", [
        "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X2", "MIOPEN_DEBUG_AMD_WINOGRAD_MPASS_F3X3",
    ]),
    ("Implicit GEMM Toggles", [
        "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1", "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_V4R1_1X1",
        "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1", "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R4",
    ]),
    ("Group Conv / CK Toggles (RDNA3/4+)", [
        "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS",
        "MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS_AI_HEUR",
        "MIOPEN_DEBUG_CK_DEFAULT_KERNELS",
    ]),
]
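The registry merge above follows a simple pattern: a description map plus a disabled-set yield per-variable defaults, with optional dtype tags attached. A self-contained miniature of that pattern, using hypothetical variable names rather than the real MIOpen keys:

```python
from typing import Any, Dict

# hypothetical miniature inputs mirroring _SOLVER_DESCS / SOLVER_DISABLED_BY_DEFAULT / SOLVER_DTYPE_TAGS
descs = {"SOLVER_A": "Enable A", "SOLVER_B": "Enable B"}
disabled = {"SOLVER_B"}
dtypes = {"SOLVER_A": "FP16/FP32"}

# same derivation: default "0" only for vars in the disabled set, dtype looked up or None
registry: Dict[str, Dict[str, Any]] = {
    var: {
        "default": "0" if var in disabled else "1",
        "desc": desc,
        "widget": "checkbox",
        "dtype": dtypes.get(var),
    }
    for var, desc in descs.items()
}
```

Keeping the disabled-set and dtype tags separate from the description map means adding a solver is a one-line change in each table, and the loop derives everything else.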
@ -0,0 +1,203 @@
import gradio as gr
import installer
from modules import scripts_manager, shared

# rocm_mgr exposes package-internal helpers (prefixed _) that are intentionally called here
# pylint: disable=protected-access


class Script(scripts_manager.Script):
    def title(self):
        return "ROCm: Advanced Config"

    def show(self, _is_img2img):
        if shared.cmd_opts.use_rocm or installer.torch_info.get('type') == 'rocm':
            return scripts_manager.AlwaysVisible  # script should be visible only if rocm is detected or forced
        return False

    def ui(self, _is_img2img):
        if not shared.cmd_opts.use_rocm and installer.torch_info.get('type') != 'rocm':  # skip ui creation if not rocm
            return []

        from scripts.rocm import rocm_mgr, rocm_vars  # pylint: disable=no-name-in-module

        config = rocm_mgr.load_config()
        var_names = []
        components = []

        def _make_component(name, meta, cfg):
            val = cfg.get(name, meta["default"])
            widget = meta["widget"]
            if widget == "checkbox":
                dtype_tag = meta.get("dtype")
                label = f"[{dtype_tag}] {meta['desc']}" if dtype_tag else meta["desc"]
                return gr.Checkbox(label=label, value=(val == "1"), elem_id=f"rocm_var_{name.lower()}")
            if widget == "dropdown":
                choices = rocm_mgr._dropdown_choices(meta["options"])
                display = rocm_mgr._dropdown_display(val, meta["options"])
                return gr.Dropdown(label=meta["desc"], choices=choices, value=display, elem_id=f"rocm_var_{name.lower()}")
            return gr.Textbox(label=meta["desc"], value=rocm_mgr._expand_venv(val), lines=1)

        def _info_html():
            d = rocm_mgr.info()
            rows = []

            def section(title):
                rows.append(f"<tr><th colspan='2' style='padding-top:6px;text-align:left;color:var(--sd-main-accent-color)'>{title}</th></tr>")

            def row(k, v):
                rows.append(f"<tr><td style='color:var(--sd-muted-color);width:38%;padding:2px 8px;border-bottom:1px solid var(--sd-panel-border-color)'>{k}</td><td style='color:var(--sd-label-color);padding:2px 8px;border-bottom:1px solid var(--sd-panel-border-color)'>{v}</td></tr>")

            section("ROCm / HIP")
            for k, v in d.get("rocm", {}).items():
                row(k, v)
            section("System DB")
            sdb = d.get("system_db", {})
            row("path", sdb.get("path", ""))
            for sub in ("solver_db", "find_db", "kernel_db"):
                for fname, sz in sdb.get(sub, {}).items():
                    row(sub.replace("_", " "), f"{fname} {sz}")
            section("User DB (~/.miopen/db)")
            udb = d.get("user_db", {})
            row("path", udb.get("path", ""))
            for fname, finfo in udb.get("files", {}).items():
                row(fname, finfo)
            return f"<table style='width:100%;border-collapse:collapse'>{''.join(rows)}</table>"

        with gr.Accordion('ROCm: Advanced Config', open=False, elem_id='rocm_config'):
            with gr.Row():
                gr.HTML("<p>Advanced configuration for ROCm users.</p><br><p>Set your database and solver selections based on GPU profile or individually.</p><br><p>Enable cuDNN in Backend Settings to activate MIOpen.</p>")
            with gr.Row():
                btn_info = gr.Button("Refresh Info", variant="primary", elem_id="rocm_btn_info", size="sm")
                btn_apply = gr.Button("Apply", variant="primary", elem_id="rocm_btn_apply", size="sm")
                btn_reset = gr.Button("Defaults", elem_id="rocm_btn_reset", size="sm")
                btn_clear = gr.Button("Clear Run Vars", elem_id="rocm_btn_clear", size="sm")
                btn_delete = gr.Button("Delete UserDb", variant="stop", elem_id="rocm_btn_delete", size="sm")
            with gr.Row():
                btn_rdna2 = gr.Button("RDNA2 (RX 6000)", elem_id="rocm_btn_rdna2")
                btn_rdna3 = gr.Button("RDNA3 (RX 7000)", elem_id="rocm_btn_rdna3")
                btn_rdna4 = gr.Button("RDNA4 (RX 9000)", elem_id="rocm_btn_rdna4")
            style_out = gr.HTML("")
            info_out = gr.HTML(value=_info_html, elem_id="rocm_info_table")

            # General vars (dropdowns, textboxes, checkboxes)
            with gr.Group():
                gr.HTML("<h3>MIOpen Settings</h3><hr>")
                for name, meta in rocm_vars.GENERAL_VARS.items():
                    comp = _make_component(name, meta, config)
                    var_names.append(name)
                    components.append(comp)

            # Solver groups (all checkboxes, grouped by section)
            for group_name, varlist in rocm_vars.SOLVER_GROUPS:
                with gr.Group():
                    gr.HTML(f"<h3>{group_name}</h3><hr>")
                    for name in varlist:
                        meta = rocm_vars.ROCM_ENV_VARS[name]
                        comp = _make_component(name, meta, config)
                        var_names.append(name)
                        components.append(comp)
            gr.HTML("<br><center><div style='margin:0 auto'><a href='https://rocm.docs.amd.com/projects/MIOpen/en/develop/reference/env_variables.html' target='_blank'>📄 MIOpen Environment Variables Reference</a></div></center><br>")

            def _autosave_field(name, value):
                meta = rocm_vars.ROCM_ENV_VARS[name]
                stored = rocm_mgr._dropdown_stored(str(value), meta["options"])
                cfg = rocm_mgr.load_config()
                cfg[name] = stored
                rocm_mgr.save_config(cfg)
                rocm_mgr.apply_env(cfg)

            for name, comp in zip(var_names, components):
                meta = rocm_vars.ROCM_ENV_VARS[name]
                if meta["widget"] == "dropdown":
                    comp.change(fn=lambda v, n=name: _autosave_field(n, v), inputs=[comp], outputs=[], show_progress='hidden')

            def apply_fn(*values):
                rocm_mgr.apply_all(var_names, list(values))
                saved = rocm_mgr.load_config()
                result = [gr.update(value="")]
                for name in var_names:
                    meta = rocm_vars.ROCM_ENV_VARS[name]
                    val = saved.get(name, meta["default"])
                    if meta["widget"] == "checkbox":
                        result.append(gr.update(value=val == "1"))
                    elif meta["widget"] == "dropdown":
                        result.append(gr.update(value=rocm_mgr._dropdown_display(val, meta["options"])))
                    else:
                        result.append(gr.update(value=rocm_mgr._expand_venv(val)))
                return result

            def _build_style(unavailable):
                if not unavailable:
                    return ""
                rules = " ".join(
                    f"#rocm_var_{v.lower()} label {{ text-decoration: line-through; opacity: 0.5; }}"
                    for v in unavailable
                )
                return f"<style>{rules}</style>"

            def reset_fn():
                rocm_mgr.reset_defaults()
                updated = rocm_mgr.load_config()
                result = [gr.update(value="")]
                for name in var_names:
                    meta = rocm_vars.ROCM_ENV_VARS[name]
                    val = updated.get(name, meta["default"])
                    if meta["widget"] == "checkbox":
                        result.append(gr.update(value=val == "1"))
                    elif meta["widget"] == "dropdown":
                        result.append(gr.update(value=rocm_mgr._dropdown_display(val, meta["options"])))
                    else:
                        result.append(gr.update(value=rocm_mgr._expand_venv(val)))
                return result

            def clear_fn():
                rocm_mgr.clear_env()
                result = [gr.update(value="")]
                for name in var_names:
                    meta = rocm_vars.ROCM_ENV_VARS[name]
                    if meta["widget"] == "checkbox":
                        result.append(gr.update(value=False))
                    elif meta["widget"] == "dropdown":
                        result.append(gr.update(value=rocm_mgr._dropdown_display(meta["default"], meta["options"])))
                    else:
                        result.append(gr.update(value=""))
                return result

            def delete_fn():
                rocm_mgr.delete_config()
                result = [gr.update(value="")]
                for name in var_names:
                    meta = rocm_vars.ROCM_ENV_VARS[name]
                    if meta["widget"] == "checkbox":
                        result.append(gr.update(value=False))
                    elif meta["widget"] == "dropdown":
                        result.append(gr.update(value=rocm_mgr._dropdown_display(meta["default"], meta["options"])))
                    else:
                        result.append(gr.update(value=""))
                return result

            def profile_fn(arch):
                from scripts.rocm import rocm_profiles  # pylint: disable=no-name-in-module
                rocm_mgr.apply_profile(arch)
                updated = rocm_mgr.load_config()
                unavailable = rocm_profiles.UNAVAILABLE.get(arch, set())
                result = [gr.update(value=_build_style(unavailable))]
                for pname in var_names:
                    meta = rocm_vars.ROCM_ENV_VARS[pname]
                    val = updated.get(pname, meta["default"])
                    if meta["widget"] == "checkbox":
                        result.append(gr.update(value=val == "1"))
                    elif meta["widget"] == "dropdown":
                        result.append(gr.update(value=rocm_mgr._dropdown_display(val, meta["options"])))
                    else:
                        result.append(gr.update(value=rocm_mgr._expand_venv(val)))
                return result

            btn_info.click(fn=_info_html, inputs=[], outputs=[info_out], show_progress='hidden')
            btn_apply.click(fn=apply_fn, inputs=components, outputs=[style_out] + components, show_progress='hidden')
            btn_reset.click(fn=reset_fn, inputs=[], outputs=[style_out] + components, show_progress='hidden')
            btn_clear.click(fn=clear_fn, inputs=[], outputs=[style_out] + components, show_progress='hidden')
            btn_delete.click(fn=delete_fn, inputs=[], outputs=[style_out] + components, show_progress='hidden')
            btn_rdna2.click(fn=lambda: profile_fn("RDNA2"), inputs=[], outputs=[style_out] + components, show_progress='hidden')
            btn_rdna3.click(fn=lambda: profile_fn("RDNA3"), inputs=[], outputs=[style_out] + components, show_progress='hidden')
            btn_rdna4.click(fn=lambda: profile_fn("RDNA4"), inputs=[], outputs=[style_out] + components, show_progress='hidden')

        return components
@ -0,0 +1,55 @@
const fs = require('fs');

/**
 * Custom stringifier that switches to minified format at a specific depth
 */
const mixedStringify = (data, maxDepth, indent = 2, currentDepth = 0) => {
  if (currentDepth >= maxDepth) {
    return JSON.stringify(data);
  }

  const spacing = ' '.repeat(indent * currentDepth);
  const nextSpacing = ' '.repeat(indent * (currentDepth + 1));

  if (Array.isArray(data)) {
    if (data.length === 0) return '[]';
    const items = data.map((item) => nextSpacing + mixedStringify(item, maxDepth, indent, currentDepth + 1));
    return `[\n${items.join(',\n')}\n${spacing}]`;
  }

  if (typeof data === 'object' && data !== null) {
    const keys = Object.keys(data);
    if (keys.length === 0) return '{}';
    const items = keys.map((key) => {
      const value = mixedStringify(data[key], maxDepth, indent, currentDepth + 1);
      return `${nextSpacing}"${key}": ${value}`;
    });
    return `{\n${items.join(',\n')}\n${spacing}}`;
  }

  return JSON.stringify(data);
};

// Capture CLI arguments
const [,, inputFile, outputFile, maxDepth] = process.argv;
console.log(`Input File: ${inputFile}, Output File: ${outputFile}, Max Depth: ${maxDepth}`);

if (!inputFile || !outputFile || !maxDepth) {
  console.log('Usage: node reformat.js <input.json> <output.json> <depth>');
  process.exit(1);
}

try {
  // Read input file
  const rawData = fs.readFileSync(inputFile, 'utf8');
  const jsonData = JSON.parse(rawData);

  // Reformat with mixed depth
  const result = mixedStringify(jsonData, parseInt(maxDepth, 10));

  // Write output file
  fs.writeFileSync(outputFile, result);
  console.log(`Success! File saved to ${outputFile} (levels expanded: ${maxDepth})`);
} catch (err) {
  console.error('Error processing JSON:', err.message);
}
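As a quick illustration of the mixed-depth behavior, here is a self-contained restatement of `mixedStringify` (same logic as the file above, compacted for the example): every level shallower than `maxDepth` is pretty-printed with one entry per line, while anything at or below the cutoff falls back to minified `JSON.stringify` output. The sample input object is arbitrary, chosen only to show both branches.

```javascript
// Minimal sketch of the mixed-depth stringifier: pretty-print above maxDepth,
// minify everything at or below it.
const mixedStringify = (data, maxDepth, indent = 2, depth = 0) => {
  if (depth >= maxDepth) return JSON.stringify(data); // deep levels: minified
  const pad = ' '.repeat(indent * depth);
  const padNext = ' '.repeat(indent * (depth + 1));
  if (Array.isArray(data)) {
    if (data.length === 0) return '[]';
    return `[\n${data.map((v) => padNext + mixedStringify(v, maxDepth, indent, depth + 1)).join(',\n')}\n${pad}]`;
  }
  if (typeof data === 'object' && data !== null) {
    const keys = Object.keys(data);
    if (keys.length === 0) return '{}';
    return `{\n${keys.map((k) => `${padNext}"${k}": ${mixedStringify(data[k], maxDepth, indent, depth + 1)}`).join(',\n')}\n${pad}}`;
  }
  return JSON.stringify(data); // primitives
};

// With maxDepth = 1, the top level is expanded and nested values are minified:
console.log(mixedStringify({ a: [1, 2], b: { c: 3 } }, 1));
// prints:
// {
//   "a": [1,2],
//   "b": {"c":3}
// }
```

This is why the CLI's `<depth>` argument controls readability vs. file size: `node reformat.js in.json out.json 1` keeps only the top-level keys on their own lines.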
|
@ -95,7 +95,7 @@ def test_grading_params_defaults():
    assert p.split_tone_balance == 0.5
    assert p.vignette == 0.0
    assert p.grain == 0.0
    assert p.lut_file == ""
    assert p.lut_cube_file == ""
    assert p.lut_strength == 1.0
    return True
wiki
@ -1 +1 @@
Subproject commit 33dbd026a2e2fb7311d545a3b2d2db0363bb887f
Subproject commit 99f4e13d03191b5269b869c71283d7fcf9c98f60