Merge pull request #4618 from vladmandic/dev

merge dev
pull/4620/head
Vladimir Mandic 2026-02-04 15:05:58 +01:00 committed by GitHub
commit 0d240b1a8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
160 changed files with 13778 additions and 2460 deletions

10
.gitignore vendored
View File

@ -1,7 +1,7 @@
# defaults
venv/
__pycache__
.ruff_cache
/cache.json
/*.json
/*.yaml
/params.txt
@ -9,13 +9,14 @@ __pycache__
/user.css
/webui-user.bat
/webui-user.sh
/html/extensions.json
/html/themes.json
/data/metadata.json
/data/extensions.json
/data/cache.json
/data/themes.json
config_states
node_modules
pnpm-lock.yaml
package-lock.json
venv
.history
cache
**/.DS_Store
@ -65,6 +66,7 @@ tunableop_results*.csv
.*/
# force included
!/data
!/models/VAE-approx
!/models/VAE-approx/model.pt
!/models/Reference

View File

@ -36,6 +36,7 @@ ignore-paths=/usr/lib/.*$,
modules/taesd,
modules/teacache,
modules/todo,
modules/res4lyf,
pipelines/bria,
pipelines/flex2,
pipelines/f_lite,

View File

@ -1,3 +1,6 @@
line-length = 250
indent-width = 4
target-version = "py310"
exclude = [
"venv",
".git",
@ -41,9 +44,6 @@ exclude = [
"extensions-builtin/sd-webui-agent-scheduler",
"extensions-builtin/sdnext-modernui/node_modules",
]
line-length = 250
indent-width = 4
target-version = "py310"
[lint]
select = [

1
.vscode/launch.json vendored
View File

@ -11,7 +11,6 @@
"env": { "USED_VSCODE_COMMAND_PICKARGS": "1" },
"args": [
"--uv",
"--quick",
"--log", "vscode.log",
"${command:pickArgs}"]
}

View File

@ -1,4 +1,5 @@
{
"files.eol": "\n",
"python.analysis.extraPaths": [".", "./modules", "./scripts", "./pipelines"],
"python.analysis.typeCheckingMode": "off",
"editor.formatOnSave": false,

View File

@ -1,5 +1,79 @@
# Change Log for SD.Next
## Update for 2026-02-04
### Highlights for 2026-02-04
Refresh release two weeks after prior release, yet we still somehow managed to pack in *~150 commits*!
Highlights would be two new models: **Z-Image-Base** and **Anima**, *captioning* support for **tagger** models and a massive addition of new **schedulers**
Also here are updates to `torch` and additional GPU archs support for `ROCm` backends, plus a lot of internal improvements and fixes.
[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic)
### Details for 2026-02-04
- **Models**
- [Tongyi-MAI Z-Image Base](https://tongyi-mai.github.io/Z-Image-blog/)
yup, its finally here, the full base model of **Z-Image**
- [CircleStone Anima](https://huggingface.co/circlestone-labs/Anima)
2B anime optimized model based on a modified Cosmos-Predict, using Qwen3-0.6B as a text encoder
- **Features**
- **caption** tab support for Booru tagger models, thanks @CalamitousFelicitousness
- add SmilingWolf WD14/WaifuDiffusion tagger models, thanks @CalamitousFelicitousness
- support comments in wildcard files, using `#`
- support aliases in metadata skip params, thanks @CalamitousFelicitousness
- ui gallery improve cache cleanup and add manual option, thanks @awsr
- selectable options to add system info to metadata, thanks @Athari
see *settings -> image metadata*
- **Schedulers**
- schedulers documentation has new home: <https://vladmandic.github.io/sdnext-docs/Schedulers/>
- add 13(!) new scheduler families
not a port, but more of inspired-by [res4lyf](https://github.com/ClownsharkBatwing/RES4LYF) library
all schedulers should be compatible with both `epsilon` and `flow` prediction style!
*note*: each family may have multiple actual schedulers, so the list total is 56(!) new schedulers
- core family: *RES*
- exponential: *DEIS, ETD, Lawson, ABNorsett*
- integrators: *Runge-Kutta, Linear-RK, Specialized-RK, Lobatto, Radau-IIA, Gauss-Legendre*
- flow: *PEC, Riemannian, Euclidean, Hyperbolic, Lorentzian, Langevin-Dynamics*
- add 3 additional schedulers: *CogXDDIM, DDIMParallel, DDPMParallel*
not originally intended to be a general purpose schedulers, but they work quite nicely and produce good results
- image metadata: always log scheduler class used
- **API**
- add `/sdapi/v1/xyz-grid` to enumerate xyz-grid axis options and their choices
see `/cli/api-xyzenum.py` for example usage
- add `/sdapi/v1/sampler` to get current sampler config
- modify `/sdapi/v1/samplers` to enumerate available samplers possible options
see `/cli/api-samplers.py` for example usage
- **Internal**
- tagged release history: <https://github.com/vladmandic/sdnext/tags>
each major for the past year is now tagged for easier reference
- **torch** update
*note*: may cause slow first startup/generate
**cuda**: update to `torch==2.10.0`
**xpu**: update to `torch==2.10.0`
**rocm**: update to `torch==2.10.0`
**openvino**: update to `torch==2.10.0` and `openvino==2025.4.1`
- rocm: expand available gfx archs, thanks @crashingalexsan
- rocm: set `MIOPEN_FIND_MODE=2` by default, thanks @crashingalexsan
- relocate all json data files to `data/` folder
existing data files are auto-migrated on startup
- refactor and improve connection monitor, thanks @awsr
- further work on type consistency and type checking, thanks @awsr
- log captured exceptions
- improve temp folder handling and cleanup
- remove torch errors/warings on fast server shutdown
- add ui placeholders for future agent-scheduler work, thanks @ryanmeador
- implement abort system on repeated errors, thanks @awsr
currently used by lora and textual-inversion loaders
- update package requirements
- **Fixes**
- add video ui elem_ids, thanks @ryanmeador
- use base steps as-is for non sd/sdxl models
- ui css fixes for modernui
- support lora inside prompt selector
- framepack video save
- metadata save for manual saves
## Update for 2026-01-22
Bugfix refresh
@ -139,7 +213,7 @@ End of year release update, just two weeks after previous one, with several new
- **Models**
- [LongCat Image](https://github.com/meituan-longcat/LongCat-Image) in *Image* and *Image Edit* variants
LongCat is a new 8B diffusion base model using Qwen-2.5 as text encoder
- [Qwen-Image-Edit 2511](Qwen/Qwen-Image-Edit-2511) in *base* and *pre-quantized* variants
- [Qwen-Image-Edit 2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) in *base* and *pre-quantized* variants
Key enhancements: mitigate image drift, improved character consistency, enhanced industrial design generation, and strengthened geometric reasoning ability
- [Qwen-Image-Layered](https://huggingface.co/Qwen/Qwen-Image-Layered) in *base* and *pre-quantized* variants
Qwen-Image-Layered, a model capable of decomposing an image into multiple RGBA layers

View File

@ -1,7 +1,7 @@
<div align="center">
<img src="https://github.com/vladmandic/sdnext/raw/master/html/logo-transparent.png" width=200 alt="SD.Next">
# SD.Next: All-in-one WebUI for AI generative image and video creation
# SD.Next: All-in-one WebUI for AI generative image and video creation and captioning
![Last update](https://img.shields.io/github/last-commit/vladmandic/sdnext?svg=true)
![License](https://img.shields.io/github/license/vladmandic/sdnext?svg=true)
@ -27,10 +27,8 @@
All individual features are not listed here, instead check [ChangeLog](CHANGELOG.md) for full list of changes
- Fully localized:
**English | Chinese | Russian | Spanish | German | French | Italian | Portuguese | Japanese | Korean**
- Multiple UIs!
**Standard | Modern**
- Desktop and Mobile support!
- Multiple [diffusion models](https://vladmandic.github.io/sdnext-docs/Model-Support/)!
- Built-in Control for Text, Image, Batch and Video processing!
- Multi-platform!
▹ **Windows | Linux | MacOS | nVidia CUDA | AMD ROCm | Intel Arc / IPEX XPU | DirectML | OpenVINO | ONNX+Olive | ZLUDA**
- Platform specific auto-detection and tuning performed on install
@ -38,9 +36,7 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG
Compile backends: *Triton | StableFast | DeepCache | OneDiff | TeaCache | etc.*
Quantization methods: *SDNQ | BitsAndBytes | Optimum-Quanto | TorchAO / LayerWise*
- **Interrogate/Captioning** with 150+ **OpenCLiP** models and 20+ built-in **VLMs**
- Built-in queue management
- Built in installer with automatic updates and dependency management
- Mobile compatible
<br>

190
TODO.md
View File

@ -1,107 +1,137 @@
# TODO
## Project Board
- <https://github.com/users/vladmandic/projects>
## Internal
- Feature: Move `nunchaku` models to refernce instead of internal decision
- Update: `transformers==5.0.0`
- Feature: Unify *huggingface* and *diffusers* model folders
- Reimplement `llama` remover for Kanvas
- Update: `transformers==5.0.0`, owner @CalamitousFelicitousness
- Deploy: Create executable for SD.Next
- Feature: Integrate natural language image search
[ImageDB](https://github.com/vladmandic/imagedb)
- Feature: Remote Text-Encoder support
- Refactor: move sampler options to settings to config
- Refactor: [GGUF](https://huggingface.co/docs/diffusers/main/en/quantization/gguf)
- Feature: LoRA add OMI format support for SD35/FLUX.1
- Refactor: remove `CodeFormer`
- Refactor: remove `GFPGAN`
- UI: Lite vs Expert mode
- Video tab: add full API support
- Control tab: add overrides handling
- Engine: `TensorRT` acceleration
- Deploy: Lite vs Expert mode
- Engine: [mmgp](https://github.com/deepbeepmeep/mmgp)
- Engine: [sharpfin](https://github.com/drhead/sharpfin) instead of `torchvision`
- Engine: `TensorRT` acceleration
- Feature: Auto handle scheduler `prediction_type`
- Feature: Cache models in memory
- Feature: Control tab add overrides handling
- Feature: Integrate natural language image search
[ImageDB](https://github.com/vladmandic/imagedb)
- Feature: LoRA add OMI format support for SD35/FLUX.1, on-hold
- Feature: Multi-user support
- Feature: Remote Text-Encoder support, sidelined for the moment
- Feature: Settings profile manager
- Feature: Video tab add full API support
- Refactor: Unify *huggingface* and *diffusers* model folders
- Refactor: Move `nunchaku` models to refernce instead of internal decision, owner @CalamitousFelicitousness
- Refactor: [GGUF](https://huggingface.co/docs/diffusers/main/en/quantization/gguf)
- Refactor: move sampler options to settings to config
- Refactor: remove `CodeFormer`, owner @CalamitousFelicitousness
- Refactor: remove `GFPGAN`, owner @CalamitousFelicitousness
- Reimplement `llama` remover for Kanvas, pending end-to-end review of `Kanvas`
## Modular
*Pending finalization of modular pipelines implementation and development of compatibility layer*
- Switch to modular pipelines
- Feature: Transformers unified cache handler
- Refactor: [Modular pipelines and guiders](https://github.com/huggingface/diffusers/issues/11915)
- [MagCache](https://github.com/lllyasviel/FramePack/pull/673/files)
- [MagCache](https://github.com/huggingface/diffusers/pull/12744)
- [SmoothCache](https://github.com/huggingface/diffusers/issues/11135)
## Features
- [Flux.2 TinyVAE](https://huggingface.co/fal/FLUX.2-Tiny-AutoEncoder)
- [IPAdapter composition](https://huggingface.co/ostris/ip-composition-adapter)
- [IPAdapter negative guidance](https://github.com/huggingface/diffusers/discussions/7167)
- [STG](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#spatiotemporal-skip-guidance)
- [Video Inpaint Pipeline](https://github.com/huggingface/diffusers/pull/12506)
- [Sonic Inpaint](https://github.com/ubc-vision/sonic)
### New models / Pipelines
## New models / Pipelines
TODO: Investigate which models are diffusers-compatible and prioritize!
- [Bria FiboEdit](https://github.com/huggingface/diffusers/commit/d7a1c31f4f85bae5a9e01cdce49bd7346bd8ccd6)
- [LTXVideo 0.98 LongMulti](https://github.com/huggingface/diffusers/pull/12614)
- [Cosmos-Predict-2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B)
- [NewBie Image Exp0.1](https://github.com/huggingface/diffusers/pull/12803)
- [Sana-I2V](https://github.com/huggingface/diffusers/pull/12634#issuecomment-3540534268)
- [Bria FIBO](https://huggingface.co/briaai/FIBO)
- [Bytedance Lynx](https://github.com/bytedance/lynx)
- [ByteDance OneReward](https://github.com/bytedance/OneReward)
- [ByteDance USO](https://github.com/bytedance/USO)
- [Chroma Radiance](https://huggingface.co/lodestones/Chroma1-Radiance)
- [Chroma Zeta](https://huggingface.co/lodestones/Zeta-Chroma)
- [DiffSynth Studio](https://github.com/modelscope/DiffSynth-Studio)
- [DiffusionForcing](https://github.com/kwsong0113/diffusion-forcing-transformer)
- [Dream0 guidance](https://huggingface.co/ByteDance/DreamO)
- [HunyuanAvatar](https://huggingface.co/tencent/HunyuanVideo-Avatar)
- [HunyuanCustom](https://github.com/Tencent-Hunyuan/HunyuanCustom)
- [Inf-DiT](https://github.com/zai-org/Inf-DiT)
- [Krea Realtime Video](https://huggingface.co/krea/krea-realtime-video)
- [LanDiff](https://github.com/landiff/landiff)
- [Liquid](https://github.com/FoundationVision/Liquid)
- [LongCat-Video](https://huggingface.co/meituan-longcat/LongCat-Video)
- [LucyEdit](https://github.com/huggingface/diffusers/pull/12340)
- [Lumina-DiMOO](https://huggingface.co/Alpha-VLLM/Lumina-DiMOO)
- [Magi](https://github.com/SandAI-org/MAGI-1)(https://github.com/huggingface/diffusers/pull/11713)
- [Ming](https://github.com/inclusionAI/Ming)
- [MUG-V 10B](https://huggingface.co/MUG-V/MUG-V-inference)
- [Ovi](https://github.com/character-ai/Ovi)
- [Phantom HuMo](https://github.com/Phantom-video/Phantom)
- [SD3 UltraEdit](https://github.com/HaozheZhao/UltraEdit)
- [SelfForcing](https://github.com/guandeh17/Self-Forcing)
- [SEVA](https://github.com/huggingface/diffusers/pull/11440)
- [Step1X](https://github.com/stepfun-ai/Step1X-Edit)
- [Wan-2.2 Animate](https://github.com/huggingface/diffusers/pull/12526)
- [Wan-2.2 S2V](https://github.com/huggingface/diffusers/pull/12258)
- [WAN-CausVid-Plus t2v](https://github.com/goatWu/CausVid-Plus/)
- [WAN-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid)
- [WAN-StepDistill](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill)
- [Wan2.2-Animate-14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B)
- [WAN2GP](https://github.com/deepbeepmeep/Wan2GP)
### Upscalers
- [HQX](https://github.com/uier/py-hqx/blob/main/hqx.py)
- [DCCI](https://every-algorithm.github.io/2024/11/06/directional_cubic_convolution_interpolation.html)
- [ICBI](https://github.com/gyfastas/ICBI/blob/master/icbi.py)
### Image-Base
- [Chroma Zeta](https://huggingface.co/lodestones/Zeta-Chroma): Image and video generator for creative effects and professional filters
- [Chroma Radiance](https://huggingface.co/lodestones/Chroma1-Radiance): Pixel-space model eliminating VAE artifacts for high visual fidelity
- [Liquid](https://github.com/FoundationVision/Liquid): Unified vision-language auto-regressive generation paradigm
- [Lumina-DiMOO](https://huggingface.co/Alpha-VLLM/Lumina-DiMOO): Foundational multi-modal generation and understanding via discrete diffusion
- [nVidia Cosmos-Predict-2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B): Physics-aware world foundation model for consistent scene prediction
- [Liquid (unified multimodal generator)](https://github.com/FoundationVision/Liquid): Auto-regressive generation paradigm across vision and language
- [Lumina-DiMOO](https://huggingface.co/Alpha-VLLM/Lumina-DiMOO): foundational multi-modal multi-task generation and understanding
### Image-Edit
- [Meituan LongCat-Image-Edit-Turbo](https://huggingface.co/meituan-longcat/LongCat-Image-Edit-Turbo):6B instruction-following image editing with high visual consistency
- [VIBE Image-Edit](https://huggingface.co/iitolstykh/VIBE-Image-Edit): (Sana+Qwen-VL)Fast visual instruction-based image editing framework
- [LucyEdit](https://github.com/huggingface/diffusers/pull/12340):Instruction-guided video editing while preserving motion and identity
- [Step1X-Edit](https://github.com/stepfun-ai/Step1X-Edit):Multimodal image editing decoding MLLM tokens via DiT
- [OneReward](https://github.com/bytedance/OneReward):Reinforcement learning grounded generative reward model for image editing
- [ByteDance DreamO](https://huggingface.co/ByteDance/DreamO): image customization framework for IP adaptation and virtual try-on
### Video
- [OpenMOSS MOVA](https://huggingface.co/OpenMOSS-Team/MOVA-720p): Unified foundation model for synchronized high-fidelity video and audio
- [Wan family (Wan2.1 / Wan2.2 variants)](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B): MoE-based foundational tools for cinematic T2V/I2V/TI2V
example: [Wan2.1-T2V-14B-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid)
distill / step-distill examples: [Wan2.1-StepDistill-CfgDistill](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill)
- [Krea Realtime Video](https://huggingface.co/krea/krea-realtime-video): (Wan2.1)Distilled real-time video diffusion using self-forcing techniques
- [MAGI-1 (autoregressive video)](https://github.com/SandAI-org/MAGI-1): Autoregressive video generation allowing infinite and timeline control
- [MUG-V 10B (video generation)](https://huggingface.co/MUG-V/MUG-V-inference): large-scale DiT-based video generation system trained via flow-matching
- [Ovi (audio/video generation)](https://github.com/character-ai/Ovi): (Wan2.2)Speech-to-video with synchronized sound effects and music
- [HunyuanVideo-Avatar / HunyuanCustom](https://huggingface.co/tencent/HunyuanVideo-Avatar): (HunyuanVideo)MM-DiT based dynamic emotion-controllable dialogue generation
- [Sana Image→Video (Sana-I2V)](https://github.com/huggingface/diffusers/pull/12634#issuecomment-3540534268): (Sana)Compact Linear DiT framework for efficient high-resolution video
- [Wan-2.2 S2V (diffusers PR)](https://github.com/huggingface/diffusers/pull/12258): (Wan2.2)Audio-driven cinematic speech-to-video generation
- [LongCat-Video](https://huggingface.co/meituan-longcat/LongCat-Video): Unified framework for minutes-long coherent video generation via Block Sparse Attention
- [LTXVideo / LTXVideo LongMulti (diffusers PR)](https://github.com/huggingface/diffusers/pull/12614): Real-time DiT-based generation with production-ready camera controls
- [DiffSynth-Studio (ModelScope)](https://github.com/modelscope/DiffSynth-Studio): (Wan2.2)Comprehensive training and quantization tools for Wan video models
- [Phantom (Phantom HuMo)](https://github.com/Phantom-video/Phantom): Human-centric video generation framework focus on subject ID consistency
- [CausVid-Plus / WAN-CausVid-Plus](https://github.com/goatWu/CausVid-Plus/): (Wan2.1)Causal diffusion for high-quality temporally consistent long videos
- [Wan2GP (workflow/GUI for Wan)](https://github.com/deepbeepmeep/Wan2GP): (Wan)Web-based UI focused on running complex video models for GPU-poor setups
- [LivePortrait](https://github.com/KwaiVGI/LivePortrait): Efficient portrait animation system with high stitching and retargeting control
- [Magi (SandAI)](https://github.com/SandAI-org/MAGI-1): High-quality autoregressive video generation framework
- [Ming (inclusionAI)](https://github.com/inclusionAI/Ming): Unified multimodal model for processing text, audio, image, and video
### Other/Unsorted
- [DiffusionForcing](https://github.com/kwsong0113/diffusion-forcing-transformer): Full-sequence diffusion with autoregressive next-token prediction
- [Self-Forcing](https://github.com/guandeh17/Self-Forcing): Framework for improving temporal consistency in long-horizon video generation
- [SEVA](https://github.com/huggingface/diffusers/pull/11440): Stable Virtual Camera for novel view synthesis and 3D-consistent video
- [ByteDance USO](https://github.com/bytedance/USO): Unified Style-Subject Optimized framework for personalized image generation
- [ByteDance Lynx](https://github.com/bytedance/lynx): State-of-the-art high-fidelity personalized video generation based on DiT
- [LanDiff](https://github.com/landiff/landiff): Coarse-to-fine text-to-video integrating Language and Diffusion Models
- [Video Inpaint Pipeline](https://github.com/huggingface/diffusers/pull/12506): Unified inpainting pipeline implementation within Diffusers library
- [Sonic Inpaint](https://github.com/ubc-vision/sonic): Audio-driven portrait animation system focus on global audio perception
- [Make-It-Count](https://github.com/Litalby1/make-it-count): CountGen method for precise numerical control of objects via object identity features
- [ControlNeXt](https://github.com/dvlab-research/ControlNeXt/): Lightweight architecture for efficient controllable image and video generation
- [MS-Diffusion](https://github.com/MS-Diffusion/MS-Diffusion): Layout-guided multi-subject image personalization framework
- [UniRef](https://github.com/FoundationVision/UniRef): Unified model for segmentation tasks designed as foundation model plug-in
- [FlashFace](https://github.com/ali-vilab/FlashFace): High-fidelity human image customization and face swapping framework
- [ReNO](https://github.com/ExplainableML/ReNO): Reward-based Noise Optimization to improve text-to-image quality during inference
### Not Planned
- [Bria FIBO](https://huggingface.co/briaai/FIBO): Fully JSON based
- [Bria FiboEdit](https://github.com/huggingface/diffusers/commit/d7a1c31f4f85bae5a9e01cdce49bd7346bd8ccd6): Fully JSON based
- [LoRAdapter](https://github.com/CompVis/LoRAdapter): Not recently updated
- [SD3 UltraEdit](https://github.com/HaozheZhao/UltraEdit): Based on SD3
- [PowerPaint](https://github.com/open-mmlab/PowerPaint): Based on SD15
- [FreeCustom](https://github.com/aim-uofa/FreeCustom): Based on SD15
- [AnyDoor](https://github.com/ali-vilab/AnyDoor): Based on SD21
- [AnyText2](https://github.com/tyxsspa/AnyText2): Based on SD15
- [DragonDiffusion](https://github.com/MC-E/DragonDiffusion): Based on SD15
- [DenseDiffusion](https://github.com/naver-ai/DenseDiffusion): Based on SD15
- [IC-Light](https://github.com/lllyasviel/IC-Light): Based on SD15
## Migration
### Asyncio
- Policy system is deprecated and will be removed in **Python 3.16**
- [Python 3.14 removals - asyncio](https://docs.python.org/3.14/whatsnew/3.14.html#id10)
- https://docs.python.org/3.14/library/asyncio-policy.html
- Affected files:
- [`webui.py`](webui.py)
- [`cli/sdapi.py`](cli/sdapi.py)
- Migration:
- [asyncio.run](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.run)
- [asyncio.Runner](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.Runner)
- Policy system is deprecated and will be removed in Python 3.16
[Python 3.14 removalsasyncio](https://docs.python.org/3.14/whatsnew/3.14.html#id10)
https://docs.python.org/3.14/library/asyncio-policy.html
Affected files:
[`webui.py`](webui.py)
[`cli/sdapi.py`](cli/sdapi.py)
Migration:
[asyncio.run](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.run)
[asyncio.Runner](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.Runner)
#### rmtree
### rmtree
- `onerror` deprecated and replaced with `onexc` in **Python 3.12**
- `onerror` deprecated and replaced with `onexc` in Python 3.12
``` python
def excRemoveReadonly(func, path, exc: BaseException):
import stat

35
cli/api-samplers.py Normal file
View File

@ -0,0 +1,35 @@
#!/usr/bin/env python
"""
get list of all samplers and details of current sampler
"""
import sys
import logging
import urllib3
import requests
url = "http://127.0.0.1:7860"
user = ""
password = ""
log_format = '%(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(level = logging.INFO, format = log_format)
log = logging.getLogger("sd")
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
log.info('available samplers')
auth = requests.auth.HTTPBasicAuth(user, password) if len(user) > 0 and len(password) > 0 else None
req = requests.get(f'{url}/sdapi/v1/samplers', verify=False, auth=auth, timeout=60)
if req.status_code != 200:
log.error({ 'url': req.url, 'request': req.status_code, 'reason': req.reason })
exit(1)
res = req.json()
for item in res:
log.info(item)
log.info('current sampler')
req = requests.get(f'{url}/sdapi/v1/sampler', verify=False, auth=auth, timeout=60)
res = req.json()
log.info(res)

42
cli/api-xyzenum.py Executable file
View File

@ -0,0 +1,42 @@
#!/usr/bin/env python
import os
import logging
import requests
import urllib3
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
sd_username = os.environ.get('SDAPI_USR', None)
sd_password = os.environ.get('SDAPI_PWD', None)
options = {
"save_images": True,
"send_images": True,
}
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
log = logging.getLogger(__name__)
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
def auth():
if sd_username is not None and sd_password is not None:
return requests.auth.HTTPBasicAuth(sd_username, sd_password)
return None
def get(endpoint: str, dct: dict = None):
req = requests.get(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
if req.status_code != 200:
return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
else:
return req.json()
if __name__ == "__main__":
options = get('/sdapi/v1/xyz-grid')
log.info(f'api-xyzgrid-options: {len(options)}')
for option in options:
log.info(f' {option}')
details = get('/sdapi/v1/xyz-grid?option=upscaler')
for choice in details[0]['choices']:
log.info(f' {choice}')

260
cli/test-schedulers.py Normal file
View File

@ -0,0 +1,260 @@
import os
import sys
import time
import numpy as np
import torch
# Ensure we can import modules
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from modules.errors import log
from modules.res4lyf import (
BASE, SIMPLE, VARIANTS,
RESUnifiedScheduler, RESMultistepScheduler, RESDEISMultistepScheduler,
ETDRKScheduler, LawsonScheduler, ABNorsettScheduler, PECScheduler,
RiemannianFlowScheduler, RESSinglestepScheduler, RESSinglestepSDEScheduler,
RESMultistepSDEScheduler, SimpleExponentialScheduler, LinearRKScheduler,
LobattoScheduler, GaussLegendreScheduler, RungeKutta44Scheduler,
RungeKutta57Scheduler, RungeKutta67Scheduler, SpecializedRKScheduler,
BongTangentScheduler, CommonSigmaScheduler, RadauIIAScheduler,
LangevinDynamicsScheduler
)
from modules.schedulers.scheduler_vdm import VDMScheduler
from modules.schedulers.scheduler_unipc_flowmatch import FlowUniPCMultistepScheduler
from modules.schedulers.scheduler_ufogen import UFOGenScheduler
from modules.schedulers.scheduler_tdd import TDDScheduler
from modules.schedulers.scheduler_tcd import TCDScheduler
from modules.schedulers.scheduler_flashflow import FlashFlowMatchEulerDiscreteScheduler
from modules.schedulers.scheduler_dpm_flowmatch import FlowMatchDPMSolverMultistepScheduler
from modules.schedulers.scheduler_dc import DCSolverMultistepScheduler
from modules.schedulers.scheduler_bdia import BDIA_DDIMScheduler
def test_scheduler(name, scheduler_class, config):
try:
scheduler = scheduler_class(**config)
except Exception as e:
log.error(f'scheduler="{name}" cls={scheduler_class} config={config} error="Init failed: {e}"')
return False
num_steps = 20
scheduler.set_timesteps(num_steps)
sample = torch.randn((1, 4, 64, 64))
has_changed = False
t0 = time.time()
messages = []
try:
for i, t in enumerate(scheduler.timesteps):
# Simulate model output (noise or x0 or v), Using random noise for stability check
model_output = torch.randn_like(sample)
# Scaling Check
step_idx = scheduler.step_index if hasattr(scheduler, "step_index") and scheduler.step_index is not None else i
# Clamp index
if hasattr(scheduler, 'sigmas'):
step_idx = min(step_idx, len(scheduler.sigmas) - 1)
sigma = scheduler.sigmas[step_idx]
else:
sigma = torch.tensor(1.0) # Dummy for non-sigma schedulers
# Re-introduce scaling calculation first
scaled_sample = scheduler.scale_model_input(sample, t)
if config.get("prediction_type") == "flow_prediction" or name in ["UFOGenScheduler", "TDDScheduler", "TCDScheduler", "BDIA_DDIMScheduler", "DCSolverMultistepScheduler"]:
# Some new schedulers don't use K-diffusion scaling
expected_scale = 1.0
else:
expected_scale = 1.0 / ((sigma**2 + 1) ** 0.5)
# Simple check with loose tolerance due to float precision
expected_scaled_sample = sample * expected_scale
if not torch.allclose(scaled_sample, expected_scaled_sample, atol=1e-4):
# If failed, double check if it's just 'sample' (no scaling)
if torch.allclose(scaled_sample, sample, atol=1e-4):
messages.append('warning="scaling is identity"')
else:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} expected={expected_scale} error="scaling mismatch"')
return False
if torch.isnan(scaled_sample).any():
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="NaN in scaled_sample"')
return False
if torch.isinf(scaled_sample).any():
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Inf in scaled_sample"')
return False
output = scheduler.step(model_output, t, sample)
# Shape and Dtype check
if output.prev_sample.shape != sample.shape:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Shape mismatch: {output.prev_sample.shape} vs {sample.shape}"')
return False
if output.prev_sample.dtype != sample.dtype:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Dtype mismatch: {output.prev_sample.dtype} vs {sample.dtype}"')
return False
# Update check: Did the sample change?
if not torch.equal(sample, output.prev_sample):
has_changed = True
# Sample Evolution Check
step_diff = (sample - output.prev_sample).abs().mean().item()
if step_diff < 1e-6:
messages.append(f'warning="minimal sample change: {step_diff}"')
sample = output.prev_sample
if torch.isnan(sample).any():
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="NaN in sample"')
return False
if torch.isinf(sample).any():
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Inf in sample"')
return False
# Divergence check
if sample.abs().max() > 1e10:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="divergence detected"')
return False
# External check for Sigma Monotonicity
if hasattr(scheduler, 'sigmas'):
sigmas = scheduler.sigmas.cpu().numpy()
if len(sigmas) > 1:
diffs = np.diff(sigmas) # Check if potentially monotonic decreasing (standard) OR increasing (some flow/inverse setups). We allow flat sections (diff=0) hence 1e-6 slack
is_monotonic_decreasing = np.all(diffs <= 1e-6)
is_monotonic_increasing = np.all(diffs >= -1e-6)
if not (is_monotonic_decreasing or is_monotonic_increasing):
messages.append('warning="sigmas are not monotonic"')
except Exception as e:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} exception: {e}')
import traceback
traceback.print_exc()
return False
if not has_changed:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} error="sample never changed"')
return False
final_std = sample.std().item()
if final_std > 50.0 or final_std < 0.1:
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} std={final_std} error="variance drift"')
t1 = time.time()
messages = list(set(messages))
log.info(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} time={t1-t0} messages={messages}')
return True
def run_tests():
prediction_types = ["epsilon", "v_prediction", "sample"] # flow_prediction is special, usually requires flow sigmas or specific setup, checking standard ones first
# Test BASE schedulers with their specific parameters
log.warning('type="base"')
for name, cls in BASE:
configs = []
# prediction_types
for pt in prediction_types:
configs.append({"prediction_type": pt})
# Specific params for specific classes
if cls == RESUnifiedScheduler:
rk_types = ["res_2m", "res_3m", "res_2s", "res_3s", "res_5s", "res_6s", "deis_1s", "deis_2m", "deis_3m"]
for rk in rk_types:
for pt in prediction_types:
configs.append({"rk_type": rk, "prediction_type": pt})
elif cls == RESMultistepScheduler:
variants = ["res_2m", "res_3m", "deis_2m", "deis_3m"]
for v in variants:
for pt in prediction_types:
configs.append({"variant": v, "prediction_type": pt})
elif cls == RESDEISMultistepScheduler:
for order in range(1, 6):
for pt in prediction_types:
configs.append({"solver_order": order, "prediction_type": pt})
elif cls == ETDRKScheduler:
variants = ["etdrk2_2s", "etdrk3_a_3s", "etdrk3_b_3s", "etdrk4_4s", "etdrk4_4s_alt"]
for v in variants:
for pt in prediction_types:
configs.append({"variant": v, "prediction_type": pt})
elif cls == LawsonScheduler:
variants = ["lawson2a_2s", "lawson2b_2s", "lawson4_4s"]
for v in variants:
for pt in prediction_types:
configs.append({"variant": v, "prediction_type": pt})
elif cls == ABNorsettScheduler:
variants = ["abnorsett_2m", "abnorsett_3m", "abnorsett_4m"]
for v in variants:
for pt in prediction_types:
configs.append({"variant": v, "prediction_type": pt})
elif cls == PECScheduler:
variants = ["pec423_2h2s", "pec433_2h3s"]
for v in variants:
for pt in prediction_types:
configs.append({"variant": v, "prediction_type": pt})
elif cls == RiemannianFlowScheduler:
metrics = ["euclidean", "hyperbolic", "spherical", "lorentzian"]
for m in metrics:
configs.append({"metric_type": m, "prediction_type": "epsilon"}) # Flow usually uses v or raw, but epsilon check matches others
if not configs:
for pt in prediction_types:
configs.append({"prediction_type": pt})
for conf in configs:
test_scheduler(name, cls, conf)
log.warning('type="simple"')
for name, cls in SIMPLE:
for pt in prediction_types:
test_scheduler(name, cls, {"prediction_type": pt})
log.warning('type="variants"')
for name, cls in VARIANTS:
# these classes preset their variants/rk_types in __init__ so we just test prediction types
for pt in prediction_types:
test_scheduler(name, cls, {"prediction_type": pt})
# Extra robustness check: Flow Prediction Type
log.warning('type="flow"')
flow_schedulers = [
# res4lyf schedulers
RESUnifiedScheduler, RESMultistepScheduler, ABNorsettScheduler,
RESSinglestepScheduler, RESSinglestepSDEScheduler, RESDEISMultistepScheduler,
RESMultistepSDEScheduler, ETDRKScheduler, LawsonScheduler, PECScheduler,
SimpleExponentialScheduler, LinearRKScheduler, LobattoScheduler,
GaussLegendreScheduler, RungeKutta44Scheduler, RungeKutta57Scheduler,
RungeKutta67Scheduler, SpecializedRKScheduler, BongTangentScheduler,
CommonSigmaScheduler, RadauIIAScheduler, LangevinDynamicsScheduler,
RiemannianFlowScheduler,
# sdnext schedulers
FlowUniPCMultistepScheduler, FlashFlowMatchEulerDiscreteScheduler, FlowMatchDPMSolverMultistepScheduler,
]
for cls in flow_schedulers:
test_scheduler(cls.__name__, cls, {"prediction_type": "flow_prediction", "use_flow_sigmas": True})
log.warning('type="sdnext"')
extended_schedulers = [
VDMScheduler,
UFOGenScheduler,
TDDScheduler,
TCDScheduler,
DCSolverMultistepScheduler,
BDIA_DDIMScheduler
]
for prediction_type in ["epsilon", "v_prediction", "sample"]:
for cls in extended_schedulers:
test_scheduler(cls.__name__, cls, {"prediction_type": prediction_type})
if __name__ == "__main__":
run_tests()

847
cli/test-tagger.py Normal file
View File

@ -0,0 +1,847 @@
#!/usr/bin/env python
"""
Tagger Settings Test Suite
Tests all WaifuDiffusion and DeepBooru tagger settings to verify they're properly
mapped and affect output correctly.
Usage:
python cli/test-tagger.py [image_path]
If no image path is provided, uses a built-in test image.
"""
import os
import sys
import time
# Add parent directory to path for imports
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, script_dir)
os.chdir(script_dir)
# Suppress installer output during import
os.environ['SD_INSTALL_QUIET'] = '1'
# Initialize cmd_args properly with all argument groups
import modules.cmd_args
import installer
# Add installer args to the parser
installer.add_args(modules.cmd_args.parser)
# Parse with empty args to get defaults
modules.cmd_args.parsed, _ = modules.cmd_args.parser.parse_known_args([])
# Now we can safely import modules that depend on cmd_args
# Default test images (in order of preference)
DEFAULT_TEST_IMAGES = [
'html/sdnext-robot-2k.jpg', # SD.Next robot mascot
'venv/lib/python3.13/site-packages/gradio/test_data/lion.jpg',
'venv/lib/python3.13/site-packages/gradio/test_data/cheetah1.jpg',
'venv/lib/python3.13/site-packages/skimage/data/astronaut.png',
'venv/lib/python3.13/site-packages/skimage/data/coffee.png',
]
def find_test_image():
"""Find a suitable test image from defaults."""
for img_path in DEFAULT_TEST_IMAGES:
full_path = os.path.join(script_dir, img_path)
if os.path.exists(full_path):
return full_path
return None
def create_test_image():
"""Create a simple test image as fallback."""
from PIL import Image, ImageDraw
img = Image.new('RGB', (512, 512), color=(200, 150, 100))
draw = ImageDraw.Draw(img)
draw.ellipse([100, 100, 400, 400], fill=(255, 200, 150), outline=(100, 50, 0))
draw.rectangle([150, 200, 350, 350], fill=(150, 100, 200))
return img
class TaggerTest:
"""Test harness for tagger settings."""
def __init__(self):
self.results = {'passed': [], 'failed': [], 'skipped': []}
self.test_image = None
self.waifudiffusion_loaded = False
self.deepbooru_loaded = False
def log_pass(self, msg):
print(f" [PASS] {msg}")
self.results['passed'].append(msg)
def log_fail(self, msg):
print(f" [FAIL] {msg}")
self.results['failed'].append(msg)
def log_skip(self, msg):
print(f" [SKIP] {msg}")
self.results['skipped'].append(msg)
def log_warn(self, msg):
print(f" [WARN] {msg}")
self.results['skipped'].append(msg)
def setup(self):
"""Load test image and models."""
from PIL import Image
print("=" * 70)
print("TAGGER SETTINGS TEST SUITE")
print("=" * 70)
# Get or create test image
if len(sys.argv) > 1 and os.path.exists(sys.argv[1]):
img_path = sys.argv[1]
print(f"\nUsing provided image: {img_path}")
self.test_image = Image.open(img_path).convert('RGB')
else:
img_path = find_test_image()
if img_path:
print(f"\nUsing default test image: {img_path}")
self.test_image = Image.open(img_path).convert('RGB')
else:
print("\nNo test image found, creating synthetic image...")
self.test_image = create_test_image()
print(f"Image size: {self.test_image.size}")
# Load models
print("\nLoading models...")
from modules.interrogate import waifudiffusion, deepbooru
t0 = time.time()
self.waifudiffusion_loaded = waifudiffusion.load_model()
print(f" WaifuDiffusion: {'loaded' if self.waifudiffusion_loaded else 'FAILED'} ({time.time()-t0:.1f}s)")
t0 = time.time()
self.deepbooru_loaded = deepbooru.load_model()
print(f" DeepBooru: {'loaded' if self.deepbooru_loaded else 'FAILED'} ({time.time()-t0:.1f}s)")
def cleanup(self):
"""Unload models and free memory."""
print("\n" + "=" * 70)
print("CLEANUP")
print("=" * 70)
from modules.interrogate import waifudiffusion, deepbooru
from modules import devices
waifudiffusion.unload_model()
deepbooru.unload_model()
devices.torch_gc(force=True)
print(" Models unloaded")
def print_summary(self):
"""Print test summary."""
print("\n" + "=" * 70)
print("TEST SUMMARY")
print("=" * 70)
print(f"\n PASSED: {len(self.results['passed'])}")
for item in self.results['passed']:
print(f" - {item}")
print(f"\n FAILED: {len(self.results['failed'])}")
for item in self.results['failed']:
print(f" - {item}")
print(f"\n SKIPPED: {len(self.results['skipped'])}")
for item in self.results['skipped']:
print(f" - {item}")
total = len(self.results['passed']) + len(self.results['failed'])
if total > 0:
success_rate = len(self.results['passed']) / total * 100
print(f"\n SUCCESS RATE: {success_rate:.1f}% ({len(self.results['passed'])}/{total})")
print("\n" + "=" * 70)
# =========================================================================
# TEST: ONNX Providers Detection
# =========================================================================
def test_onnx_providers(self):
"""Verify ONNX runtime providers are properly detected."""
print("\n" + "=" * 70)
print("TEST: ONNX Providers Detection")
print("=" * 70)
from modules import devices
# Test 1: onnxruntime can be imported
try:
import onnxruntime as ort
self.log_pass(f"onnxruntime imported: version={ort.__version__}")
except ImportError as e:
self.log_fail(f"onnxruntime import failed: {e}")
return
# Test 2: Get available providers
available = ort.get_available_providers()
if available and len(available) > 0:
self.log_pass(f"Available providers: {available}")
else:
self.log_fail("No ONNX providers available")
return
# Test 3: devices.onnx is properly configured
if devices.onnx is not None and len(devices.onnx) > 0:
self.log_pass(f"devices.onnx configured: {devices.onnx}")
else:
self.log_fail(f"devices.onnx not configured: {devices.onnx}")
# Test 4: Configured providers exist in available providers
for provider in devices.onnx:
if provider in available:
self.log_pass(f"Provider '{provider}' is available")
else:
self.log_fail(f"Provider '{provider}' configured but not available")
# Test 5: If WaifuDiffusion loaded, check session providers
if self.waifudiffusion_loaded:
from modules.interrogate import waifudiffusion
if waifudiffusion.tagger.session is not None:
session_providers = waifudiffusion.tagger.session.get_providers()
self.log_pass(f"WaifuDiffusion session providers: {session_providers}")
else:
self.log_skip("WaifuDiffusion session not initialized")
# =========================================================================
# TEST: Memory Management (Offload/Reload/Unload)
# =========================================================================
def get_memory_stats(self):
"""Get current GPU and CPU memory usage."""
import torch
stats = {}
# GPU memory (if CUDA available)
if torch.cuda.is_available():
torch.cuda.synchronize()
stats['gpu_allocated'] = torch.cuda.memory_allocated() / 1024 / 1024 # MB
stats['gpu_reserved'] = torch.cuda.memory_reserved() / 1024 / 1024 # MB
else:
stats['gpu_allocated'] = 0
stats['gpu_reserved'] = 0
# CPU/RAM memory (try psutil, fallback to basic)
try:
import psutil
process = psutil.Process()
stats['ram_used'] = process.memory_info().rss / 1024 / 1024 # MB
except ImportError:
stats['ram_used'] = 0
return stats
def test_memory_management(self):
"""Test model offload to RAM, reload to GPU, and unload with memory monitoring."""
print("\n" + "=" * 70)
print("TEST: Memory Management (Offload/Reload/Unload)")
print("=" * 70)
import torch
import gc
from modules import devices
from modules.interrogate import waifudiffusion, deepbooru
# Memory leak tolerance (MB) - some variance is expected
GPU_LEAK_TOLERANCE_MB = 50
RAM_LEAK_TOLERANCE_MB = 200
# =====================================================================
# DeepBooru: Test GPU/CPU movement with memory monitoring
# =====================================================================
if self.deepbooru_loaded:
print("\n DeepBooru Memory Management:")
# Baseline memory before any operations
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
baseline = self.get_memory_stats()
print(f" Baseline: GPU={baseline['gpu_allocated']:.1f}MB, RAM={baseline['ram_used']:.1f}MB")
# Test 1: Check initial state (should be on CPU after load)
initial_device = next(deepbooru.model.model.parameters()).device
print(f" Initial device: {initial_device}")
if initial_device.type == 'cpu':
self.log_pass("DeepBooru: initial state on CPU")
else:
self.log_pass(f"DeepBooru: initial state on {initial_device}")
# Test 2: Move to GPU (start)
deepbooru.model.start()
gpu_device = next(deepbooru.model.model.parameters()).device
after_gpu = self.get_memory_stats()
print(f" After start(): {gpu_device} | GPU={after_gpu['gpu_allocated']:.1f}MB (+{after_gpu['gpu_allocated']-baseline['gpu_allocated']:.1f}MB)")
if gpu_device.type == devices.device.type:
self.log_pass(f"DeepBooru: moved to GPU ({gpu_device})")
else:
self.log_fail(f"DeepBooru: failed to move to GPU, got {gpu_device}")
# Test 3: Run inference while on GPU
try:
tags = deepbooru.model.tag_multi(self.test_image, max_tags=3)
after_infer = self.get_memory_stats()
print(f" After inference: GPU={after_infer['gpu_allocated']:.1f}MB")
if tags:
self.log_pass(f"DeepBooru: inference on GPU works ({tags[:30]}...)")
else:
self.log_fail("DeepBooru: inference on GPU returned empty")
except Exception as e:
self.log_fail(f"DeepBooru: inference on GPU failed: {e}")
# Test 4: Offload to CPU (stop)
deepbooru.model.stop()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
after_offload = self.get_memory_stats()
cpu_device = next(deepbooru.model.model.parameters()).device
print(f" After stop(): {cpu_device} | GPU={after_offload['gpu_allocated']:.1f}MB, RAM={after_offload['ram_used']:.1f}MB")
if cpu_device.type == 'cpu':
self.log_pass("DeepBooru: offloaded to CPU")
else:
self.log_fail(f"DeepBooru: failed to offload, still on {cpu_device}")
# Check GPU memory returned to near baseline after offload
gpu_diff = after_offload['gpu_allocated'] - baseline['gpu_allocated']
if gpu_diff <= GPU_LEAK_TOLERANCE_MB:
self.log_pass(f"DeepBooru: GPU memory cleared after offload (diff={gpu_diff:.1f}MB)")
else:
self.log_fail(f"DeepBooru: GPU memory leak after offload (diff={gpu_diff:.1f}MB > {GPU_LEAK_TOLERANCE_MB}MB)")
# Test 5: Full cycle - reload and run again
deepbooru.model.start()
try:
tags = deepbooru.model.tag_multi(self.test_image, max_tags=3)
if tags:
self.log_pass("DeepBooru: reload cycle works")
else:
self.log_fail("DeepBooru: reload cycle returned empty")
except Exception as e:
self.log_fail(f"DeepBooru: reload cycle failed: {e}")
deepbooru.model.stop()
# Test 6: Full unload with memory check
deepbooru.unload_model()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
after_unload = self.get_memory_stats()
print(f" After unload: GPU={after_unload['gpu_allocated']:.1f}MB, RAM={after_unload['ram_used']:.1f}MB")
if deepbooru.model.model is None:
self.log_pass("DeepBooru: unload successful")
else:
self.log_fail("DeepBooru: unload failed, model still exists")
# Check for memory leaks after full unload
gpu_leak = after_unload['gpu_allocated'] - baseline['gpu_allocated']
ram_leak = after_unload['ram_used'] - baseline['ram_used']
if gpu_leak <= GPU_LEAK_TOLERANCE_MB:
self.log_pass(f"DeepBooru: no GPU memory leak after unload (diff={gpu_leak:.1f}MB)")
else:
self.log_fail(f"DeepBooru: GPU memory leak detected (diff={gpu_leak:.1f}MB > {GPU_LEAK_TOLERANCE_MB}MB)")
if ram_leak <= RAM_LEAK_TOLERANCE_MB:
self.log_pass(f"DeepBooru: no RAM leak after unload (diff={ram_leak:.1f}MB)")
else:
self.log_warn(f"DeepBooru: RAM increased after unload (diff={ram_leak:.1f}MB) - may be caching")
# Reload for remaining tests
deepbooru.load_model()
# =====================================================================
# WaifuDiffusion: Test session lifecycle with memory monitoring
# =====================================================================
if self.waifudiffusion_loaded:
print("\n WaifuDiffusion Memory Management:")
# Baseline memory
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
baseline = self.get_memory_stats()
print(f" Baseline: GPU={baseline['gpu_allocated']:.1f}MB, RAM={baseline['ram_used']:.1f}MB")
# Test 1: Session exists
if waifudiffusion.tagger.session is not None:
self.log_pass("WaifuDiffusion: session loaded")
else:
self.log_fail("WaifuDiffusion: session not loaded")
return
# Test 2: Get current providers
providers = waifudiffusion.tagger.session.get_providers()
print(f" Active providers: {providers}")
self.log_pass(f"WaifuDiffusion: using providers {providers}")
# Test 3: Run inference
try:
tags = waifudiffusion.tagger.predict(self.test_image, max_tags=3)
after_infer = self.get_memory_stats()
print(f" After inference: GPU={after_infer['gpu_allocated']:.1f}MB, RAM={after_infer['ram_used']:.1f}MB")
if tags:
self.log_pass(f"WaifuDiffusion: inference works ({tags[:30]}...)")
else:
self.log_fail("WaifuDiffusion: inference returned empty")
except Exception as e:
self.log_fail(f"WaifuDiffusion: inference failed: {e}")
# Test 4: Unload session with memory check
model_name = waifudiffusion.tagger.model_name
waifudiffusion.unload_model()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
after_unload = self.get_memory_stats()
print(f" After unload: GPU={after_unload['gpu_allocated']:.1f}MB, RAM={after_unload['ram_used']:.1f}MB")
if waifudiffusion.tagger.session is None:
self.log_pass("WaifuDiffusion: unload successful")
else:
self.log_fail("WaifuDiffusion: unload failed, session still exists")
# Check for memory leaks after unload
gpu_leak = after_unload['gpu_allocated'] - baseline['gpu_allocated']
ram_leak = after_unload['ram_used'] - baseline['ram_used']
if gpu_leak <= GPU_LEAK_TOLERANCE_MB:
self.log_pass(f"WaifuDiffusion: no GPU memory leak after unload (diff={gpu_leak:.1f}MB)")
else:
self.log_fail(f"WaifuDiffusion: GPU memory leak detected (diff={gpu_leak:.1f}MB > {GPU_LEAK_TOLERANCE_MB}MB)")
if ram_leak <= RAM_LEAK_TOLERANCE_MB:
self.log_pass(f"WaifuDiffusion: no RAM leak after unload (diff={ram_leak:.1f}MB)")
else:
self.log_warn(f"WaifuDiffusion: RAM increased after unload (diff={ram_leak:.1f}MB) - may be caching")
# Test 5: Reload session
waifudiffusion.load_model(model_name)
after_reload = self.get_memory_stats()
print(f" After reload: GPU={after_reload['gpu_allocated']:.1f}MB, RAM={after_reload['ram_used']:.1f}MB")
if waifudiffusion.tagger.session is not None:
self.log_pass("WaifuDiffusion: reload successful")
else:
self.log_fail("WaifuDiffusion: reload failed")
# Test 6: Inference after reload
try:
tags = waifudiffusion.tagger.predict(self.test_image, max_tags=3)
if tags:
self.log_pass("WaifuDiffusion: inference after reload works")
else:
self.log_fail("WaifuDiffusion: inference after reload returned empty")
except Exception as e:
self.log_fail(f"WaifuDiffusion: inference after reload failed: {e}")
# Final memory check after full cycle
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
final = self.get_memory_stats()
print(f" Final (after full cycle): GPU={final['gpu_allocated']:.1f}MB, RAM={final['ram_used']:.1f}MB")
# =========================================================================
# TEST: Settings Existence
# =========================================================================
def test_settings_exist(self):
"""Verify all tagger settings exist in shared.opts."""
print("\n" + "=" * 70)
print("TEST: Settings Existence")
print("=" * 70)
from modules import shared
settings = [
('tagger_threshold', float),
('tagger_include_rating', bool),
('tagger_max_tags', int),
('tagger_sort_alpha', bool),
('tagger_use_spaces', bool),
('tagger_escape_brackets', bool),
('tagger_exclude_tags', str),
('tagger_show_scores', bool),
('waifudiffusion_model', str),
('waifudiffusion_character_threshold', float),
('interrogate_offload', bool),
]
for setting, _expected_type in settings:
if hasattr(shared.opts, setting):
value = getattr(shared.opts, setting)
self.log_pass(f"{setting} = {value!r}")
else:
self.log_fail(f"{setting} - NOT FOUND")
# =========================================================================
# TEST: Parameter Effect - Tests a single parameter on both taggers
# =========================================================================
def test_parameter(self, param_name, test_func, waifudiffusion_supported=True, deepbooru_supported=True):
"""Test a parameter on both WaifuDiffusion and DeepBooru."""
print(f"\n Testing: {param_name}")
if waifudiffusion_supported and self.waifudiffusion_loaded:
try:
result = test_func('waifudiffusion')
if result is True:
self.log_pass(f"WaifuDiffusion: {param_name}")
elif result is False:
self.log_fail(f"WaifuDiffusion: {param_name}")
else:
self.log_skip(f"WaifuDiffusion: {param_name} - {result}")
except Exception as e:
self.log_fail(f"WaifuDiffusion: {param_name} - {e}")
elif waifudiffusion_supported:
self.log_skip(f"WaifuDiffusion: {param_name} - model not loaded")
if deepbooru_supported and self.deepbooru_loaded:
try:
result = test_func('deepbooru')
if result is True:
self.log_pass(f"DeepBooru: {param_name}")
elif result is False:
self.log_fail(f"DeepBooru: {param_name}")
else:
self.log_skip(f"DeepBooru: {param_name} - {result}")
except Exception as e:
self.log_fail(f"DeepBooru: {param_name} - {e}")
elif deepbooru_supported:
self.log_skip(f"DeepBooru: {param_name} - model not loaded")
def tag(self, tagger, **kwargs):
"""Helper to call the appropriate tagger."""
if tagger == 'waifudiffusion':
from modules.interrogate import waifudiffusion
return waifudiffusion.tagger.predict(self.test_image, **kwargs)
else:
from modules.interrogate import deepbooru
return deepbooru.model.tag(self.test_image, **kwargs)
# =========================================================================
# TEST: general_threshold
# =========================================================================
def test_threshold(self):
"""Test that threshold affects tag count."""
print("\n" + "=" * 70)
print("TEST: general_threshold effect")
print("=" * 70)
def check_threshold(tagger):
tags_high = self.tag(tagger, general_threshold=0.9)
tags_low = self.tag(tagger, general_threshold=0.1)
count_high = len(tags_high.split(', ')) if tags_high else 0
count_low = len(tags_low.split(', ')) if tags_low else 0
print(f" {tagger}: threshold=0.9 -> {count_high} tags, threshold=0.1 -> {count_low} tags")
if count_low > count_high:
return True
elif count_low == count_high == 0:
return "no tags returned"
else:
return "threshold effect unclear"
self.test_parameter('general_threshold', check_threshold)
# =========================================================================
# TEST: max_tags
# =========================================================================
def test_max_tags(self):
"""Test that max_tags limits output."""
print("\n" + "=" * 70)
print("TEST: max_tags effect")
print("=" * 70)
def check_max_tags(tagger):
tags_5 = self.tag(tagger, general_threshold=0.1, max_tags=5)
tags_50 = self.tag(tagger, general_threshold=0.1, max_tags=50)
count_5 = len(tags_5.split(', ')) if tags_5 else 0
count_50 = len(tags_50.split(', ')) if tags_50 else 0
print(f" {tagger}: max_tags=5 -> {count_5} tags, max_tags=50 -> {count_50} tags")
return count_5 <= 5
self.test_parameter('max_tags', check_max_tags)
# =========================================================================
# TEST: use_spaces
# =========================================================================
def test_use_spaces(self):
"""Test that use_spaces converts underscores to spaces."""
print("\n" + "=" * 70)
print("TEST: use_spaces effect")
print("=" * 70)
def check_use_spaces(tagger):
tags_under = self.tag(tagger, use_spaces=False, max_tags=10)
tags_space = self.tag(tagger, use_spaces=True, max_tags=10)
print(f" {tagger} use_spaces=False: {tags_under[:50]}...")
print(f" {tagger} use_spaces=True: {tags_space[:50]}...")
# Check if underscores are converted to spaces
has_underscore_before = '_' in tags_under
has_underscore_after = '_' in tags_space.replace(', ', ',') # ignore comma-space
# If there were underscores before but not after, it worked
if has_underscore_before and not has_underscore_after:
return True
# If there were never underscores, inconclusive
elif not has_underscore_before:
return "no underscores in tags to convert"
else:
return False
self.test_parameter('use_spaces', check_use_spaces)
# =========================================================================
# TEST: escape_brackets
# =========================================================================
def test_escape_brackets(self):
"""Test that escape_brackets escapes special characters."""
print("\n" + "=" * 70)
print("TEST: escape_brackets effect")
print("=" * 70)
def check_escape_brackets(tagger):
tags_escaped = self.tag(tagger, escape_brackets=True, max_tags=30, general_threshold=0.1)
tags_raw = self.tag(tagger, escape_brackets=False, max_tags=30, general_threshold=0.1)
print(f" {tagger} escape=True: {tags_escaped[:60]}...")
print(f" {tagger} escape=False: {tags_raw[:60]}...")
# Check for escaped brackets (\\( or \\))
has_escaped = '\\(' in tags_escaped or '\\)' in tags_escaped
has_unescaped = '(' in tags_raw.replace('\\(', '') or ')' in tags_raw.replace('\\)', '')
if has_escaped:
return True
elif has_unescaped:
# Has brackets but not escaped - fail
return False
else:
return "no brackets in tags to escape"
self.test_parameter('escape_brackets', check_escape_brackets)
# =========================================================================
# TEST: sort_alpha
# =========================================================================
def test_sort_alpha(self):
"""Test that sort_alpha sorts tags alphabetically."""
print("\n" + "=" * 70)
print("TEST: sort_alpha effect")
print("=" * 70)
def check_sort_alpha(tagger):
tags_conf = self.tag(tagger, sort_alpha=False, max_tags=20, general_threshold=0.1)
tags_alpha = self.tag(tagger, sort_alpha=True, max_tags=20, general_threshold=0.1)
list_conf = [t.strip() for t in tags_conf.split(',')]
list_alpha = [t.strip() for t in tags_alpha.split(',')]
print(f" {tagger} by_confidence: {', '.join(list_conf[:5])}...")
print(f" {tagger} alphabetical: {', '.join(list_alpha[:5])}...")
is_sorted = list_alpha == sorted(list_alpha)
return is_sorted
self.test_parameter('sort_alpha', check_sort_alpha)
# =========================================================================
# TEST: exclude_tags
# =========================================================================
def test_exclude_tags(self):
"""Test that exclude_tags removes specified tags."""
print("\n" + "=" * 70)
print("TEST: exclude_tags effect")
print("=" * 70)
def check_exclude_tags(tagger):
tags_all = self.tag(tagger, max_tags=50, general_threshold=0.1, exclude_tags='')
tag_list = [t.strip().replace(' ', '_') for t in tags_all.split(',')]
if len(tag_list) < 2:
return "not enough tags to test"
# Exclude the first tag
tag_to_exclude = tag_list[0]
tags_filtered = self.tag(tagger, max_tags=50, general_threshold=0.1, exclude_tags=tag_to_exclude)
print(f" {tagger} without exclusion: {tags_all[:50]}...")
print(f" {tagger} excluding '{tag_to_exclude}': {tags_filtered[:50]}...")
# Check if the exact tag was removed by parsing the filtered list
filtered_list = [t.strip().replace(' ', '_') for t in tags_filtered.split(',')]
# Also check space variant
tag_space_variant = tag_to_exclude.replace('_', ' ')
tag_present = tag_to_exclude in filtered_list or tag_space_variant in [t.strip() for t in tags_filtered.split(',')]
return not tag_present
self.test_parameter('exclude_tags', check_exclude_tags)
# =========================================================================
# TEST: tagger_show_scores (via shared.opts)
# =========================================================================
def test_show_scores(self):
"""Test that tagger_show_scores adds confidence scores."""
print("\n" + "=" * 70)
print("TEST: tagger_show_scores effect")
print("=" * 70)
from modules import shared
def check_show_scores(tagger):
original = shared.opts.tagger_show_scores
shared.opts.tagger_show_scores = False
tags_no_scores = self.tag(tagger, max_tags=5)
shared.opts.tagger_show_scores = True
tags_with_scores = self.tag(tagger, max_tags=5)
shared.opts.tagger_show_scores = original
print(f" {tagger} show_scores=False: {tags_no_scores[:50]}...")
print(f" {tagger} show_scores=True: {tags_with_scores[:50]}...")
has_scores = ':' in tags_with_scores and '(' in tags_with_scores
no_scores = ':' not in tags_no_scores
return has_scores and no_scores
self.test_parameter('tagger_show_scores', check_show_scores)
# =========================================================================
# TEST: include_rating
# =========================================================================
def test_include_rating(self):
"""Test that include_rating includes/excludes rating tags."""
print("\n" + "=" * 70)
print("TEST: include_rating effect")
print("=" * 70)
def check_include_rating(tagger):
tags_no_rating = self.tag(tagger, include_rating=False, max_tags=100, general_threshold=0.01)
tags_with_rating = self.tag(tagger, include_rating=True, max_tags=100, general_threshold=0.01)
print(f" {tagger} include_rating=False: {tags_no_rating[:60]}...")
print(f" {tagger} include_rating=True: {tags_with_rating[:60]}...")
# Rating tags typically start with "rating:" or are like "safe", "questionable", "explicit"
rating_keywords = ['rating:', 'safe', 'questionable', 'explicit', 'general', 'sensitive']
has_rating_before = any(kw in tags_no_rating.lower() for kw in rating_keywords)
has_rating_after = any(kw in tags_with_rating.lower() for kw in rating_keywords)
if has_rating_after and not has_rating_before:
return True
elif has_rating_after and has_rating_before:
return "rating tags appear in both (may need very low threshold)"
elif not has_rating_after:
return "no rating tags detected"
else:
return False
self.test_parameter('include_rating', check_include_rating)
# =========================================================================
# TEST: character_threshold (WaifuDiffusion only)
# =========================================================================
def test_character_threshold(self):
"""Test that character_threshold affects character tag count (WaifuDiffusion only)."""
print("\n" + "=" * 70)
print("TEST: character_threshold effect (WaifuDiffusion only)")
print("=" * 70)
def check_character_threshold(tagger):
if tagger != 'waifudiffusion':
return "not supported"
# Character threshold only affects character tags
# We need an image with character tags to properly test this
tags_high = self.tag(tagger, character_threshold=0.99, general_threshold=0.5)
tags_low = self.tag(tagger, character_threshold=0.1, general_threshold=0.5)
print(f" {tagger} char_threshold=0.99: {tags_high[:50]}...")
print(f" {tagger} char_threshold=0.10: {tags_low[:50]}...")
# If thresholds are different, the setting is at least being applied
# Hard to verify without an image with known character tags
return True # Setting exists and is applied (verified by code inspection)
self.test_parameter('character_threshold', check_character_threshold, deepbooru_supported=False)
# =========================================================================
# TEST: Unified Interface
# =========================================================================
def test_unified_interface(self):
"""Test that the unified tagger interface works for both backends."""
print("\n" + "=" * 70)
print("TEST: Unified tagger.tag() interface")
print("=" * 70)
from modules.interrogate import tagger
# Test WaifuDiffusion through unified interface
if self.waifudiffusion_loaded:
try:
models = tagger.get_models()
waifudiffusion_model = next((m for m in models if m != 'DeepBooru'), None)
if waifudiffusion_model:
tags = tagger.tag(self.test_image, model_name=waifudiffusion_model, max_tags=5)
print(f" WaifuDiffusion ({waifudiffusion_model}): {tags[:50]}...")
self.log_pass("Unified interface: WaifuDiffusion")
except Exception as e:
self.log_fail(f"Unified interface: WaifuDiffusion - {e}")
# Test DeepBooru through unified interface
if self.deepbooru_loaded:
try:
tags = tagger.tag(self.test_image, model_name='DeepBooru', max_tags=5)
print(f" DeepBooru: {tags[:50]}...")
self.log_pass("Unified interface: DeepBooru")
except Exception as e:
self.log_fail(f"Unified interface: DeepBooru - {e}")
def run_all_tests(self):
"""Run all tests."""
self.setup()
self.test_onnx_providers()
self.test_memory_management()
self.test_settings_exist()
self.test_threshold()
self.test_max_tags()
self.test_use_spaces()
self.test_escape_brackets()
self.test_sort_alpha()
self.test_exclude_tags()
self.test_show_scores()
self.test_include_rating()
self.test_character_threshold()
self.test_unified_interface()
self.cleanup()
self.print_summary()
return len(self.results['failed']) == 0
if __name__ == "__main__":
test = TaggerTest()
success = test.run_all_tests()
sys.exit(0 if success else 1)

View File

@ -128,5 +128,12 @@
"preview": "shuttleai--shuttle-jaguar.jpg",
"tags": "community",
"skip": true
},
"Anima": {
"path": "CalamitousFelicitousness/Anima-sdnext-diffusers",
"preview": "CalamitousFelicitousness--Anima-sdnext-diffusers.png",
"desc": "Modified Cosmos-Predict-2B that replaces the T5-11B text encoder with Qwen3-0.6B. Anima is a 2 billion parameter text-to-image model created via a collaboration between CircleStone Labs and Comfy Org. It is focused mainly on anime concepts, characters, and styles, but is also capable of generating a wide variety of other non-photorealistic content. The model is designed for making illustrations and artistic images, and will not work well at realism.",
"tags": "community",
"skip": true
}
}

View File

@ -143,6 +143,15 @@
"date": "2025 January"
},
"Z-Image": {
"path": "Tongyi-MAI/Z-Image",
"preview": "Tongyi-MAI--Z-Image.jpg",
"desc": "Z-Image, an efficient image generation foundation model built on a Single-Stream Diffusion Transformer architecture. It preserves the complete training signal with full CFG support, enabling aesthetic versatility from hyper-realistic photography to anime, enhanced output diversity, and robust negative prompting for artifact suppression. Ideal base for LoRA training, ControlNet, and semantic conditioning.",
"skip": true,
"extras": "sampler: Default, cfg_scale: 4.0, steps: 50",
"size": 20.3,
"date": "2026 January"
},
"Z-Image-Turbo": {
"path": "Tongyi-MAI/Z-Image-Turbo",
"preview": "Tongyi-MAI--Z-Image-Turbo.jpg",

View File

@ -53,6 +53,7 @@ const jsConfig = defineConfig([
generateForever: 'readonly',
showContributors: 'readonly',
opts: 'writable',
monitorOption: 'readonly',
sortUIElements: 'readonly',
all_gallery_buttons: 'readonly',
selected_gallery_button: 'readonly',
@ -98,6 +99,8 @@ const jsConfig = defineConfig([
idbAdd: 'readonly',
idbCount: 'readonly',
idbFolderCleanup: 'readonly',
idbClearAll: 'readonly',
idbIsReady: 'readonly',
initChangelog: 'readonly',
sendNotification: 'readonly',
monitorConnection: 'readonly',
@ -241,6 +244,9 @@ const jsonConfig = defineConfig([
plugins: { json },
language: 'json/json',
extends: ['json/recommended'],
rules: {
'json/no-empty-keys': 'off',
},
},
]);

View File

@ -90,7 +90,7 @@
{"id":"","label":"Embedding","localized":"","reload":"","hint":"Textual inversion embedding is a trained embedded information about the subject"},
{"id":"","label":"Hypernetwork","localized":"","reload":"","hint":"Small trained neural network that modifies behavior of the loaded model"},
{"id":"","label":"VLM Caption","localized":"","reload":"","hint":"Analyze image using vision langugage model"},
{"id":"","label":"CLiP Interrogate","localized":"","reload":"","hint":"Analyze image using CLiP model"},
{"id":"","label":"OpenCLiP","localized":"","reload":"","hint":"Analyze image using CLiP model via OpenCLiP"},
{"id":"","label":"VAE","localized":"","reload":"","hint":"Variational Auto Encoder: model used to run image decode at the end of generate"},
{"id":"","label":"History","localized":"","reload":"","hint":"List of previous generations that can be further reprocessed"},
{"id":"","label":"UI disable variable aspect ratio","localized":"","reload":"","hint":"When disabled, all thumbnails appear as squared images"},

View File

@ -112,7 +112,7 @@ def install_traceback(suppress: list = []):
width = os.environ.get("SD_TRACEWIDTH", console.width if console else None)
if width is not None:
width = int(width)
traceback_install(
log.excepthook = traceback_install(
console=console,
extra_lines=int(os.environ.get("SD_TRACELINES", 1)),
max_frames=int(os.environ.get("SD_TRACEFRAMES", 16)),
@ -168,7 +168,6 @@ def setup_logging():
def get(self):
return self.buffer
class LogFilter(logging.Filter):
def __init__(self):
super().__init__()
@ -215,6 +214,23 @@ def setup_logging():
logging.Logger.trace = partialmethod(logging.Logger.log, logging.TRACE)
logging.trace = partial(logging.log, logging.TRACE)
def exception_hook(e: Exception, suppress=[]):
from rich.traceback import Traceback
tb = Traceback.from_exception(type(e), e, e.__traceback__, show_locals=False, max_frames=16, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width)
# print-to-console, does not get printed-to-file
exc_type, exc_value, exc_traceback = sys.exc_info()
log.excepthook(exc_type, exc_value, exc_traceback)
# print-to-file, temporarily disable-console-handler
for handler in log.handlers.copy():
if isinstance(handler, RichHandler):
log.removeHandler(handler)
with console.capture() as capture:
console.print(tb)
log.critical(capture.get())
log.addHandler(rh)
log.traceback = exception_hook
level = logging.DEBUG if (args.debug or args.trace) else logging.INFO
log.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
log.print = rprint
@ -240,8 +256,10 @@ def setup_logging():
)
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
pretty_install(console=console)
install_traceback()
while log.hasHandlers() and len(log.handlers) > 0:
log.removeHandler(log.handlers[0])
@ -288,7 +306,6 @@ def setup_logging():
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("ControlNet").handlers = log.handlers
logging.getLogger("lycoris").handlers = log.handlers
# logging.getLogger("DeepSpeed").handlers = log.handlers
ts('log', t_start)
@ -712,9 +729,9 @@ def install_cuda():
log.info('CUDA: nVidia toolkit detected')
ts('cuda', t_start)
if args.use_nightly:
cmd = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --extra-index-url https://download.pytorch.org/whl/nightly/cu126')
cmd = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --extra-index-url https://download.pytorch.org/whl/nightly/cu130')
else:
cmd = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+cu128 torchvision==0.24.1+cu128 --index-url https://download.pytorch.org/whl/cu128')
cmd = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+cu128 torchvision==0.25.0+cu128 --index-url https://download.pytorch.org/whl/cu128')
return cmd
@ -765,7 +782,6 @@ def install_rocm_zluda():
if sys.platform == "win32":
if args.use_zluda:
#check_python(supported_minors=[10, 11, 12, 13], reason='ZLUDA backend requires a Python version between 3.10 and 3.13')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.1+cu118 torchvision==0.22.1+cu118 --index-url https://download.pytorch.org/whl/cu118')
if args.device_id is not None:
@ -795,6 +811,7 @@ def install_rocm_zluda():
torch_command = os.environ.get('TORCH_COMMAND', f'torch torchvision --index-url https://rocm.nightlies.amd.com/{device.therock}')
else:
check_python(supported_minors=[12], reason='ROCm: Windows preview python==3.12 required')
# torch 2.8.0a0 is the last version with rocm 6.4 support
torch_command = os.environ.get('TORCH_COMMAND', '--no-cache-dir https://repo.radeon.com/rocm/windows/rocm-rel-6.4.4/torch-2.8.0a0%2Bgitfc14c65-cp312-cp312-win_amd64.whl https://repo.radeon.com/rocm/windows/rocm-rel-6.4.4/torchvision-0.24.0a0%2Bc85f008-cp312-cp312-win_amd64.whl')
else:
#check_python(supported_minors=[10, 11, 12, 13, 14], reason='ROCm backend requires a Python version between 3.10 and 3.13')
@ -804,7 +821,11 @@ def install_rocm_zluda():
else: # oldest rocm version on nightly is 7.0
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0')
else:
if rocm.version is None or float(rocm.version) >= 6.4: # assume the latest if version check fails
if rocm.version is None or float(rocm.version) >= 7.1: # assume the latest if version check fails
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+rocm7.1 torchvision==0.25.0+rocm7.1 --index-url https://download.pytorch.org/whl/rocm7.1')
elif rocm.version == "7.0": # assume the latest if version check fails
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+rocm7.0 torchvision==0.25.0+rocm7.0 --index-url https://download.pytorch.org/whl/rocm7.0')
elif rocm.version == "6.4":
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+rocm6.4 torchvision==0.24.1+rocm6.4 --index-url https://download.pytorch.org/whl/rocm6.4')
elif rocm.version == "6.3":
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+rocm6.3 torchvision==0.24.1+rocm6.3 --index-url https://download.pytorch.org/whl/rocm6.3')
@ -841,7 +862,7 @@ def install_ipex():
if args.use_nightly:
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu')
else:
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+xpu torchvision==0.24.1+xpu --index-url https://download.pytorch.org/whl/xpu')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+xpu torchvision==0.25.0+xpu --index-url https://download.pytorch.org/whl/xpu')
ts('ipex', t_start)
return torch_command
@ -854,13 +875,13 @@ def install_openvino():
#check_python(supported_minors=[10, 11, 12, 13], reason='OpenVINO backend requires a Python version between 3.10 and 3.13')
if sys.platform == 'darwin':
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1 torchvision==0.24.1')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0 torchvision==0.25.0')
else:
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+cpu torchvision==0.24.1 --index-url https://download.pytorch.org/whl/cpu')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+cpu torchvision==0.25.0 --index-url https://download.pytorch.org/whl/cpu')
if not (args.skip_all or args.skip_requirements):
install(os.environ.get('OPENVINO_COMMAND', 'openvino==2025.3.0'), 'openvino')
install(os.environ.get('NNCF_COMMAND', 'nncf==2.18.0'), 'nncf')
install(os.environ.get('OPENVINO_COMMAND', 'openvino==2025.4.1'), 'openvino')
install(os.environ.get('NNCF_COMMAND', 'nncf==2.19.0'), 'nncf')
ts('openvino', t_start)
return torch_command
@ -1427,6 +1448,7 @@ def set_environment():
os.environ.setdefault('TORCH_CUDNN_V8_API_ENABLED', '1')
os.environ.setdefault('TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD', '1')
os.environ.setdefault('TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL', '1')
os.environ.setdefault('MIOPEN_FIND_MODE', '2')
os.environ.setdefault('UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS', '1')
os.environ.setdefault('USE_TORCH', '1')
os.environ.setdefault('UV_INDEX_STRATEGY', 'unsafe-any-match')
@ -1540,7 +1562,7 @@ def check_ui(ver):
t_start = time.time()
if not same(ver):
log.debug(f'Branch mismatch: sdnext={ver["branch"]} ui={ver["ui"]}')
log.debug(f'Branch mismatch: {ver}')
cwd = os.getcwd()
try:
os.chdir('extensions-builtin/sdnext-modernui')
@ -1548,10 +1570,7 @@ def check_ui(ver):
git('checkout ' + target, ignore=True, optional=True)
os.chdir(cwd)
ver = get_version(force=True)
if not same(ver):
log.debug(f'Branch synchronized: {ver["branch"]}')
else:
log.debug(f'Branch sync failed: sdnext={ver["branch"]} ui={ver["ui"]}')
log.debug(f'Branch sync: {ver}')
except Exception as e:
log.debug(f'Branch switch: {e}')
os.chdir(cwd)

View File

@ -2,6 +2,7 @@
let ws;
let url;
let currentImage = null;
let currentGalleryFolder = null;
let pruneImagesTimer;
let outstanding = 0;
let lastSort = 0;
@ -20,6 +21,7 @@ const el = {
search: undefined,
status: undefined,
btnSend: undefined,
clearCacheFolder: undefined,
};
const SUPPORTED_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'tiff', 'jp2', 'jxl', 'gif', 'mp4', 'mkv', 'avi', 'mjpeg', 'mpg', 'avr'];
@ -117,9 +119,12 @@ function updateGalleryStyles() {
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
transition-duration: 0.2s;
transition-property: color, opacity, background-color, border-color;
transition-timing-function: ease-out;
}
.gallery-folder:hover {
background-color: var(--button-primary-background-fill-hover);
background-color: var(--button-primary-background-fill-hover, var(--sd-button-hover-color));
}
.gallery-folder-selected {
background-color: var(--sd-button-selected-color);
@ -258,6 +263,14 @@ class SimpleFunctionQueue {
this.#queue = [];
}
static abortLogger(identifier, result) {
if (typeof result === 'string' || (result instanceof DOMException && result.name === 'AbortError')) {
log(identifier, result?.message || result);
} else {
error(identifier, result.message);
}
}
/**
* @param {{
* signal: AbortSignal,
@ -301,6 +314,8 @@ class SimpleFunctionQueue {
// HTML Elements
class GalleryFolder extends HTMLElement {
static folders = new Set();
constructor(folder) {
super();
// Support both old format (string) and new format (object with path and label)
@ -314,21 +329,173 @@ class GalleryFolder extends HTMLElement {
this.style.overflowX = 'hidden';
this.shadow = this.attachShadow({ mode: 'open' });
this.shadow.adoptedStyleSheets = [folderStylesheet];
this.div = document.createElement('div');
}
connectedCallback() {
const div = document.createElement('div');
div.className = 'gallery-folder';
div.innerHTML = `<span class="gallery-folder-icon">\uf03e</span> ${this.label}`;
div.title = this.name; // Show full path on hover
div.addEventListener('click', () => {
for (const folder of el.folders.children) {
if (folder.name === this.name) folder.shadow.firstElementChild.classList.add('gallery-folder-selected');
else folder.shadow.firstElementChild.classList.remove('gallery-folder-selected');
if (GalleryFolder.folders.has(this)) return; // Element is just being moved
this.div.className = 'gallery-folder';
this.div.innerHTML = `<span class="gallery-folder-icon">\uf03e</span> ${this.label}`;
this.div.title = this.name; // Show full path on hover
this.div.addEventListener('click', () => { this.updateSelected(); }); // Ensures 'this' isn't the div in the called method
this.div.addEventListener('click', fetchFilesWS); // eslint-disable-line no-use-before-define
this.shadow.appendChild(this.div);
GalleryFolder.folders.add(this);
}
async disconnectedCallback() {
await Promise.resolve(); // Wait for other microtasks (such as element moving)
if (this.isConnected) return;
GalleryFolder.folders.delete(this);
}
updateSelected() {
this.div.classList.add('gallery-folder-selected');
for (const folder of GalleryFolder.folders) {
if (folder !== this) {
folder.div.classList.remove('gallery-folder-selected');
}
});
div.addEventListener('click', fetchFilesWS); // eslint-disable-line no-use-before-define
this.shadow.appendChild(div);
}
}
}
async function delayFetchThumb(fn, signal) {
await awaitForOutstanding(16, signal);
try {
outstanding++;
const ts = Date.now().toString();
const res = await authFetch(`${window.api}/browser/thumb?file=${encodeURI(fn)}&ts=${ts}`, { priority: 'low' });
if (!res.ok) {
error(`fetchThumb: ${res.statusText}`);
return undefined;
}
const json = await res.json();
if (!res || !json || json.error || Object.keys(json).length === 0) {
if (json.error) error(`fetchThumb: ${json.error}`);
return undefined;
}
return json;
} finally {
outstanding--;
}
}
class GalleryFile extends HTMLElement {
/** @type {AbortSignal} */
#signal;
constructor(folder, file, signal) {
super();
this.folder = folder;
this.name = file;
this.#signal = signal;
this.size = 0;
this.mtime = 0;
this.hash = undefined;
this.exif = '';
this.width = 0;
this.height = 0;
this.src = `${this.folder}/${this.name}`;
this.shadow = this.attachShadow({ mode: 'open' });
this.shadow.adoptedStyleSheets = [fileStylesheet];
this.firstRun = true;
}
async connectedCallback() {
if (!this.firstRun) return; // Element is just being moved
this.firstRun = false;
// Check separator state early to hide the element immediately
const dir = this.name.match(/(.*)[/\\]/);
if (dir && dir[1]) {
const dirPath = dir[1];
const isOpen = separatorStates.get(dirPath);
if (isOpen === false) {
this.style.display = 'none';
}
}
// Normalize path to ensure consistent hash regardless of which folder view is used
const normalizedPath = this.src.replace(/\/+/g, '/').replace(/\/$/, '');
this.hash = await getHash(`${normalizedPath}/${this.size}/${this.mtime}`); // eslint-disable-line no-use-before-define
const cachedData = (this.hash && opts.browser_cache) ? await idbGet(this.hash).catch(() => undefined) : undefined;
const img = document.createElement('img');
img.className = 'gallery-file';
img.loading = 'lazy';
img.onload = async () => {
img.title += `\nResolution: ${this.width} x ${this.height}`;
this.title = img.title;
if (!cachedData && opts.browser_cache) {
if ((this.width === 0) || (this.height === 0)) { // fetch thumb failed so we use actual image
this.width = img.naturalWidth;
this.height = img.naturalHeight;
}
}
};
let ok = true;
if (cachedData?.img) {
img.src = cachedData.img;
this.exif = cachedData.exif;
this.width = cachedData.width;
this.height = cachedData.height;
this.size = cachedData.size;
this.mtime = new Date(cachedData.mtime);
} else {
try {
const json = await delayFetchThumb(this.src, this.#signal);
if (!json) {
ok = false;
} else {
img.src = json.data;
this.exif = json.exif;
this.width = json.width;
this.height = json.height;
this.size = json.size;
this.mtime = new Date(json.mtime);
if (opts.browser_cache) {
await idbAdd({
hash: this.hash,
folder: this.folder,
file: this.name,
size: this.size,
mtime: this.mtime,
width: this.width,
height: this.height,
src: this.src,
exif: this.exif,
img: img.src,
// exif: await getExif(img), // alternative client-side exif
// img: await createThumb(img), // alternative client-side thumb
});
}
}
} catch (err) { // thumb fetch failed so assign actual image
img.src = `file=${this.src}`;
}
}
if (this.#signal.aborted) { // Do not change the operations order from here...
return;
}
galleryHashes.add(this.hash);
if (!ok) {
return;
} // ... to here unless modifications are also being made to maintenance functionality and the usage of AbortController/AbortSignal
img.onclick = () => {
setGallerySelectionByElement(this, { send: true });
};
img.title = `Folder: ${this.folder}\nFile: ${this.name}\nSize: ${this.size.toLocaleString()} bytes\nModified: ${this.mtime.toLocaleString()}`;
this.title = img.title;
// Final visibility check based on search term.
const shouldDisplayBasedOnSearch = this.title.toLowerCase().includes(el.search.value.toLowerCase());
if (this.style.display !== 'none') { // Only proceed if not already hidden by a closed separator
this.style.display = shouldDisplayBasedOnSearch ? 'unset' : 'none';
}
this.shadow.appendChild(img);
}
}
@ -459,148 +626,6 @@ async function addSeparators() {
}
}
async function delayFetchThumb(fn, signal) {
await awaitForOutstanding(16, signal);
try {
outstanding++;
const ts = Date.now().toString();
const res = await authFetch(`${window.api}/browser/thumb?file=${encodeURI(fn)}&ts=${ts}`, { priority: 'low' });
if (!res.ok) {
error(`fetchThumb: ${res.statusText}`);
return undefined;
}
const json = await res.json();
if (!res || !json || json.error || Object.keys(json).length === 0) {
if (json.error) error(`fetchThumb: ${json.error}`);
return undefined;
}
return json;
} finally {
outstanding--;
}
}
class GalleryFile extends HTMLElement {
/** @type {AbortSignal} */
#signal;
constructor(folder, file, signal) {
super();
this.folder = folder;
this.name = file;
this.#signal = signal;
this.size = 0;
this.mtime = 0;
this.hash = undefined;
this.exif = '';
this.width = 0;
this.height = 0;
this.src = `${this.folder}/${this.name}`;
this.shadow = this.attachShadow({ mode: 'open' });
this.shadow.adoptedStyleSheets = [fileStylesheet];
}
async connectedCallback() {
if (this.shadow.children.length > 0) {
return;
}
// Check separator state early to hide the element immediately
const dir = this.name.match(/(.*)[/\\]/);
if (dir && dir[1]) {
const dirPath = dir[1];
const isOpen = separatorStates.get(dirPath);
if (isOpen === false) {
this.style.display = 'none';
}
}
// Normalize path to ensure consistent hash regardless of which folder view is used
const normalizedPath = this.src.replace(/\/+/g, '/').replace(/\/$/, '');
this.hash = await getHash(`${normalizedPath}/${this.size}/${this.mtime}`); // eslint-disable-line no-use-before-define
const cachedData = (this.hash && opts.browser_cache) ? await idbGet(this.hash).catch(() => undefined) : undefined;
const img = document.createElement('img');
img.className = 'gallery-file';
img.loading = 'lazy';
img.onload = async () => {
img.title += `\nResolution: ${this.width} x ${this.height}`;
this.title = img.title;
if (!cachedData && opts.browser_cache) {
if ((this.width === 0) || (this.height === 0)) { // fetch thumb failed so we use actual image
this.width = img.naturalWidth;
this.height = img.naturalHeight;
}
}
};
let ok = true;
if (cachedData?.img) {
img.src = cachedData.img;
this.exif = cachedData.exif;
this.width = cachedData.width;
this.height = cachedData.height;
this.size = cachedData.size;
this.mtime = new Date(cachedData.mtime);
} else {
try {
const json = await delayFetchThumb(this.src, this.#signal);
if (!json) {
ok = false;
} else {
img.src = json.data;
this.exif = json.exif;
this.width = json.width;
this.height = json.height;
this.size = json.size;
this.mtime = new Date(json.mtime);
if (opts.browser_cache) {
// Store file's actual parent directory (not browsed folder) for consistent cleanup
const fileDir = this.src.replace(/\/+/g, '/').replace(/\/[^/]+$/, '');
await idbAdd({
hash: this.hash,
folder: fileDir,
file: this.name,
size: this.size,
mtime: this.mtime,
width: this.width,
height: this.height,
src: this.src,
exif: this.exif,
img: img.src,
// exif: await getExif(img), // alternative client-side exif
// img: await createThumb(img), // alternative client-side thumb
});
}
}
} catch (err) { // thumb fetch failed so assign actual image
img.src = `file=${this.src}`;
}
}
if (this.#signal.aborted) { // Do not change the operations order from here...
return;
}
galleryHashes.add(this.hash);
if (!ok) {
return;
} // ... to here unless modifications are also being made to maintenance functionality and the usage of AbortController/AbortSignal
img.onclick = () => {
setGallerySelectionByElement(this, { send: true });
};
img.title = `Folder: ${this.folder}\nFile: ${this.name}\nSize: ${this.size.toLocaleString()} bytes\nModified: ${this.mtime.toLocaleString()}`;
if (this.shadow.children.length > 0) {
return; // avoid double-adding
}
this.title = img.title;
// Final visibility check based on search term.
const shouldDisplayBasedOnSearch = this.title.toLowerCase().includes(el.search.value.toLowerCase());
if (this.style.display !== 'none') { // Only proceed if not already hidden by a closed separator
this.style.display = shouldDisplayBasedOnSearch ? 'unset' : 'none';
}
this.shadow.appendChild(img);
}
}
// methods
const gallerySendImage = (_images) => [currentImage]; // invoked by gradio button
@ -919,9 +944,10 @@ async function gallerySort(btn) {
/**
* Generate and display the overlay to announce cleanup is in progress.
* @param {number} count - Number of entries being cleaned up
* @param {boolean} all - Indicate that all thumbnails are being cleared
* @returns {ClearMsgCallback}
*/
function showCleaningMsg(count) {
function showCleaningMsg(count, all = false) {
// Rendering performance isn't a priority since this doesn't run often
const parent = el.folders.parentElement;
const cleaningOverlay = document.createElement('div');
@ -936,7 +962,7 @@ function showCleaningMsg(count) {
msgText.style.cssText = 'font-size: 1.2em';
msgInfo.style.cssText = 'font-size: 0.9em; text-align: center;';
msgText.innerText = 'Thumbnail cleanup...';
msgInfo.innerText = `Found ${count} old entries`;
msgInfo.innerText = all ? 'Clearing all entries' : `Found ${count} old entries`;
anim.classList.add('idbBusyAnim');
msgDiv.append(msgText, msgInfo);
@ -945,16 +971,17 @@ function showCleaningMsg(count) {
return () => { cleaningOverlay.remove(); };
}
const maintenanceQueue = new SimpleFunctionQueue('Maintenance');
const maintenanceQueue = new SimpleFunctionQueue('Gallery Maintenance');
/**
* Handles calling the cleanup function for the thumbnail cache
* @param {string} folder - Folder to clean
* @param {number} imgCount - Expected number of images in gallery
* @param {AbortController} controller - AbortController that's handling this task
* @param {boolean} force - Force full cleanup of the folder
*/
async function thumbCacheCleanup(folder, imgCount, controller) {
if (!opts.browser_cache) return;
async function thumbCacheCleanup(folder, imgCount, controller, force = false) {
if (!opts.browser_cache && !force) return;
try {
if (typeof folder !== 'string' || typeof imgCount !== 'number') {
throw new Error('Function called with invalid arguments');
@ -971,14 +998,14 @@ async function thumbCacheCleanup(folder, imgCount, controller) {
callback: async () => {
log(`Thumbnail DB cleanup: Checking if "${folder}" needs cleaning`);
const t0 = performance.now();
const staticGalleryHashes = new Set(galleryHashes); // External context should be safe since this function run is guarded by AbortController/AbortSignal in the SimpleFunctionQueue
const keptGalleryHashes = force ? new Set() : new Set(galleryHashes.values()); // External context should be safe since this function run is guarded by AbortController/AbortSignal in the SimpleFunctionQueue
const cachedHashesCount = await idbCount(folder)
.catch((e) => {
error(`Thumbnail DB cleanup: Error when getting entry count for "${folder}".`, e);
return Infinity; // Forces next check to fail if something went wrong
});
const cleanupCount = cachedHashesCount - staticGalleryHashes.size;
if (cleanupCount < 500 || !Number.isFinite(cleanupCount)) {
const cleanupCount = cachedHashesCount - keptGalleryHashes.size;
if (!force && (cleanupCount < 500 || !Number.isFinite(cleanupCount))) {
// Don't run when there aren't many excess entries
return;
}
@ -988,30 +1015,95 @@ async function thumbCacheCleanup(folder, imgCount, controller) {
return;
}
const cb_clearMsg = showCleaningMsg(cleanupCount);
const tRun = Date.now(); // Doesn't need high resolution
await idbFolderCleanup(staticGalleryHashes, folder, controller.signal)
await idbFolderCleanup(keptGalleryHashes, folder, controller.signal)
.then((delcount) => {
const t1 = performance.now();
log(`Thumbnail DB cleanup: folder=${folder} kept=${staticGalleryHashes.size} deleted=${delcount} time=${Math.floor(t1 - t0)}ms`);
log(`Thumbnail DB cleanup: folder=${folder} kept=${keptGalleryHashes.size} deleted=${delcount} time=${Math.floor(t1 - t0)}ms`);
currentGalleryFolder = null;
el.clearCacheFolder.innerText = '<select a folder first>';
updateStatusWithSort('Thumbnail cache cleared');
})
.catch((reason) => {
if (typeof reason === 'string' || (reason instanceof DOMException && reason.name === 'AbortError')) {
log('Thumbnail DB cleanup:', reason?.message || reason);
} else {
error('Thumbnail DB cleanup:', reason.message);
}
SimpleFunctionQueue.abortLogger('Thumbnail DB cleanup:', reason);
})
.finally(async () => {
// Ensure at least enough time to see that it's a message and not the UI breaking/flickering
await new Promise((resolve) => {
setTimeout(resolve, Math.min(1000, Math.max(1000 - (Date.now() - tRun), 0))); // Total display time of at least 1 second
});
await new Promise((resolve) => { setTimeout(resolve, 1000); }); // Delay removal by 1 second to ensure at least minimum visibility
cb_clearMsg();
});
},
});
}
function resetGalleryState(reason) {
maintenanceController.abort(reason);
const controller = new AbortController();
maintenanceController = controller;
galleryHashes.clear(); // Must happen AFTER the AbortController steps
galleryProgressBar.clear();
resetGallerySelection();
return controller;
}
function clearCacheIfDisabled(browser_cache) {
if (browser_cache === false) {
log('Thumbnail DB cleanup:', 'Image gallery cache setting disabled. Clearing cache.');
const controller = resetGalleryState('Clearing all thumbnails from cache');
maintenanceQueue.enqueue({
signal: controller.signal,
callback: async () => {
const t0 = performance.now();
const cb_clearMsg = showCleaningMsg(0, true);
await idbClearAll(controller.signal)
.then(() => {
log(`Thumbnail DB cleanup: Cache cleared. time=${Math.floor(performance.now() - t0)}ms`);
currentGalleryFolder = null;
el.clearCacheFolder.innerText = '<select a folder first>';
updateStatusWithSort('Thumbnail cache cleared');
})
.catch((e) => {
SimpleFunctionQueue.abortLogger('Thumbnail DB cleanup:', e);
})
.finally(async () => {
await new Promise((resolve) => { setTimeout(resolve, 1000); });
cb_clearMsg();
});
},
});
}
}
function addCacheClearLabel() { // Don't use async
const setting = document.querySelector('#setting_browser_cache');
if (setting) {
const div = document.createElement('div');
div.style.marginBlock = '0.75rem';
const span = document.createElement('span');
span.style.cssText = 'font-weight: bold; text-decoration: underline; cursor: pointer; color: var(--color-blue); user-select: none;';
span.innerText = '<select a folder first>';
div.append('Clear the thumbnail cache for: ', span, ' (double-click)');
setting.parentElement.insertAdjacentElement('afterend', div);
el.clearCacheFolder = span;
span.addEventListener('dblclick', (evt) => {
evt.preventDefault();
evt.stopPropagation();
if (!currentGalleryFolder) return;
el.clearCacheFolder.style.color = 'var(--color-green)';
setTimeout(() => {
el.clearCacheFolder.style.color = 'var(--color-blue)';
}, 1000);
const controller = resetGalleryState('Clearing folder thumbnails cache');
el.files.innerHTML = '';
thumbCacheCleanup(currentGalleryFolder, 0, controller, true);
});
return true;
}
return false;
}
async function fetchFilesHT(evt, controller) {
const t0 = performance.now();
const fragment = document.createDocumentFragment();
@ -1049,12 +1141,8 @@ async function fetchFilesHT(evt, controller) {
async function fetchFilesWS(evt) { // fetch file-by-file list over websockets
if (!url) return;
const controller = new AbortController(); // Only called here because fetchFilesHT isn't called directly
maintenanceController.abort('Gallery update'); // Abort previous controller
maintenanceController = controller; // Point to new controller for next time
galleryHashes.clear(); // Must happen AFTER the AbortController steps
galleryProgressBar.clear();
resetGallerySelection();
// Abort previous controller and point to new controller for next time
const controller = resetGalleryState('Gallery update'); // Called here because fetchFilesHT isn't called directly
el.files.innerHTML = '';
updateGalleryStyles();
@ -1068,6 +1156,10 @@ async function fetchFilesWS(evt) { // fetch file-by-file list over websockets
return;
}
log(`gallery: connected=${wsConnected} state=${ws?.readyState} url=${ws?.url}`);
currentGalleryFolder = evt.target.name;
if (el.clearCacheFolder) {
el.clearCacheFolder.innerText = currentGalleryFolder;
}
if (!wsConnected) {
await fetchFilesHT(evt, controller); // fallback to http
return;
@ -1115,26 +1207,17 @@ async function fetchFilesWS(evt) { // fetch file-by-file list over websockets
ws.send(encodeURI(evt.target.name));
}
async function pruneImages() {
// TODO replace img.src with placeholder for images that are not visible
}
async function galleryVisible() {
async function updateFolders() {
// if (el.folders.children.length > 0) return;
const res = await authFetch(`${window.api}/browser/folders`);
if (!res || res.status !== 200) return;
el.folders.innerHTML = '';
url = res.url.split('/sdapi')[0].replace('http', 'ws'); // update global url as ws need fqdn
const folders = await res.json();
el.folders.innerHTML = '';
for (const folder of folders) {
const f = new GalleryFolder(folder);
el.folders.appendChild(f);
}
pruneImagesTimer = setInterval(pruneImages, 1000);
}
async function galleryHidden() {
if (pruneImagesTimer) clearInterval(pruneImagesTimer);
}
async function monitorGalleries() {
@ -1165,6 +1248,32 @@ async function setOverlayAnimation() {
document.head.append(busyAnimation);
}
async function galleryClearInit() {
let galleryClearInitTimeout = 0;
const tryCleanupInit = setInterval(() => {
if (addCacheClearLabel() || galleryClearInitTimeout++ === 60) {
clearInterval(tryCleanupInit);
monitorOption('browser_cache', clearCacheIfDisabled);
}
}, 1000);
}
async function blockQueueUntilReady() {
// Add block to maintenanceQueue until cache is ready
maintenanceQueue.enqueue({
signal: new AbortController().signal, // Use standalone AbortSignal that can't be aborted
callback: async () => {
let timeout = 0;
while (!idbIsReady() && timeout++ < 60) {
await new Promise((resolve) => { setTimeout(resolve, 1000); });
}
if (!idbIsReady()) {
throw new Error('Timed out waiting for thumbnail cache');
}
},
});
}
async function initGallery() { // triggered on gradio change to monitor when ui gets sufficiently constructed
log('initGallery');
el.folders = gradioApp().getElementById('tab-gallery-folders');
@ -1175,9 +1284,12 @@ async function initGallery() { // triggered on gradio change to monitor when ui
error('initGallery', 'Missing gallery elements');
return;
}
blockQueueUntilReady(); // Run first
updateGalleryStyles();
injectGalleryStatusCSS();
setOverlayAnimation();
galleryClearInit();
const progress = gradioApp().getElementById('tab-gallery-progress');
if (progress) {
galleryProgressBar.attachTo(progress);
@ -1188,12 +1300,9 @@ async function initGallery() { // triggered on gradio change to monitor when ui
el.btnSend = gradioApp().getElementById('tab-gallery-send-image');
document.getElementById('tab-gallery-files').style.height = opts.logmonitor_show ? '75vh' : '85vh';
const intersectionObserver = new IntersectionObserver((entries) => {
if (entries[0].intersectionRatio <= 0) galleryHidden();
if (entries[0].intersectionRatio > 0) galleryVisible();
});
intersectionObserver.observe(el.folders);
monitorGalleries();
updateFolders();
monitorOption('browser_folders', updateFolders);
}
// register on startup

View File

@ -36,6 +36,41 @@ async function initIndexDB() {
if (!db) await createDB();
}
function idbIsReady() {
return db !== null;
}
/**
* Reusable setup for handling IDB transactions.
* @param {Object} resources - Required resources for implementation
* @param {IDBTransaction} resources.transaction
* @param {AbortSignal} resources.signal
* @param {Function} resources.resolve
* @param {Function} resources.reject
* @param {*} resolveValue - Value to resolve the outer Promise with
* @returns {() => void} - Function for manually aborting the transaction
*/
function configureTransactionAbort({ transaction, signal, resolve, reject }, resolveValue) {
function abortTransaction() {
signal.removeEventListener('abort', abortTransaction);
transaction.abort();
}
signal.addEventListener('abort', abortTransaction);
transaction.onabort = () => {
signal.removeEventListener('abort', abortTransaction);
reject(new DOMException(`Aborting database transaction. ${signal.reason}`, 'AbortError'));
};
transaction.onerror = (e) => {
signal.removeEventListener('abort', abortTransaction);
reject(new Error('Database transaction error.', e));
};
transaction.oncomplete = () => {
signal.removeEventListener('abort', abortTransaction);
resolve(resolveValue);
};
return abortTransaction;
}
async function add(record) {
if (!db) return null;
return new Promise((resolve, reject) => {
@ -150,10 +185,7 @@ async function idbFolderCleanup(keepSet, folder, signal) {
throw new Error('IndexedDB cleaning function must be told the current active folder');
}
// Use range query to match folder and all its subdirectories
const folderNormalized = folder.replace(/\/+/g, '/').replace(/\/$/, '');
const range = IDBKeyRange.bound(folderNormalized, `${folderNormalized}\uffff`, false, true);
let removals = new Set(await idbGetAllKeys('folder', range));
let removals = new Set(await idbGetAllKeys('folder', folder));
removals = removals.difference(keepSet); // Don't need to keep full set in memory
const totalRemovals = removals.size;
if (signal.aborted) {
@ -161,31 +193,20 @@ async function idbFolderCleanup(keepSet, folder, signal) {
}
return new Promise((resolve, reject) => {
const transaction = db.transaction('thumbs', 'readwrite');
function abortTransaction() {
signal.removeEventListener('abort', abortTransaction);
transaction.abort();
}
signal.addEventListener('abort', abortTransaction);
transaction.onabort = () => {
signal.removeEventListener('abort', abortTransaction);
reject(`Aborting. ${signal.reason}`); // eslint-disable-line prefer-promise-reject-errors
};
transaction.onerror = () => {
signal.removeEventListener('abort', abortTransaction);
reject(new Error('Database transaction error'));
};
transaction.oncomplete = async () => {
signal.removeEventListener('abort', abortTransaction);
resolve(totalRemovals);
};
const props = { transaction, signal, resolve, reject };
configureTransactionAbort(props, totalRemovals);
const store = transaction.objectStore('thumbs');
removals.forEach((entry) => { store.delete(entry); });
});
}
try {
const store = transaction.objectStore('thumbs');
removals.forEach((entry) => { store.delete(entry); });
} catch (err) {
error(err);
abortTransaction();
}
async function idbClearAll(signal) {
if (!db) return null;
return new Promise((resolve, reject) => {
const transaction = db.transaction(['thumbs'], 'readwrite');
const props = { transaction, signal, resolve, reject };
configureTransactionAbort(props, null);
transaction.objectStore('thumbs').clear();
});
}

View File

@ -1,31 +1,63 @@
const getModel = () => {
const cp = opts?.sd_model_checkpoint || '';
if (!cp) return 'unknown model';
const noBracket = cp.replace(/\s*\[.*\]\s*$/, ''); // remove trailing [hash]
const parts = noBracket.split(/[\\/]/); // split on / or \
return parts[parts.length - 1].trim() || 'unknown model';
};
class ConnectionMonitorState {
static element;
static version = '';
static commit = '';
static branch = '';
static online = false;
static getModel() {
const cp = opts?.sd_model_checkpoint || '';
return cp ? this.trimModelName(cp) : 'unknown model';
}
static trimModelName(name) {
// remove trailing [hash], split on / or \, return last segment, trim
return name.replace(/\s*\[.*\]\s*$/, '').split(/[\\/]/).pop().trim() || 'unknown model';
}
static setData({ online, updated, commit, branch }) {
this.online = online;
this.version = updated;
this.commit = commit;
this.branch = branch;
}
static setElement(el) {
this.element = el;
}
static toHTML(modelOverride) {
return `
Version: <b>${this.version}</b><br>
Commit: <b>${this.commit}</b><br>
Branch: <b>${this.branch}</b><br>
Status: ${this.online ? '<b style="color:lime">online</b>' : '<b style="color:darkred">offline</b>'}<br>
Model: <b>${modelOverride ? this.trimModelName(modelOverride) : this.getModel()}</b><br>
Since: ${new Date().toLocaleString()}<br>
`;
}
static updateState(incomingModel) {
this.element.dataset.hint = this.toHTML(incomingModel);
this.element.style.backgroundColor = this.online ? 'var(--sd-main-accent-color)' : 'var(--color-error)';
}
}
let monitorAutoUpdating = false;
async function updateIndicator(online, data, msg) {
const el = document.getElementById('logo_nav');
if (!el || !data) return;
const status = online ? '<b style="color:lime">online</b>' : '<b style="color:darkred">offline</b>';
const date = new Date();
const template = `
Version: <b>${data.updated}</b><br>
Commit: <b>${data.commit}</b><br>
Branch: <b>${data.branch}</b><br>
Status: ${status}<br>
Model: <b>${getModel()}</b><br>
Since: ${date.toLocaleString()}<br>
`;
ConnectionMonitorState.setElement(el);
if (!monitorAutoUpdating) {
monitorOption('sd_model_checkpoint', (newVal) => { ConnectionMonitorState.updateState(newVal); }); // Runs before opt actually changes
monitorAutoUpdating = true;
}
ConnectionMonitorState.setData({ online, ...data });
ConnectionMonitorState.updateState();
if (online) {
el.dataset.hint = template;
el.style.backgroundColor = 'var(--sd-main-accent-color)';
log('monitorConnection: online', data);
} else {
el.dataset.hint = template;
el.style.backgroundColor = 'var(--color-error)';
log('monitorConnection: offline', msg);
}
}

View File

@ -11,6 +11,10 @@ const monitoredOpts = [
{ sd_backend: () => gradioApp().getElementById('refresh_sd_model_checkpoint')?.click() },
];
function monitorOption(option, callback) {
monitoredOpts.push({ [option]: callback });
}
const AppyOpts = [
{ compact_view: (val, old) => toggleCompact(val, old) },
{ gradio_theme: (val, old) => setTheme(val, old) },
@ -25,17 +29,15 @@ async function updateOpts(json_string) {
const t1 = performance.now();
for (const op of monitoredOpts) {
const key = Object.keys(op)[0];
const callback = op[key];
if (opts[key] && opts[key] !== settings_data.values[key]) {
log('updateOpt', key, opts[key], settings_data.values[key]);
const [key, callback] = Object.entries(op)[0];
if (Object.hasOwn(opts, key) && opts[key] !== new_opts[key]) {
log('updateOpt', key, opts[key], new_opts[key]);
if (callback) callback(new_opts[key], opts[key]);
}
}
for (const op of AppyOpts) {
const key = Object.keys(op)[0];
const callback = op[key];
const [key, callback] = Object.entries(op)[0];
if (callback) callback(new_opts[key], opts[key]);
}

View File

@ -574,7 +574,7 @@ function toggleCompact(val, old) {
function previewTheme() {
let name = gradioApp().getElementById('setting_gradio_theme').querySelectorAll('input')?.[0].value || '';
fetch(`${window.subpath}/file=html/themes.json`)
fetch(`${window.subpath}/file=data/themes.json`)
.then((res) => {
res.json()
.then((themes) => {

Binary file not shown.

After

Width:  |  Height:  |  Size: 255 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 64 KiB

View File

@ -103,6 +103,7 @@ class Api:
self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int)
self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"])
self.add_api_route("/sdapi/v1/sampler", endpoints.get_sampler, methods=["GET"], response_model=dict)
# lora api
from modules.api import loras
@ -116,6 +117,10 @@ class Api:
from modules.api import nudenet
nudenet.register_api()
# xyz-grid api
from modules.api import xyz_grid
xyz_grid.register_api()
# civitai api
from modules.civitai import api_civitai
api_civitai.register_api()

View File

@ -6,8 +6,28 @@ from modules.api import models, helpers
def get_samplers():
from modules import sd_samplers
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
from modules import sd_samplers_diffusers
all_samplers = []
for k, v in sd_samplers_diffusers.config.items():
if k in ['All', 'Default', 'Res4Lyf']:
continue
all_samplers.append({
'name': k,
'options': v,
})
return all_samplers
def get_sampler():
if not shared.sd_loaded or shared.sd_model is None:
return {}
if hasattr(shared.sd_model, 'scheduler'):
scheduler = shared.sd_model.scheduler
config = {k: v for k, v in scheduler.config.items() if not k.startswith('_')}
return {
'name': scheduler.__class__.__name__,
'options': config
}
return {}
def get_sd_vaes():
from modules.sd_vae import vae_dict
@ -75,6 +95,13 @@ def get_interrogate():
from modules.interrogate.openclip import refresh_clip_models
return ['deepdanbooru'] + refresh_clip_models()
def get_schedulers():
from modules.sd_samplers import list_samplers
all_schedulers = list_samplers()
for s in all_schedulers:
shared.log.critical(s)
return all_schedulers
def post_interrogate(req: models.ReqInterrogate):
if req.image is None or len(req.image) < 64:
raise HTTPException(status_code=404, detail="Image not found")

View File

@ -86,8 +86,7 @@ class PydanticModelGenerator:
class ItemSampler(BaseModel):
name: str = Field(title="Name")
aliases: List[str] = Field(title="Aliases")
options: Dict[str, str] = Field(title="Options")
options: dict
class ItemVae(BaseModel):
model_name: str = Field(title="Model Name")
@ -199,6 +198,11 @@ class ItemExtension(BaseModel):
commit_date: Union[str, int] = Field(title="Commit Date", description="Extension Repository Commit Date")
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
class ItemScheduler(BaseModel):
name: str = Field(title="Name", description="Scheduler name")
cls: str = Field(title="Class", description="Scheduler class name")
options: Dict[str, Any] = Field(title="Options", description="Dictionary of scheduler options")
### request/response classes
ReqTxt2Img = PydanticModelGenerator(

26
modules/api/xyz_grid.py Normal file
View File

@ -0,0 +1,26 @@
from typing import List
def xyz_grid_enum(option: str = "") -> List[dict]:
from scripts.xyz import xyz_grid_classes # pylint: disable=no-name-in-module
options = []
for x in xyz_grid_classes.axis_options:
_option = {
'label': x.label,
'type': x.type.__name__,
'cost': x.cost,
'choices': x.choices is not None,
}
if len(option) == 0:
options.append(_option)
else:
if x.label.lower().startswith(option.lower()) or x.label.lower().endswith(option.lower()):
if callable(x.choices):
_option['choices'] = x.choices()
options.append(_option)
return options
def register_api():
from modules.shared import api as api_instance
api_instance.add_api_route("/sdapi/v1/xyz-grid", xyz_grid_enum, methods=["GET"], response_model=List[dict])

73
modules/errorlimiter.py Normal file
View File

@ -0,0 +1,73 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Iterable
class ErrorLimiterTrigger(BaseException): # Use BaseException to avoid being caught by "except Exception:".
def __init__(self, name: str, *args):
super().__init__(*args)
self.name = name
class ErrorLimiterAbort(RuntimeError):
def __init__(self, msg: str):
super().__init__(msg)
class ErrorLimiter:
_store: dict[str, int] = {}
@classmethod
def start(cls, name: str, limit: int = 5):
cls._store[name] = limit
@classmethod
def notify(cls, name: str | Iterable[str]): # Can be manually triggered if execution is spread across multiple files
if isinstance(name, str):
name = (name,)
for key in name:
if key in cls._store.keys():
cls._store[key] = cls._store[key] - 1
if cls._store[key] <= 0:
raise ErrorLimiterTrigger(key)
@classmethod
def end(cls, name: str):
cls._store.pop(name)
@contextmanager
def limit_errors(name: str, limit: int = 5):
"""Limiter for aborting execution after being triggered a specified number of times (default 5).
>>> with limit_errors("identifier", limit=5) as elimit:
>>> while do_thing():
>>> if (something_bad):
>>> print("Something bad happened")
>>> elimit() # In this example, raises ErrorLimiterAbort on the 5th call
>>> try:
>>> something_broken()
>>> except Exception:
>>> print("Encountered an exception")
>>> elimit() # Count is shared across all calls
Args:
name (str): Identifier.
limit (int, optional): Abort after `limit` number of triggers. Defaults to 5.
Raises:
ErrorLimiterAbort: Subclass of RuntimeException.
Yields:
Callable: Notification function to indicate that an error occurred.
"""
try:
ErrorLimiter.start(name, limit)
yield lambda: ErrorLimiter.notify(name)
except ErrorLimiterTrigger as e:
raise ErrorLimiterAbort(f"HALTING. Too many errors during '{e.name}'") from None
finally:
ErrorLimiter.end(name)

View File

@ -1,6 +1,7 @@
import logging
import warnings
from installer import get_log, get_console, setup_logging, install_traceback
from modules.errorlimiter import ErrorLimiterAbort
log = get_log()
@ -16,9 +17,18 @@ def install(suppress=[]):
def display(e: Exception, task: str, suppress=[]):
log.error(f"{task or 'error'}: {type(e).__name__}")
if isinstance(e, ErrorLimiterAbort):
return
log.critical(f"{task or 'error'}: {type(e).__name__}")
"""
trace = traceback.format_exc()
log.error(trace)
for line in traceback.format_tb(e.__traceback__):
log.error(repr(line))
console = get_console()
console.print_exception(show_locals=False, max_frames=16, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width)
"""
log.traceback(e, suppress=suppress)
def display_once(e: Exception, task):

View File

@ -151,33 +151,30 @@ def deactivate(p, extra_network_data=None, force=shared.opts.lora_force_reload):
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
def parse_prompt(prompt):
res = defaultdict(list)
def parse_prompt(prompt: str | None) -> tuple[str, defaultdict[str, list[ExtraNetworkParams]]]:
res: defaultdict[str, list[ExtraNetworkParams]] = defaultdict(list)
if prompt is None:
return prompt, res
return "", res
if isinstance(prompt, list):
shared.log.warning(f"parse_prompt was called with a list instead of a string: {prompt}")
return parse_prompts(prompt)
def found(m):
name = m.group(1)
args = m.group(2)
def found(m: re.Match[str]):
name, args = m.group(1, 2)
res[name].append(ExtraNetworkParams(items=args.split(":")))
return ""
if isinstance(prompt, list):
prompt = [re.sub(re_extra_net, found, p) for p in prompt]
else:
prompt = re.sub(re_extra_net, found, prompt)
return prompt, res
updated_prompt = re.sub(re_extra_net, found, prompt)
return updated_prompt, res
def parse_prompts(prompts):
res = []
extra_data = None
if prompts is None:
return prompts, extra_data
def parse_prompts(prompts: list[str]):
updated_prompt_list: list[str] = []
extra_data: defaultdict[str, list[ExtraNetworkParams]] = defaultdict(list)
for prompt in prompts:
updated_prompt, parsed_extra_data = parse_prompt(prompt)
if extra_data is None:
if not extra_data:
extra_data = parsed_extra_data
res.append(updated_prompt)
updated_prompt_list.append(updated_prompt)
return res, extra_data
return updated_prompt_list, extra_data

View File

@ -205,7 +205,7 @@ def face_id(
ip_model_dict["faceid_embeds"] = face_embeds # overwrite placeholder
faceid_model.set_scale(scale)
if p.all_prompts is None or len(p.all_prompts) == 0:
if not p.all_prompts:
processing.process_init(p)
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
for n in range(p.n_iter):

View File

@ -63,7 +63,7 @@ def instant_id(p: processing.StableDiffusionProcessing, app, source_images, stre
sd_models.move_model(shared.sd_model, devices.device) # move pipeline to device
# pipeline specific args
if p.all_prompts is None or len(p.all_prompts) == 0:
if not p.all_prompts:
processing.process_init(p)
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
orig_prompt_attention = shared.opts.prompt_attention
@ -73,8 +73,8 @@ def instant_id(p: processing.StableDiffusionProcessing, app, source_images, stre
p.task_args['controlnet_conditioning_scale'] = float(conditioning)
p.task_args['ip_adapter_scale'] = float(strength)
shared.log.debug(f"InstantID args: {p.task_args}")
p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts is not None else p.prompt
p.task_args['negative_prompt'] = p.all_negative_prompts[0] if p.all_negative_prompts is not None else p.negative_prompt
p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts else p.prompt
p.task_args['negative_prompt'] = p.all_negative_prompts[0] if p.all_negative_prompts else p.negative_prompt
p.task_args['image_embeds'] = face_embeds[0] # overwrite placeholder
# run processing

View File

@ -34,7 +34,7 @@ def photo_maker(p: processing.StableDiffusionProcessing, app, model: str, input_
return None
# validate prompt
if p.all_prompts is None or len(p.all_prompts) == 0:
if not p.all_prompts:
processing.process_init(p)
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
trigger_ids = shared.sd_model.tokenizer.encode(trigger) + shared.sd_model.tokenizer_2.encode(trigger)
@ -61,7 +61,7 @@ def photo_maker(p: processing.StableDiffusionProcessing, app, model: str, input_
shared.opts.data['prompt_attention'] = 'fixed' # otherwise need to deal with class_tokens_mask
p.task_args['input_id_images'] = input_images
p.task_args['start_merge_step'] = int(start * p.steps)
p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts is not None else p.prompt
p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts else p.prompt
is_v2 = 'v2' in model
if is_v2:

View File

@ -43,7 +43,7 @@ def vae_decode_simple(latents):
def vae_decode_tiny(latents):
global taesd # pylint: disable=global-statement
if taesd is None:
from modules import sd_vae_taesd
from modules.vae import sd_vae_taesd
taesd, _variant = sd_vae_taesd.get_model(variant='TAE HunyuanVideo')
shared.log.debug(f'Video VAE: type=Tiny cls={taesd.__class__.__name__} latents={latents.shape}')
with devices.inference_context():
@ -56,7 +56,7 @@ def vae_decode_tiny(latents):
def vae_decode_remote(latents):
# from modules.sd_vae_remote import remote_decode
# from modules.vae.sd_vae_remote import remote_decode
# images = remote_decode(latents, model_type='hunyuanvideo')
from diffusers.utils.remote_utils import remote_decode
images = remote_decode(

View File

@ -309,16 +309,18 @@ def worker(
break
total_generated_frames, _video_filename = save_video(
None,
history_pixels,
mp4_fps,
mp4_codec,
mp4_opt,
mp4_ext,
mp4_sf,
mp4_video,
mp4_frames,
mp4_interpolate,
p=None,
pixels=history_pixels,
audio=None,
binary=None,
mp4_fps=mp4_fps,
mp4_codec=mp4_codec,
mp4_opt=mp4_opt,
mp4_ext=mp4_ext,
mp4_sf=mp4_sf,
mp4_video=mp4_video,
mp4_frames=mp4_frames,
mp4_interpolate=mp4_interpolate,
pbar=pbar,
stream=stream,
metadata=metadata,
@ -327,7 +329,23 @@ def worker(
except AssertionError:
shared.log.info('FramePack: interrupted')
if shared.opts.keep_incomplete:
save_video(None, history_pixels, mp4_fps, mp4_codec, mp4_opt, mp4_ext, mp4_sf, mp4_video, mp4_frames, mp4_interpolate=0, stream=stream, metadata=metadata)
save_video(
p=None,
pixels=history_pixels,
audio=None,
binary=None,
mp4_fps=mp4_fps,
mp4_codec=mp4_codec,
mp4_opt=mp4_opt,
mp4_ext=mp4_ext,
mp4_sf=mp4_sf,
mp4_video=mp4_video,
mp4_frames=mp4_frames,
mp4_interpolate=0,
pbar=pbar,
stream=stream,
metadata=metadata,
)
except Exception as e:
shared.log.error(f'FramePack: {e}')
errors.display(e, 'FramePack')

View File

@ -17,6 +17,10 @@ debug('Trace: PASTE')
parse_generation_parameters = parse # compatibility
infotext_to_setting_name_mapping = mapping # compatibility
# Mapping of aliases to metadata parameter names, populated automatically from component labels/elem_ids
# This allows users to use component labels, elem_ids, or metadata names in the "skip params" setting
param_aliases: dict[str, str] = {}
class ParamBinding:
def __init__(self, paste_button, tabname: str, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
@ -74,7 +78,8 @@ def image_from_url_text(filedata):
filedata = filedata[len("data:image/jxl;base64,"):]
filebytes = base64.decodebytes(filedata.encode('utf-8'))
image = Image.open(io.BytesIO(filebytes))
images.read_info_from_image(image)
image.load()
# images.read_info_from_image(image)
return image
@ -85,9 +90,36 @@ def add_paste_fields(tabname: str, init_img: gr.Image | gr.HTML | None, fields:
except Exception as e:
shared.log.error(f"Paste fields: tab={tabname} fields={fields} {e}")
field_names[tabname] = []
# Build param_aliases automatically from component labels and elem_ids
if fields is not None:
for component, metadata_name in fields:
if metadata_name is None or callable(metadata_name):
continue
metadata_lower = metadata_name.lower()
# Extract label from component (e.g., "Batch size" -> maps to "Batch-2")
label = getattr(component, 'label', None)
if label and isinstance(label, str):
label_lower = label.lower()
if label_lower != metadata_lower and label_lower not in param_aliases:
param_aliases[label_lower] = metadata_lower
# Extract elem_id and derive variable name (e.g., "txt2img_batch_size" -> "batch_size")
elem_id = getattr(component, 'elem_id', None)
if elem_id and isinstance(elem_id, str):
# Strip common prefixes like "txt2img_", "img2img_", "control_"
var_name = elem_id
for prefix in ['txt2img_', 'img2img_', 'control_', 'video_', 'extras_']:
if var_name.startswith(prefix):
var_name = var_name[len(prefix):]
break
var_name_lower = var_name.lower()
if var_name_lower != metadata_lower and var_name_lower not in param_aliases:
param_aliases[var_name_lower] = metadata_lower
# backwards compatibility for existing extensions
debug(f'Paste fields: tab={tabname} fields={field_names[tabname]}')
debug(f'All fields: {get_all_fields()}')
debug(f'Param aliases: {param_aliases}')
import modules.ui
if tabname == 'txt2img':
modules.ui.txt2img_paste_fields = fields # compatibility
@ -133,10 +165,22 @@ def should_skip(param: str):
skip_params = [p.strip().lower() for p in shared.opts.disable_apply_params.split(",")]
if not shared.opts.clip_skip_enabled:
skip_params += ['clip skip']
# Expand skip_params with aliases (e.g., "batch_size" -> "batch-2")
expanded_skip = set(skip_params)
for skip in skip_params:
if skip in param_aliases:
expanded_skip.add(param_aliases[skip])
# Check if param should be skipped
param_lower = param.lower()
# Also check normalized name (without -1/-2) so "batch" skips both "batch-1" and "batch-2"
param_normalized = param_lower.replace('-1', '').replace('-2', '')
all_params = [p.lower() for p in get_all_fields()]
valid = any(p in all_params for p in skip_params)
skip = param.lower() in skip_params
debug(f'Check: param="{param}" valid={valid} skip={skip}')
skip = param_lower in expanded_skip or param_normalized in expanded_skip
debug(f'Check: param="{param}" valid={valid} skip={skip} expanded={expanded_skip}')
return skip

View File

@ -104,6 +104,9 @@ def on_tmpdir_changed():
def cleanup_tmpdr():
temp_dir = shared.opts.temp_dir
if temp_dir == "" or not os.path.isdir(temp_dir):
temp_dir = os.path.join(paths.temp_dir, "gradio")
shared.log.debug(f'Temp folder: path="{temp_dir}"')
if not os.path.isdir(temp_dir):
return
for root, _dirs, files in os.walk(temp_dir, topdown=False):
for name in files:

View File

@ -6,7 +6,7 @@ from modules.json_helpers import readfile, writefile
from modules.paths import data_path
cache_filename = os.path.join(data_path, "cache.json")
cache_filename = os.path.join(data_path, 'data', 'cache.json')
cache_data = None
progress_ok = True

View File

@ -311,7 +311,7 @@ def parse_novelai_metadata(data: dict):
return geninfo
def read_info_from_image(image: Image.Image, watermark: bool = False):
def read_info_from_image(image: Image.Image, watermark: bool = False) -> tuple[str, dict]:
if image is None:
return '', {}
if isinstance(image, str):
@ -322,9 +322,11 @@ def read_info_from_image(image: Image.Image, watermark: bool = False):
return '', {}
items = image.info or {}
geninfo = items.pop('parameters', None) or items.pop('UserComment', None) or ''
if geninfo is not None and len(geninfo) > 0:
if isinstance(geninfo, dict):
if 'UserComment' in geninfo:
geninfo = geninfo['UserComment']
geninfo = geninfo['UserComment'] # Info was nested
else:
geninfo = '' # Unknown format. Ignore contents
items['UserComment'] = geninfo
if "exif" in items:
@ -342,7 +344,7 @@ def read_info_from_image(image: Image.Image, watermark: bool = False):
val = round(val[0] / val[1], 2)
if val is not None and key in ExifTags.TAGS: # add known tags
if ExifTags.TAGS[key] == 'UserComment': # add geninfo from UserComment
geninfo = val
geninfo = str(val)
items['parameters'] = val
else:
items[ExifTags.TAGS[key]] = val

View File

@ -10,7 +10,8 @@ from pathlib import Path
from modules import shared, errors
debug = errors.log.trace if os.environ.get('SD_NAMEGEN_DEBUG', None) is not None else lambda *args, **kwargs: None
debug= os.environ.get('SD_NAMEGEN_DEBUG', None) is not None
debug_log = errors.log.trace if debug else lambda *args, **kwargs: None
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
@ -66,9 +67,9 @@ class FilenameGenerator:
def __init__(self, p, seed, prompt, image=None, grid=False, width=None, height=None):
if p is None:
debug('Filename generator init skip')
debug_log('Filename generator init skip')
else:
debug(f'Filename generator init: seed={seed} prompt="{prompt}"')
debug_log(f'Filename generator init: seed={seed} prompt="{prompt}"')
self.p = p
if seed is not None and int(seed) > 0:
self.seed = seed
@ -163,7 +164,7 @@ class FilenameGenerator:
def prompt_sanitize(self, prompt):
invalid_chars = '#<>:\'"\\|?*\n\t\r'
sanitized = prompt.translate({ ord(x): '_' for x in invalid_chars }).strip()
debug(f'Prompt sanitize: input="{prompt}" output={sanitized}')
debug_log(f'Prompt sanitize: input="{prompt}" output="{sanitized}"')
return sanitized
def sanitize(self, filename):
@ -200,7 +201,7 @@ class FilenameGenerator:
while len(os.path.abspath(fn)) > max_length:
fn = fn[:-1]
fn += ext
debug(f'Filename sanitize: input="{filename}" parts={parts} output="{fn}" ext={ext} max={max_length} len={len(fn)}')
debug_log(f'Filename sanitize: input="{filename}" parts={parts} output="{fn}" ext={ext} max={max_length} len={len(fn)}')
return fn
def safe_int(self, s):
@ -234,25 +235,38 @@ class FilenameGenerator:
def apply(self, x):
res = ''
if debug:
for k in self.replacements.keys():
try:
fn = self.replacements.get(k, None)
debug_log(f'Namegen: key={k} value={fn(self)}')
except Exception as e:
shared.log.error(f'Namegen: key={k} {e}')
errors.display(e, 'namegen')
for m in re_pattern.finditer(x):
text, pattern = m.groups()
if pattern is None:
res += text
continue
pattern_args = []
while True:
m = re_pattern_arg.match(pattern)
if m is None:
break
pattern, arg = m.groups()
pattern_args.insert(0, arg)
debug_log(f'Filename apply: text="{text}" pattern="{pattern}"')
if isinstance(pattern, list):
pattern = ' '.join(pattern)
if pattern is None or not isinstance(pattern, str) or pattern.strip() == '':
debug_log(f'Filename skip: pattern="{pattern}"')
res += text
continue
_pattern = pattern
pattern_args = []
while True:
m = re_pattern_arg.match(_pattern)
if m is None:
break
_pattern, arg = m.groups()
pattern_args.insert(0, arg)
fun = self.replacements.get(pattern.lower(), None)
if fun is not None:
try:
debug(f'Filename apply: pattern={pattern.lower()} args={pattern_args}')
replacement = fun(self, *pattern_args)
debug_log(f'Filename apply: pattern="{pattern}" args={pattern_args} replacement="{replacement}"')
except Exception as e:
replacement = None
errors.display(e, 'namegen')

View File

@ -4,7 +4,7 @@ import threading
import torch
import numpy as np
from PIL import Image
from modules import modelloader, paths, devices, shared, sd_models
from modules import modelloader, paths, devices, shared
re_special = re.compile(r'([\\()])')
load_lock = threading.Lock()
@ -35,21 +35,55 @@ class DeepDanbooru:
def start(self):
self.load()
sd_models.move_model(self.model, devices.device)
self.model.to(devices.device)
def stop(self):
if shared.opts.interrogate_offload:
sd_models.move_model(self.model, devices.cpu)
self.model.to(devices.cpu)
devices.torch_gc()
def tag(self, pil_image):
def tag(self, pil_image, **kwargs):
self.start()
res = self.tag_multi(pil_image)
res = self.tag_multi(pil_image, **kwargs)
self.stop()
return res
def tag_multi(self, pil_image, force_disable_ranks=False):
def tag_multi(
self,
pil_image,
general_threshold: float = None,
include_rating: bool = None,
exclude_tags: str = None,
max_tags: int = None,
sort_alpha: bool = None,
use_spaces: bool = None,
escape_brackets: bool = None,
):
"""Run inference and return formatted tag string.
Args:
pil_image: PIL Image to tag
general_threshold: Threshold for tag scores (0-1)
include_rating: Whether to include rating tags
exclude_tags: Comma-separated tags to exclude
max_tags: Maximum number of tags to return
sort_alpha: Sort tags alphabetically vs by confidence
use_spaces: Use spaces instead of underscores
escape_brackets: Escape parentheses/brackets in tags
Returns:
Formatted tag string
"""
# Use settings defaults if not specified
general_threshold = general_threshold or shared.opts.tagger_threshold
include_rating = include_rating if include_rating is not None else shared.opts.tagger_include_rating
exclude_tags = exclude_tags or shared.opts.tagger_exclude_tags
max_tags = max_tags or shared.opts.tagger_max_tags
sort_alpha = sort_alpha if sort_alpha is not None else shared.opts.tagger_sort_alpha
use_spaces = use_spaces if use_spaces is not None else shared.opts.tagger_use_spaces
escape_brackets = escape_brackets if escape_brackets is not None else shared.opts.tagger_escape_brackets
if isinstance(pil_image, list):
pil_image = pil_image[0] if len(pil_image) > 0 else None
if isinstance(pil_image, dict) and 'name' in pil_image:
@ -58,35 +92,237 @@ class DeepDanbooru:
return ''
pic = pil_image.resize((512, 512), resample=Image.Resampling.LANCZOS).convert("RGB")
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
with devices.inference_context(), devices.autocast():
x = torch.from_numpy(a).to(devices.device)
with devices.inference_context():
x = torch.from_numpy(a).to(device=devices.device, dtype=devices.dtype)
y = self.model(x)[0].detach().float().cpu().numpy()
probability_dict = {}
for tag, probability in zip(self.model.tags, y):
if probability < shared.opts.deepbooru_score_threshold:
for current, probability in zip(self.model.tags, y):
if probability < general_threshold:
continue
if tag.startswith("rating:"):
if current.startswith("rating:") and not include_rating:
continue
probability_dict[tag] = probability
if shared.opts.deepbooru_sort_alpha:
probability_dict[current] = probability
if sort_alpha:
tags = sorted(probability_dict)
else:
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
res = []
filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
for tag in [x for x in tags if x not in filtertags]:
probability = probability_dict[tag]
tag_outformat = tag
if shared.opts.deepbooru_use_spaces:
filtertags = {x.strip().replace(' ', '_') for x in exclude_tags.split(",")}
for filtertag in [x for x in tags if x not in filtertags]:
probability = probability_dict[filtertag]
tag_outformat = filtertag
if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ')
if shared.opts.deepbooru_escape:
if escape_brackets:
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
if shared.opts.interrogate_score and not force_disable_ranks:
if shared.opts.tagger_show_scores:
tag_outformat = f"({tag_outformat}:{probability:.2f})"
res.append(tag_outformat)
if len(res) > shared.opts.deepbooru_max_tags:
res = res[:shared.opts.deepbooru_max_tags]
if max_tags > 0 and len(res) > max_tags:
res = res[:max_tags]
return ", ".join(res)
model = DeepDanbooru()
def _save_tags_to_file(img_path, tags_str: str, save_append: bool) -> bool:
"""Save tags to a text file with error handling.
Args:
img_path: Path to the image file
tags_str: Tags string to save
save_append: If True, append to existing file; otherwise overwrite
Returns:
True if save succeeded, False otherwise
"""
try:
txt_path = img_path.with_suffix('.txt')
if save_append and txt_path.exists():
with open(txt_path, 'a', encoding='utf-8') as f:
f.write(f', {tags_str}')
else:
with open(txt_path, 'w', encoding='utf-8') as f:
f.write(tags_str)
return True
except Exception as e:
shared.log.error(f'DeepBooru batch: failed to save file="{img_path}" error={e}')
return False
def get_models() -> list:
"""Return list of available DeepBooru models (just one)."""
return ["DeepBooru"]
def load_model(model_name: str = None) -> bool: # pylint: disable=unused-argument
"""Load the DeepBooru model."""
try:
model.load()
return model.model is not None
except Exception as e:
shared.log.error(f'DeepBooru load: {e}')
return False
def unload_model():
"""Unload the DeepBooru model and free memory."""
if model.model is not None:
shared.log.debug('DeepBooru unload')
model.model = None
devices.torch_gc(force=True)
def tag(image, **kwargs) -> str:
"""Tag an image using DeepBooru.
Args:
image: PIL Image to tag
**kwargs: Tagger parameters (general_threshold, include_rating, exclude_tags,
max_tags, sort_alpha, use_spaces, escape_brackets)
Returns:
Formatted tag string
"""
import time
t0 = time.time()
jobid = shared.state.begin('DeepBooru Tag')
shared.log.info(f'DeepBooru: image_size={image.size if image else None}')
try:
result = model.tag(image, **kwargs)
shared.log.debug(f'DeepBooru: complete time={time.time()-t0:.2f} tags={len(result.split(", ")) if result else 0}')
except Exception as e:
result = f"Exception {type(e)}"
shared.log.error(f'DeepBooru: {e}')
shared.state.end(jobid)
return result
def batch(
model_name: str, # pylint: disable=unused-argument
batch_files: list,
batch_folder: str,
batch_str: str,
save_output: bool = True,
save_append: bool = False,
recursive: bool = False,
**kwargs
) -> str:
"""Process multiple images in batch mode.
Args:
model_name: Model name (ignored, only DeepBooru available)
batch_files: List of file paths
batch_folder: Folder path from file picker
batch_str: Folder path as string
save_output: Save caption to .txt files
save_append: Append to existing caption files
recursive: Recursively process subfolders
**kwargs: Additional arguments (for interface compatibility)
Returns:
Combined tag results
"""
import time
from pathlib import Path
import rich.progress as rp
# Load model
model.load()
# Collect image files
image_files = []
image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif'}
# From file picker
if batch_files:
for f in batch_files:
if isinstance(f, dict):
image_files.append(Path(f['name']))
elif hasattr(f, 'name'):
image_files.append(Path(f.name))
else:
image_files.append(Path(f))
# From folder picker
if batch_folder:
folder_path = None
if isinstance(batch_folder, list) and len(batch_folder) > 0:
f = batch_folder[0]
if isinstance(f, dict):
folder_path = Path(f['name']).parent
elif hasattr(f, 'name'):
folder_path = Path(f.name).parent
if folder_path and folder_path.is_dir():
if recursive:
for ext in image_extensions:
image_files.extend(folder_path.rglob(f'*{ext}'))
else:
for ext in image_extensions:
image_files.extend(folder_path.glob(f'*{ext}'))
# From string path
if batch_str and batch_str.strip():
folder_path = Path(batch_str.strip())
if folder_path.is_dir():
if recursive:
for ext in image_extensions:
image_files.extend(folder_path.rglob(f'*{ext}'))
else:
for ext in image_extensions:
image_files.extend(folder_path.glob(f'*{ext}'))
# Remove duplicates while preserving order
seen = set()
unique_files = []
for f in image_files:
f_resolved = f.resolve()
if f_resolved not in seen:
seen.add(f_resolved)
unique_files.append(f)
image_files = unique_files
if not image_files:
shared.log.warning('DeepBooru batch: no images found')
return ''
t0 = time.time()
jobid = shared.state.begin('DeepBooru Batch')
shared.log.info(f'DeepBooru batch: images={len(image_files)} write={save_output} append={save_append} recursive={recursive}')
results = []
model.start()
# Progress bar
pbar = rp.Progress(rp.TextColumn('[cyan]DeepBooru:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
with pbar:
task = pbar.add_task(total=len(image_files), description='starting...')
for img_path in image_files:
pbar.update(task, advance=1, description=str(img_path.name))
try:
if shared.state.interrupted:
shared.log.info('DeepBooru batch: interrupted')
break
image = Image.open(img_path)
tags_str = model.tag_multi(image, **kwargs)
if save_output:
_save_tags_to_file(img_path, tags_str, save_append)
results.append(f'{img_path.name}: {tags_str[:100]}...' if len(tags_str) > 100 else f'{img_path.name}: {tags_str}')
except Exception as e:
shared.log.error(f'DeepBooru batch: file="{img_path}" error={e}')
results.append(f'{img_path.name}: ERROR - {e}')
model.stop()
elapsed = time.time() - t0
shared.log.info(f'DeepBooru batch: complete images={len(results)} time={elapsed:.1f}s')
shared.state.end(jobid)
return '\n'.join(results)

View File

@ -20,10 +20,21 @@ def interrogate(image):
prompt = openclip.interrogate(image, mode=shared.opts.interrogate_clip_mode)
shared.log.debug(f'Interrogate: time={time.time()-t0:.2f} answer="{prompt}"')
return prompt
elif shared.opts.interrogate_default_type == 'DeepBooru':
shared.log.info(f'Interrogate: type={shared.opts.interrogate_default_type}')
from modules.interrogate import deepbooru
prompt = deepbooru.model.tag(image)
elif shared.opts.interrogate_default_type == 'Tagger':
shared.log.info(f'Interrogate: type={shared.opts.interrogate_default_type} model="{shared.opts.waifudiffusion_model}"')
from modules.interrogate import tagger
prompt = tagger.tag(
image=image,
model_name=shared.opts.waifudiffusion_model,
general_threshold=shared.opts.tagger_threshold,
character_threshold=shared.opts.waifudiffusion_character_threshold,
include_rating=shared.opts.tagger_include_rating,
exclude_tags=shared.opts.tagger_exclude_tags,
max_tags=shared.opts.tagger_max_tags,
sort_alpha=shared.opts.tagger_sort_alpha,
use_spaces=shared.opts.tagger_use_spaces,
escape_brackets=shared.opts.tagger_escape_brackets,
)
shared.log.debug(f'Interrogate: time={time.time()-t0:.2f} answer="{prompt}"')
return prompt
elif shared.opts.interrogate_default_type == 'VLM':

View File

@ -11,7 +11,7 @@ from modules.interrogate import vqa_detection
# Debug logging - function-based to avoid circular import
debug_enabled = os.environ.get('SD_VQA_DEBUG', None) is not None
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
def debug(*args, **kwargs):
if debug_enabled:

View File

@ -1,4 +1,5 @@
import os
import time
from collections import namedtuple
import threading
import re
@ -7,6 +8,23 @@ from PIL import Image
from modules import devices, paths, shared, errors, sd_models
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
def _apply_blip2_fix(model, processor):
"""Apply compatibility fix for BLIP2 models with newer transformers versions."""
from transformers import AddedToken
if not hasattr(model.config, 'num_query_tokens'):
return
processor.num_query_tokens = model.config.num_query_tokens
image_token = AddedToken("<image>", normalized=False, special=True)
processor.tokenizer.add_tokens([image_token], special_tokens=True)
model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64)
model.config.image_token_index = len(processor.tokenizer) - 1
debug_log(f'CLIP load: applied BLIP2 tokenizer fix num_query_tokens={model.config.num_query_tokens}')
caption_models = {
'blip-base': 'Salesforce/blip-image-captioning-base',
'blip-large': 'Salesforce/blip-image-captioning-large',
@ -79,10 +97,14 @@ def load_interrogator(clip_model, blip_model):
clip_interrogator.clip_interrogator.CAPTION_MODELS = caption_models
global ci # pylint: disable=global-statement
if ci is None:
shared.log.debug(f'Interrogate load: clip="{clip_model}" blip="{blip_model}"')
t0 = time.time()
device = devices.get_optimal_device()
cache_path = os.path.join(paths.models_path, 'Interrogator')
shared.log.info(f'CLIP load: clip="{clip_model}" blip="{blip_model}" device={device}')
debug_log(f'CLIP load: cache_path="{cache_path}" max_length={shared.opts.interrogate_clip_max_length} chunk_size={shared.opts.interrogate_clip_chunk_size} flavor_count={shared.opts.interrogate_clip_flavor_count} offload={shared.opts.interrogate_offload}')
interrogator_config = clip_interrogator.Config(
device=devices.get_optimal_device(),
cache_path=os.path.join(paths.models_path, 'Interrogator'),
device=device,
cache_path=cache_path,
clip_model_name=clip_model,
caption_model_name=blip_model,
quiet=True,
@ -93,22 +115,39 @@ def load_interrogator(clip_model, blip_model):
caption_offload=shared.opts.interrogate_offload,
)
ci = clip_interrogator.Interrogator(interrogator_config)
if blip_model.startswith('blip2-'):
_apply_blip2_fix(ci.caption_model, ci.caption_processor)
shared.log.debug(f'CLIP load: time={time.time()-t0:.2f}')
elif clip_model != ci.config.clip_model_name or blip_model != ci.config.caption_model_name:
ci.config.clip_model_name = clip_model
ci.config.clip_model = None
ci.load_clip_model()
ci.config.caption_model_name = blip_model
ci.config.caption_model = None
ci.load_caption_model()
t0 = time.time()
if clip_model != ci.config.clip_model_name:
shared.log.info(f'CLIP load: clip="{clip_model}" reloading')
debug_log(f'CLIP load: previous clip="{ci.config.clip_model_name}"')
ci.config.clip_model_name = clip_model
ci.config.clip_model = None
ci.load_clip_model()
if blip_model != ci.config.caption_model_name:
shared.log.info(f'CLIP load: blip="{blip_model}" reloading')
debug_log(f'CLIP load: previous blip="{ci.config.caption_model_name}"')
ci.config.caption_model_name = blip_model
ci.config.caption_model = None
ci.load_caption_model()
if blip_model.startswith('blip2-'):
_apply_blip2_fix(ci.caption_model, ci.caption_processor)
shared.log.debug(f'CLIP load: time={time.time()-t0:.2f}')
else:
debug_log(f'CLIP: models already loaded clip="{clip_model}" blip="{blip_model}"')
def unload_clip_model():
if ci is not None and shared.opts.interrogate_offload:
shared.log.debug('CLIP unload: offloading models to CPU')
sd_models.move_model(ci.caption_model, devices.cpu)
sd_models.move_model(ci.clip_model, devices.cpu)
ci.caption_offloaded = True
ci.clip_offloaded = True
devices.torch_gc()
debug_log('CLIP unload: complete')
def interrogate(image, mode, caption=None):
@ -119,6 +158,8 @@ def interrogate(image, mode, caption=None):
if image is None:
return ''
image = image.convert("RGB")
t0 = time.time()
debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={caption is not None} min_flavors={shared.opts.interrogate_clip_min_flavors} max_flavors={shared.opts.interrogate_clip_max_flavors}')
if mode == 'best':
prompt = ci.interrogate(image, caption=caption, min_flavors=shared.opts.interrogate_clip_min_flavors, max_flavors=shared.opts.interrogate_clip_max_flavors, )
elif mode == 'caption':
@ -131,22 +172,27 @@ def interrogate(image, mode, caption=None):
prompt = ci.interrogate_negative(image, max_flavors=shared.opts.interrogate_clip_max_flavors)
else:
raise RuntimeError(f"Unknown mode {mode}")
debug_log(f'CLIP: mode="{mode}" time={time.time()-t0:.2f} result="{prompt[:100]}..."' if len(prompt) > 100 else f'CLIP: mode="{mode}" time={time.time()-t0:.2f} result="{prompt}"')
return prompt
def interrogate_image(image, clip_model, blip_model, mode):
jobid = shared.state.begin('Interrogate CLiP')
t0 = time.time()
shared.log.info(f'CLIP: mode="{mode}" clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
try:
if shared.sd_loaded:
from modules.sd_models import apply_balanced_offload # prevent circular import
apply_balanced_offload(shared.sd_model)
debug_log('CLIP: applied balanced offload to sd_model')
load_interrogator(clip_model, blip_model)
image = image.convert('RGB')
prompt = interrogate(image, mode)
devices.torch_gc()
shared.log.debug(f'CLIP: complete time={time.time()-t0:.2f}')
except Exception as e:
prompt = f"Exception {type(e)}"
shared.log.error(f'Interrogate: {e}')
shared.log.error(f'CLIP: {e}')
errors.display(e, 'Interrogate')
shared.state.end(jobid)
return prompt
@ -162,8 +208,11 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
from modules.files_cache import list_files
files += list(list_files(batch_str, ext_filter=['.png', '.jpg', '.jpeg', '.webp', '.jxl'], recursive=recursive))
if len(files) == 0:
shared.log.warning('Interrogate batch: type=clip no images')
shared.log.warning('CLIP batch: no images found')
return ''
t0 = time.time()
shared.log.info(f'CLIP batch: mode="{mode}" images={len(files)} clip="{clip_model}" blip="{blip_model}" write={write} append={append}')
debug_log(f'CLIP batch: recursive={recursive} files={files[:5]}{"..." if len(files) > 5 else ""}')
jobid = shared.state.begin('Interrogate batch')
prompts = []
@ -171,6 +220,7 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
if write:
file_mode = 'w' if not append else 'a'
writer = BatchWriter(os.path.dirname(files[0]), mode=file_mode)
debug_log(f'CLIP batch: writing to "{os.path.dirname(files[0])}" mode="{file_mode}"')
import rich.progress as rp
pbar = rp.Progress(rp.TextColumn('[cyan]Caption:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
with pbar:
@ -179,6 +229,7 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
pbar.update(task, advance=1, description=file)
try:
if shared.state.interrupted:
shared.log.info('CLIP batch: interrupted')
break
image = Image.open(file).convert('RGB')
prompt = interrogate(image, mode)
@ -186,19 +237,23 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
if write:
writer.add(file, prompt)
except OSError as e:
shared.log.error(f'Interrogate batch: {e}')
shared.log.error(f'CLIP batch: file="{file}" error={e}')
if write:
writer.close()
ci.config.quiet = False
unload_clip_model()
shared.state.end(jobid)
shared.log.info(f'CLIP batch: complete images={len(prompts)} time={time.time()-t0:.2f}')
return '\n\n'.join(prompts)
def analyze_image(image, clip_model, blip_model):
t0 = time.time()
shared.log.info(f'CLIP analyze: clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
load_interrogator(clip_model, blip_model)
image = image.convert('RGB')
image_features = ci.image_to_features(image)
debug_log(f'CLIP analyze: features shape={image_features.shape if hasattr(image_features, "shape") else "unknown"}')
top_mediums = ci.mediums.rank(image_features, 5)
top_artists = ci.artists.rank(image_features, 5)
top_movements = ci.movements.rank(image_features, 5)
@ -209,6 +264,7 @@ def analyze_image(image, clip_model, blip_model):
movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements)), key=lambda x: x[1], reverse=True))
trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings)), key=lambda x: x[1], reverse=True))
flavor_ranks = dict(sorted(zip(top_flavors, ci.similarities(image_features, top_flavors)), key=lambda x: x[1], reverse=True))
shared.log.debug(f'CLIP analyze: complete time={time.time()-t0:.2f}')
# Format labels as text
def format_category(name, ranks):

View File

@ -0,0 +1,79 @@
# Unified Tagger Interface - Dispatches to WaifuDiffusion or DeepBooru based on model selection
# Provides a common interface for the Booru Tags tab
from modules import shared
DEEPBOORU_MODEL = "DeepBooru"
def get_models() -> list:
"""Return combined list: DeepBooru + WaifuDiffusion models."""
from modules.interrogate import waifudiffusion
return [DEEPBOORU_MODEL] + waifudiffusion.get_models()
def refresh_models() -> list:
"""Refresh and return all models."""
return get_models()
def is_deepbooru(model_name: str) -> bool:
"""Check if selected model is DeepBooru."""
return model_name == DEEPBOORU_MODEL
def load_model(model_name: str) -> bool:
"""Load appropriate backend."""
if is_deepbooru(model_name):
from modules.interrogate import deepbooru
return deepbooru.load_model()
else:
from modules.interrogate import waifudiffusion
return waifudiffusion.load_model(model_name)
def unload_model():
"""Unload both backends to ensure memory is freed."""
from modules.interrogate import deepbooru, waifudiffusion
deepbooru.unload_model()
waifudiffusion.unload_model()
def tag(image, model_name: str = None, **kwargs) -> str:
"""Unified tagging - dispatch to correct backend.
Args:
image: PIL Image to tag
model_name: Model to use (DeepBooru or WaifuDiffusion model name)
**kwargs: Additional arguments passed to the backend
Returns:
Formatted tag string
"""
if model_name is None:
model_name = shared.opts.waifudiffusion_model
if is_deepbooru(model_name):
from modules.interrogate import deepbooru
return deepbooru.tag(image, **kwargs)
else:
from modules.interrogate import waifudiffusion
return waifudiffusion.tag(image, model_name=model_name, **kwargs)
def batch(model_name: str, **kwargs) -> str:
"""Unified batch processing.
Args:
model_name: Model to use (DeepBooru or WaifuDiffusion model name)
**kwargs: Additional arguments passed to the backend
Returns:
Combined tag results
"""
if is_deepbooru(model_name):
from modules.interrogate import deepbooru
return deepbooru.batch(model_name=model_name, **kwargs)
else:
from modules.interrogate import waifudiffusion
return waifudiffusion.batch(model_name=model_name, **kwargs)

View File

@ -13,7 +13,7 @@ from modules.interrogate import vqa_detection
# Debug logging - function-based to avoid circular import
debug_enabled = os.environ.get('SD_VQA_DEBUG', None) is not None
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
def debug(*args, **kwargs):
if debug_enabled:

View File

@ -0,0 +1,544 @@
# WaifuDiffusion Tagger - ONNX-based anime/illustration tagging
# Based on SmilingWolf's tagger models: https://huggingface.co/SmilingWolf
import os
import re
import time
import threading
import numpy as np
from PIL import Image
from modules import shared, devices, errors
# Debug logging - enable with SD_INTERROGATE_DEBUG environment variable
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
re_special = re.compile(r'([\\()])')
load_lock = threading.Lock()
# WaifuDiffusion model repository mappings
WAIFUDIFFUSION_MODELS = {
# v3 models (latest, recommended)
"wd-eva02-large-tagger-v3": "SmilingWolf/wd-eva02-large-tagger-v3",
"wd-vit-tagger-v3": "SmilingWolf/wd-vit-tagger-v3",
"wd-convnext-tagger-v3": "SmilingWolf/wd-convnext-tagger-v3",
"wd-swinv2-tagger-v3": "SmilingWolf/wd-swinv2-tagger-v3",
# v2 models
"wd-v1-4-moat-tagger-v2": "SmilingWolf/wd-v1-4-moat-tagger-v2",
"wd-v1-4-swinv2-tagger-v2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
"wd-v1-4-convnext-tagger-v2": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
"wd-v1-4-convnextv2-tagger-v2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
"wd-v1-4-vit-tagger-v2": "SmilingWolf/wd-v1-4-vit-tagger-v2",
}
# Tag categories from selected_tags.csv
CATEGORY_GENERAL = 0
CATEGORY_CHARACTER = 4
CATEGORY_RATING = 9
class WaifuDiffusionTagger:
"""WaifuDiffusion Tagger using ONNX inference."""
def __init__(self):
self.session = None
self.tags = None
self.tag_categories = None
self.model_name = None
self.model_path = None
self.image_size = 448 # Standard for WD models
def load(self, model_name: str = None):
"""Load the ONNX model and tags from HuggingFace."""
import huggingface_hub
if model_name is None:
model_name = shared.opts.waifudiffusion_model
if model_name not in WAIFUDIFFUSION_MODELS:
shared.log.error(f'WaifuDiffusion: unknown model "{model_name}"')
return False
with load_lock:
if self.session is not None and self.model_name == model_name:
debug_log(f'WaifuDiffusion: model already loaded model="{model_name}"')
return True # Already loaded
# Unload previous model if different
if self.model_name != model_name and self.session is not None:
debug_log(f'WaifuDiffusion: switching model from "{self.model_name}" to "{model_name}"')
self.unload()
repo_id = WAIFUDIFFUSION_MODELS[model_name]
t0 = time.time()
shared.log.info(f'WaifuDiffusion load: model="{model_name}" repo="{repo_id}"')
try:
# Download only ONNX model and tags CSV (skip safetensors/msgpack variants)
debug_log(f'WaifuDiffusion load: downloading from HuggingFace cache_dir="{shared.opts.hfcache_dir}"')
self.model_path = huggingface_hub.snapshot_download(
repo_id,
cache_dir=shared.opts.hfcache_dir,
allow_patterns=["model.onnx", "selected_tags.csv"],
)
debug_log(f'WaifuDiffusion load: model_path="{self.model_path}"')
# Load ONNX model
model_file = os.path.join(self.model_path, "model.onnx")
if not os.path.exists(model_file):
shared.log.error(f'WaifuDiffusion load: model file not found: {model_file}')
return False
import onnxruntime as ort
debug_log(f'WaifuDiffusion load: onnxruntime version={ort.__version__}')
self.session = ort.InferenceSession(model_file, providers=devices.onnx)
self.model_name = model_name
# Get actual providers used
actual_providers = self.session.get_providers()
debug_log(f'WaifuDiffusion load: active providers={actual_providers}')
# Load tags from CSV
self._load_tags()
load_time = time.time() - t0
shared.log.debug(f'WaifuDiffusion load: time={load_time:.2f} tags={len(self.tags)}')
debug_log(f'WaifuDiffusion load: input_name={self.session.get_inputs()[0].name} output_name={self.session.get_outputs()[0].name}')
return True
except Exception as e:
shared.log.error(f'WaifuDiffusion load: failed error={e}')
errors.display(e, 'WaifuDiffusion load')
self.unload()
return False
def _load_tags(self):
"""Load tags and categories from selected_tags.csv."""
import csv
csv_path = os.path.join(self.model_path, "selected_tags.csv")
if not os.path.exists(csv_path):
shared.log.error(f'WaifuDiffusion load: tags file not found: {csv_path}')
return
self.tags = []
self.tag_categories = []
with open(csv_path, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
self.tags.append(row['name'])
self.tag_categories.append(int(row['category']))
# Count tags by category
category_counts = {}
for cat in self.tag_categories:
category_counts[cat] = category_counts.get(cat, 0) + 1
debug_log(f'WaifuDiffusion load: tag categories={category_counts}')
def unload(self):
"""Unload the model and free resources."""
if self.session is not None:
shared.log.debug(f'WaifuDiffusion unload: model="{self.model_name}"')
self.session = None
self.tags = None
self.tag_categories = None
self.model_name = None
self.model_path = None
devices.torch_gc(force=True)
debug_log('WaifuDiffusion unload: complete')
else:
debug_log('WaifuDiffusion unload: no model loaded')
def preprocess_image(self, image: Image.Image) -> np.ndarray:
"""Preprocess image for WaifuDiffusion model input.
- Resize to 448x448 (standard for WD models)
- Pad to square with white background
- Normalize to [0, 1] range
- BGR channel order (as used by these models)
"""
original_size = image.size
original_mode = image.mode
# Convert to RGB if needed
if image.mode != 'RGB':
image = image.convert('RGB')
# Pad to square with white background
w, h = image.size
max_dim = max(w, h)
pad_left = (max_dim - w) // 2
pad_top = (max_dim - h) // 2
padded = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
padded.paste(image, (pad_left, pad_top))
# Resize to model input size
if max_dim != self.image_size:
padded = padded.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
# Convert to numpy array and normalize
img_array = np.array(padded, dtype=np.float32)
# Convert RGB to BGR (model expects BGR)
img_array = img_array[:, :, ::-1]
# Add batch dimension
img_array = np.expand_dims(img_array, axis=0)
debug_log(f'WaifuDiffusion preprocess: original_size={original_size} mode={original_mode} padded_size={max_dim} output_shape={img_array.shape}')
return img_array
def predict(
self,
image: Image.Image,
general_threshold: float = None,
character_threshold: float = None,
include_rating: bool = None,
exclude_tags: str = None,
max_tags: int = None,
sort_alpha: bool = None,
use_spaces: bool = None,
escape_brackets: bool = None,
) -> str:
"""Run inference and return formatted tag string.
Args:
image: PIL Image to tag
general_threshold: Threshold for general tags (0-1)
character_threshold: Threshold for character tags (0-1)
include_rating: Whether to include rating tags
exclude_tags: Comma-separated tags to exclude
max_tags: Maximum number of tags to return
sort_alpha: Sort tags alphabetically vs by confidence
use_spaces: Use spaces instead of underscores
escape_brackets: Escape parentheses/brackets in tags
Returns:
Formatted tag string
"""
t0 = time.time()
# Use settings defaults if not specified
general_threshold = general_threshold or shared.opts.tagger_threshold
character_threshold = character_threshold or shared.opts.waifudiffusion_character_threshold
include_rating = include_rating if include_rating is not None else shared.opts.tagger_include_rating
exclude_tags = exclude_tags or shared.opts.tagger_exclude_tags
max_tags = max_tags or shared.opts.tagger_max_tags
sort_alpha = sort_alpha if sort_alpha is not None else shared.opts.tagger_sort_alpha
use_spaces = use_spaces if use_spaces is not None else shared.opts.tagger_use_spaces
escape_brackets = escape_brackets if escape_brackets is not None else shared.opts.tagger_escape_brackets
debug_log(f'WaifuDiffusion predict: general_threshold={general_threshold} character_threshold={character_threshold} max_tags={max_tags} include_rating={include_rating} sort_alpha={sort_alpha}')
# Handle input variations
if isinstance(image, list):
image = image[0] if len(image) > 0 else None
if isinstance(image, dict) and 'name' in image:
image = Image.open(image['name'])
if image is None:
shared.log.error('WaifuDiffusion predict: no image provided')
return ''
# Load model if needed
if self.session is None:
if not self.load():
return ''
# Preprocess image
img_input = self.preprocess_image(image)
# Run inference
t_infer = time.time()
input_name = self.session.get_inputs()[0].name
output_name = self.session.get_outputs()[0].name
probs = self.session.run([output_name], {input_name: img_input})[0][0]
infer_time = time.time() - t_infer
debug_log(f'WaifuDiffusion predict: inference time={infer_time:.3f}s output_shape={probs.shape}')
# Build tag list with probabilities
tag_probs = {}
exclude_set = {x.strip().replace(' ', '_').lower() for x in exclude_tags.split(',') if x.strip()}
if exclude_set:
debug_log(f'WaifuDiffusion predict: exclude_tags={exclude_set}')
general_count = 0
character_count = 0
rating_count = 0
for i, (tag_name, prob) in enumerate(zip(self.tags, probs)):
category = self.tag_categories[i]
tag_lower = tag_name.lower()
# Skip excluded tags
if tag_lower in exclude_set:
continue
# Apply category-specific thresholds
if category == CATEGORY_RATING:
if not include_rating:
continue
# Always include rating if enabled
tag_probs[tag_name] = float(prob)
rating_count += 1
elif category == CATEGORY_CHARACTER:
if prob >= character_threshold:
tag_probs[tag_name] = float(prob)
character_count += 1
elif category == CATEGORY_GENERAL:
if prob >= general_threshold:
tag_probs[tag_name] = float(prob)
general_count += 1
else:
# Other categories use general threshold
if prob >= general_threshold:
tag_probs[tag_name] = float(prob)
debug_log(f'WaifuDiffusion predict: matched tags general={general_count} character={character_count} rating={rating_count} total={len(tag_probs)}')
# Sort tags
if sort_alpha:
sorted_tags = sorted(tag_probs.keys())
else:
sorted_tags = [t for t, _ in sorted(tag_probs.items(), key=lambda x: -x[1])]
# Limit number of tags
if max_tags > 0 and len(sorted_tags) > max_tags:
sorted_tags = sorted_tags[:max_tags]
debug_log(f'WaifuDiffusion predict: limited to max_tags={max_tags}')
# Format output
result = []
for tag_name in sorted_tags:
formatted_tag = tag_name
if use_spaces:
formatted_tag = formatted_tag.replace('_', ' ')
if escape_brackets:
formatted_tag = re.sub(re_special, r'\\\1', formatted_tag)
if shared.opts.tagger_show_scores:
formatted_tag = f"({formatted_tag}:{tag_probs[tag_name]:.2f})"
result.append(formatted_tag)
output = ", ".join(result)
total_time = time.time() - t0
debug_log(f'WaifuDiffusion predict: complete tags={len(result)} time={total_time:.2f} result="{output[:100]}..."' if len(output) > 100 else f'WaifuDiffusion predict: complete tags={len(result)} time={total_time:.2f} result="{output}"')
return output
def tag(self, image: Image.Image, **kwargs) -> str:
"""Alias for predict() to match deepbooru interface."""
return self.predict(image, **kwargs)
# Global tagger instance
tagger = WaifuDiffusionTagger()
def _save_tags_to_file(img_path, tags_str: str, save_append: bool) -> bool:
"""Save tags to a text file with error handling.
Args:
img_path: Path to the image file
tags_str: Tags string to save
save_append: If True, append to existing file; otherwise overwrite
Returns:
True if save succeeded, False otherwise
"""
try:
txt_path = img_path.with_suffix('.txt')
if save_append and txt_path.exists():
with open(txt_path, 'a', encoding='utf-8') as f:
f.write(f', {tags_str}')
else:
with open(txt_path, 'w', encoding='utf-8') as f:
f.write(tags_str)
return True
except Exception as e:
shared.log.error(f'WaifuDiffusion batch: failed to save file="{img_path}" error={e}')
return False
def get_models() -> list:
"""Return list of available WaifuDiffusion model names."""
return list(WAIFUDIFFUSION_MODELS.keys())
def refresh_models() -> list:
"""Refresh and return list of available models."""
# For now, just return the static list
# Could be extended to check for locally cached models
return get_models()
def load_model(model_name: str = None) -> bool:
"""Load the specified WaifuDiffusion model."""
return tagger.load(model_name)
def unload_model():
"""Unload the current WaifuDiffusion model."""
tagger.unload()
def tag(image: Image.Image, model_name: str = None, **kwargs) -> str:
"""Tag an image using WaifuDiffusion tagger.
Args:
image: PIL Image to tag
model_name: Model to use (loads if needed)
**kwargs: Additional arguments passed to predict()
Returns:
Formatted tag string
"""
t0 = time.time()
jobid = shared.state.begin('WaifuDiffusion Tag')
shared.log.info(f'WaifuDiffusion: model="{model_name or tagger.model_name or shared.opts.waifudiffusion_model}" image_size={image.size if image else None}')
try:
if model_name and model_name != tagger.model_name:
tagger.load(model_name)
result = tagger.predict(image, **kwargs)
shared.log.debug(f'WaifuDiffusion: complete time={time.time()-t0:.2f} tags={len(result.split(", ")) if result else 0}')
# Offload model if setting enabled
if shared.opts.interrogate_offload:
tagger.unload()
except Exception as e:
result = f"Exception {type(e)}"
shared.log.error(f'WaifuDiffusion: {e}')
errors.display(e, 'WaifuDiffusion Tag')
shared.state.end(jobid)
return result
def batch(
model_name: str,
batch_files: list,
batch_folder: str,
batch_str: str,
save_output: bool = True,
save_append: bool = False,
recursive: bool = False,
**kwargs
) -> str:
"""Process multiple images in batch mode.
Args:
model_name: Model to use
batch_files: List of file paths
batch_folder: Folder path from file picker
batch_str: Folder path as string
save_output: Save caption to .txt files
save_append: Append to existing caption files
recursive: Recursively process subfolders
**kwargs: Additional arguments passed to predict()
Returns:
Combined tag results
"""
from pathlib import Path
# Load model
if model_name:
tagger.load(model_name)
elif tagger.session is None:
tagger.load()
# Collect image files
image_files = []
image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif'}
# From file picker
if batch_files:
for f in batch_files:
if isinstance(f, dict):
image_files.append(Path(f['name']))
elif hasattr(f, 'name'):
image_files.append(Path(f.name))
else:
image_files.append(Path(f))
# From folder picker
if batch_folder:
folder_path = None
if isinstance(batch_folder, list) and len(batch_folder) > 0:
f = batch_folder[0]
if isinstance(f, dict):
folder_path = Path(f['name']).parent
elif hasattr(f, 'name'):
folder_path = Path(f.name).parent
if folder_path and folder_path.is_dir():
if recursive:
for ext in image_extensions:
image_files.extend(folder_path.rglob(f'*{ext}'))
else:
for ext in image_extensions:
image_files.extend(folder_path.glob(f'*{ext}'))
# From string path
if batch_str and batch_str.strip():
folder_path = Path(batch_str.strip())
if folder_path.is_dir():
if recursive:
for ext in image_extensions:
image_files.extend(folder_path.rglob(f'*{ext}'))
else:
for ext in image_extensions:
image_files.extend(folder_path.glob(f'*{ext}'))
# Remove duplicates while preserving order
seen = set()
unique_files = []
for f in image_files:
f_resolved = f.resolve()
if f_resolved not in seen:
seen.add(f_resolved)
unique_files.append(f)
image_files = unique_files
if not image_files:
shared.log.warning('WaifuDiffusion batch: no images found')
return ''
t0 = time.time()
jobid = shared.state.begin('WaifuDiffusion Batch')
shared.log.info(f'WaifuDiffusion batch: model="{tagger.model_name}" images={len(image_files)} write={save_output} append={save_append} recursive={recursive}')
debug_log(f'WaifuDiffusion batch: files={[str(f) for f in image_files[:5]]}{"..." if len(image_files) > 5 else ""}')
results = []
# Progress bar
import rich.progress as rp
pbar = rp.Progress(rp.TextColumn('[cyan]WaifuDiffusion:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
with pbar:
task = pbar.add_task(total=len(image_files), description='starting...')
for img_path in image_files:
pbar.update(task, advance=1, description=str(img_path.name))
try:
if shared.state.interrupted:
shared.log.info('WaifuDiffusion batch: interrupted')
break
image = Image.open(img_path)
tags_str = tagger.predict(image, **kwargs)
if save_output:
_save_tags_to_file(img_path, tags_str, save_append)
results.append(f'{img_path.name}: {tags_str[:100]}...' if len(tags_str) > 100 else f'{img_path.name}: {tags_str}')
except Exception as e:
shared.log.error(f'WaifuDiffusion batch: file="{img_path}" error={e}')
results.append(f'{img_path.name}: ERROR - {e}')
elapsed = time.time() - t0
shared.log.info(f'WaifuDiffusion batch: complete images={len(results)} time={elapsed:.1f}s')
shared.state.end(jobid)
return '\n'.join(results)

View File

@ -5,14 +5,18 @@ Lightweight IP-Adapter applied to existing pipeline in Diffusers
- IP adapters: https://huggingface.co/h94/IP-Adapter
"""
from __future__ import annotations
import os
import time
import json
from typing import TYPE_CHECKING
from PIL import Image
import diffusers
import transformers
from modules import processing, shared, devices, sd_models, errors, model_quant
if TYPE_CHECKING:
from diffusers import DiffusionPipeline
clip_loaded = None
adapters_loaded = []
@ -160,7 +164,7 @@ def unapply(pipe, unload: bool = False): # pylint: disable=arguments-differ
pass
def load_image_encoder(pipe: diffusers.DiffusionPipeline, adapter_names: list[str]):
def load_image_encoder(pipe: DiffusionPipeline, adapter_names: list[str]):
global clip_loaded # pylint: disable=global-statement
for adapter_name in adapter_names:
# which clip to use

View File

@ -47,10 +47,10 @@ def readfile(filename: str, silent: bool = False, lock: bool = False, *, as_type
log.debug(f'Read: file="{filename}" json={len(data)} bytes={os.path.getsize(filename)} time={t1-t0:.3f} fn={fn}')
except FileNotFoundError as err:
if not silent:
log.debug(f'Reading failed: {filename} {err}')
log.debug(f'Read failed: file="{filename}" {err}')
except Exception as err:
if not silent:
log.error(f'Reading failed: {filename} {err}')
log.error(f'Read failed: file="{filename}" {err}')
try:
if locking_available and lock_file is not None:
lock_file.release_read_lock()

View File

@ -46,6 +46,13 @@ except Exception as e:
sys.exit(1)
timer.startup.record("scipy")
try:
import atexit
import torch._inductor.async_compile as ac
atexit.unregister(ac.shutdown_compile_workers)
except Exception:
pass
import torch # pylint: disable=C0411
if torch.__version__.startswith('2.5.0'):
errors.log.warning(f'Disabling cuDNN for SDP on torch={torch.__version__}')

View File

@ -3,6 +3,7 @@ import re
import time
import torch
import diffusers.models.lora
from modules.errorlimiter import ErrorLimiter
from modules.lora import lora_common as l
from modules import shared, devices, errors, model_quant
@ -141,6 +142,7 @@ def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
if l.debug:
errors.display(e, 'LoRA')
raise RuntimeError('LoRA apply weight') from e
ErrorLimiter.notify(("network_activate", "network_deactivate"))
continue
return batch_updown, batch_ex_bias

View File

@ -269,7 +269,7 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
continue
if net is None:
failed_to_load_networks.append(name)
shared.log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed')
shared.log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} not found')
continue
if hasattr(sd_model, 'embedding_db'):
sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings)

View File

@ -1,6 +1,7 @@
from contextlib import nullcontext
import time
import rich.progress as rp
from modules.errorlimiter import limit_errors
from modules.lora import lora_common as l
from modules.lora.lora_apply import network_apply_weights, network_apply_direct, network_backup_weights, network_calc_weights
from modules import shared, devices, sd_models
@ -12,61 +13,62 @@ default_components = ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'text_
def network_activate(include=[], exclude=[]):
t0 = time.time()
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
if shared.opts.diffusers_offload_mode == "sequential":
sd_models.disable_offload(sd_model)
sd_models.move_model(sd_model, device=devices.cpu)
device = None
modules = {}
components = include if len(include) > 0 else default_components
components = [x for x in components if x not in exclude]
active_components = []
for name in components:
component = getattr(sd_model, name, None)
if component is not None and hasattr(component, 'named_modules'):
active_components.append(name)
modules[name] = list(component.named_modules())
total = sum(len(x) for x in modules.values())
if len(l.loaded_networks) > 0:
pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=activate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
task = pbar.add_task(description='' , total=total)
else:
task = None
pbar = nullcontext()
applied_weight = 0
applied_bias = 0
with devices.inference_context(), pbar:
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in l.loaded_networks) if len(l.loaded_networks) > 0 else ()
applied_layers.clear()
backup_size = 0
for component in modules.keys():
device = getattr(sd_model, component, None).device
for _, module in modules[component]:
network_layer_name = getattr(module, 'network_layer_name', None)
current_names = getattr(module, "network_current_names", ())
if getattr(module, 'weight', None) is None or shared.state.interrupted or (network_layer_name is None) or (current_names == wanted_names):
with limit_errors("network_activate"):
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
if shared.opts.diffusers_offload_mode == "sequential":
sd_models.disable_offload(sd_model)
sd_models.move_model(sd_model, device=devices.cpu)
device = None
modules = {}
components = include if len(include) > 0 else default_components
components = [x for x in components if x not in exclude]
active_components = []
for name in components:
component = getattr(sd_model, name, None)
if component is not None and hasattr(component, 'named_modules'):
active_components.append(name)
modules[name] = list(component.named_modules())
total = sum(len(x) for x in modules.values())
if len(l.loaded_networks) > 0:
pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=activate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
task = pbar.add_task(description='' , total=total)
else:
task = None
pbar = nullcontext()
applied_weight = 0
applied_bias = 0
with devices.inference_context(), pbar:
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in l.loaded_networks) if len(l.loaded_networks) > 0 else ()
applied_layers.clear()
backup_size = 0
for component in modules.keys():
device = getattr(sd_model, component, None).device
for _, module in modules[component]:
network_layer_name = getattr(module, 'network_layer_name', None)
current_names = getattr(module, "network_current_names", ())
if getattr(module, 'weight', None) is None or shared.state.interrupted or (network_layer_name is None) or (current_names == wanted_names):
if task is not None:
pbar.update(task, advance=1)
continue
backup_size += network_backup_weights(module, network_layer_name, wanted_names)
batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name)
if shared.opts.lora_fuse_native:
network_apply_direct(module, batch_updown, batch_ex_bias, device=device)
else:
network_apply_weights(module, batch_updown, batch_ex_bias, device=device)
if batch_updown is not None or batch_ex_bias is not None:
applied_layers.append(network_layer_name)
applied_weight += 1 if batch_updown is not None else 0
applied_bias += 1 if batch_ex_bias is not None else 0
batch_updown, batch_ex_bias = None, None
del batch_updown, batch_ex_bias
module.network_current_names = wanted_names
if task is not None:
pbar.update(task, advance=1)
continue
backup_size += network_backup_weights(module, network_layer_name, wanted_names)
batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name)
if shared.opts.lora_fuse_native:
network_apply_direct(module, batch_updown, batch_ex_bias, device=device)
else:
network_apply_weights(module, batch_updown, batch_ex_bias, device=device)
if batch_updown is not None or batch_ex_bias is not None:
applied_layers.append(network_layer_name)
applied_weight += 1 if batch_updown is not None else 0
applied_bias += 1 if batch_ex_bias is not None else 0
batch_updown, batch_ex_bias = None, None
del batch_updown, batch_ex_bias
module.network_current_names = wanted_names
if task is not None:
bs = round(backup_size/1024/1024/1024, 2) if backup_size > 0 else None
pbar.update(task, advance=1, description=f'networks={len(l.loaded_networks)} modules={active_components} layers={total} weights={applied_weight} bias={applied_bias} backup={bs} device={device}')
bs = round(backup_size/1024/1024/1024, 2) if backup_size > 0 else None
pbar.update(task, advance=1, description=f'networks={len(l.loaded_networks)} modules={active_components} layers={total} weights={applied_weight} bias={applied_bias} backup={bs} device={device}')
if task is not None and len(applied_layers) == 0:
pbar.remove_task(task) # hide progress bar for no action
if task is not None and len(applied_layers) == 0:
pbar.remove_task(task) # hide progress bar for no action
l.timer.activate += time.time() - t0
if l.debug and len(l.loaded_networks) > 0:
shared.log.debug(f'Network load: type=LoRA networks={[n.name for n in l.loaded_networks]} modules={active_components} layers={total} weights={applied_weight} bias={applied_bias} backup={round(backup_size/1024/1024/1024, 2)} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers} device={device} time={l.timer.summary}')
@ -81,49 +83,49 @@ def network_deactivate(include=[], exclude=[]):
if len(l.previously_loaded_networks) == 0:
return
t0 = time.time()
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
if shared.opts.diffusers_offload_mode == "sequential":
sd_models.disable_offload(sd_model)
sd_models.move_model(sd_model, device=devices.cpu)
modules = {}
with limit_errors("network_deactivate"):
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
if shared.opts.diffusers_offload_mode == "sequential":
sd_models.disable_offload(sd_model)
sd_models.move_model(sd_model, device=devices.cpu)
modules = {}
components = include if len(include) > 0 else ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'unet', 'transformer']
components = [x for x in components if x not in exclude]
active_components = []
for name in components:
component = getattr(sd_model, name, None)
if component is not None and hasattr(component, 'named_modules'):
modules[name] = list(component.named_modules())
active_components.append(name)
total = sum(len(x) for x in modules.values())
if len(l.previously_loaded_networks) > 0 and l.debug:
pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=deactivate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
task = pbar.add_task(description='', total=total)
else:
task = None
pbar = nullcontext()
with devices.inference_context(), pbar:
applied_layers.clear()
for component in modules.keys():
device = getattr(sd_model, component, None).device
for _, module in modules[component]:
network_layer_name = getattr(module, 'network_layer_name', None)
if shared.state.interrupted or network_layer_name is None:
components = include if len(include) > 0 else ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'unet', 'transformer']
components = [x for x in components if x not in exclude]
active_components = []
for name in components:
component = getattr(sd_model, name, None)
if component is not None and hasattr(component, 'named_modules'):
modules[name] = list(component.named_modules())
active_components.append(name)
total = sum(len(x) for x in modules.values())
if len(l.previously_loaded_networks) > 0 and l.debug:
pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=deactivate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
task = pbar.add_task(description='', total=total)
else:
task = None
pbar = nullcontext()
with devices.inference_context(), pbar:
applied_layers.clear()
for component in modules.keys():
device = getattr(sd_model, component, None).device
for _, module in modules[component]:
network_layer_name = getattr(module, 'network_layer_name', None)
if shared.state.interrupted or network_layer_name is None:
if task is not None:
pbar.update(task, advance=1)
continue
batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name, use_previous=True)
if shared.opts.lora_fuse_native:
network_apply_direct(module, batch_updown, batch_ex_bias, device=device, deactivate=True)
else:
network_apply_weights(module, batch_updown, batch_ex_bias, device=device, deactivate=True)
if batch_updown is not None or batch_ex_bias is not None:
applied_layers.append(network_layer_name)
del batch_updown, batch_ex_bias
module.network_current_names = ()
if task is not None:
pbar.update(task, advance=1)
continue
batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name, use_previous=True)
if shared.opts.lora_fuse_native:
network_apply_direct(module, batch_updown, batch_ex_bias, device=device, deactivate=True)
else:
network_apply_weights(module, batch_updown, batch_ex_bias, device=device, deactivate=True)
if batch_updown is not None or batch_ex_bias is not None:
applied_layers.append(network_layer_name)
del batch_updown, batch_ex_bias
module.network_current_names = ()
if task is not None:
pbar.update(task, advance=1, description=f'networks={len(l.previously_loaded_networks)} modules={active_components} layers={total} unapply={len(applied_layers)}')
pbar.update(task, advance=1, description=f'networks={len(l.previously_loaded_networks)} modules={active_components} layers={total} unapply={len(applied_layers)}')
l.timer.deactivate = time.time() - t0
if l.debug and len(l.previously_loaded_networks) > 0:
shared.log.debug(f'Network deactivate: type=LoRA networks={[n.name for n in l.previously_loaded_networks]} modules={active_components} layers={total} apply={len(applied_layers)} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers} time={l.timer.summary}')

View File

@ -15,7 +15,7 @@ def create_ui(prompt, negative, styles, overrides, init_image, init_strength, la
generate = gr.Button('Generate', elem_id="ltx_generate_btn", variant='primary', visible=False)
with gr.Row():
ltx_models = [m.name for m in models['LTX Video']] if 'LTX Video' in models else ['None']
model = gr.Dropdown(label='LTX model', choices=ltx_models, value=ltx_models[0])
model = gr.Dropdown(label='LTX model', choices=ltx_models, value=ltx_models[0], elem_id="ltx_model")
with gr.Accordion(open=False, label="Condition", elem_id='ltx_condition_accordion'):
with gr.Tabs():
with gr.Tab('Video', id='ltx_condition_video_tab'):

36
modules/migrate.py Normal file
View File

@ -0,0 +1,36 @@
import os
from modules.paths import data_path
from installer import log
files = [
'cache.json',
'metadata.json',
'html/extensions.json',
'html/previews.json',
'html/upscalers.json',
'html/reference.json',
'html/themes.json',
'html/reference-quant.json',
'html/reference-distilled.json',
'html/reference-community.json',
'html/reference-cloud.json',
]
def migrate_data():
for f in files:
old_filename = os.path.join(data_path, f)
new_filename = os.path.join(data_path, "data", os.path.basename(f))
if os.path.exists(old_filename):
if not os.path.exists(new_filename):
log.info(f'Migrating: file="{old_filename}" target="{new_filename}"')
try:
os.rename(old_filename, new_filename)
except Exception as e:
log.error(f'Migrating: file="{old_filename}" target="{new_filename}" {e}')
else:
log.warning(f'Migrating: file="{old_filename}" target="{new_filename}" skip existing')
migrate_data()

View File

@ -58,7 +58,7 @@ def get_model_type(pipe):
model_type = 'sana'
elif "HiDream" in name:
model_type = 'h1'
elif "Cosmos2TextToImage" in name:
elif "Cosmos2TextToImage" in name or "AnimaTextToImage" in name:
model_type = 'cosmos'
elif "FLite" in name:
model_type = 'flite'

View File

@ -11,8 +11,10 @@ if TYPE_CHECKING:
from modules.ui_components import DropdownEditable
def options_section(section_identifier: tuple[str, str], options_dict: dict[str, OptionInfo | LegacyOption]):
def options_section(section_identifier: tuple[str, str], options_dict: dict[str, OptionInfo | LegacyOption]) -> dict[str, OptionInfo | LegacyOption]:
"""Set the `section` value for all OptionInfo/LegacyOption items"""
if len(section_identifier) > 2:
section_identifier = section_identifier[:2]
for v in options_dict.values():
v.section = section_identifier
return options_dict

View File

@ -15,7 +15,7 @@ def apply(p: processing.StableDiffusionProcessing): # pylint: disable=arguments-
cls = unapply()
if p.pag_scale == 0:
return
if 'PAG' in cls.__name__:
if cls is not None and 'PAG' in cls.__name__:
pass
elif detect.is_sd15(cls):
if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE:

View File

@ -4,6 +4,7 @@ import sys
import json
import shlex
import argparse
import tempfile
from installer import log
@ -18,12 +19,16 @@ cli = parser.parse_known_args(argv)[0]
parser.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(cli.data_dir, 'config.json')), help="Use specific server configuration file, default: %(default)s") # twice because we want data_dir
cli = parser.parse_known_args(argv)[0]
config_path = cli.config if os.path.isabs(cli.config) else os.path.join(cli.data_dir, cli.config)
try:
with open(config_path, 'r', encoding='utf8') as f:
config = json.load(f)
except Exception:
config = {}
temp_dir = config.get('temp_dir', '')
if len(temp_dir) == 0:
temp_dir = tempfile.gettempdir()
reference_path = os.path.join('models', 'Reference')
modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path)

View File

@ -422,7 +422,7 @@ class YoloRestorer(Detailer):
pc.image_mask = [item.mask]
pc.overlay_images = []
# explictly disable for detailer pass
pc.enable_hr = False
pc.enable_hr = False
pc.do_not_save_samples = True
pc.do_not_save_grid = True
# set recursion flag to avoid nested detailer calls

View File

@ -243,13 +243,13 @@ def process_init(p: StableDiffusionProcessing):
seed = get_fixed_seed(p.seed)
subseed = get_fixed_seed(p.subseed)
reset_prompts = False
if p.all_prompts is None:
if not p.all_prompts:
p.all_prompts = p.prompt if isinstance(p.prompt, list) else p.batch_size * p.n_iter * [p.prompt]
reset_prompts = True
if p.all_negative_prompts is None:
if not p.all_negative_prompts:
p.all_negative_prompts = p.negative_prompt if isinstance(p.negative_prompt, list) else p.batch_size * p.n_iter * [p.negative_prompt]
reset_prompts = True
if p.all_seeds is None:
if not p.all_seeds:
reset_prompts = True
if type(seed) == list:
p.all_seeds = [int(s) for s in seed]
@ -262,7 +262,7 @@ def process_init(p: StableDiffusionProcessing):
for i in range(len(p.all_prompts)):
seed = get_fixed_seed(p.seed)
p.all_seeds.append(int(seed) + (i if p.subseed_strength == 0 else 0))
if p.all_subseeds is None:
if not p.all_subseeds:
if type(subseed) == list:
p.all_subseeds = [int(s) for s in subseed]
else:
@ -270,8 +270,8 @@ def process_init(p: StableDiffusionProcessing):
if reset_prompts:
if not hasattr(p, 'keep_prompts'):
p.all_prompts, p.all_negative_prompts = shared.prompt_styles.apply_styles_to_prompts(p.all_prompts, p.all_negative_prompts, p.styles, p.all_seeds)
p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
p.prompts = p.all_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)]
p.negative_prompts = p.all_negative_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)]
p.prompts, _ = extra_networks.parse_prompts(p.prompts)
@ -427,13 +427,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
continue
if not hasattr(p, 'keep_prompts'):
p.prompts = p.all_prompts[n * p.batch_size:(n+1) * p.batch_size]
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n+1) * p.batch_size]
p.seeds = p.all_seeds[n * p.batch_size:(n+1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n+1) * p.batch_size]
p.prompts = p.all_prompts[(n * p.batch_size):((n+1) * p.batch_size)]
p.negative_prompts = p.all_negative_prompts[(n * p.batch_size):((n+1) * p.batch_size)]
p.seeds = p.all_seeds[(n * p.batch_size):((n+1) * p.batch_size)]
p.subseeds = p.all_subseeds[(n * p.batch_size):((n+1) * p.batch_size)]
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
if len(p.prompts) == 0:
if not p.prompts:
break
p.prompts, p.network_data = extra_networks.parse_prompts(p.prompts)
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
@ -469,8 +469,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
p.scripts.postprocess_batch(p, samples, batch_number=n)
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
p.prompts = p.all_prompts[n * p.batch_size:(n+1) * p.batch_size]
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n+1) * p.batch_size]
p.prompts = p.all_prompts[(n * p.batch_size):((n+1) * p.batch_size)]
p.negative_prompts = p.all_negative_prompts[(n * p.batch_size):((n+1) * p.batch_size)]
batch_params = scripts_manager.PostprocessBatchListArgs(list(samples))
p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
samples = batch_params.images

View File

@ -67,7 +67,7 @@ def task_specific_kwargs(p, model):
if 'hires' not in p.ops:
p.ops.append('img2img')
if p.vae_type == 'Remote':
from modules.sd_vae_remote import remote_encode
from modules.vae.sd_vae_remote import remote_encode
p.init_images = remote_encode(p.init_images)
task_args = {
'image': p.init_images,
@ -117,7 +117,7 @@ def task_specific_kwargs(p, model):
p.ops.append('inpaint')
mask_image = p.task_args.get('image_mask', None) or getattr(p, 'image_mask', None) or getattr(p, 'mask', None)
if p.vae_type == 'Remote':
from modules.sd_vae_remote import remote_encode
from modules.vae.sd_vae_remote import remote_encode
p.init_images = remote_encode(p.init_images)
# mask_image = remote_encode(mask_image)
task_args = {
@ -269,7 +269,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
kwargs['output_type'] = 'np' # only set latent if model has vae
# model specific
if 'Kandinsky' in model.__class__.__name__ or 'Cosmos2' in model.__class__.__name__ or 'OmniGen2' in model.__class__.__name__:
if 'Kandinsky' in model.__class__.__name__ or 'Cosmos2' in model.__class__.__name__ or 'Anima' in model.__class__.__name__ or 'OmniGen2' in model.__class__.__name__:
kwargs['output_type'] = 'np' # only set latent if model has vae
if 'StableCascade' in model.__class__.__name__:
kwargs.pop("guidance_scale") # remove

View File

@ -308,15 +308,14 @@ class StableDiffusionProcessing:
shared.log.error(f'Override: {override_settings} {e}')
self.override_settings = {}
# null items initialized later
self.prompts = None
self.negative_prompts = None
self.all_prompts = None
self.all_negative_prompts = None
self.prompts = []
self.negative_prompts = []
self.all_prompts = []
self.all_negative_prompts = []
self.seeds = []
self.subseeds = []
self.all_seeds = None
self.all_subseeds = None
self.all_seeds = []
self.all_subseeds = []
# a1111 compatibility items
self.seed_enable_extras: bool = True

View File

@ -5,7 +5,8 @@ https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space
import os
import torch
from modules import shared, sd_vae_taesd, devices
from modules import shared, devices
from modules.vae import sd_vae_taesd
debug_enabled = os.environ.get('SD_HDR_DEBUG', None) is not None

View File

@ -563,9 +563,9 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.INPAINTING) # force pipeline
if len(getattr(p, 'init_images', [])) == 0:
p.init_images = [TF.to_pil_image(torch.rand((3, getattr(p, 'height', 512), getattr(p, 'width', 512))))]
if p.prompts is None or len(p.prompts) == 0:
if not p.prompts:
p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
if p.negative_prompts is None or len(p.negative_prompts) == 0:
if not p.negative_prompts:
p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes

View File

@ -386,7 +386,7 @@ def calculate_base_steps(p, use_denoise_start, use_refiner_start):
if len(getattr(p, 'timesteps', [])) > 0:
return None
cls = shared.sd_model.__class__.__name__
if 'Flex' in cls or 'Kontext' in cls or 'Edit' in cls or 'Wan' in cls or 'Flux2' in cls or 'Layered' in cls:
if shared.sd_model_type not in ['sd', 'sdxl']:
steps = p.steps
elif is_modular():
steps = p.steps

View File

@ -45,6 +45,7 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No
"Steps": p.steps,
"Size": f"{p.width}x{p.height}" if hasattr(p, 'width') and hasattr(p, 'height') else None,
"Sampler": p.sampler_name if p.sampler_name != 'Default' else None,
"Scheduler": shared.sd_model.scheduler.__class__.__name__ if getattr(shared.sd_model, 'scheduler', None) is not None else None,
"Seed": all_seeds[index],
"Seed resize from": None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}",
"CFG scale": p.cfg_scale if p.cfg_scale > 1.0 else 1.0,

View File

@ -2,7 +2,8 @@ import os
import time
import numpy as np
import torch
from modules import shared, devices, sd_models, sd_vae, sd_vae_taesd, errors
from modules import shared, devices, sd_models, sd_vae, errors
from modules.vae import sd_vae_taesd
debug = os.environ.get('SD_VAE_DEBUG', None) is not None
@ -286,13 +287,13 @@ def vae_decode(latents, model, output_type='np', vae_type='Full', width=None, he
if vae_type == 'Remote':
jobid = shared.state.begin('Remote VAE')
from modules.sd_vae_remote import remote_decode
from modules.vae.sd_vae_remote import remote_decode
tensors = remote_decode(latents=latents, width=width, height=height)
shared.state.end(jobid)
if tensors is not None and len(tensors) > 0:
return vae_postprocess(tensors, model, output_type)
if vae_type == 'Repa':
from modules.sd_vae_repa import repa_load
from modules.vae.sd_vae_repa import repa_load
vae = repa_load(latents)
vae_type = 'Full'
if vae is not None:
@ -310,14 +311,17 @@ def vae_decode(latents, model, output_type='np', vae_type='Full', width=None, he
latents = latents.unsqueeze(0)
if latents.shape[-1] <= 4: # not a latent, likely an image
decoded = latents.float().cpu().numpy()
elif vae_type == 'Full' and hasattr(model, "vae"):
decoded = full_vae_decode(latents=latents, model=model)
elif hasattr(model, "vqgan"):
decoded = full_vqgan_decode(latents=latents, model=model)
else:
elif vae_type == 'Tiny':
decoded = taesd_vae_decode(latents=latents)
if torch.is_tensor(decoded):
decoded = 2.0 * decoded - 1.0 # typical normalized range
elif hasattr(model, "vqgan"):
decoded = full_vqgan_decode(latents=latents, model=model)
elif hasattr(model, "vae"):
decoded = full_vae_decode(latents=latents, model=model)
else:
shared.log.error('VAE not found in model')
decoded = []
images = vae_postprocess(decoded, model, output_type)
if shared.cmd_opts.profile or debug:
@ -339,11 +343,14 @@ def vae_encode(image, model, vae_type='Full'): # pylint: disable=unused-variable
shared.log.error('VAE not found in model')
return []
tensor = f.to_tensor(image.convert("RGB")).unsqueeze(0).to(devices.device, devices.dtype_vae)
if vae_type == 'Full':
if vae_type == 'Tiny':
latents = taesd_vae_encode(image=tensor)
elif vae_type == 'Full' and hasattr(model, 'vae'):
tensor = tensor * 2 - 1
latents = full_vae_encode(image=tensor, model=shared.sd_model)
else:
latents = taesd_vae_encode(image=tensor)
shared.log.error('VAE not found in model')
latents = []
devices.torch_gc()
shared.state.end(jobid)
return latents

File diff suppressed because it is too large Load Diff

267
modules/res4lyf/__init__.py Normal file
View File

@ -0,0 +1,267 @@
# res4lyf
from .abnorsett_scheduler import ABNorsettScheduler
from .bong_tangent_scheduler import BongTangentScheduler
from .common_sigma_scheduler import CommonSigmaScheduler
from .deis_scheduler_alt import RESDEISMultistepScheduler
from .etdrk_scheduler import ETDRKScheduler
from .gauss_legendre_scheduler import GaussLegendreScheduler
from .langevin_dynamics_scheduler import LangevinDynamicsScheduler
from .lawson_scheduler import LawsonScheduler
from .linear_rk_scheduler import LinearRKScheduler
from .lobatto_scheduler import LobattoScheduler
from .pec_scheduler import PECScheduler
from .radau_iia_scheduler import RadauIIAScheduler
from .res_multistep_scheduler import RESMultistepScheduler
from .res_multistep_sde_scheduler import RESMultistepSDEScheduler
from .res_singlestep_scheduler import RESSinglestepScheduler
from .res_singlestep_sde_scheduler import RESSinglestepSDEScheduler
from .res_unified_scheduler import RESUnifiedScheduler
from .riemannian_flow_scheduler import RiemannianFlowScheduler
from .rungekutta_44s_scheduler import RungeKutta44Scheduler
from .rungekutta_57s_scheduler import RungeKutta57Scheduler
from .rungekutta_67s_scheduler import RungeKutta67Scheduler
from .simple_exponential_scheduler import SimpleExponentialScheduler
from .specialized_rk_scheduler import SpecializedRKScheduler
from .variants import (
ABNorsett2MScheduler,
ABNorsett3MScheduler,
ABNorsett4MScheduler,
SigmaArcsineScheduler,
DEIS1MultistepScheduler,
DEIS2MScheduler,
DEIS2MultistepScheduler,
DEIS3MScheduler,
DEIS3MultistepScheduler,
DEISUnified1SScheduler,
DEISUnified2MScheduler,
DEISUnified3MScheduler,
SigmaEasingScheduler,
ETDRK2Scheduler,
ETDRK3AScheduler,
ETDRK3BScheduler,
ETDRK4AltScheduler,
ETDRK4Scheduler,
FlowEuclideanScheduler,
FlowHyperbolicScheduler,
Lawson2AScheduler,
Lawson2BScheduler,
Lawson4Scheduler,
LinearRK2Scheduler,
LinearRK3Scheduler,
LinearRK4Scheduler,
LinearRKMidpointScheduler,
LinearRKRalsstonScheduler,
Lobatto2Scheduler,
Lobatto3Scheduler,
Lobatto4Scheduler,
FlowLorentzianScheduler,
PEC2H2SScheduler,
PEC2H3SScheduler,
RadauIIA2Scheduler,
RadauIIA3Scheduler,
RES2MScheduler,
RES2MSDEScheduler,
RES2SScheduler,
RES2SSDEScheduler,
RES3MScheduler,
RES3MSDEScheduler,
RES3SScheduler,
RES3SSDEScheduler,
RES5SScheduler,
RES5SSDEScheduler,
RES6SScheduler,
RES6SSDEScheduler,
RESUnified2MScheduler,
RESUnified2SScheduler,
RESUnified3MScheduler,
RESUnified3SScheduler,
RESUnified5SScheduler,
RESUnified6SScheduler,
SigmaSigmoidScheduler,
SigmaSineScheduler,
SigmaSmoothScheduler,
FlowSphericalScheduler,
GaussLegendre2SScheduler,
GaussLegendre3SScheduler,
GaussLegendre4SScheduler,
)
__all__ = [ # noqa: RUF022
# Base
"RESUnifiedScheduler",
"RESMultistepScheduler",
"RESMultistepSDEScheduler",
"RESSinglestepScheduler",
"RESSinglestepSDEScheduler",
"RESDEISMultistepScheduler",
"ETDRKScheduler",
"LawsonScheduler",
"ABNorsettScheduler",
"PECScheduler",
"BongTangentScheduler",
"RiemannianFlowScheduler",
"LangevinDynamicsScheduler",
"CommonSigmaScheduler",
"SimpleExponentialScheduler",
"LinearRKScheduler",
"LobattoScheduler",
"RadauIIAScheduler",
"GaussLegendreScheduler",
"SpecializedRKScheduler",
# Variants
"RES2MScheduler",
"RES3MScheduler",
"DEIS2MScheduler",
"DEIS3MScheduler",
"RES2MSDEScheduler",
"RES3MSDEScheduler",
"RES2SScheduler",
"RES3SScheduler",
"RES5SScheduler",
"RES6SScheduler",
"RES2SSDEScheduler",
"RES3SSDEScheduler",
"RES5SSDEScheduler",
"RES6SSDEScheduler",
"ETDRK2Scheduler",
"ETDRK3AScheduler",
"ETDRK3BScheduler",
"ETDRK4Scheduler",
"ETDRK4AltScheduler",
"Lawson2AScheduler",
"Lawson2BScheduler",
"Lawson4Scheduler",
"ABNorsett2MScheduler",
"ABNorsett3MScheduler",
"ABNorsett4MScheduler",
"PEC2H2SScheduler",
"PEC2H3SScheduler",
"FlowEuclideanScheduler",
"FlowHyperbolicScheduler",
"FlowSphericalScheduler",
"FlowLorentzianScheduler",
"SigmaSigmoidScheduler",
"SigmaSineScheduler",
"SigmaEasingScheduler",
"SigmaArcsineScheduler",
"SigmaSmoothScheduler",
"DEISUnified1SScheduler",
"DEISUnified2MScheduler",
"DEISUnified3MScheduler",
"RESUnified2MScheduler",
"RESUnified3MScheduler",
"RESUnified2SScheduler",
"RESUnified3SScheduler",
"RESUnified5SScheduler",
"RESUnified6SScheduler",
"DEIS1MultistepScheduler",
"DEIS2MultistepScheduler",
"DEIS3MultistepScheduler",
"LinearRK2Scheduler",
"LinearRK3Scheduler",
"LinearRK4Scheduler",
"LinearRKRalsstonScheduler",
"LinearRKMidpointScheduler",
"Lobatto2Scheduler",
"Lobatto3Scheduler",
"Lobatto4Scheduler",
"RadauIIA2Scheduler",
"RadauIIA3Scheduler",
"GaussLegendre2SScheduler",
"GaussLegendre3SScheduler",
"GaussLegendre4SScheduler",
"RungeKutta44Scheduler",
"RungeKutta57Scheduler",
"RungeKutta67Scheduler",
]
BASE = [
("RES Unified", RESUnifiedScheduler),
("RES Multistep", RESMultistepScheduler),
("RES Multistep SDE", RESMultistepSDEScheduler),
("RES Singlestep", RESSinglestepScheduler),
("RES Singlestep SDE", RESSinglestepSDEScheduler),
("DEIS Multistep", RESDEISMultistepScheduler),
("ETDRK", ETDRKScheduler),
("Lawson", LawsonScheduler),
("ABNorsett", ABNorsettScheduler),
("PEC", PECScheduler),
("Common Sigma", CommonSigmaScheduler),
("Riemannian Flow", RiemannianFlowScheduler),
("Specialized RK", SpecializedRKScheduler),
]
SIMPLE = [
("Bong Tangent", BongTangentScheduler),
("Langevin Dynamics", LangevinDynamicsScheduler),
("Simple Exponential", SimpleExponentialScheduler),
]
VARIANTS = [
("RES 2M", RES2MScheduler),
("RES 3M", RES3MScheduler),
("DEIS 2M", DEIS2MScheduler),
("DEIS 3M", DEIS3MScheduler),
("RES 2M SDE", RES2MSDEScheduler),
("RES 3M SDE", RES3MSDEScheduler),
("RES 2S", RES2SScheduler),
("RES 3S", RES3SScheduler),
("RES 5S", RES5SScheduler),
("RES 6S", RES6SScheduler),
("RES 2S SDE", RES2SSDEScheduler),
("RES 3S SDE", RES3SSDEScheduler),
("RES 5S SDE", RES5SSDEScheduler),
("RES 6S SDE", RES6SSDEScheduler),
("ETDRK 2", ETDRK2Scheduler),
("ETDRK 3A", ETDRK3AScheduler),
("ETDRK 3B", ETDRK3BScheduler),
("ETDRK 4", ETDRK4Scheduler),
("ETDRK 4 Alt", ETDRK4AltScheduler),
("Lawson 2A", Lawson2AScheduler),
("Lawson 2B", Lawson2BScheduler),
("Lawson 4", Lawson4Scheduler),
("ABNorsett 2M", ABNorsett2MScheduler),
("ABNorsett 3M", ABNorsett3MScheduler),
("ABNorsett 4M", ABNorsett4MScheduler),
("PEC 2H2S", PEC2H2SScheduler),
("PEC 2H3S", PEC2H3SScheduler),
("Euclidean Flow", FlowEuclideanScheduler),
("Hyperbolic Flow", FlowHyperbolicScheduler),
("Spherical Flow", FlowSphericalScheduler),
("Lorentzian Flow", FlowLorentzianScheduler),
("Sigmoid Sigma", SigmaSigmoidScheduler),
("Sine Sigma", SigmaSineScheduler),
("Easing Sigma", SigmaEasingScheduler),
("Arcsine Sigma", SigmaArcsineScheduler),
("Smoothstep Sigma", SigmaSmoothScheduler),
("DEIS Unified 1", DEISUnified1SScheduler),
("DEIS Unified 2", DEISUnified2MScheduler),
("DEIS Unified 3", DEISUnified3MScheduler),
("RES Unified 2M", RESUnified2MScheduler),
("RES Unified 3M", RESUnified3MScheduler),
("RES Unified 2S", RESUnified2SScheduler),
("RES Unified 3S", RESUnified3SScheduler),
("RES Unified 5S", RESUnified5SScheduler),
("RES Unified 6S", RESUnified6SScheduler),
("DEIS Multistep 1", DEIS1MultistepScheduler),
("DEIS Multistep 2", DEIS2MultistepScheduler),
("DEIS Multistep 3", DEIS3MultistepScheduler),
("Linear-RK 2", LinearRK2Scheduler),
("Linear-RK 3", LinearRK3Scheduler),
("Linear-RK 4", LinearRK4Scheduler),
("Linear-RK Ralston", LinearRKRalsstonScheduler),
("Linear-RK Midpoint", LinearRKMidpointScheduler),
("Lobatto 2", Lobatto2Scheduler),
("Lobatto 3", Lobatto3Scheduler),
("Lobatto 4", Lobatto4Scheduler),
("Radau-IIA 2", RadauIIA2Scheduler),
("Radau-IIA 3", RadauIIA3Scheduler),
("Gauss-Legendre 2S", GaussLegendre2SScheduler),
("Gauss-Legendre 3S", GaussLegendre3SScheduler),
("Gauss-Legendre 4S", GaussLegendre4SScheduler),
("Runge-Kutta 4/4", RungeKutta44Scheduler),
("Runge-Kutta 5/7", RungeKutta57Scheduler),
("Runge-Kutta 6/7", RungeKutta67Scheduler),
]

View File

@ -0,0 +1,340 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
from .phi_functions import Phi
logger = logging.get_logger(__name__)
class ABNorsettScheduler(SchedulerMixin, ConfigMixin):
"""
Adams-Bashforth Norsett (ABNorsett) scheduler.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
variant: Literal["abnorsett_2m", "abnorsett_3m", "abnorsett_4m"] = "abnorsett_2m",
use_analytic_solution: bool = True,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does not exist.")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Buffer for multistep
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
self._step_index = None
self._begin_index = None
self.init_noise_sigma = 1.0
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
s_min = getattr(self.config, "sigma_min", None)
s_max = getattr(self.config, "sigma_max", None)
if s_min is None:
s_min = 0.001
if s_max is None:
s_max = 1.0
sigmas = np.linspace(s_max, s_min, num_inference_steps)
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
# Map shifted sigmas back to timesteps (Linear mapping for Flow)
# t = sigma * 1000. Use standard linear scaling.
# This ensures the model receives the correct time embedding for the shifted noise level.
# We assume Flow sigmas are in [1.0, 0.0] range (before shift) and model expects [1000, 0].
timesteps = sigmas * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step = self._step_index
sigma = self.sigmas[step]
sigma_next = self.sigmas[step + 1]
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
# RECONSTRUCT X0
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
self.model_outputs.append(model_output)
self.x0_outputs.append(x0)
self.prev_sigmas.append(sigma)
variant = self.config.variant
order = int(variant[-2])
curr_order = min(len(self.prev_sigmas), order)
phi = Phi(h, [0], getattr(self.config, "use_analytic_solution", True))
if sigma_next == 0:
x_next = x0
else:
# Multi-step coefficients b for ABNorsett family
if curr_order == 1:
b = [[phi(1)]]
elif curr_order == 2:
b2 = -phi(2)
b1 = phi(1) - b2
b = [[b1, b2]]
elif curr_order == 3:
b2 = -2 * phi(2) - 2 * phi(3)
b3 = 0.5 * phi(2) + phi(3)
b1 = phi(1) - (b2 + b3)
b = [[b1, b2, b3]]
elif curr_order == 4:
b2 = -3 * phi(2) - 5 * phi(3) - 3 * phi(4)
b3 = 1.5 * phi(2) + 4 * phi(3) + 3 * phi(4)
b4 = -1 / 3 * phi(2) - phi(3) - phi(4)
b1 = phi(1) - (b2 + b3 + b4)
b = [[b1, b2, b3, b4]]
else:
b = [[phi(1)]]
# Apply coefficients to x0 buffer
res = torch.zeros_like(sample)
for i, b_val in enumerate(b[0]):
idx = len(self.x0_outputs) - 1 - i
if idx >= 0:
res += b_val * self.x0_outputs[idx]
# Exponential Integrator Update
if self.config.prediction_type == "flow_prediction":
# Variable Step Adams-Bashforth for Flow Matching
# x_{n+1} = x_n + \int_{t_n}^{t_{n+1}} v(t) dt
sigma_curr = sigma
dt = sigma_next - sigma_curr
# Current derivative v_n is self.model_outputs[-1]
v_n = self.model_outputs[-1]
if curr_order == 1:
# Euler: x_{n+1} = x_n + dt * v_n
x_next = sample + dt * v_n
elif curr_order == 2:
# AB2 Variable Step
# x_{n+1} = x_n + dt * [ (1 + r/2) * v_n - (r/2) * v_{n-1} ]
# where r = dt_cur / dt_prev
v_nm1 = self.model_outputs[-2]
sigma_prev = self.prev_sigmas[-2]
dt_prev = sigma_curr - sigma_prev
if abs(dt_prev) < 1e-8:
# Fallback to Euler if division by zero risk
x_next = sample + dt * v_n
else:
r = dt / dt_prev
# Standard variable step AB2 coefficients
c0 = 1 + 0.5 * r
c1 = -0.5 * r
x_next = sample + dt * (c0 * v_n + c1 * v_nm1)
elif curr_order >= 3:
# For now, fallback to AB2 (variable) for higher orders to ensure stability
# given the complexity of variable-step AB3/4 formulas inline.
# The user specifically requested abnorsett_2m.
v_nm1 = self.model_outputs[-2]
sigma_prev = self.prev_sigmas[-2]
dt_prev = sigma_curr - sigma_prev
if abs(dt_prev) < 1e-8:
x_next = sample + dt * v_n
else:
r = dt / dt_prev
c0 = 1 + 0.5 * r
c1 = -0.5 * r
x_next = sample + dt * (c0 * v_n + c1 * v_nm1)
else:
x_next = sample + dt * v_n
else:
x_next = torch.exp(-h) * sample + h * res
self._step_index += 1
if len(self.x0_outputs) > order:
self.x0_outputs.pop(0)
self.model_outputs.pop(0)
self.prev_sigmas.pop(0)
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,278 @@
# Copyright 2025 The RES4LYF Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class BongTangentScheduler(SchedulerMixin, ConfigMixin):
"""
BongTangent scheduler using Exponential Integrator step.
"""
_compatibles: ClassVar[List[str]] = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
start: float = 1.0,
middle: float = 0.5,
end: float = 0.0,
pivot_1: float = 0.6,
pivot_2: float = 0.6,
slope_1: float = 0.2,
slope_2: float = 0.2,
pad: bool = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sigmas = torch.Tensor([])
self.timesteps = torch.Tensor([])
self.num_inference_steps = None
self._step_index = None
self._begin_index = None
self.init_noise_sigma = 1.0
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
timestep_spacing = getattr(self.config, "timestep_spacing", "linspace")
steps_offset = getattr(self.config, "steps_offset", 0)
if timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += steps_offset
elif timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {timestep_spacing} is not supported.")
# Derived sigma range from alphas_cumprod
base_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
# Note: alphas_cumprod[0] is ~0.999 (small sigma), alphas_cumprod[-1] is ~0.0001 (large sigma)
sigma_max = base_sigmas[-1]
sigma_min = base_sigmas[0]
sigma_mid = (sigma_max + sigma_min) / 2 # Default midpoint for tangent nodes
steps = num_inference_steps
midpoint = int(steps * getattr(self.config, "midpoint", 0.5))
p1 = int(steps * getattr(self.config, "pivot_1", 0.6))
p2 = int(steps * getattr(self.config, "pivot_2", 0.6))
s1 = getattr(self.config, "slope_1", 0.2) / (steps / 40)
s2 = getattr(self.config, "slope_2", 0.2) / (steps / 40)
stage_1_len = midpoint
stage_2_len = steps - midpoint + 1
# Use model's sigma range for start/middle/end
start_cfg = getattr(self.config, "start", 1.0)
start_val = sigma_max * start_cfg if start_cfg > 1.0 else sigma_max
end_val = sigma_min
mid_val = sigma_mid
tan_sigmas_1 = self._get_bong_tangent_sigmas(stage_1_len, s1, p1, start_val, mid_val, dtype=dtype)
tan_sigmas_2 = self._get_bong_tangent_sigmas(stage_2_len, s2, p2 - stage_1_len, mid_val, end_val, dtype=dtype)
tan_sigmas_1 = tan_sigmas_1[:-1]
sigmas_list = tan_sigmas_1 + tan_sigmas_2
if getattr(self.config, "pad", False):
sigmas_list.append(0.0)
sigmas = np.array(sigmas_list)
if getattr(self.config, "use_karras_sigmas", False):
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif getattr(self.config, "use_exponential_sigmas", False):
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif getattr(self.config, "use_beta_sigmas", False):
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif getattr(self.config, "use_flow_sigmas", False):
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
shift = getattr(self.config, "shift", 1.0)
use_dynamic_shifting = getattr(self.config, "use_dynamic_shifting", False)
if shift != 1.0 or use_dynamic_shifting:
if use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
getattr(self.config, "base_shift", 0.5),
getattr(self.config, "max_shift", 1.5),
getattr(self.config, "base_image_seq_len", 256),
getattr(self.config, "max_image_seq_len", 4096),
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def _get_bong_tangent_sigmas(self, steps: int, slope: float, pivot: int, start: float, end: float, dtype: torch.dtype = torch.float32) -> List[float]:
x = torch.arange(steps, dtype=dtype)
def bong_fn(val):
return ((2 / torch.pi) * torch.atan(-slope * (val - pivot)) + 1) / 2
smax = bong_fn(torch.tensor(0.0))
smin = bong_fn(torch.tensor(steps - 1.0))
srange = smax - smin
sscale = start - end
sigmas = ((bong_fn(x) - smin) * (1 / srange) * sscale + end)
return sigmas.tolist()
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self._step_index]
sigma_next = self.sigmas[self._step_index + 1]
# RECONSTRUCT X0
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1)**0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
# Exponential Integrator Update
if sigma_next == 0:
x_next = x0
else:
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
x_next = torch.exp(-h) * sample + (1 - torch.exp(-h)) * x0
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,263 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import ClassVar, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class CommonSigmaScheduler(SchedulerMixin, ConfigMixin):
"""
Common Sigma scheduler using Exponential Integrator step.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order: ClassVar[int] = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
profile: Literal["sigmoid", "sine", "easing", "arcsine", "smoothstep"] = "sigmoid",
variant: str = "logistic",
strength: float = 1.0,
gain: float = 1.0,
offset: float = 0.0,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does not exist.")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# Setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = torch.Tensor([])
self._step_index = None
self._begin_index = None
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")
# Derived sigma range from alphas_cumprod
base_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigma_max = base_sigmas[-1]
sigma_min = base_sigmas[0]
t = torch.linspace(0, 1, num_inference_steps)
profile = self.config.profile
variant = self.config.variant
gain = self.config.gain
offset = self.config.offset
if profile == "sigmoid":
x = gain * (t * 10 - 5 + offset)
if variant == "logistic":
result = 1.0 / (1.0 + torch.exp(-x))
elif variant == "tanh":
result = (torch.tanh(x) + 1) / 2
else:
result = torch.sigmoid(x)
elif profile == "sine":
result = torch.sin(t * math.pi / 2)
elif profile == "easing":
result = t * t * (3 - 2 * t)
elif profile == "arcsine":
result = torch.arcsin(t) / (math.pi / 2)
else:
result = t
# Map profile to sigma range
sigmas = (sigma_max * (1 - result) + sigma_min * result).cpu().numpy()
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
# RECONSTRUCT X0
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1)**0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
# Exponential Integrator Update
if sigma_next == 0:
x_next = x0
else:
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
x_next = torch.exp(-h) * sample + (1 - torch.exp(-h)) * x0
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,403 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from .phi_functions import Phi
def get_def_integral_2(a, b, start, end, c):
coeff = (end**3 - start**3) / 3 - (end**2 - start**2) * (a + b) / 2 + (end - start) * a * b
return coeff / ((c - a) * (c - b))
def get_def_integral_3(a, b, c, start, end, d):
coeff = (end**4 - start**4) / 4 - (end**3 - start**3) * (a + b + c) / 3 + (end**2 - start**2) * (a * b + a * c + b * c) / 2 - (end - start) * a * b * c
return coeff / ((d - a) * (d - b) * (d - c))
class RESDEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
RESDEISMultistepScheduler: Diffusion Explicit Iterative Sampler with high-order multistep.
Adapted from the RES4LYF repository.
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
rho: float = 7.0,
shift: Optional[float] = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
timestep_spacing: str = "linspace",
solver_order: int = 2,
use_analytic_solution: bool = True,
clip_sample: bool = False,
sample_max_value: float = 1.0,
set_alpha_to_one: bool = False,
skip_prk_steps: bool = False,
interpolation_type: str = "linear",
steps_offset: int = 0,
timestep_type: str = "discrete",
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = None
self.init_noise_sigma = 1.0
# Internal state
self.model_outputs = []
self.hist_samples = []
self._step_index = None
self._sigmas_cpu = None
self.all_coeffs = []
self.prev_sigmas = []
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
mu: Optional[float] = None,
dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float).copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
timesteps = (np.arange(num_inference_steps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= step_ratio
else:
raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")
if self.config.timestep_spacing == "trailing":
timesteps = np.maximum(timesteps, 0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas_all = np.log(np.maximum(sigmas, 1e-10))
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
else:
raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")
# 2. Sigma Schedule
if self.config.use_karras_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
rho = self.config.rho
ramp = np.linspace(0, 1, num_inference_steps)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
elif self.config.use_exponential_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
elif self.config.use_beta_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
alpha, beta = 0.6, 0.6
ramp = np.linspace(0, 1, num_inference_steps)
try:
import torch.distributions as dist
b = dist.Beta(alpha, beta)
ramp = b.sample((num_inference_steps,)).sort().values.numpy()
except Exception:
pass
sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
elif self.config.use_flow_sigmas:
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
# 3. Shifting
if self.config.use_dynamic_shifting and mu is not None:
sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
elif self.config.shift is not None:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
# Map back to timesteps
if self.config.use_flow_sigmas:
timesteps = sigmas * self.config.num_train_timesteps
else:
timesteps = np.interp(np.log(np.maximum(sigmas, 1e-10)), log_sigmas_all, np.arange(len(log_sigmas_all)))
self.sigmas = torch.from_numpy(np.append(sigmas, 0.0)).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps + self.config.steps_offset).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._sigmas_cpu = self.sigmas.detach().cpu().numpy()
# Precompute coefficients
self.all_coeffs = []
num_steps = len(timesteps)
for i in range(num_steps):
sigma_t = self._sigmas_cpu[i]
sigma_next = self._sigmas_cpu[i + 1]
if sigma_next <= 0:
coeffs = None
else:
current_order = min(i + 1, self.config.solver_order)
if current_order == 1:
coeffs = [sigma_next - sigma_t]
else:
ts = [self._sigmas_cpu[i - j] for j in range(current_order)]
t_next = sigma_next
if current_order == 2:
t_cur, t_prev1 = ts[0], ts[1]
coeff_cur = ((t_next - t_prev1) ** 2 - (t_cur - t_prev1) ** 2) / (2 * (t_cur - t_prev1))
coeff_prev1 = (t_next - t_cur) ** 2 / (2 * (t_prev1 - t_cur))
coeffs = [coeff_cur, coeff_prev1]
elif current_order == 3:
t_cur, t_prev1, t_prev2 = ts[0], ts[1], ts[2]
coeffs = [
get_def_integral_2(t_prev1, t_prev2, t_cur, t_next, t_cur),
get_def_integral_2(t_cur, t_prev2, t_cur, t_next, t_prev1),
get_def_integral_2(t_cur, t_prev1, t_cur, t_next, t_prev2),
]
elif current_order == 4:
t_cur, t_prev1, t_prev2, t_prev3 = ts[0], ts[1], ts[2], ts[3]
coeffs = [
get_def_integral_3(t_prev1, t_prev2, t_prev3, t_cur, t_next, t_cur),
get_def_integral_3(t_cur, t_prev2, t_prev3, t_cur, t_next, t_prev1),
get_def_integral_3(t_cur, t_prev1, t_prev3, t_cur, t_next, t_prev2),
get_def_integral_3(t_cur, t_prev1, t_prev2, t_cur, t_next, t_prev3),
]
else:
coeffs = [(sigma_next - sigma_t) / sigma_t] # Fallback to Euler
self.all_coeffs.append(coeffs)
# Reset history
self.model_outputs = []
self.hist_samples = []
self._step_index = None
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def _init_step_index(self, timestep):
if self._step_index is None:
self._step_index = self.index_for_timestep(timestep)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
sigma_t = self.sigmas[step_index]
# RECONSTRUCT X0 (Matching PEC pattern)
if self.config.prediction_type == "epsilon":
denoised = sample - sigma_t * model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
denoised = alpha_t * sample - sigma_actual * model_output
elif self.config.prediction_type == "flow_prediction":
denoised = sample - sigma_t * model_output
elif self.config.prediction_type == "sample":
denoised = model_output
else:
raise ValueError(f"prediction_type error: {self.config.prediction_type}")
if self.config.clip_sample:
denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)
if self.config.prediction_type == "flow_prediction":
# Variable Step Adams-Bashforth for Flow Matching
self.model_outputs.append(model_output)
self.prev_sigmas.append(sigma_t)
# Note: deis uses hist_samples for x0? I'll use model_outputs for v.
if len(self.model_outputs) > 4:
self.model_outputs.pop(0)
self.prev_sigmas.pop(0)
dt = self.sigmas[step_index + 1] - sigma_t
v_n = model_output
curr_order = min(len(self.prev_sigmas), 3)
if curr_order == 1:
x_next = sample + dt * v_n
elif curr_order == 2:
sigma_prev = self.prev_sigmas[-2]
dt_prev = sigma_t - sigma_prev
r = dt / dt_prev if abs(dt_prev) > 1e-8 else 0.0
if dt_prev == 0 or r < -0.9 or r > 2.0:
x_next = sample + dt * v_n
else:
c0 = 1 + 0.5 * r
c1 = -0.5 * r
x_next = sample + dt * (c0 * v_n + c1 * self.model_outputs[-2])
else:
# AB2 fallback
sigma_prev = self.prev_sigmas[-2]
dt_prev = sigma_t - sigma_prev
r = dt / dt_prev if abs(dt_prev) > 1e-8 else 0.0
c0 = 1 + 0.5 * r
c1 = -0.5 * r
x_next = sample + dt * (c0 * v_n + c1 * self.model_outputs[-2])
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
sigma_next = self.sigmas[step_index + 1]
if self.config.solver_order == 1:
# 1st order step (Euler) in x-space
x_next = (sigma_next / sigma_t) * sample + (1 - sigma_next / sigma_t) * denoised
prev_sample = x_next
else:
# Multistep weights based on phi functions (consistent with RESMultistep)
h = -torch.log(sigma_next / sigma_t) if sigma_t > 0 and sigma_next > 0 else torch.zeros_like(sigma_t)
phi = Phi(h, [0], getattr(self.config, "use_analytic_solution", True))
phi_1 = phi(1)
# History of denoised samples
x0s = [denoised] + self.model_outputs[::-1]
orders = min(len(x0s), self.config.solver_order)
# Force Order 1 at the end of schedule
if self.num_inference_steps is not None and step_index >= self.num_inference_steps - 3:
res = phi_1 * denoised
elif orders == 1:
res = phi_1 * denoised
elif orders == 2:
# Use phi(2) for 2nd order interpolation
h_prev = -np.log(self._sigmas_cpu[step_index] / (self._sigmas_cpu[step_index - 1] + 1e-9))
h_prev_t = torch.tensor(h_prev, device=sample.device, dtype=sample.dtype)
r = h_prev_t / (h + 1e-9)
h_prev = -np.log(self._sigmas_cpu[step_index] / (self._sigmas_cpu[step_index - 1] + 1e-9))
h_prev_t = torch.tensor(h_prev, device=sample.device, dtype=sample.dtype)
r = h_prev_t / (h + 1e-9)
# Hard Restart
if r < 0.5 or r > 2.0:
res = phi_1 * denoised
else:
phi_2 = phi(2)
# Correct Adams-Bashforth-like coefficients: b2 = -phi_2 / r
b2 = -phi_2 / (r + 1e-9)
b1 = phi_1 - b2
res = b1 * x0s[0] + b2 * x0s[1]
elif orders == 3:
# 3rd order with varying step sizes
# 3rd order with varying step sizes
h_p1 = -np.log(self._sigmas_cpu[step_index] / (self._sigmas_cpu[step_index - 1] + 1e-9))
h_p2 = -np.log(self._sigmas_cpu[step_index] / (self._sigmas_cpu[step_index - 2] + 1e-9))
r1 = torch.tensor(h_p1, device=sample.device, dtype=sample.dtype) / (h + 1e-9)
r2 = torch.tensor(h_p2, device=sample.device, dtype=sample.dtype) / (h + 1e-9)
h_p1 = -np.log(self._sigmas_cpu[step_index] / (self._sigmas_cpu[step_index - 1] + 1e-9))
h_p2 = -np.log(self._sigmas_cpu[step_index] / (self._sigmas_cpu[step_index - 2] + 1e-9))
r1 = torch.tensor(h_p1, device=sample.device, dtype=sample.dtype) / (h + 1e-9)
r2 = torch.tensor(h_p2, device=sample.device, dtype=sample.dtype) / (h + 1e-9)
# Hard Restart
if r1 < 0.5 or r1 > 2.0 or r2 < 0.5 or r2 > 2.0:
res = phi_1 * denoised
else:
phi_2, phi_3 = phi(2), phi(3)
denom = r2 - r1 + 1e-9
b3 = (phi_3 + r1 * phi_2) / (r2 * denom)
b2 = -(phi_3 + r2 * phi_2) / (r1 * denom)
b1 = phi_1 - b2 - b3
res = b1 * x0s[0] + b2 * x0s[1] + b3 * x0s[2]
else:
# Fallback to Euler or lower order
res = phi_1 * denoised
# Stable update in x-space
if sigma_next == 0:
x_next = denoised
else:
x_next = torch.exp(-h) * sample + h * res
prev_sample = x_next
# Store state (always store x0)
self.model_outputs.append(denoised)
self.hist_samples.append(sample)
if len(self.model_outputs) > 4:
self.model_outputs.pop(0)
self.hist_samples.pop(0)
if self._step_index is not None:
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,285 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
from .phi_functions import Phi
logger = logging.get_logger(__name__)
class ETDRKScheduler(SchedulerMixin, ConfigMixin):
"""
Exponential Time Differencing Runge-Kutta (ETDRK) scheduler.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
variant: Literal["etdrk2_2s", "etdrk3_a_3s", "etdrk3_b_3s", "etdrk4_4s", "etdrk4_4s_alt"] = "etdrk4_4s",
use_analytic_solution: bool = True,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does not exist.")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Buffer for multistage/multistep
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
self._step_index = None
self._begin_index = None
self.init_noise_sigma = 1.0
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
# RECONSTRUCT X0
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
self.model_outputs.append(model_output)
self.x0_outputs.append(x0)
self.prev_sigmas.append(sigma)
variant = self.config.variant
if sigma_next == 0:
x_next = x0
else:
# ETDRK coefficients
if variant == "etdrk2_2s":
ci = [0.0, 1.0]
phi = Phi(h, ci, self.config.use_analytic_solution)
if len(self.x0_outputs) < 2:
res = phi(1) * x0
else:
eps_1, eps_2 = self.x0_outputs[-2:]
b2 = phi(2)
b1 = phi(1) - b2
res = b1 * eps_1 + b2 * eps_2
elif variant == "etdrk3_b_3s":
ci = [0, 4/9, 2/3]
phi = Phi(h, ci, self.config.use_analytic_solution)
if len(self.x0_outputs) < 3:
res = phi(1) * x0
else:
eps_1, eps_2, eps_3 = self.x0_outputs[-3:]
b3 = (3/2) * phi(2)
b2 = 0
b1 = phi(1) - b3
res = b1 * eps_1 + b2 * eps_2 + b3 * eps_3
elif variant == "etdrk4_4s":
ci = [0, 1/2, 1/2, 1]
phi = Phi(h, ci, self.config.use_analytic_solution)
if len(self.x0_outputs) < 4:
res = phi(1) * x0
else:
e1, e2, e3, e4 = self.x0_outputs[-4:]
b2 = 2*phi(2) - 4*phi(3)
b3 = 2*phi(2) - 4*phi(3)
b4 = -phi(2) + 4*phi(3)
b1 = phi(1) - (b2 + b3 + b4)
res = b1 * e1 + b2 * e2 + b3 * e3 + b4 * e4
else:
res = Phi(h, [0], self.config.use_analytic_solution)(1) * x0
# Exponential Integrator Update
x_next = torch.exp(-h) * sample + h * res
self._step_index += 1
# Buffer control
limit = 4 if variant.startswith("etdrk4") else 3
if len(self.x0_outputs) > limit:
self.x0_outputs.pop(0)
self.model_outputs.pop(0)
self.prev_sigmas.pop(0)
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,384 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
class GaussLegendreScheduler(SchedulerMixin, ConfigMixin):
"""
GaussLegendreScheduler: High-accuracy implicit symplectic integrators.
Supports various orders (2s, 3s, 4s, 5s, 8s-diagonal).
Adapted from the RES4LYF repository.
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
variant: str = "gauss-legendre_2s", # 2s to 8s variants
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
rho: float = 7.0,
shift: Optional[float] = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
timestep_spacing: str = "linspace",
clip_sample: bool = False,
sample_max_value: float = 1.0,
set_alpha_to_one: bool = False,
skip_prk_steps: bool = False,
interpolation_type: str = "linear",
steps_offset: int = 0,
timestep_type: str = "discrete",
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = None
self.init_noise_sigma = 1.0
# Internal state
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index = None
def _get_tableau(self):
v = self.config.variant
if v == "gauss-legendre_2s":
r3 = 3**0.5
a = [[1 / 4, 1 / 4 - r3 / 6], [1 / 4 + r3 / 6, 1 / 4]]
b = [1 / 2, 1 / 2]
c = [1 / 2 - r3 / 6, 1 / 2 + r3 / 6]
elif v == "gauss-legendre_3s":
r15 = 15**0.5
a = [[5 / 36, 2 / 9 - r15 / 15, 5 / 36 - r15 / 30], [5 / 36 + r15 / 24, 2 / 9, 5 / 36 - r15 / 24], [5 / 36 + r15 / 30, 2 / 9 + r15 / 15, 5 / 36]]
b = [5 / 18, 4 / 9, 5 / 18]
c = [1 / 2 - r15 / 10, 1 / 2, 1 / 2 + r15 / 10]
elif v == "gauss-legendre_4s":
r15 = 15**0.5
a = [[1 / 4, 1 / 4 - r15 / 6, 1 / 4 + r15 / 6, 1 / 4], [1 / 4 + r15 / 6, 1 / 4, 1 / 4 - r15 / 6, 1 / 4], [1 / 4, 1 / 4 + r15 / 6, 1 / 4, 1 / 4 - r15 / 6], [1 / 4 - r15 / 6, 1 / 4, 1 / 4 + r15 / 6, 1 / 4]]
b = [1 / 8, 3 / 8, 3 / 8, 1 / 8]
c = [1 / 2 - r15 / 10, 1 / 2 + r15 / 10, 1 / 2 + r15 / 10, 1 / 2 - r15 / 10]
elif v == "gauss-legendre_5s":
r739 = 739**0.5
a = [
[
4563950663 / 32115191526,
(310937500000000 / 2597974476091533 + 45156250000 * r739 / 8747388808389),
(310937500000000 / 2597974476091533 - 45156250000 * r739 / 8747388808389),
(5236016175 / 88357462711 + 709703235 * r739 / 353429850844),
(5236016175 / 88357462711 - 709703235 * r739 / 353429850844),
],
[
(4563950663 / 32115191526 - 38339103 * r739 / 6250000000),
(310937500000000 / 2597974476091533 + 9557056475401 * r739 / 3498955523355600000),
(310937500000000 / 2597974476091533 - 14074198220719489 * r739 / 3498955523355600000),
(5236016175 / 88357462711 + 5601362553163918341 * r739 / 2208936567775000000000),
(5236016175 / 88357462711 - 5040458465159165409 * r739 / 2208936567775000000000),
],
[
(4563950663 / 32115191526 + 38339103 * r739 / 6250000000),
(310937500000000 / 2597974476091533 + 14074198220719489 * r739 / 3498955523355600000),
(310937500000000 / 2597974476091533 - 9557056475401 * r739 / 3498955523355600000),
(5236016175 / 88357462711 + 5040458465159165409 * r739 / 2208936567775000000000),
(5236016175 / 88357462711 - 5601362553163918341 * r739 / 2208936567775000000000),
],
[
(4563950663 / 32115191526 - 38209 * r739 / 7938810),
(310937500000000 / 2597974476091533 - 359369071093750 * r739 / 70145310854471391),
(310937500000000 / 2597974476091533 - 323282178906250 * r739 / 70145310854471391),
(5236016175 / 88357462711 - 470139 * r739 / 1413719403376),
(5236016175 / 88357462711 - 44986764863 * r739 / 21205791050640),
],
[
(4563950663 / 32115191526 + 38209 * r739 / 7938810),
(310937500000000 / 2597974476091533 + 359369071093750 * r739 / 70145310854471391),
(310937500000000 / 2597974476091533 + 323282178906250 * r739 / 70145310854471391),
(5236016175 / 88357462711 + 44986764863 * r739 / 21205791050640),
(5236016175 / 88357462711 + 470139 * r739 / 1413719403376),
],
]
b = [4563950663 / 16057595763, 621875000000000 / 2597974476091533, 621875000000000 / 2597974476091533, 10472032350 / 88357462711, 10472032350 / 88357462711]
c = [1 / 2, 1 / 2 - 99 * r739 / 10000, 1 / 2 + 99 * r739 / 10000, 1 / 2 - r739 / 60, 1 / 2 + r739 / 60]
elif v == "gauss-legendre_diag_8s":
a = [
[0.5, 0, 0, 0, 0, 0, 0, 0],
[1.0818949631055815, 0.5, 0, 0, 0, 0, 0, 0],
[0.9599572962220549, 1.0869589243008327, 0.5, 0, 0, 0, 0, 0],
[1.0247213458032004, 0.9550588736973743, 1.0880938387323083, 0.5, 0, 0, 0, 0],
[0.9830238267636289, 1.0287597754747493, 0.9538345351852, 1.0883471611098278, 0.5, 0, 0, 0],
[1.0122259141132982, 0.9799828723635913, 1.0296038730649779, 0.9538345351852, 1.0880938387323083, 0.5, 0, 0],
[0.9912514332308026, 1.0140743558891669, 0.9799828723635913, 1.0287597754747493, 0.9550588736973743, 1.0869589243008327, 0.5, 0],
[1.0054828082532159, 0.9912514332308026, 1.0122259141132982, 0.9830238267636289, 1.0247213458032004, 0.9599572962220549, 1.0818949631055815, 0.5],
]
b = [0.05061426814518813, 0.11119051722668724, 0.15685332293894364, 0.181341891689181, 0.181341891689181, 0.15685332293894364, 0.11119051722668724, 0.05061426814518813]
c = [0.019855071751231884, 0.10166676129318663, 0.2372337950418355, 0.4082826787521751, 0.5917173212478249, 0.7627662049581645, 0.8983332387068134, 0.9801449282487681]
else:
raise ValueError(f"Unknown variant: {v}")
return np.array(a), np.array(b), np.array(c)
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
else:
raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")
# 2. Sigma Schedule
if self.config.use_karras_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
rho = self.config.rho
ramp = np.linspace(0, 1, num_inference_steps)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
elif self.config.use_exponential_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
elif self.config.use_beta_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
alpha, beta = 0.6, 0.6
ramp = np.linspace(0, 1, num_inference_steps)
try:
import torch.distributions as dist
b = dist.Beta(alpha, beta)
ramp = b.sample((num_inference_steps,)).sort().values.numpy()
except Exception:
pass
sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
elif self.config.use_flow_sigmas:
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
# 3. Shifting
if self.config.use_dynamic_shifting and mu is not None:
sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
elif self.config.shift is not None:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
# We handle multi-history expansion
_a_mat, _b_vec, c_vec = self._get_tableau()
len(c_vec)
sigmas_expanded = []
for i in range(len(sigmas) - 1):
s_curr = sigmas[i]
s_next = sigmas[i + 1]
for c_val in c_vec:
sigmas_expanded.append(s_curr + c_val * (s_next - s_curr))
sigmas_expanded.append(0.0)
sigmas_interpolated = np.array(sigmas_expanded)
# Linear remapping for Flow Matching
timesteps_expanded = sigmas_interpolated * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps_expanded + self.config.steps_offset).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index = None
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def _init_step_index(self, timestep):
if self._step_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
a_mat, b_vec, c_vec = self._get_tableau()
num_stages = len(c_vec)
stage_index = step_index % num_stages
base_step_index = (step_index // num_stages) * num_stages
sigma_curr = self.sigmas[base_step_index]
sigma_next_idx = min(base_step_index + num_stages, len(self.sigmas) - 1)
sigma_next = self.sigmas[sigma_next_idx]
if sigma_next <= 0:
sigma_t = self.sigmas[step_index]
prediction_type = getattr(self.config, "prediction_type", "epsilon")
if prediction_type == "epsilon":
denoised = sample - sigma_t * model_output
elif prediction_type == "v_prediction":
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
denoised = alpha_t * sample - sigma_actual * model_output
elif prediction_type == "flow_prediction":
denoised = sample - sigma_t * model_output
else:
denoised = model_output
if getattr(self.config, "clip_sample", False):
denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)
prev_sample = denoised
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
h = sigma_next - sigma_curr
sigma_t = self.sigmas[step_index]
prediction_type = getattr(self.config, "prediction_type", "epsilon")
if prediction_type == "epsilon":
denoised = sample - sigma_t * model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
denoised = alpha_t * sample - sigma_actual * model_output
elif prediction_type == "flow_prediction":
denoised = sample - sigma_t * model_output
elif prediction_type == "sample":
denoised = model_output
else:
raise ValueError(f"prediction_type error: {prediction_type}")
if self.config.clip_sample:
denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)
# derivative = (x - x0) / sigma
derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)
if self.sample_at_start_of_step is None:
if stage_index > 0:
# Mid-step fallback for Img2Img/Inpainting
sigma_next_t = self.sigmas[self._step_index + 1]
dt = sigma_next_t - sigma_t
prev_sample = sample + dt * derivative
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
self.sample_at_start_of_step = sample
self.model_outputs = [derivative] * stage_index
if stage_index == 0:
self.model_outputs = [derivative]
self.sample_at_start_of_step = sample
else:
self.model_outputs.append(derivative)
# Predict sample for next stage
next_stage_idx = stage_index + 1
if next_stage_idx < num_stages:
sum_ak = 0
for j in range(len(self.model_outputs)):
sum_ak = sum_ak + a_mat[next_stage_idx][j] * self.model_outputs[j]
sigma_next_stage = self.sigmas[min(step_index + 1, len(self.sigmas) - 1)]
# Update x (unnormalized sample)
prev_sample = self.sample_at_start_of_step + (sigma_next_stage - sigma_curr) * sum_ak
else:
# Final step update using b coefficients
sum_bk = 0
for j in range(len(self.model_outputs)):
sum_bk = sum_bk + b_vec[j] * self.model_outputs[j]
prev_sample = self.sample_at_start_of_step + h * sum_bk
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,249 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import ClassVar, List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin):
"""
Langevin Dynamics sigma scheduler using Exponential Integrator step.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order: ClassVar[int] = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
temperature: float = 0.5,
friction: float = 1.0,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does not exist.")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# Setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = torch.zeros((num_train_timesteps,), dtype=torch.float32)
self._step_index = None
self._begin_index = None
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
generator: Optional[torch.Generator] = None,
mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
# Discretization parameters for Langevin schedule generation
dt = 1.0 / num_inference_steps
sqrt_2dt = math.sqrt(2 * dt)
start_sigma = 10.0
if hasattr(self, "alphas_cumprod"):
start_sigma = float(((1 - self.alphas_cumprod[-1]) / self.alphas_cumprod[-1]) ** 0.5)
end_sigma = 0.01
def grad_U(x):
return x - end_sigma
x = torch.tensor([start_sigma], dtype=dtype)
v = torch.zeros(1)
trajectory = [start_sigma]
temperature = self.config.temperature
friction = self.config.friction
for _ in range(num_inference_steps - 1):
v = v - dt * friction * v - dt * grad_U(x) / 2
x = x + dt * v
noise = torch.randn(1, generator=generator) * sqrt_2dt * temperature
v = v - dt * friction * v - dt * grad_U(x) / 2 + noise
trajectory.append(x.item())
sigmas = np.array(trajectory)
# Force monotonicity to prevent negative h in step()
sigmas = np.sort(sigmas)[::-1]
sigmas[-1] = end_sigma
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(np.linspace(1000, 0, num_inference_steps)).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
# Determine denoised (x_0 prediction)
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1)**0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
# Exponential Integrator Update
if sigma_next == 0:
x_next = x0
else:
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
x_next = torch.exp(-h) * sample + (1 - torch.exp(-h)) * x0
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,277 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class LawsonScheduler(SchedulerMixin, ConfigMixin):
"""
Lawson's integration method scheduler.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
variant: Literal["lawson2a_2s", "lawson2b_2s", "lawson4_4s"] = "lawson4_4s",
use_analytic_solution: bool = True,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does not exist.")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Buffer for multistage/multistep
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
self._step_index = None
self._begin_index = None
self.init_noise_sigma = 1.0
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
exp_h = torch.exp(-h)
# RECONSTRUCT X0
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
self.model_outputs.append(model_output)
self.x0_outputs.append(x0)
self.prev_sigmas.append(sigma)
variant = self.config.variant
if sigma_next == 0:
x_next = x0
else:
# Lawson coefficients (anchored at x0)
if variant == "lawson2a_2s":
if len(self.x0_outputs) < 2:
res = (1 - exp_h) / h * x0
else:
x0_1, x0_2 = self.x0_outputs[-2:]
# b2 = exp(-h/2)
# b1 = phi(1) - b2? No, Lawson is different.
# But if we want it to be a valid exponential integrator,
# we use the Lawson-specific weighting.
res = torch.exp(-h/2) * x0_2
elif variant == "lawson2b_2s":
if len(self.x0_outputs) < 2:
res = (1 - exp_h) / h * x0
else:
x0_1, x0_2 = self.x0_outputs[-2:]
res = 0.5 * exp_h * x0_1 + 0.5 * x0_2
elif variant == "lawson4_4s":
if len(self.x0_outputs) < 4:
res = (1 - exp_h) / h * x0
else:
e1, e2, e3, e4 = self.x0_outputs[-4:]
b1 = (1/6) * exp_h
b2 = (1/3) * torch.exp(-h/2)
b3 = (1/3) * torch.exp(-h/2)
b4 = 1/6
res = b1 * e1 + b2 * e2 + b3 * e3 + b4 * e4
else:
res = (1 - exp_h) / h * x0
# Update
x_next = exp_h * sample + h * res
self._step_index += 1
# Buffer control
limit = 4 if variant == "lawson4_4s" else 2
if len(self.x0_outputs) > limit:
self.x0_outputs.pop(0)
self.model_outputs.pop(0)
self.prev_sigmas.pop(0)
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,321 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
class LinearRKScheduler(SchedulerMixin, ConfigMixin):
"""
LinearRKScheduler: Standard explicit Runge-Kutta integrators.
Supports Ralston, Midpoint, Heun, Kutta, and standard RK4.
Adapted from the RES4LYF repository.
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
variant: str = "rk4", # euler, heun, rk2, rk3, rk4, ralston, midpoint
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
rho: float = 7.0,
shift: Optional[float] = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
timestep_spacing: str = "linspace",
clip_sample: bool = False,
sample_max_value: float = 1.0,
set_alpha_to_one: bool = False,
skip_prk_steps: bool = False,
interpolation_type: str = "linear",
steps_offset: int = 0,
timestep_type: str = "discrete",
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = None
self.init_noise_sigma = 1.0
# Internal state
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index = None
def _get_tableau(self):
v = str(self.config.variant).lower().strip()
if v in ["ralston", "ralston_2s"]:
a, b, c = [[2 / 3]], [1 / 4, 3 / 4], [0, 2 / 3]
elif v in ["midpoint", "midpoint_2s"]:
a, b, c = [[1 / 2]], [0, 1], [0, 1 / 2]
elif v in ["heun", "heun_2s"]:
a, b, c = [[1]], [1 / 2, 1 / 2], [0, 1]
elif v == "heun_3s":
a, b, c = [[1 / 3], [0, 2 / 3]], [1 / 4, 0, 3 / 4], [0, 1 / 3, 2 / 3]
elif v in ["kutta", "kutta_3s"]:
a, b, c = [[1 / 2], [-1, 2]], [1 / 6, 2 / 3, 1 / 6], [0, 1 / 2, 1]
elif v in ["rk4", "rk4_4s"]:
a, b, c = [[1 / 2], [0, 1 / 2], [0, 0, 1]], [1 / 6, 1 / 3, 1 / 3, 1 / 6], [0, 1 / 2, 1 / 2, 1]
elif v in ["rk2", "heun"]:
a, b, c = [[1]], [1 / 2, 1 / 2], [0, 1]
elif v == "rk3":
a, b, c = [[1 / 2], [-1, 2]], [1 / 6, 2 / 3, 1 / 6], [0, 1 / 2, 1]
elif v == "euler":
a, b, c = [], [1], [0]
else:
raise ValueError(f"Unknown variant: {v}")
# Expand 'a' to full matrix
stages = len(c)
full_a = np.zeros((stages, stages))
for i, row in enumerate(a):
full_a[i + 1, : len(row)] = row
return full_a, np.array(b), np.array(c)
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
else:
raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")
# 2. Sigma Schedule
if self.config.use_karras_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
rho = self.config.rho
ramp = np.linspace(0, 1, num_inference_steps)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
elif self.config.use_exponential_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
elif self.config.use_beta_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
alpha, beta = 0.6, 0.6
ramp = np.linspace(0, 1, num_inference_steps)
try:
import torch.distributions as dist
b = dist.Beta(alpha, beta)
ramp = b.sample((num_inference_steps,)).sort().values.numpy()
except Exception:
pass
sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
elif self.config.use_flow_sigmas:
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
# 3. Shifting
if self.config.use_dynamic_shifting and mu is not None:
sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
elif self.config.shift is not None:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
# We handle multi-history expansion
_a_mat, _b_vec, c_vec = self._get_tableau()
len(c_vec)
sigmas_expanded = []
for i in range(len(sigmas) - 1):
s_curr = sigmas[i]
s_next = sigmas[i + 1]
for c_val in c_vec:
sigmas_expanded.append(s_curr + c_val * (s_next - s_curr))
sigmas_expanded.append(0.0)
sigmas_interpolated = np.array(sigmas_expanded)
# Linear remapping for Flow Matching
timesteps_expanded = sigmas_interpolated * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps_expanded + self.config.steps_offset).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index = None
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def _init_step_index(self, timestep):
if self._step_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
a_mat, b_vec, c_vec = self._get_tableau()
num_stages = len(c_vec)
stage_index = self._step_index % num_stages
base_step_index = (self._step_index // num_stages) * num_stages
sigma_curr = self.sigmas[base_step_index]
sigma_next_idx = min(base_step_index + num_stages, len(self.sigmas) - 1)
sigma_next = self.sigmas[sigma_next_idx]
if sigma_next <= 0:
sigma_t = self.sigmas[self._step_index]
denoised = sample - sigma_t * model_output if self.config.prediction_type == "epsilon" else model_output
prev_sample = denoised
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
h = sigma_next - sigma_curr
sigma_t = self.sigmas[self._step_index]
if self.config.prediction_type == "epsilon":
denoised = sample - sigma_t * model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
denoised = alpha_t * sample - sigma_actual * model_output
elif self.config.prediction_type == "flow_prediction":
denoised = sample - sigma_t * model_output
elif self.config.prediction_type == "sample":
denoised = model_output
else:
raise ValueError(f"prediction_type {self.config.prediction_type} is not supported.")
if self.config.clip_sample:
denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)
# derivative = (x - x0) / sigma
derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)
if self.sample_at_start_of_step is None:
if stage_index > 0:
# Mid-step fallback for Img2Img/Inpainting
sigma_next_t = self.sigmas[self._step_index + 1]
dt = sigma_next_t - sigma_t
prev_sample = sample + dt * derivative
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
self.sample_at_start_of_step = sample
self.model_outputs = [derivative] * stage_index
if stage_index == 0:
self.model_outputs = [derivative]
self.sample_at_start_of_step = sample
else:
self.model_outputs.append(derivative)
next_stage_idx = stage_index + 1
if next_stage_idx < num_stages:
sum_ak = 0
for j in range(len(self.model_outputs)):
sum_ak = sum_ak + a_mat[next_stage_idx][j] * self.model_outputs[j]
sigma_next_stage = self.sigmas[self._step_index + 1]
# Update x (unnormalized sample)
prev_sample = self.sample_at_start_of_step + (sigma_next_stage - sigma_curr) * sum_ak
else:
sum_bk = 0
for j in range(len(self.model_outputs)):
sum_bk = sum_bk + b_vec[j] * self.model_outputs[j]
prev_sample = self.sample_at_start_of_step + h * sum_bk
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,321 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
# pylint: disable=no-member
class LobattoScheduler(SchedulerMixin, ConfigMixin):
"""
LobattoScheduler: High-accuracy implicit integrators from the Lobatto family.
Supports variants IIIA, IIIB, IIIC, IIIC*, IIID (orders 2, 3, 4).
Adapted from the RES4LYF repository.
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
variant: str = "lobatto_iiia_3s", # Available: iiia, iiib, iiic
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
rho: float = 7.0,
shift: Optional[float] = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
timestep_spacing: str = "linspace",
clip_sample: bool = False,
sample_max_value: float = 1.0,
set_alpha_to_one: bool = False,
skip_prk_steps: bool = False,
interpolation_type: str = "linear",
steps_offset: int = 0,
timestep_type: str = "discrete",
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = None
self.init_noise_sigma = 1.0
# Internal state
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index = None
def _get_tableau(self):
v = self.config.variant
r5 = 5**0.5
if v == "lobatto_iiia_2s":
a, b, c = [[0, 0], [1 / 2, 1 / 2]], [1 / 2, 1 / 2], [0, 1]
elif v == "lobatto_iiia_3s":
a, b, c = [[0, 0, 0], [5 / 24, 1 / 3, -1 / 24], [1 / 6, 2 / 3, 1 / 6]], [1 / 6, 2 / 3, 1 / 6], [0, 1 / 2, 1]
elif v == "lobatto_iiia_4s":
a = [[0, 0, 0, 0], [(11 + r5) / 120, (25 - r5) / 120, (25 - 13 * r5) / 120, (-1 + r5) / 120], [(11 - r5) / 120, (25 + 13 * r5) / 120, (25 + r5) / 120, (-1 - r5) / 120], [1 / 12, 5 / 12, 5 / 12, 1 / 12]]
b = [1 / 12, 5 / 12, 5 / 12, 1 / 12]
c = [0, (5 - r5) / 10, (5 + r5) / 10, 1]
elif v == "lobatto_iiib_2s":
a, b, c = [[1 / 2, 0], [1 / 2, 0]], [1 / 2, 1 / 2], [0, 1]
elif v == "lobatto_iiib_3s":
a, b, c = [[1 / 6, -1 / 6, 0], [1 / 6, 1 / 3, 0], [1 / 6, 5 / 6, 0]], [1 / 6, 2 / 3, 1 / 6], [0, 1 / 2, 1]
elif v == "lobatto_iiic_2s":
a, b, c = [[1 / 2, -1 / 2], [1 / 2, 1 / 2]], [1 / 2, 1 / 2], [0, 1]
elif v == "lobatto_iiic_3s":
a, b, c = [[1 / 6, -1 / 3, 1 / 6], [1 / 6, 5 / 12, -1 / 12], [1 / 6, 2 / 3, 1 / 6]], [1 / 6, 2 / 3, 1 / 6], [0, 1 / 2, 1]
elif v == "kraaijevanger_spijker_2s":
a, b, c = [[1 / 2, 0], [-1 / 2, 2]], [-1 / 2, 3 / 2], [1 / 2, 3 / 2]
elif v == "qin_zhang_2s":
a, b, c = [[1 / 4, 0], [1 / 2, 1 / 4]], [1 / 2, 1 / 2], [1 / 4, 3 / 4]
elif v == "pareschi_russo_2s":
gamma = 1 - 2**0.5 / 2
a, b, c = [[gamma, 0], [1 - 2 * gamma, gamma]], [1 / 2, 1 / 2], [gamma, 1 - gamma]
else:
raise ValueError(f"Unknown variant: {v}")
return np.array(a), np.array(b), np.array(c)
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
else:
raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")
# 2. Sigma Schedule
if self.config.use_karras_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
rho = self.config.rho
ramp = np.linspace(0, 1, num_inference_steps)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
elif self.config.use_exponential_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
elif self.config.use_beta_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
alpha, beta = 0.6, 0.6
ramp = np.linspace(0, 1, num_inference_steps)
try:
import torch.distributions as dist
b = dist.Beta(alpha, beta)
ramp = b.sample((num_inference_steps,)).sort().values.numpy()
except Exception:
pass
sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
elif self.config.use_flow_sigmas:
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
# 3. Shifting
if self.config.use_dynamic_shifting and mu is not None:
sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
elif self.config.shift is not None:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
# We handle multi-history expansion
_a_mat, _b_vec, c_vec = self._get_tableau()
len(c_vec)
sigmas_expanded = []
for i in range(len(sigmas) - 1):
s_curr = sigmas[i]
s_next = sigmas[i + 1]
for c_val in c_vec:
sigmas_expanded.append(s_curr + c_val * (s_next - s_curr))
sigmas_expanded.append(0.0) # Add the final sigma=0 for the last step
sigmas_interpolated = np.array(sigmas_expanded)
# Linear remapping for Flow Matching
timesteps_expanded = sigmas_interpolated * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps_expanded + self.config.steps_offset).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index = None
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def _init_step_index(self, timestep):
if self._step_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
a_mat, b_vec, c_vec = self._get_tableau()
num_stages = len(c_vec)
stage_index = self._step_index % num_stages
base_step_index = (self._step_index // num_stages) * num_stages
sigma_curr = self.sigmas[base_step_index]
sigma_next_idx = min(base_step_index + num_stages, len(self.sigmas) - 1)
sigma_next = self.sigmas[sigma_next_idx]
if sigma_next <= 0:
sigma_t = self.sigmas[self._step_index]
denoised = sample - sigma_t * model_output if self.config.prediction_type == "epsilon" else model_output
prev_sample = denoised
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
h = sigma_next - sigma_curr
sigma_t = self.sigmas[self._step_index]
if self.config.prediction_type == "epsilon":
denoised = sample - sigma_t * model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
denoised = alpha_t * sample - sigma_actual * model_output
elif self.config.prediction_type == "flow_prediction":
denoised = sample - sigma_t * model_output
elif self.config.prediction_type == "sample":
denoised = model_output
else:
raise ValueError(f"prediction_type error: {getattr(self.config, 'prediction_type', 'epsilon')}")
if self.config.clip_sample:
denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)
# derivative = (x - x0) / sigma
derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)
if self.sample_at_start_of_step is None:
if stage_index > 0:
# Mid-step fallback for Img2Img/Inpainting
sigma_next_t = self.sigmas[self._step_index + 1]
dt = sigma_next_t - sigma_t
prev_sample = sample + dt * derivative
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
self.sample_at_start_of_step = sample
self.model_outputs = [derivative] * stage_index
if stage_index == 0:
self.model_outputs = [derivative]
self.sample_at_start_of_step = sample
else:
self.model_outputs.append(derivative)
next_stage_idx = stage_index + 1
if next_stage_idx < num_stages:
sum_ak = 0
for j in range(len(self.model_outputs)):
sum_ak = sum_ak + a_mat[next_stage_idx][j] * self.model_outputs[j]
sigma_next_stage = self.sigmas[self._step_index + 1]
# Update x (unnormalized sample)
prev_sample = self.sample_at_start_of_step + (sigma_next_stage - sigma_curr) * sum_ak
else:
sum_bk = 0
for j in range(len(self.model_outputs)):
sum_bk = sum_bk + b_vec[j] * self.model_outputs[j]
prev_sample = self.sample_at_start_of_step + h * sum_bk
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,275 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
from .phi_functions import Phi
logger = logging.get_logger(__name__)
class PECScheduler(SchedulerMixin, ConfigMixin):
"""
Predictor-Corrector (PEC) scheduler.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
variant: Literal["pec423_2h2s", "pec433_2h3s"] = "pec423_2h2s",
use_analytic_solution: bool = True,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does not exist.")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Buffer for multistep
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
self._step_index = None
self._begin_index = None
self.init_noise_sigma = 1.0
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
mu: Optional[float] = None,
dtype: torch.dtype = torch.float32,
):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
# RECONSTRUCT X0
if self.config.prediction_type == "epsilon":
# This x0 is actually a * x0 in discrete NSR space
x0 = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "v_prediction":
# This x0 is the true clean x0
alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
self.model_outputs.append(model_output)
self.x0_outputs.append(x0)
self.prev_sigmas.append(sigma)
variant = self.config.variant
phi = Phi(h, [0], getattr(self.config, "use_analytic_solution", True))
if sigma_next == 0:
x_next = x0
else:
# PEC coefficients (anchored at x0)
if variant == "pec423_2h2s":
if len(self.x0_outputs) < 2:
res = phi(1) * x0
else:
x0_n, x0_p1 = self.x0_outputs[-2:]
b2 = (1/3)*phi(2) + phi(3) + phi(4)
b1 = phi(1) - b2
res = b1 * x0_n + b2 * x0_p1
elif variant == "pec433_2h3s":
if len(self.x0_outputs) < 3:
res = phi(1) * x0
else:
x0_n, x0_p1, x0_p2 = self.x0_outputs[-3:]
b3 = (1/3)*phi(2) + phi(3) + phi(4)
b2 = 0
b1 = phi(1) - b3
res = b1 * x0_n + b2 * x0_p1 + b3 * x0_p2
else:
res = phi(1) * x0
# Update in x-space
x_next = torch.exp(-h) * sample + h * res
self._step_index += 1
# Buffer control
limit = 3 if variant == "pec433_2h3s" else 2
if len(self.x0_outputs) > limit:
self.x0_outputs.pop(0)
self.model_outputs.pop(0)
self.prev_sigmas.pop(0)
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,143 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, List, Tuple, Union
import torch
from mpmath import exp as mp_exp
from mpmath import factorial as mp_factorial
from mpmath import mp, mpf
# Set precision for mpmath
mp.dps = 80
def calculate_gamma(c2: float, c3: float) -> float:
"""Calculates the gamma parameter for RES 3s samplers."""
return (3 * (c3**3) - 2 * c3) / (c2 * (2 - 3 * c2))
def _torch_factorial(n: int) -> float:
return float(math.factorial(n))
def phi_standard_torch(j: int, neg_h: torch.Tensor) -> torch.Tensor:
r"""
Standard implementation of phi functions using torch.
ϕj(-h) = (e^(-h) - \sum_{k=0}^{j-1} (-h)^k / k!) / (-h)^j
For h=0, ϕj(0) = 1/j!
"""
assert j > 0
# Handle h=0 case
if torch.all(neg_h == 0):
return torch.full_like(neg_h, 1.0 / _torch_factorial(j))
# We use double precision for the series to avoid early overflow/precision loss
orig_dtype = neg_h.dtype
neg_h = neg_h.to(torch.float64)
# For very small h, use series expansion to avoid 0/0
if torch.any(torch.abs(neg_h) < 1e-4):
# 1/j! + z/(j+1)! + z^2/(2!(j+2)!) ...
result = torch.full_like(neg_h, 1.0 / _torch_factorial(j))
term = torch.full_like(neg_h, 1.0 / _torch_factorial(j))
for k in range(1, 5):
term = term * neg_h / (j + k)
result += term
return result.to(orig_dtype)
remainder = torch.zeros_like(neg_h)
for k in range(j):
remainder += (neg_h**k) / _torch_factorial(k)
phi_val = (torch.exp(neg_h) - remainder) / (neg_h**j)
return phi_val.to(orig_dtype)
def phi_mpmath_series(j: int, neg_h: float) -> float:
"""Arbitrary-precision phi_j(-h) via series definition."""
j = int(j)
z = mpf(float(neg_h))
# Handle h=0 case: phi_j(0) = 1/j!
if z == 0:
return float(1.0 / mp_factorial(j))
s_val = mp.mpf("0")
for k in range(j):
s_val += (z**k) / mp_factorial(k)
phi_val = (mp_exp(z) - s_val) / (z**j)
return float(phi_val)
class Phi:
"""
Class to manage phi function calculations and caching.
Supports both standard torch-based and high-precision mpmath-based solutions.
"""
def __init__(self, h: torch.Tensor, c: List[Union[float, mpf]], analytic_solution: bool = True):
self.h = h
self.c = c
self.cache: Dict[Tuple[int, int], Union[float, torch.Tensor]] = {}
self.analytic_solution = analytic_solution
if analytic_solution:
self.phi_f = phi_mpmath_series
self.h_mpf = mpf(float(h))
self.c_mpf = [mpf(float(c_val)) for c_val in c]
else:
self.phi_f = phi_standard_torch
def __call__(self, j: int, i: int = -1) -> Union[float, torch.Tensor]:
if (j, i) in self.cache:
return self.cache[(j, i)]
if i < 0:
c_val = 1.0
else:
c_val = self.c[i - 1]
if c_val == 0:
self.cache[(j, i)] = 0.0
return 0.0
if self.analytic_solution:
h_val = self.h_mpf
c_mapped = self.c_mpf[i - 1] if i >= 0 else 1.0
if j == 0:
result = float(mp_exp(-h_val * c_mapped))
else:
# Use the mpmath internal function for higher precision
z = -h_val * c_mapped
if z == 0:
result = float(1.0 / mp_factorial(j))
else:
s_val = mp.mpf("0")
for k in range(j):
s_val += (z**k) / mp_factorial(k)
result = float((mp_exp(z) - s_val) / (z**j))
else:
h_val = self.h
c_mapped = float(c_val)
if j == 0:
result = torch.exp(-h_val * c_mapped)
else:
result = self.phi_f(j, -h_val * c_mapped)
self.cache[(j, i)] = result
return result

View File

@ -0,0 +1,364 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
# pylint: disable=no-member
class RadauIIAScheduler(SchedulerMixin, ConfigMixin):
"""
RadauIIAScheduler: Fully implicit Runge-Kutta integrators.
Supports variants with 2, 3, 5, 7, 9, 11 stages.
Adapted from the RES4LYF repository.
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
variant: str = "radau_iia_3s", # 2s to 11s variants
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
rho: float = 7.0,
shift: Optional[float] = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
timestep_spacing: str = "linspace",
clip_sample: bool = False,
sample_max_value: float = 1.0,
set_alpha_to_one: bool = False,
skip_prk_steps: bool = False,
interpolation_type: str = "linear",
steps_offset: int = 0,
timestep_type: str = "discrete",
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = None
self.init_noise_sigma = 1.0
# Internal state
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index = None
def _get_tableau(self):
v = self.config.variant
if v == "radau_iia_2s":
a, b, c = [[5 / 12, -1 / 12], [3 / 4, 1 / 4]], [3 / 4, 1 / 4], [1 / 3, 1]
elif v == "radau_iia_3s":
r6 = 6**0.5
a = [[11 / 45 - 7 * r6 / 360, 37 / 225 - 169 * r6 / 1800, -2 / 225 + r6 / 75], [37 / 225 + 169 * r6 / 1800, 11 / 45 + 7 * r6 / 360, -2 / 225 - r6 / 75], [4 / 9 - r6 / 36, 4 / 9 + r6 / 36, 1 / 9]]
b, c = [4 / 9 - r6 / 36, 4 / 9 + r6 / 36, 1 / 9], [2 / 5 - r6 / 10, 2 / 5 + r6 / 10, 1]
elif v == "radau_iia_5s":
a = [
[0.07299886, -0.02673533, 0.01867693, -0.01287911, 0.00504284],
[0.15377523, 0.14621487, -0.03644457, 0.02123306, -0.00793558],
[0.14006305, 0.29896713, 0.16758507, -0.03396910, 0.01094429],
[0.14489431, 0.27650007, 0.32579792, 0.12875675, -0.01570892],
[0.14371356, 0.28135602, 0.31182652, 0.22310390, 0.04000000],
]
b = [0.14371356, 0.28135602, 0.31182652, 0.22310390, 0.04]
c = [0.05710420, 0.27684301, 0.58359043, 0.86024014, 1.0]
elif v == "radau_iia_7s":
a = [
[0.03754626, -0.01403933, 0.01035279, -0.00815832, 0.00638841, -0.00460233, 0.00182894],
[0.08014760, 0.08106206, -0.02123799, 0.01400029, -0.01023419, 0.00715347, -0.00281264],
[0.07206385, 0.17106835, 0.10961456, -0.02461987, 0.01476038, -0.00957526, 0.00367268],
[0.07570513, 0.15409016, 0.22710774, 0.11747819, -0.02381083, 0.01270999, -0.00460884],
[0.07391234, 0.16135561, 0.20686724, 0.23700712, 0.10308679, -0.01885414, 0.00585890],
[0.07470556, 0.15830722, 0.21415342, 0.21987785, 0.19875212, 0.06926550, -0.00811601],
[0.07449424, 0.15910212, 0.21235189, 0.22355491, 0.19047494, 0.11961374, 0.02040816],
]
b = [0.07449424, 0.15910212, 0.21235189, 0.22355491, 0.19047494, 0.11961374, 0.02040816]
c = [0.02931643, 0.14807860, 0.33698469, 0.55867152, 0.76923386, 0.92694567, 1.0]
elif v == "radau_iia_9s":
a = [
[0.02278838, -0.00858964, 0.00645103, -0.00525753, 0.00438883, -0.00365122, 0.00294049, -0.00214927, 0.00085884],
[0.04890795, 0.05070205, -0.01352381, 0.00920937, -0.00715571, 0.00574725, -0.00454258, 0.00328816, -0.00130907],
[0.04374276, 0.10830189, 0.07291957, -0.01687988, 0.01070455, -0.00790195, 0.00599141, -0.00424802, 0.00167815],
[0.04624924, 0.09656073, 0.15429877, 0.08671937, -0.01845164, 0.01103666, -0.00767328, 0.00522822, -0.00203591],
[0.04483444, 0.10230685, 0.13821763, 0.18126393, 0.09043360, -0.01808506, 0.01019339, -0.00640527, 0.00242717],
[0.04565876, 0.09914547, 0.14574704, 0.16364828, 0.18594459, 0.08361326, -0.01580994, 0.00813825, -0.00291047],
[0.04520060, 0.10085371, 0.14194224, 0.17118947, 0.16978339, 0.16776829, 0.06707903, -0.01179223, 0.00360925],
[0.04541652, 0.10006040, 0.14365284, 0.16801908, 0.17556077, 0.15588627, 0.12889391, 0.04281083, -0.00493457],
[0.04535725, 0.10027665, 0.14319335, 0.16884698, 0.17413650, 0.15842189, 0.12359469, 0.07382701, 0.01234568],
]
b = [0.04535725, 0.10027665, 0.14319335, 0.16884698, 0.17413650, 0.15842189, 0.12359469, 0.07382701, 0.01234568]
c = [0.01777992, 0.09132361, 0.21430848, 0.37193216, 0.54518668, 0.71317524, 0.85563374, 0.95536604, 1.0]
elif v == "radau_iia_11s":
a = [
[0.01528052, -0.00578250, 0.00438010, -0.00362104, 0.00309298, -0.00267283, 0.00230509, -0.00195565, 0.00159387, -0.00117286, 0.00046993],
[0.03288398, 0.03451351, -0.00928542, 0.00641325, -0.00509546, 0.00424609, -0.00358767, 0.00300683, -0.00243267, 0.00178278, -0.00071315],
[0.02933250, 0.07416243, 0.05114868, -0.01200502, 0.00777795, -0.00594470, 0.00480266, -0.00392360, 0.00312733, -0.00227314, 0.00090638],
[0.03111455, 0.06578995, 0.10929963, 0.06381052, -0.01385359, 0.00855744, -0.00630764, 0.00491336, -0.00381400, 0.00273343, -0.00108397],
[0.03005269, 0.07011285, 0.09714692, 0.13539160, 0.07147108, -0.01471024, 0.00873319, -0.00619941, 0.00459164, -0.00321333, 0.00126286],
[0.03072807, 0.06751926, 0.10334060, 0.12083526, 0.15032679, 0.07350932, -0.01451288, 0.00829665, -0.00561283, 0.00376623, -0.00145771],
[0.03029202, 0.06914472, 0.09972096, 0.12801064, 0.13493180, 0.15289670, 0.06975993, -0.01327455, 0.00725877, -0.00448439, 0.00168785],
[0.03056654, 0.06813851, 0.10188107, 0.12403361, 0.14211432, 0.13829395, 0.14289135, 0.06052636, -0.01107774, 0.00559867, -0.00198773],
[0.03040663, 0.06871881, 0.10066096, 0.12619527, 0.13848876, 0.14450774, 0.13065189, 0.12111401, 0.04655548, -0.00802620, 0.00243764],
[0.03048412, 0.06843925, 0.10124185, 0.12518732, 0.14011843, 0.14190387, 0.13500343, 0.11262870, 0.08930604, 0.02896966, -0.00331170],
[0.03046255, 0.06851684, 0.10108155, 0.12546269, 0.13968067, 0.14258278, 0.13393354, 0.11443306, 0.08565881, 0.04992304, 0.00826446],
]
b = [0.03046255, 0.06851684, 0.10108155, 0.12546269, 0.13968067, 0.14258278, 0.13393354, 0.11443306, 0.08565881, 0.04992304, 0.00826446]
c = [0.01191761, 0.06173207, 0.14711145, 0.26115968, 0.39463985, 0.53673877, 0.67594446, 0.80097892, 0.90171099, 0.96997097, 1.0]
else:
raise ValueError(f"Unknown variant: {v}")
return np.array(a), np.array(b), np.array(c)
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
else:
raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")
# 2. Sigma Schedule
if self.config.use_karras_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
rho = self.config.rho
ramp = np.linspace(0, 1, num_inference_steps)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
elif self.config.use_exponential_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
elif self.config.use_beta_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
alpha, beta = 0.6, 0.6
ramp = np.linspace(0, 1, num_inference_steps)
try:
import torch.distributions as dist
b = dist.Beta(alpha, beta)
ramp = b.sample((num_inference_steps,)).sort().values.numpy()
except Exception:
pass
sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
elif self.config.use_flow_sigmas:
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
# 3. Shifting
if self.config.use_dynamic_shifting and mu is not None:
sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
elif self.config.shift is not None:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
# We handle multi-history expansion
_a_mat, _b_vec, c_vec = self._get_tableau()
len(c_vec)
sigmas_expanded = []
for i in range(len(sigmas) - 1):
s_curr = sigmas[i]
s_next = sigmas[i + 1]
for c_val in c_vec:
sigmas_expanded.append(s_curr + c_val * (s_next - s_curr))
sigmas_expanded.append(0.0)
sigmas_interpolated = np.array(sigmas_expanded)
# Linear remapping for Flow Matching
timesteps_expanded = sigmas_interpolated * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps_expanded + self.config.steps_offset).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index = None
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
if isinstance(schedule_timesteps, torch.Tensor):
schedule_timesteps = schedule_timesteps.detach().cpu().numpy()
if isinstance(timestep, torch.Tensor):
timestep = timestep.detach().cpu().numpy()
return np.abs(schedule_timesteps - timestep).argmin().item()
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def _init_step_index(self, timestep):
if self._step_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
a_mat, b_vec, c_vec = self._get_tableau()
num_stages = len(c_vec)
stage_index = self._step_index % num_stages
base_step_index = (self._step_index // num_stages) * num_stages
sigma_curr = self.sigmas[base_step_index]
sigma_next_idx = min(base_step_index + num_stages, len(self.sigmas) - 1)
sigma_next = self.sigmas[sigma_next_idx]
if sigma_next <= 0:
sigma_t = self.sigmas[self._step_index]
denoised = sample - sigma_t * model_output if getattr(self.config, "prediction_type", "epsilon") == "epsilon" else model_output
prev_sample = denoised
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
h = sigma_next - sigma_curr
sigma_t = self.sigmas[self._step_index]
prediction_type = getattr(self.config, "prediction_type", "epsilon")
if prediction_type == "epsilon":
denoised = sample - sigma_t * model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
denoised = alpha_t * sample - sigma_actual * model_output
elif prediction_type == "flow_prediction":
denoised = sample - sigma_t * model_output
elif prediction_type == "sample":
denoised = model_output
else:
raise ValueError(f"prediction_type error: {getattr(self.config, 'prediction_type', 'epsilon')}")
if self.config.clip_sample:
denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)
# derivative = (x - x0) / sigma
derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)
if self.sample_at_start_of_step is None:
if stage_index > 0:
# Mid-step fallback for Img2Img/Inpainting
sigma_next_t = self.sigmas[self._step_index + 1]
dt = sigma_next_t - sigma_t
prev_sample = sample + dt * derivative
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
self.sample_at_start_of_step = sample
self.model_outputs = [derivative] * stage_index
if stage_index == 0:
self.model_outputs = [derivative]
self.sample_at_start_of_step = sample
else:
self.model_outputs.append(derivative)
next_stage_idx = stage_index + 1
if next_stage_idx < num_stages:
sum_ak = 0
for j in range(len(self.model_outputs)):
sum_ak = sum_ak + a_mat[next_stage_idx][j] * self.model_outputs[j]
sigma_next_stage = self.sigmas[self._step_index + 1]
# Update x (unnormalized sample)
prev_sample = self.sample_at_start_of_step + (sigma_next_stage - sigma_curr) * sum_ak
else:
sum_bk = 0
for j in range(len(self.model_outputs)):
sum_bk = sum_bk + b_vec[j] * self.model_outputs[j]
prev_sample = self.sample_at_start_of_step + h * sum_bk
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,451 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
from .phi_functions import Phi, calculate_gamma
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class RESMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
RESMultistepScheduler (Restartable Exponential Integrator) ported from RES4LYF.
Supports RES 2M, 3M and DEIS 2M, 3M variants.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to "linear"):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
prediction_type (`str`, defaults to "epsilon"):
The prediction type of the scheduler function.
variant (`str`, defaults to "res_2m"):
The specific RES/DEIS variant to use. Supported: "res_2m", "res_3m", "deis_2m", "deis_3m".
use_analytic_solution (`bool`, defaults to True):
Whether to use high-precision analytic solutions for phi functions.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
prediction_type: str = "epsilon",
variant: Literal["res_2m", "res_3m", "deis_2m", "deis_3m"] = "res_2m",
use_analytic_solution: bool = True,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Buffer for multistep
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
self._step_index = None
self._begin_index = None
self.init_noise_sigma = 1.0
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
# Linear remapping for Flow Matching
if self.config.use_flow_sigmas:
# Standardize linear spacing
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
# Already handled above, ensuring variable consistency
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
if self.config.use_flow_sigmas:
timesteps = sigmas * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
self.lower_order_nums = 0
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step = self._step_index
sigma = self.sigmas[step]
sigma_next = self.sigmas[step + 1]
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
# RECONSTRUCT X0 (Matching PEC pattern)
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
raise ValueError(f"prediction_type {self.config.prediction_type} is not supported.")
self.model_outputs.append(model_output)
self.x0_outputs.append(x0)
self.prev_sigmas.append(sigma)
# Order logic
variant = self.config.variant
order = int(variant[-2]) if variant.endswith("m") else 1
# Effective order for current step
curr_order = min(len(self.prev_sigmas), order) if sigma > 0 else 1
if self.config.prediction_type == "flow_prediction":
# Variable Step Adams-Bashforth for Flow Matching
dt = sigma_next - sigma
v_n = model_output
if curr_order == 1:
x_next = sample + dt * v_n
elif curr_order == 2:
# AB2
sigma_prev = self.prev_sigmas[-2]
dt_prev = sigma - sigma_prev
r = dt / dt_prev if abs(dt_prev) > 1e-8 else 0.0
# Stability check
if dt_prev == 0 or r < -0.9 or r > 2.0: # Fallback
x_next = sample + dt * v_n
else:
c0 = 1 + 0.5 * r
c1 = -0.5 * r
x_next = sample + dt * (c0 * v_n + c1 * self.model_outputs[-2])
elif curr_order >= 3:
# Re-implement AB2 logic
sigma_prev = self.prev_sigmas[-2]
dt_prev = sigma - sigma_prev
r = dt / dt_prev if abs(dt_prev) > 1e-8 else 0.0
c0 = 1 + 0.5 * r
c1 = -0.5 * r
x_next = sample + dt * (c0 * v_n + c1 * self.model_outputs[-2])
self._step_index += 1
if len(self.model_outputs) > order:
self.model_outputs.pop(0)
self.x0_outputs.pop(0)
self.prev_sigmas.pop(0)
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
# Exponential Integrator Setup
phi = Phi(h, [0], getattr(self.config, "use_analytic_solution", True))
phi_1 = phi(1)
if variant.startswith("res"):
# Force Order 1 at the end of schedule
if self.num_inference_steps is not None and self._step_index >= self.num_inference_steps - 3:
curr_order = 1
if curr_order == 2:
h_prev = -torch.log(self.prev_sigmas[-1] / self.prev_sigmas[-2])
elif curr_order == 3:
pass
else:
pass
# Exponential Integrator Update in x-space
if curr_order == 1:
res = phi_1 * x0
elif curr_order == 2:
# b2 = -phi_2 / r
# b2 = -phi_2 / r = -phi(2) / (h_prev/h)
# Here we use: b2 = phi(2) / ((-h_prev / h) + 1e-9)
# Since (-h_prev/h) is negative (-r), this gives correct negative sign for b2.
# Stability check
r_check = h_prev / (h + 1e-9) # This is effectively -r if using h_prev definition above?
# Wait, h_prev above is -log(). Positive.
# h is positive.
# So h_prev/h is positive. defined as r in other files.
# But here code uses -h_prev / h in denominator.
# Stability check
r_check = h_prev / (h + 1e-9)
# Hard Restart
if r_check < 0.5 or r_check > 2.0:
res = phi_1 * x0
else:
b2 = phi(2) / ((-h_prev / h) + 1e-9)
b1 = phi_1 - b2
res = b1 * self.x0_outputs[-1] + b2 * self.x0_outputs[-2]
elif curr_order == 3:
# Generalized AB3 for Exponential Integrators
h_p1 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
h_p2 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-3] + 1e-9))
r1 = h_p1 / (h + 1e-9)
r2 = h_p2 / (h + 1e-9)
if r1 < 0.5 or r1 > 2.0 or r2 < 0.5 or r2 > 2.0:
res = phi_1 * x0
else:
phi_2, phi_3 = phi(2), phi(3)
denom = r2 - r1 + 1e-9
b3 = (phi_3 + r1 * phi_2) / (r2 * denom)
b2 = -(phi_3 + r2 * phi_2) / (r1 * denom)
b1 = phi_1 - b2 - b3
res = b1 * self.x0_outputs[-1] + b2 * self.x0_outputs[-2] + b3 * self.x0_outputs[-3]
else:
res = phi_1 * x0
if sigma_next == 0:
x_next = x0
else:
x_next = torch.exp(-h) * sample + h * res
else:
# DEIS logic (Linear multistep in log-sigma space)
b = self._get_deis_coefficients(curr_order, sigma, sigma_next)
# For DEIS, we apply b to the denoised estimates
res = torch.zeros_like(sample)
for i, b_val in enumerate(b[0]):
idx = len(self.x0_outputs) - 1 - i
if idx >= 0:
res += b_val * self.x0_outputs[idx]
# DEIS update in x-space
if sigma_next == 0:
x_next = x0
else:
x_next = torch.exp(-h) * sample + h * res
self._step_index += 1
if len(self.model_outputs) > order:
self.model_outputs.pop(0)
self.x0_outputs.pop(0)
self.prev_sigmas.pop(0)
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _get_res_coefficients(self, rk_type, h, c2, c3):
ci = [0, c2, c3]
phi = Phi(h, ci, getattr(self.config, "use_analytic_solution", True))
if rk_type == "res_2s":
b2 = phi(2) / (c2 + 1e-9)
b = [[phi(1) - b2, b2]]
a = [[0, 0], [c2 * phi(1, 2), 0]]
elif rk_type == "res_3s":
gamma_val = calculate_gamma(c2, c3)
b3 = phi(2) / (gamma_val * c2 + c3 + 1e-9)
b2 = gamma_val * b3
b = [[phi(1) - (b2 + b3), b2, b3]]
a = [] # Simplified
else:
b = [[phi(1)]]
a = [[0]]
return a, b, ci
def _get_deis_coefficients(self, order, sigma, sigma_next):
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
phi = Phi(h, [0], getattr(self.config, "use_analytic_solution", True))
phi_1 = phi(1)
if order == 1:
return [[phi_1]]
elif order == 2:
h_prev = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
r = h_prev / (h + 1e-9)
# Correct Adams-Bashforth-like coefficients for Exponential Integrators
# Hard Restart for stability
if r < 0.5 or r > 2.0:
return [[phi_1]]
phi_2 = phi(2)
b2 = -phi_2 / (r + 1e-9)
b1 = phi_1 - b2
return [[b1, b2]]
elif order == 3:
h_prev1 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
h_prev2 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-3] + 1e-9))
r1 = h_prev1 / (h + 1e-9)
r2 = h_prev2 / (h + 1e-9)
if r1 < 0.5 or r1 > 2.0 or r2 < 0.5 or r2 > 2.0:
return [[phi_1]]
phi_2 = phi(2)
phi_3 = phi(3)
# Generalized AB3 for Exponential Integrators (Varying steps)
denom = r2 - r1 + 1e-9
b3 = (phi_3 + r1 * phi_2) / (r2 * denom)
b2 = -(phi_3 + r2 * phi_2) / (r1 * denom)
b1 = phi_1 - (b2 + b3)
return [[b1, b2, b3]]
else:
return [[phi_1]]
def _init_step_index(self, timestep):
if self.begin_index is None:
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,330 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class RESMultistepSDEScheduler(SchedulerMixin, ConfigMixin):
"""
RESMultistepSDEScheduler (Stochastic Exponential Integrator) ported from RES4LYF.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
variant (`str`, defaults to "res_2m"):
The specific RES/DEIS variant to use. Supported: "res_2m", "res_3m".
eta (`float`, defaults to 1.0):
The amount of noise to add during sampling (0.0 for ODE, 1.0 for full SDE).
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
prediction_type: str = "epsilon",
variant: Literal["res_2m", "res_3m"] = "res_2m",
eta: float = 1.0,
use_analytic_solution: bool = True,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Buffer for multistep
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
self._step_index = None
self._begin_index = None
self.init_noise_sigma = 1.0
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step = self._step_index
sigma = self.sigmas[step]
sigma_next = self.sigmas[step + 1]
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
# RECONSTRUCT X0
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
self.model_outputs.append(model_output)
self.x0_outputs.append(x0)
self.prev_sigmas.append(sigma)
# Order logic
variant = self.config.variant
order = int(variant[-2]) if variant.endswith("m") else 1
# Effective order for current step
curr_order = min(len(self.prev_sigmas), order)
# REiS Multistep logic
c2, c3 = 0.5, 1.0
if curr_order == 2:
h_prev = -torch.log(self.prev_sigmas[-1] / self.prev_sigmas[-2])
c2 = (-h_prev / h).item() if h > 0 else 0.5
rk_type = "res_2s"
elif curr_order == 3:
h_prev1 = -torch.log(self.prev_sigmas[-1] / self.prev_sigmas[-2])
h_prev2 = -torch.log(self.prev_sigmas[-1] / self.prev_sigmas[-3])
c2 = (-h_prev1 / h).item() if h > 0 else 0.5
c3 = (-h_prev2 / h).item() if h > 0 else 1.0
rk_type = "res_3s"
else:
rk_type = "res_1s"
if curr_order == 1:
rk_type = "res_1s"
_a, b, _ci = self._get_res_coefficients(rk_type, h, c2, c3)
# Apply coefficients to get multistep x_0
res = torch.zeros_like(sample)
for i, b_val in enumerate(b[0]):
idx = len(self.x0_outputs) - 1 - i
if idx >= 0:
res += b_val * self.x0_outputs[idx]
# SDE stochastic step
eta = self.config.eta
if sigma_next == 0:
x_next = x0
else:
# Ancestral SDE logic:
# 1. Calculate sigma_up and sigma_down to preserve variance
# sigma_up = eta * sigma_next * sqrt(1 - (sigma_next/sigma)^2)
# sigma_down = sqrt(sigma_next^2 - sigma_up^2)
sigma_up = eta * (sigma_next**2 * (sigma**2 - sigma_next**2) / (sigma**2 + 1e-9))**0.5
sigma_down = (sigma_next**2 - sigma_up**2)**0.5
# 2. Take deterministic step to sigma_down
h_det = -torch.log(sigma_down / sigma) if sigma > 0 and sigma_down > 0 else h
# Re-calculate coefficients for h_det
_a, b_det, _ci = self._get_res_coefficients(rk_type, h_det, c2, c3)
res_det = torch.zeros_like(sample)
for i, b_val in enumerate(b_det[0]):
idx = len(self.x0_outputs) - 1 - i
if idx >= 0:
res_det += b_val * self.x0_outputs[idx]
x_det = torch.exp(-h_det) * sample + h_det * res_det
# 3. Add noise scaled by sigma_up
if eta > 0:
noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
x_next = x_det + sigma_up * noise
else:
x_next = x_det
self._step_index += 1
if len(self.x0_outputs) > order:
self.x0_outputs.pop(0)
self.model_outputs.pop(0)
self.prev_sigmas.pop(0)
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _get_res_coefficients(self, rk_type, h, c2, c3):
from .phi_functions import Phi, calculate_gamma
ci = [0, c2, c3]
phi = Phi(h, ci, self.config.use_analytic_solution)
if rk_type == "res_2s":
b2 = phi(2) / (c2 + 1e-9)
b = [[phi(1) - b2, b2]]
a = [[0, 0], [c2 * phi(1, 2), 0]]
elif rk_type == "res_3s":
gamma_val = calculate_gamma(c2, c3)
b3 = phi(2) / (gamma_val * c2 + c3 + 1e-9)
b2 = gamma_val * b3
b = [[phi(1) - (b2 + b3), b2, b3]]
a = []
else:
b = [[phi(1)]]
a = [[0]]
return a, b, ci
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,243 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class RESSinglestepScheduler(SchedulerMixin, ConfigMixin):
"""
RESSinglestepScheduler (Multistage Exponential Integrator) ported from RES4LYF.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
prediction_type: str = "epsilon",
variant: Literal["res_2s", "res_3s", "res_5s", "res_6s"] = "res_2s",
use_analytic_solution: bool = True,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self._step_index = None
self._begin_index = None
self.init_noise_sigma = 1.0
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
# Linear remapping logic
if self.config.use_flow_sigmas:
# Logic handled here
pass
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
if not self.config.use_flow_sigmas:
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
if self.config.use_flow_sigmas:
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
if self.config.use_flow_sigmas:
timesteps = sigmas * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step = self._step_index
sigma = self.sigmas[step]
sigma_next = self.sigmas[step + 1]
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
# RECONSTRUCT X0 (Matching PEC pattern)
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
if self.config.prediction_type == "flow_prediction":
dt = sigma_next - sigma
x_next = sample + dt * model_output
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
# Exponential Integrator Update
if sigma_next == 0:
x_next = x0
else:
# For singlestep RES (multistage), a proper RK requires model evals at intermediate ci * h.
# Here we provide the standard 1st order update as a base.
x_next = torch.exp(-h) * sample + (1 - torch.exp(-h)) * x0
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,237 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class RESSinglestepSDEScheduler(SchedulerMixin, ConfigMixin):
"""
RESSinglestepSDEScheduler (Stochastic Multistage Exponential Integrator) ported from RES4LYF.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
prediction_type: str = "epsilon",
variant: Literal["res_2s", "res_3s", "res_5s", "res_6s"] = "res_2s",
eta: float = 1.0,
use_analytic_solution: bool = True,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self._step_index = None
self._begin_index = None
self.init_noise_sigma = 1.0
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step = self._step_index
sigma = self.sigmas[step]
sigma_next = self.sigmas[step + 1]
eta = self.config.eta
# RECONSTRUCT X0
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1)**0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
# Exponential Integrator Update (Deterministic Part)
if sigma_next == 0:
x_next = x0
else:
# Ancestral SDE logic
sigma_up = eta * (sigma_next**2 * (sigma**2 - sigma_next**2) / (sigma**2 + 1e-9))**0.5
sigma_down = (sigma_next**2 - sigma_up**2)**0.5
h_det = -torch.log(sigma_down / sigma) if sigma > 0 and sigma_down > 0 else torch.zeros_like(sigma)
# Deterministic update to sigma_down
x_det = torch.exp(-h_det) * sample + (1 - torch.exp( -h_det)) * x0
# Stochastic part
if eta > 0:
noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
x_next = x_det + sigma_up * noise
else:
x_next = x_det
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,342 @@
from typing import ClassVar, List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from .phi_functions import Phi
class RESUnifiedScheduler(SchedulerMixin, ConfigMixin):
"""
RESUnifiedScheduler (Exponential Integrator) ported from RES4LYF.
Supports RES 2M, 3M, 2S, 3S, 5S, 6S
Supports DEIS 1S, 2M, 3M
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order: ClassVar[int] = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
prediction_type: str = "epsilon",
rk_type: str = "res_2m",
use_analytic_solution: bool = True,
rescale_betas_zero_snr: bool = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sigmas = torch.Tensor([])
self.timesteps = torch.Tensor([])
self.model_outputs = []
self.x0_outputs = []
self.prev_sigmas = []
self._step_index = None
self._begin_index = None
self.init_noise_sigma = 1.0
def set_sigmas(self, sigmas: torch.Tensor):
self.sigmas = sigmas
self._step_index = None
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
timestep_spacing = getattr(self.config, "timestep_spacing", "linspace")
steps_offset = getattr(self.config, "steps_offset", 0)
if timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += steps_offset
elif timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {timestep_spacing} is not supported.")
# Derived sigma range from alphas_cumprod
base_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = base_sigmas[::-1].copy() # Ensure high to low
if getattr(self.config, "use_karras_sigmas", False):
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif getattr(self.config, "use_exponential_sigmas", False):
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif getattr(self.config, "use_beta_sigmas", False):
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif getattr(self.config, "use_flow_sigmas", False):
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
else:
if self.config.use_flow_sigmas:
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
else:
# Re-sample the base sigmas at the requested steps
idx = np.linspace(0, len(base_sigmas) - 1, num_inference_steps)
sigmas = np.interp(idx, np.arange(len(base_sigmas)), base_sigmas)[::-1].copy()
shift = getattr(self.config, "shift", 1.0)
use_dynamic_shifting = getattr(self.config, "use_dynamic_shifting", False)
if shift != 1.0 or use_dynamic_shifting:
if use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
getattr(self.config, "base_shift", 0.5),
getattr(self.config, "max_shift", 1.5),
getattr(self.config, "base_image_seq_len", 256),
getattr(self.config, "max_image_seq_len", 4096),
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
if getattr(self.config, "use_flow_sigmas", False):
timesteps = sigmas * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def _get_coefficients(self, sigma, sigma_next):
h = -torch.log(sigma_next / sigma) if sigma > 0 else torch.zeros_like(sigma)
phi = Phi(h, [], getattr(self.config, "use_analytic_solution", True))
phi_1 = phi(1)
phi_2 = phi(2)
# phi_2 = phi(2) # Moved inside conditional blocks as needed
history_len = len(self.x0_outputs)
# Stability: Force Order 1 for final few steps to prevent degradation at low noise levels
if self.num_inference_steps is not None and self._step_index >= self.num_inference_steps - 3:
return [phi_1], h
if self.config.rk_type in ["res_2m", "deis_2m"] and history_len >= 2:
h_prev = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
r = h_prev / (h + 1e-9)
h_prev = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
r = h_prev / (h + 1e-9)
# Hard Restart: if step sizes vary too wildly, fallback to order 1
if r < 0.5 or r > 2.0:
return [phi_1], h
phi_2 = phi(2)
# Correct Adams-Bashforth-like coefficients for Exponential Integrators
b2 = -phi_2 / (r + 1e-9)
b1 = phi_1 - b2
return [b1, b2], h
elif self.config.rk_type in ["res_3m", "deis_3m"] and history_len >= 3:
h_prev1 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
h_prev2 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-3] + 1e-9))
r1 = h_prev1 / (h + 1e-9)
r2 = h_prev2 / (h + 1e-9)
h_prev1 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
h_prev2 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-3] + 1e-9))
r1 = h_prev1 / (h + 1e-9)
r2 = h_prev2 / (h + 1e-9)
# Hard Restart check
if r1 < 0.5 or r1 > 2.0 or r2 < 0.5 or r2 > 2.0:
return [phi_1], h
phi_2 = phi(2)
phi_3 = phi(3)
# Generalized AB3 for Exponential Integrators (Varying steps)
denom = r2 - r1 + 1e-9
b3 = (phi_3 + r1 * phi_2) / (r2 * denom)
b2 = -(phi_3 + r2 * phi_2) / (r1 * denom)
b1 = phi_1 - (b2 + b3)
return [b1, b2, b3], h
return [phi_1], h
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self._step_index]
sigma_next = self.sigmas[self._step_index + 1]
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
# RECONSTRUCT X0 (Matching PEC pattern)
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
self.x0_outputs.append(x0)
self.model_outputs.append(model_output) # Added for AB support
self.prev_sigmas.append(sigma)
if len(self.x0_outputs) > 3:
self.x0_outputs.pop(0)
self.model_outputs.pop(0)
self.prev_sigmas.pop(0)
if self.config.prediction_type == "flow_prediction":
# Variable Step Adams-Bashforth for Flow Matching
dt = sigma_next - sigma
v_n = model_output
curr_order = min(len(self.prev_sigmas), 3) # Max order 3 here
if curr_order == 1:
x_next = sample + dt * v_n
elif curr_order == 2:
sigma_prev = self.prev_sigmas[-2]
dt_prev = sigma - sigma_prev
r = dt / dt_prev if abs(dt_prev) > 1e-8 else 0.0
if dt_prev == 0 or r < -0.9 or r > 2.0:
x_next = sample + dt * v_n
else:
c0 = 1 + 0.5 * r
c1 = -0.5 * r
x_next = sample + dt * (c0 * v_n + c1 * self.model_outputs[-2])
else:
# AB2 fallback for robustness
sigma_prev = self.prev_sigmas[-2]
dt_prev = sigma - sigma_prev
r = dt / dt_prev if abs(dt_prev) > 1e-8 else 0.0
c0 = 1 + 0.5 * r
c1 = -0.5 * r
x_next = sample + dt * (c0 * v_n + c1 * self.model_outputs[-2])
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
# GET COEFFICIENTS
b, h_val = self._get_coefficients(sigma, sigma_next)
if len(b) == 1:
res = b[0] * x0
elif len(b) == 2:
res = b[0] * self.x0_outputs[-1] + b[1] * self.x0_outputs[-2]
elif len(b) == 3:
res = b[0] * self.x0_outputs[-1] + b[1] * self.x0_outputs[-2] + b[2] * self.x0_outputs[-3]
else:
res = b[0] * x0
# UPDATE
if sigma_next == 0:
x_next = x0
else:
# Propagate in x-space (unnormalized)
x_next = torch.exp(-h) * sample + h * res
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,264 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class RiemannianFlowScheduler(SchedulerMixin, ConfigMixin):
"""
Riemannian Flow scheduler using Exponential Integrator step.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order: ClassVar[int] = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
metric_type: Literal["euclidean", "hyperbolic", "spherical", "lorentzian"] = "hyperbolic",
curvature: float = 1.0,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does not exist.")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# Setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = torch.zeros((num_train_timesteps,), dtype=torch.float32)
self._step_index = None
self._begin_index = None
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
timestep_spacing = getattr(self.config, "timestep_spacing", "linspace")
steps_offset = getattr(self.config, "steps_offset", 0)
if timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
timesteps += steps_offset
elif timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
timesteps -= 1
else:
raise ValueError(f"timestep_spacing {timestep_spacing} is not supported.")
# Derived sigma range from alphas_cumprod
# In FM, we usually go from sigma_max to sigma_min
base_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
# Note: alphas_cumprod[0] is ~0.999 (small sigma), alphas_cumprod[-1] is ~0.0001 (large sigma)
start_sigma = base_sigmas[-1]
end_sigma = base_sigmas[0]
t = torch.linspace(0, 1, num_inference_steps, device=device)
metric_type = self.config.metric_type
curvature = self.config.curvature
if metric_type == "euclidean":
result = start_sigma * (1 - t) + end_sigma * t
elif metric_type == "hyperbolic":
x_start = torch.tanh(torch.tensor(start_sigma / 2, device=device))
x_end = torch.tanh(torch.tensor(end_sigma / 2, device=device))
d = torch.acosh(torch.clamp(1 + 2 * ((x_start - x_end)**2) / ((1 - x_start**2) * (1 - x_end**2) + 1e-9), min=1.0))
lambda_t = torch.sinh(t * d) / (torch.sinh(d) + 1e-9)
result = 2 * torch.atanh(torch.clamp((1 - lambda_t) * x_start + lambda_t * x_end, -0.999, 0.999))
elif metric_type == "spherical":
k = torch.tensor(curvature, device=device)
theta_start = start_sigma * torch.sqrt(k)
theta_end = end_sigma * torch.sqrt(k)
result = torch.sin((1 - t) * theta_start + t * theta_end) / torch.sqrt(k)
elif metric_type == "lorentzian":
gamma = 1 / torch.sqrt(torch.clamp(1 - curvature * t**2, min=1e-9))
result = (start_sigma * (1 - t) + end_sigma * t) * gamma
else:
result = start_sigma * (1 - t) + end_sigma * t
result = torch.clamp(result, min=min(start_sigma, end_sigma), max=max(start_sigma, end_sigma))
if start_sigma > end_sigma:
result, _ = torch.sort(result, descending=True)
sigmas = result.cpu().numpy()
if getattr(self.config, "use_karras_sigmas", False):
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif getattr(self.config, "use_exponential_sigmas", False):
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif getattr(self.config, "use_beta_sigmas", False):
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif getattr(self.config, "use_flow_sigmas", False):
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
shift = getattr(self.config, "shift", 1.0)
use_dynamic_shifting = getattr(self.config, "use_dynamic_shifting", False)
if shift != 1.0 or use_dynamic_shifting:
if use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
getattr(self.config, "base_shift", 0.5),
getattr(self.config, "max_shift", 1.5),
getattr(self.config, "base_image_seq_len", 256),
getattr(self.config, "max_image_seq_len", 4096),
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
# Determine denoised (x_0 prediction)
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1)**0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
# Exponential Integrator Update (1st order)
if sigma_next == 0:
x_next = x0
else:
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
x_next = torch.exp(-h) * sample + (1 - torch.exp(-h)) * x0
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,251 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
class RungeKutta44Scheduler(SchedulerMixin, ConfigMixin):
"""
RK4: Classical 4th-order Runge-Kutta scheduler.
Adapted from the RES4LYF repository.
This scheduler uses 4 stages per step.
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
rho: float = 7.0,
shift: Optional[float] = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
timestep_spacing: str = "linspace",
clip_sample: bool = False,
sample_max_value: float = 1.0,
set_alpha_to_one: bool = False,
skip_prk_steps: bool = False,
interpolation_type: str = "linear",
steps_offset: int = 0,
timestep_type: str = "discrete",
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for RungeKutta44Scheduler")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sigmas = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.init_noise_sigma = 1.0
# Internal state for multi-stage
self.model_outputs = []
self.sample_at_start_of_step = None
self._sigmas_cpu = None
self._step_index = None
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Base sigmas
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
else:
raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")
if self.config.use_karras_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
rho = self.config.rho
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
# 2. Add sub-step sigmas for multi-stage RK
# RK4 has c = [0, 1/2, 1/2, 1]
c_values = [0.0, 0.5, 0.5, 1.0]
sigmas_expanded = []
for i in range(len(sigmas) - 1):
s_curr = sigmas[i]
s_next = sigmas[i + 1]
# Intermediate sigmas: s_curr + c * (s_next - s_curr)
for c in c_values:
# Add a tiny epsilon to duplicate sigmas to allow distinct indexing if needed,
# but better to rely on internal counter.
sigmas_expanded.append(s_curr + c * (s_next - s_curr))
sigmas_expanded.append(0.0) # terminal sigma
# 3. Map back to timesteps
sigmas_interpolated = np.array(sigmas_expanded)
# Linear remapping for Flow Matching
timesteps_expanded = sigmas_interpolated * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps_expanded + self.config.steps_offset).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._sigmas_cpu = self.sigmas.detach().cpu().numpy()
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index = None
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
# Use argmin for robust float matching
index = torch.abs(schedule_timesteps - timestep).argmin().item()
return index
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self._sigmas_cpu[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
stage_index = step_index % 4
# Current and next step interval sigmas
base_step_index = (step_index // 4) * 4
sigma_curr = self._sigmas_cpu[base_step_index]
sigma_next_idx = min(base_step_index + 4, len(self._sigmas_cpu) - 1)
sigma_next = self._sigmas_cpu[sigma_next_idx] # The sigma at the end of this 4-stage step
h = sigma_next - sigma_curr
sigma_t = self._sigmas_cpu[step_index]
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
prediction_type = getattr(self.config, "prediction_type", "epsilon")
if prediction_type == "epsilon":
denoised = sample - sigma_t * model_output
elif prediction_type == "v_prediction":
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
denoised = alpha_t * sample - sigma_actual * model_output
elif prediction_type == "flow_prediction":
denoised = sample - sigma_t * model_output
elif prediction_type == "sample":
denoised = model_output
else:
raise ValueError(f"prediction_type error: {prediction_type}")
if self.config.clip_sample:
denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)
# derivative = (x - x0) / sigma
derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)
if self.sample_at_start_of_step is None:
if stage_index > 0:
# Mid-step fallback for Img2Img/Inpainting
sigma_next_t = self._sigmas_cpu[self._step_index + 1]
dt = sigma_next_t - sigma_t
prev_sample = sample + dt * derivative
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
self.sample_at_start_of_step = sample
self.model_outputs = [derivative] * stage_index
if stage_index == 0:
self.model_outputs = [derivative]
self.sample_at_start_of_step = sample
# Stage 2 input: y + 0.5 * h * k1
prev_sample = self.sample_at_start_of_step + 0.5 * h * derivative
elif stage_index == 1:
self.model_outputs.append(derivative)
# Stage 3 input: y + 0.5 * h * k2
prev_sample = self.sample_at_start_of_step + 0.5 * h * derivative
elif stage_index == 2:
self.model_outputs.append(derivative)
# Stage 4 input: y + h * k3
prev_sample = self.sample_at_start_of_step + h * derivative
elif stage_index == 3:
self.model_outputs.append(derivative)
# Final result: y + (h/6) * (k1 + 2*k2 + 2*k3 + k4)
k1, k2, k3, k4 = self.model_outputs
prev_sample = self.sample_at_start_of_step + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
# Clear state
self.model_outputs = []
self.sample_at_start_of_step = None
# Increment step index
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,299 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
class RungeKutta57Scheduler(SchedulerMixin, ConfigMixin):
"""
RK5_7S: 5th-order Runge-Kutta scheduler with 7 stages.
Adapted from the RES4LYF repository.
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
rho: float = 7.0,
shift: Optional[float] = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
timestep_spacing: str = "linspace",
clip_sample: bool = False,
sample_max_value: float = 1.0,
set_alpha_to_one: bool = False,
skip_prk_steps: bool = False,
interpolation_type: str = "linear",
steps_offset: int = 0,
timestep_type: str = "discrete",
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = None
self.init_noise_sigma = 1.0
# Internal state
self.model_outputs = []
self.sample_at_start_of_step = None
self._sigmas_cpu = None
self._step_index = None
self._timesteps_cpu = None
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float).copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(float)
timesteps -= step_ratio
else:
raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")
# Ensure trailing ends at 0
if self.config.timestep_spacing == "trailing":
timesteps = np.maximum(timesteps, 0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
else:
raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")
if self.config.use_karras_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
rho = self.config.rho
ramp = np.linspace(0, 1, num_inference_steps)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
elif self.config.use_exponential_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
elif self.config.use_beta_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
alpha, beta = 0.6, 0.6
ramp = np.linspace(0, 1, num_inference_steps)
try:
import torch.distributions as dist
b = dist.Beta(alpha, beta)
ramp = b.sample((num_inference_steps,)).sort().values.numpy()
except Exception:
pass
sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
elif self.config.use_flow_sigmas:
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
# 3. Shifting
if self.config.use_dynamic_shifting and mu is not None:
sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
elif self.config.shift is not None:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
# RK5_7s c values: [0, 1/5, 3/10, 4/5, 8/9, 1, 1]
c_values = [0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1, 1]
sigmas_expanded = []
for i in range(len(sigmas) - 1):
s_curr = sigmas[i]
s_next = sigmas[i + 1]
for c in c_values:
sigmas_expanded.append(s_curr + c * (s_next - s_curr))
sigmas_expanded.append(0.0)
sigmas_interpolated = np.array(sigmas_expanded)
# Linear remapping for Flow Matching
timesteps_expanded = sigmas_interpolated * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps_expanded + self.config.steps_offset).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._sigmas_cpu = self.sigmas.detach().cpu().numpy()
self._timesteps_cpu = self.timesteps.detach().cpu().numpy()
self._step_index = None
self.model_outputs = []
self.sample_at_start_of_step = None
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def _init_step_index(self, timestep):
if self._step_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self._sigmas_cpu[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
# Dormand-Prince 5(4) Coefficients
a = [
[],
[1/5],
[3/40, 9/40],
[44/45, -56/15, 32/9],
[19372/6561, -25360/2187, 64448/6561, -212/729],
[9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84]
]
b = [35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0]
step_index = self._step_index
stage_index = step_index % 7
base_step_index = (step_index // 7) * 7
sigma_curr = self._sigmas_cpu[base_step_index]
sigma_next_idx = min(base_step_index + 7, len(self._sigmas_cpu) - 1)
sigma_next = self._sigmas_cpu[sigma_next_idx]
h = sigma_next - sigma_curr
sigma_t = self._sigmas_cpu[step_index]
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
prediction_type = getattr(self.config, "prediction_type", "epsilon")
if prediction_type == "epsilon":
denoised = sample - sigma_t * model_output
elif prediction_type == "v_prediction":
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
denoised = alpha_t * sample - sigma_actual * model_output
elif prediction_type == "flow_prediction":
denoised = sample - sigma_t * model_output
elif prediction_type == "sample":
denoised = model_output
else:
raise ValueError(f"prediction_type error: {prediction_type}")
if self.config.clip_sample:
denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)
# derivative = (x - x0) / sigma
derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)
if self.sample_at_start_of_step is None:
if stage_index > 0:
# Mid-step fallback for Img2Img/Inpainting
sigma_next_t = self._sigmas_cpu[self._step_index + 1]
dt = sigma_next_t - sigma_t
prev_sample = sample + dt * derivative
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
self.sample_at_start_of_step = sample
self.model_outputs = [derivative] * stage_index
if stage_index == 0:
self.model_outputs = [derivative]
self.sample_at_start_of_step = sample
else:
self.model_outputs.append(derivative)
if stage_index < 6:
# Predict next stage sample: y_next_stage = y_start + h * sum(a[stage_index+1][j] * k[j])
next_a_row = a[stage_index + 1]
sum_ak = torch.zeros_like(derivative)
for j, weight in enumerate(next_a_row):
sum_ak += weight * self.model_outputs[j]
prev_sample = self.sample_at_start_of_step + h * sum_ak
else:
# Final 7th stage complete, calculate final step
sum_bk = torch.zeros_like(derivative)
for j, weight in enumerate(b):
sum_bk += weight * self.model_outputs[j]
prev_sample = self.sample_at_start_of_step + h * sum_bk
# Clear state
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,301 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
class RungeKutta67Scheduler(SchedulerMixin, ConfigMixin):
"""
RK6_7S: 6th-order Runge-Kutta scheduler with 7 stages.
Adapted from the RES4LYF repository.
(Note: Defined as 5th order in some contexts, but follows the 7-stage tableau).
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
rho: float = 7.0,
shift: Optional[float] = None,
base_shift: float = 0.5,
max_shift: float = 1.15,
use_dynamic_shifting: bool = False,
timestep_spacing: str = "linspace",
clip_sample: bool = False,
sample_max_value: float = 1.0,
set_alpha_to_one: bool = False,
skip_prk_steps: bool = False,
interpolation_type: str = "linear",
steps_offset: int = 0,
timestep_type: str = "discrete",
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.init_noise_sigma = 1.0
# internal state
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = None
self.model_outputs = []
self.sample_at_start_of_step = None
self._sigmas_cpu = None
self._timesteps_cpu = None
self._step_index = None
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
self.num_inference_steps = num_inference_steps
# 1. Spacing
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float).copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(float)
timesteps -= step_ratio
else:
raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")
# Ensure trailing ends at 0
if self.config.timestep_spacing == "trailing":
timesteps = np.maximum(timesteps, 0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
else:
raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")
if self.config.use_karras_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
rho = self.config.rho
ramp = np.linspace(0, 1, num_inference_steps)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
elif self.config.use_exponential_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
elif self.config.use_beta_sigmas:
sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
alpha, beta = 0.6, 0.6
ramp = np.linspace(0, 1, num_inference_steps)
try:
import torch.distributions as dist
b = dist.Beta(alpha, beta)
ramp = b.sample((num_inference_steps,)).sort().values.numpy()
except Exception:
pass
sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
elif self.config.use_flow_sigmas:
sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
# 3. Shifting
if self.config.use_dynamic_shifting and mu is not None:
sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
elif self.config.shift is not None:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
# RK6_7s c values: [0, 1/3, 2/3, 1/3, 1/2, 1/2, 1]
c_values = [0, 1 / 3, 2 / 3, 1 / 3, 1 / 2, 1 / 2, 1]
sigmas_expanded = []
for i in range(len(sigmas) - 1):
s_curr = sigmas[i]
s_next = sigmas[i + 1]
for c in c_values:
sigmas_expanded.append(s_curr + c * (s_next - s_curr))
sigmas_expanded.append(0.0)
sigmas_interpolated = np.array(sigmas_expanded)
# Linear remapping for Flow Matching
timesteps_expanded = sigmas_interpolated * self.config.num_train_timesteps
self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(timesteps_expanded + self.config.steps_offset).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._sigmas_cpu = self.sigmas.detach().cpu().numpy()
self._timesteps_cpu = self.timesteps.detach().cpu().numpy()
self._step_index = None
self.model_outputs = []
self.sample_at_start_of_step = None
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def _init_step_index(self, timestep):
if self._step_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self._sigmas_cpu[self._step_index]
return sample / ((sigma**2 + 1) ** 0.5)
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
stage_index = step_index % 7
base_step_index = (step_index // 7) * 7
sigma_curr = self._sigmas_cpu[base_step_index]
sigma_next_idx = min(base_step_index + 7, len(self._sigmas_cpu) - 1)
sigma_next = self._sigmas_cpu[sigma_next_idx]
h = sigma_next - sigma_curr
sigma_t = self._sigmas_cpu[step_index]
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
prediction_type = getattr(self.config, "prediction_type", "epsilon")
if prediction_type == "epsilon":
denoised = sample - sigma_t * model_output
elif prediction_type == "v_prediction":
alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
sigma_actual = sigma_t * alpha_t
denoised = alpha_t * sample - sigma_actual * model_output
elif prediction_type == "flow_prediction":
denoised = sample - sigma_t * model_output
elif prediction_type == "sample":
denoised = model_output
else:
raise ValueError(f"prediction_type error: {prediction_type}")
if self.config.clip_sample:
denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)
# derivative = (x - x0) / sigma
derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)
if self.sample_at_start_of_step is None:
if stage_index > 0:
# Mid-step fallback for Img2Img/Inpainting
sigma_next_t = self._sigmas_cpu[self._step_index + 1]
dt = sigma_next_t - sigma_t
prev_sample = sample + dt * derivative
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
self.sample_at_start_of_step = sample
self.model_outputs = [derivative] * stage_index
# Butcher Tableau A matrix for rk6_7s
a = [
[],
[1 / 3],
[0, 2 / 3],
[1 / 12, 1 / 3, -1 / 12],
[-1 / 16, 9 / 8, -3 / 16, -3 / 8],
[0, 9 / 8, -3 / 8, -3 / 4, 1 / 2],
[9 / 44, -9 / 11, 63 / 44, 18 / 11, 0, -16 / 11],
]
# Butcher Tableau B weights for rk6_7s
b = [11 / 120, 0, 27 / 40, 27 / 40, -4 / 15, -4 / 15, 11 / 120]
if stage_index == 0:
self.model_outputs = [derivative]
self.sample_at_start_of_step = sample
else:
self.model_outputs.append(derivative)
if stage_index < 6:
# Predict next stage sample: y_next_stage = y_start + h * sum(a[stage_index+1][j] * k[j])
next_a_row = a[stage_index + 1]
sum_ak = torch.zeros_like(derivative)
for j, weight in enumerate(next_a_row):
sum_ak += weight * self.model_outputs[j]
prev_sample = self.sample_at_start_of_step + h * sum_ak
else:
# Final 7th stage complete, calculate final step
sum_bk = torch.zeros_like(derivative)
for j, weight in enumerate(b):
sum_bk += weight * self.model_outputs[j]
prev_sample = self.sample_at_start_of_step + h * sum_bk
# Clear state
self.model_outputs = []
self.sample_at_start_of_step = None
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def __len__(self):
return self.config.num_train_timesteps

View File

@ -0,0 +1,119 @@
import math
from typing import Literal
import numpy as np
import torch
try:
import scipy.stats
_scipy_available = True
except ImportError:
_scipy_available = False
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=dtype)
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
alphas_bar_sqrt -= alphas_bar_sqrt_T
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
alphas_bar = alphas_bar_sqrt**2
alphas = alphas_bar[1:] / alphas_bar[:-1]
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu", dtype: torch.dtype = torch.float32):
ramp = np.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return torch.from_numpy(sigmas).to(dtype=dtype, device=device)
def get_sigmas_exponential(n, sigma_min, sigma_max, device="cpu", dtype: torch.dtype = torch.float32):
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), n))
return torch.from_numpy(sigmas).to(dtype=dtype, device=device)
def get_sigmas_beta(n, sigma_min, sigma_max, alpha=0.6, beta=0.6, device="cpu", dtype: torch.dtype = torch.float32):
if not _scipy_available:
raise ImportError("scipy is required for beta sigmas")
sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, n)
]
]
)
return torch.from_numpy(sigmas).to(dtype=dtype, device=device)
def get_sigmas_flow(n, sigma_min, sigma_max, device="cpu", dtype: torch.dtype = torch.float32):
# Linear flow sigmas
sigmas = np.linspace(sigma_max, sigma_min, n)
return torch.from_numpy(sigmas).to(dtype=dtype, device=device)
def apply_shift(sigmas, shift):
return shift * sigmas / (1 + (shift - 1) * sigmas)
def get_dynamic_shift(mu, base_shift, max_shift, base_seq_len, max_seq_len):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
return m * mu + b
def index_for_timestep(timestep, timesteps):
# Normalize inputs to numpy arrays for a robust, device-agnostic argmin
if isinstance(timestep, torch.Tensor):
timestep_np = timestep.detach().cpu().numpy()
else:
timestep_np = np.array(timestep)
if isinstance(timesteps, torch.Tensor):
timesteps_np = timesteps.detach().cpu().numpy()
else:
timesteps_np = np.array(timesteps)
# Use numpy argmin on absolute difference for stability
idx = np.abs(timesteps_np - timestep_np).argmin()
return int(idx)
def add_noise_to_sample(
original_samples: torch.Tensor,
noise: torch.Tensor,
sigmas: torch.Tensor,
timestep: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
step_index = index_for_timestep(timestep, timesteps)
sigma = sigmas[step_index].to(original_samples.dtype)
noisy_samples = original_samples + sigma * noise
return noisy_samples

View File

@ -0,0 +1,214 @@
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class SimpleExponentialScheduler(SchedulerMixin, ConfigMixin):
"""
Simple Exponential sigma scheduler using Exponential Integrator step.
"""
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
order: ClassVar[int] = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: str = "linear",
sigma_max: float = 1.0,
sigma_min: float = 0.01,
gain: float = 1.0,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
rescale_betas_zero_snr: bool = False,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
use_flow_sigmas: bool = False,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float = 0.5,
max_shift: float = 1.15,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
):
from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does not exist.")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# Setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.sigmas = torch.zeros((num_train_timesteps,), dtype=torch.float32)
self._step_index = None
self._begin_index = None
@property
def step_index(self) -> Optional[int]:
return self._step_index
@property
def begin_index(self) -> Optional[int]:
return self._begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
self._begin_index = begin_index
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
from .scheduler_utils import (
apply_shift,
get_dynamic_shift,
get_sigmas_beta,
get_sigmas_exponential,
get_sigmas_flow,
get_sigmas_karras,
)
self.num_inference_steps = num_inference_steps
sigmas = np.exp(np.linspace(np.log(self.config.sigma_max), np.log(self.config.sigma_min), num_inference_steps))
if self.config.use_karras_sigmas:
sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_exponential_sigmas:
sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_beta_sigmas:
sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
elif self.config.use_flow_sigmas:
sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
shift = self.config.shift
if self.config.use_dynamic_shifting and mu is not None:
shift = get_dynamic_shift(
mu,
self.config.base_shift,
self.config.max_shift,
self.config.base_image_seq_len,
self.config.max_image_seq_len,
)
sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()
self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
self.timesteps = torch.from_numpy(np.linspace(1000, 0, num_inference_steps)).to(device=device, dtype=dtype)
self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
from .scheduler_utils import index_for_timestep
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
return index_for_timestep(timestep, schedule_timesteps)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
from .scheduler_utils import add_noise_to_sample
return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
if self._step_index is None:
self._init_step_index(timestep)
if self.config.prediction_type == "flow_prediction":
return sample
sigma = self.sigmas[self._step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def step(
self,
model_output: torch.Tensor,
timestep: Union[float, torch.Tensor],
sample: torch.Tensor,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self._step_index is None:
self._init_step_index(timestep)
step_index = self._step_index
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
# Determine denoised (x_0 prediction)
if self.config.prediction_type == "epsilon":
x0 = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
alpha_t = 1.0 / (sigma**2 + 1)**0.5
sigma_t = sigma * alpha_t
x0 = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "sample":
x0 = model_output
elif self.config.prediction_type == "flow_prediction":
x0 = sample - sigma * model_output
else:
x0 = model_output
# Exponential Integrator Update (1st order)
if sigma_next == 0:
x_next = x0
else:
h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
x_next = torch.exp(-h) * sample + (1 - torch.exp(-h)) * x0
self._step_index += 1
if not return_dict:
return (x_next,)
return SchedulerOutput(prev_sample=x_next)
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def __len__(self):
return self.config.num_train_timesteps

Some files were not shown because too many files have changed in this diff Show More