mirror of https://github.com/vladmandic/automatic
commit
0d240b1a8f
|
|
@ -1,7 +1,7 @@
|
|||
# defaults
|
||||
venv/
|
||||
__pycache__
|
||||
.ruff_cache
|
||||
/cache.json
|
||||
/*.json
|
||||
/*.yaml
|
||||
/params.txt
|
||||
|
|
@ -9,13 +9,14 @@ __pycache__
|
|||
/user.css
|
||||
/webui-user.bat
|
||||
/webui-user.sh
|
||||
/html/extensions.json
|
||||
/html/themes.json
|
||||
/data/metadata.json
|
||||
/data/extensions.json
|
||||
/data/cache.json
|
||||
/data/themes.json
|
||||
config_states
|
||||
node_modules
|
||||
pnpm-lock.yaml
|
||||
package-lock.json
|
||||
venv
|
||||
.history
|
||||
cache
|
||||
**/.DS_Store
|
||||
|
|
@ -65,6 +66,7 @@ tunableop_results*.csv
|
|||
.*/
|
||||
|
||||
# force included
|
||||
!/data
|
||||
!/models/VAE-approx
|
||||
!/models/VAE-approx/model.pt
|
||||
!/models/Reference
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ ignore-paths=/usr/lib/.*$,
|
|||
modules/taesd,
|
||||
modules/teacache,
|
||||
modules/todo,
|
||||
modules/res4lyf,
|
||||
pipelines/bria,
|
||||
pipelines/flex2,
|
||||
pipelines/f_lite,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
line-length = 250
|
||||
indent-width = 4
|
||||
target-version = "py310"
|
||||
exclude = [
|
||||
"venv",
|
||||
".git",
|
||||
|
|
@ -41,9 +44,6 @@ exclude = [
|
|||
"extensions-builtin/sd-webui-agent-scheduler",
|
||||
"extensions-builtin/sdnext-modernui/node_modules",
|
||||
]
|
||||
line-length = 250
|
||||
indent-width = 4
|
||||
target-version = "py310"
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@
|
|||
"env": { "USED_VSCODE_COMMAND_PICKARGS": "1" },
|
||||
"args": [
|
||||
"--uv",
|
||||
"--quick",
|
||||
"--log", "vscode.log",
|
||||
"${command:pickArgs}"]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
{
|
||||
"files.eol": "\n",
|
||||
"python.analysis.extraPaths": [".", "./modules", "./scripts", "./pipelines"],
|
||||
"python.analysis.typeCheckingMode": "off",
|
||||
"editor.formatOnSave": false,
|
||||
|
|
|
|||
76
CHANGELOG.md
76
CHANGELOG.md
|
|
@ -1,5 +1,79 @@
|
|||
# Change Log for SD.Next
|
||||
|
||||
## Update for 2026-02-04
|
||||
|
||||
### Highlights for 2026-02-04
|
||||
|
||||
Refresh release two weeks after prior release, yet we still somehow managed to pack in *~150 commits*!
|
||||
Highlights would be two new models: **Z-Image-Base** and **Anima**, *captioning* support for **tagger** models and a massive addition of new **schedulers**
|
||||
Also here are updates to `torch` and additional GPU archs support for `ROCm` backends, plus a lot of internal improvements and fixes.
|
||||
|
||||
[ReadMe](https://github.com/vladmandic/automatic/blob/master/README.md) | [ChangeLog](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [Docs](https://vladmandic.github.io/sdnext-docs/) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867) | [Sponsor](https://github.com/sponsors/vladmandic)
|
||||
|
||||
### Details for 2026-02-04
|
||||
|
||||
- **Models**
|
||||
- [Tongyi-MAI Z-Image Base](https://tongyi-mai.github.io/Z-Image-blog/)
|
||||
yup, it's finally here, the full base model of **Z-Image**
|
||||
- [CircleStone Anima](https://huggingface.co/circlestone-labs/Anima)
|
||||
2B anime optimized model based on a modified Cosmos-Predict, using Qwen3-0.6B as a text encoder
|
||||
- **Features**
|
||||
- **caption** tab support for Booru tagger models, thanks @CalamitousFelicitousness
|
||||
- add SmilingWolf WD14/WaifuDiffusion tagger models, thanks @CalamitousFelicitousness
|
||||
- support comments in wildcard files, using `#`
|
||||
- support aliases in metadata skip params, thanks @CalamitousFelicitousness
|
||||
- ui gallery improve cache cleanup and add manual option, thanks @awsr
|
||||
- selectable options to add system info to metadata, thanks @Athari
|
||||
see *settings -> image metadata*
|
||||
- **Schedulers**
|
||||
- schedulers documentation has new home: <https://vladmandic.github.io/sdnext-docs/Schedulers/>
|
||||
- add 13(!) new scheduler families
|
||||
not a port, but more of inspired-by [res4lyf](https://github.com/ClownsharkBatwing/RES4LYF) library
|
||||
all schedulers should be compatible with both `epsilon` and `flow` prediction style!
|
||||
*note*: each family may have multiple actual schedulers, so the list total is 56(!) new schedulers
|
||||
- core family: *RES*
|
||||
- exponential: *DEIS, ETD, Lawson, ABNorsett*
|
||||
- integrators: *Runge-Kutta, Linear-RK, Specialized-RK, Lobatto, Radau-IIA, Gauss-Legendre*
|
||||
- flow: *PEC, Riemannian, Euclidean, Hyperbolic, Lorentzian, Langevin-Dynamics*
|
||||
- add 3 additional schedulers: *CogXDDIM, DDIMParallel, DDPMParallel*
|
||||
not originally intended to be general-purpose schedulers, but they work quite nicely and produce good results
|
||||
- image metadata: always log scheduler class used
|
||||
- **API**
|
||||
- add `/sdapi/v1/xyz-grid` to enumerate xyz-grid axis options and their choices
|
||||
see `/cli/api-xyzenum.py` for example usage
|
||||
- add `/sdapi/v1/sampler` to get current sampler config
|
||||
- modify `/sdapi/v1/samplers` to enumerate available samplers and their possible options
|
||||
see `/cli/api-samplers.py` for example usage
|
||||
- **Internal**
|
||||
- tagged release history: <https://github.com/vladmandic/sdnext/tags>
|
||||
each major release for the past year is now tagged for easier reference
|
||||
- **torch** update
|
||||
*note*: may cause slow first startup/generate
|
||||
**cuda**: update to `torch==2.10.0`
|
||||
**xpu**: update to `torch==2.10.0`
|
||||
**rocm**: update to `torch==2.10.0`
|
||||
**openvino**: update to `torch==2.10.0` and `openvino==2025.4.1`
|
||||
- rocm: expand available gfx archs, thanks @crashingalexsan
|
||||
- rocm: set `MIOPEN_FIND_MODE=2` by default, thanks @crashingalexsan
|
||||
- relocate all json data files to `data/` folder
|
||||
existing data files are auto-migrated on startup
|
||||
- refactor and improve connection monitor, thanks @awsr
|
||||
- further work on type consistency and type checking, thanks @awsr
|
||||
- log captured exceptions
|
||||
- improve temp folder handling and cleanup
|
||||
- remove torch errors/warnings on fast server shutdown
|
||||
- add ui placeholders for future agent-scheduler work, thanks @ryanmeador
|
||||
- implement abort system on repeated errors, thanks @awsr
|
||||
currently used by lora and textual-inversion loaders
|
||||
- update package requirements
|
||||
- **Fixes**
|
||||
- add video ui elem_ids, thanks @ryanmeador
|
||||
- use base steps as-is for non sd/sdxl models
|
||||
- ui css fixes for modernui
|
||||
- support lora inside prompt selector
|
||||
- framepack video save
|
||||
- metadata save for manual saves
|
||||
|
||||
## Update for 2026-01-22
|
||||
|
||||
Bugfix refresh
|
||||
|
|
@ -139,7 +213,7 @@ End of year release update, just two weeks after previous one, with several new
|
|||
- **Models**
|
||||
- [LongCat Image](https://github.com/meituan-longcat/LongCat-Image) in *Image* and *Image Edit* variants
|
||||
LongCat is a new 8B diffusion base model using Qwen-2.5 as text encoder
|
||||
- [Qwen-Image-Edit 2511](Qwen/Qwen-Image-Edit-2511) in *base* and *pre-quantized* variants
|
||||
- [Qwen-Image-Edit 2511](https://huggingface.co/Qwen/Qwen-Image-Edit-2511) in *base* and *pre-quantized* variants
|
||||
Key enhancements: mitigate image drift, improved character consistency, enhanced industrial design generation, and strengthened geometric reasoning ability
|
||||
- [Qwen-Image-Layered](https://huggingface.co/Qwen/Qwen-Image-Layered) in *base* and *pre-quantized* variants
|
||||
Qwen-Image-Layered, a model capable of decomposing an image into multiple RGBA layers
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
<div align="center">
|
||||
<img src="https://github.com/vladmandic/sdnext/raw/master/html/logo-transparent.png" width=200 alt="SD.Next">
|
||||
|
||||
# SD.Next: All-in-one WebUI for AI generative image and video creation
|
||||
# SD.Next: All-in-one WebUI for AI generative image and video creation and captioning
|
||||
|
||||

|
||||

|
||||
|
|
@ -27,10 +27,8 @@
|
|||
All individual features are not listed here, instead check [ChangeLog](CHANGELOG.md) for full list of changes
|
||||
- Fully localized:
|
||||
▹ **English | Chinese | Russian | Spanish | German | French | Italian | Portuguese | Japanese | Korean**
|
||||
- Multiple UIs!
|
||||
▹ **Standard | Modern**
|
||||
- Desktop and Mobile support!
|
||||
- Multiple [diffusion models](https://vladmandic.github.io/sdnext-docs/Model-Support/)!
|
||||
- Built-in Control for Text, Image, Batch and Video processing!
|
||||
- Multi-platform!
|
||||
▹ **Windows | Linux | MacOS | nVidia CUDA | AMD ROCm | Intel Arc / IPEX XPU | DirectML | OpenVINO | ONNX+Olive | ZLUDA**
|
||||
- Platform specific auto-detection and tuning performed on install
|
||||
|
|
@ -38,9 +36,7 @@ All individual features are not listed here, instead check [ChangeLog](CHANGELOG
|
|||
Compile backends: *Triton | StableFast | DeepCache | OneDiff | TeaCache | etc.*
|
||||
Quantization methods: *SDNQ | BitsAndBytes | Optimum-Quanto | TorchAO / LayerWise*
|
||||
- **Interrogate/Captioning** with 150+ **OpenCLiP** models and 20+ built-in **VLMs**
|
||||
- Built-in queue management
|
||||
- Built in installer with automatic updates and dependency management
|
||||
- Mobile compatible
|
||||
|
||||
<br>
|
||||
|
||||
|
|
|
|||
190
TODO.md
190
TODO.md
|
|
@ -1,107 +1,137 @@
|
|||
# TODO
|
||||
|
||||
## Project Board
|
||||
|
||||
- <https://github.com/users/vladmandic/projects>
|
||||
|
||||
## Internal
|
||||
|
||||
- Feature: Move `nunchaku` models to reference instead of internal decision
|
||||
- Update: `transformers==5.0.0`
|
||||
- Feature: Unify *huggingface* and *diffusers* model folders
|
||||
- Reimplement `llama` remover for Kanvas
|
||||
- Update: `transformers==5.0.0`, owner @CalamitousFelicitousness
|
||||
- Deploy: Create executable for SD.Next
|
||||
- Feature: Integrate natural language image search
|
||||
[ImageDB](https://github.com/vladmandic/imagedb)
|
||||
- Feature: Remote Text-Encoder support
|
||||
- Refactor: move sampler options to settings to config
|
||||
- Refactor: [GGUF](https://huggingface.co/docs/diffusers/main/en/quantization/gguf)
|
||||
- Feature: LoRA add OMI format support for SD35/FLUX.1
|
||||
- Refactor: remove `CodeFormer`
|
||||
- Refactor: remove `GFPGAN`
|
||||
- UI: Lite vs Expert mode
|
||||
- Video tab: add full API support
|
||||
- Control tab: add overrides handling
|
||||
- Engine: `TensorRT` acceleration
|
||||
- Deploy: Lite vs Expert mode
|
||||
- Engine: [mmgp](https://github.com/deepbeepmeep/mmgp)
|
||||
- Engine: [sharpfin](https://github.com/drhead/sharpfin) instead of `torchvision`
|
||||
- Engine: `TensorRT` acceleration
|
||||
- Feature: Auto handle scheduler `prediction_type`
|
||||
- Feature: Cache models in memory
|
||||
- Feature: Control tab add overrides handling
|
||||
- Feature: Integrate natural language image search
|
||||
[ImageDB](https://github.com/vladmandic/imagedb)
|
||||
- Feature: LoRA add OMI format support for SD35/FLUX.1, on-hold
|
||||
- Feature: Multi-user support
|
||||
- Feature: Remote Text-Encoder support, sidelined for the moment
|
||||
- Feature: Settings profile manager
|
||||
- Feature: Video tab add full API support
|
||||
- Refactor: Unify *huggingface* and *diffusers* model folders
|
||||
- Refactor: Move `nunchaku` models to reference instead of internal decision, owner @CalamitousFelicitousness
|
||||
- Refactor: [GGUF](https://huggingface.co/docs/diffusers/main/en/quantization/gguf)
|
||||
- Refactor: move sampler options to settings to config
|
||||
- Refactor: remove `CodeFormer`, owner @CalamitousFelicitousness
|
||||
- Refactor: remove `GFPGAN`, owner @CalamitousFelicitousness
|
||||
- Reimplement `llama` remover for Kanvas, pending end-to-end review of `Kanvas`
|
||||
|
||||
## Modular
|
||||
|
||||
*Pending finalization of modular pipelines implementation and development of compatibility layer*
|
||||
|
||||
- Switch to modular pipelines
|
||||
- Feature: Transformers unified cache handler
|
||||
- Refactor: [Modular pipelines and guiders](https://github.com/huggingface/diffusers/issues/11915)
|
||||
- [MagCache](https://github.com/lllyasviel/FramePack/pull/673/files)
|
||||
- [MagCache](https://github.com/huggingface/diffusers/pull/12744)
|
||||
- [SmoothCache](https://github.com/huggingface/diffusers/issues/11135)
|
||||
|
||||
## Features
|
||||
|
||||
- [Flux.2 TinyVAE](https://huggingface.co/fal/FLUX.2-Tiny-AutoEncoder)
|
||||
- [IPAdapter composition](https://huggingface.co/ostris/ip-composition-adapter)
|
||||
- [IPAdapter negative guidance](https://github.com/huggingface/diffusers/discussions/7167)
|
||||
- [STG](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#spatiotemporal-skip-guidance)
|
||||
- [Video Inpaint Pipeline](https://github.com/huggingface/diffusers/pull/12506)
|
||||
- [Sonic Inpaint](https://github.com/ubc-vision/sonic)
|
||||
|
||||
### New models / Pipelines
|
||||
## New models / Pipelines
|
||||
|
||||
TODO: Investigate which models are diffusers-compatible and prioritize!
|
||||
|
||||
- [Bria FiboEdit](https://github.com/huggingface/diffusers/commit/d7a1c31f4f85bae5a9e01cdce49bd7346bd8ccd6)
|
||||
- [LTXVideo 0.98 LongMulti](https://github.com/huggingface/diffusers/pull/12614)
|
||||
- [Cosmos-Predict-2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B)
|
||||
- [NewBie Image Exp0.1](https://github.com/huggingface/diffusers/pull/12803)
|
||||
- [Sana-I2V](https://github.com/huggingface/diffusers/pull/12634#issuecomment-3540534268)
|
||||
- [Bria FIBO](https://huggingface.co/briaai/FIBO)
|
||||
- [Bytedance Lynx](https://github.com/bytedance/lynx)
|
||||
- [ByteDance OneReward](https://github.com/bytedance/OneReward)
|
||||
- [ByteDance USO](https://github.com/bytedance/USO)
|
||||
- [Chroma Radiance](https://huggingface.co/lodestones/Chroma1-Radiance)
|
||||
- [Chroma Zeta](https://huggingface.co/lodestones/Zeta-Chroma)
|
||||
- [DiffSynth Studio](https://github.com/modelscope/DiffSynth-Studio)
|
||||
- [DiffusionForcing](https://github.com/kwsong0113/diffusion-forcing-transformer)
|
||||
- [Dream0 guidance](https://huggingface.co/ByteDance/DreamO)
|
||||
- [HunyuanAvatar](https://huggingface.co/tencent/HunyuanVideo-Avatar)
|
||||
- [HunyuanCustom](https://github.com/Tencent-Hunyuan/HunyuanCustom)
|
||||
- [Inf-DiT](https://github.com/zai-org/Inf-DiT)
|
||||
- [Krea Realtime Video](https://huggingface.co/krea/krea-realtime-video)
|
||||
- [LanDiff](https://github.com/landiff/landiff)
|
||||
- [Liquid](https://github.com/FoundationVision/Liquid)
|
||||
- [LongCat-Video](https://huggingface.co/meituan-longcat/LongCat-Video)
|
||||
- [LucyEdit](https://github.com/huggingface/diffusers/pull/12340)
|
||||
- [Lumina-DiMOO](https://huggingface.co/Alpha-VLLM/Lumina-DiMOO)
|
||||
- [Magi](https://github.com/SandAI-org/MAGI-1)(https://github.com/huggingface/diffusers/pull/11713)
|
||||
- [Ming](https://github.com/inclusionAI/Ming)
|
||||
- [MUG-V 10B](https://huggingface.co/MUG-V/MUG-V-inference)
|
||||
- [Ovi](https://github.com/character-ai/Ovi)
|
||||
- [Phantom HuMo](https://github.com/Phantom-video/Phantom)
|
||||
- [SD3 UltraEdit](https://github.com/HaozheZhao/UltraEdit)
|
||||
- [SelfForcing](https://github.com/guandeh17/Self-Forcing)
|
||||
- [SEVA](https://github.com/huggingface/diffusers/pull/11440)
|
||||
- [Step1X](https://github.com/stepfun-ai/Step1X-Edit)
|
||||
- [Wan-2.2 Animate](https://github.com/huggingface/diffusers/pull/12526)
|
||||
- [Wan-2.2 S2V](https://github.com/huggingface/diffusers/pull/12258)
|
||||
- [WAN-CausVid-Plus t2v](https://github.com/goatWu/CausVid-Plus/)
|
||||
- [WAN-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid)
|
||||
- [WAN-StepDistill](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill)
|
||||
- [Wan2.2-Animate-14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B)
|
||||
- [WAN2GP](https://github.com/deepbeepmeep/Wan2GP)
|
||||
### Upscalers
|
||||
|
||||
- [HQX](https://github.com/uier/py-hqx/blob/main/hqx.py)
|
||||
- [DCCI](https://every-algorithm.github.io/2024/11/06/directional_cubic_convolution_interpolation.html)
|
||||
- [ICBI](https://github.com/gyfastas/ICBI/blob/master/icbi.py)
|
||||
|
||||
### Image-Base
|
||||
- [Chroma Zeta](https://huggingface.co/lodestones/Zeta-Chroma): Image and video generator for creative effects and professional filters
|
||||
- [Chroma Radiance](https://huggingface.co/lodestones/Chroma1-Radiance): Pixel-space model eliminating VAE artifacts for high visual fidelity
|
||||
- [Liquid](https://github.com/FoundationVision/Liquid): Unified vision-language auto-regressive generation paradigm
|
||||
- [Lumina-DiMOO](https://huggingface.co/Alpha-VLLM/Lumina-DiMOO): Foundational multi-modal generation and understanding via discrete diffusion
|
||||
- [nVidia Cosmos-Predict-2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B): Physics-aware world foundation model for consistent scene prediction
|
||||
- [Liquid (unified multimodal generator)](https://github.com/FoundationVision/Liquid): Auto-regressive generation paradigm across vision and language
|
||||
- [Lumina-DiMOO](https://huggingface.co/Alpha-VLLM/Lumina-DiMOO): foundational multi-modal multi-task generation and understanding
|
||||
|
||||
### Image-Edit
|
||||
- [Meituan LongCat-Image-Edit-Turbo](https://huggingface.co/meituan-longcat/LongCat-Image-Edit-Turbo): 6B instruction-following image editing with high visual consistency
|
||||
- [VIBE Image-Edit](https://huggingface.co/iitolstykh/VIBE-Image-Edit): (Sana+Qwen-VL)Fast visual instruction-based image editing framework
|
||||
- [LucyEdit](https://github.com/huggingface/diffusers/pull/12340): Instruction-guided video editing while preserving motion and identity
|
||||
- [Step1X-Edit](https://github.com/stepfun-ai/Step1X-Edit): Multimodal image editing decoding MLLM tokens via DiT
|
||||
- [OneReward](https://github.com/bytedance/OneReward): Reinforcement learning grounded generative reward model for image editing
|
||||
- [ByteDance DreamO](https://huggingface.co/ByteDance/DreamO): image customization framework for IP adaptation and virtual try-on
|
||||
|
||||
### Video
|
||||
- [OpenMOSS MOVA](https://huggingface.co/OpenMOSS-Team/MOVA-720p): Unified foundation model for synchronized high-fidelity video and audio
|
||||
- [Wan family (Wan2.1 / Wan2.2 variants)](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B): MoE-based foundational tools for cinematic T2V/I2V/TI2V
|
||||
example: [Wan2.1-T2V-14B-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid)
|
||||
distill / step-distill examples: [Wan2.1-StepDistill-CfgDistill](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill)
|
||||
- [Krea Realtime Video](https://huggingface.co/krea/krea-realtime-video): (Wan2.1)Distilled real-time video diffusion using self-forcing techniques
|
||||
- [MAGI-1 (autoregressive video)](https://github.com/SandAI-org/MAGI-1): Autoregressive video generation allowing infinite and timeline control
|
||||
- [MUG-V 10B (video generation)](https://huggingface.co/MUG-V/MUG-V-inference): large-scale DiT-based video generation system trained via flow-matching
|
||||
- [Ovi (audio/video generation)](https://github.com/character-ai/Ovi): (Wan2.2)Speech-to-video with synchronized sound effects and music
|
||||
- [HunyuanVideo-Avatar / HunyuanCustom](https://huggingface.co/tencent/HunyuanVideo-Avatar): (HunyuanVideo)MM-DiT based dynamic emotion-controllable dialogue generation
|
||||
- [Sana Image→Video (Sana-I2V)](https://github.com/huggingface/diffusers/pull/12634#issuecomment-3540534268): (Sana)Compact Linear DiT framework for efficient high-resolution video
|
||||
- [Wan-2.2 S2V (diffusers PR)](https://github.com/huggingface/diffusers/pull/12258): (Wan2.2)Audio-driven cinematic speech-to-video generation
|
||||
- [LongCat-Video](https://huggingface.co/meituan-longcat/LongCat-Video): Unified framework for minutes-long coherent video generation via Block Sparse Attention
|
||||
- [LTXVideo / LTXVideo LongMulti (diffusers PR)](https://github.com/huggingface/diffusers/pull/12614): Real-time DiT-based generation with production-ready camera controls
|
||||
- [DiffSynth-Studio (ModelScope)](https://github.com/modelscope/DiffSynth-Studio): (Wan2.2)Comprehensive training and quantization tools for Wan video models
|
||||
- [Phantom (Phantom HuMo)](https://github.com/Phantom-video/Phantom): Human-centric video generation framework focused on subject ID consistency
|
||||
- [CausVid-Plus / WAN-CausVid-Plus](https://github.com/goatWu/CausVid-Plus/): (Wan2.1)Causal diffusion for high-quality temporally consistent long videos
|
||||
- [Wan2GP (workflow/GUI for Wan)](https://github.com/deepbeepmeep/Wan2GP): (Wan)Web-based UI focused on running complex video models for GPU-poor setups
|
||||
- [LivePortrait](https://github.com/KwaiVGI/LivePortrait): Efficient portrait animation system with high stitching and retargeting control
|
||||
- [Magi (SandAI)](https://github.com/SandAI-org/MAGI-1): High-quality autoregressive video generation framework
|
||||
- [Ming (inclusionAI)](https://github.com/inclusionAI/Ming): Unified multimodal model for processing text, audio, image, and video
|
||||
|
||||
### Other/Unsorted
|
||||
- [DiffusionForcing](https://github.com/kwsong0113/diffusion-forcing-transformer): Full-sequence diffusion with autoregressive next-token prediction
|
||||
- [Self-Forcing](https://github.com/guandeh17/Self-Forcing): Framework for improving temporal consistency in long-horizon video generation
|
||||
- [SEVA](https://github.com/huggingface/diffusers/pull/11440): Stable Virtual Camera for novel view synthesis and 3D-consistent video
|
||||
- [ByteDance USO](https://github.com/bytedance/USO): Unified Style-Subject Optimized framework for personalized image generation
|
||||
- [ByteDance Lynx](https://github.com/bytedance/lynx): State-of-the-art high-fidelity personalized video generation based on DiT
|
||||
- [LanDiff](https://github.com/landiff/landiff): Coarse-to-fine text-to-video integrating Language and Diffusion Models
|
||||
- [Video Inpaint Pipeline](https://github.com/huggingface/diffusers/pull/12506): Unified inpainting pipeline implementation within Diffusers library
|
||||
- [Sonic Inpaint](https://github.com/ubc-vision/sonic): Audio-driven portrait animation system focused on global audio perception
|
||||
- [Make-It-Count](https://github.com/Litalby1/make-it-count): CountGen method for precise numerical control of objects via object identity features
|
||||
- [ControlNeXt](https://github.com/dvlab-research/ControlNeXt/): Lightweight architecture for efficient controllable image and video generation
|
||||
- [MS-Diffusion](https://github.com/MS-Diffusion/MS-Diffusion): Layout-guided multi-subject image personalization framework
|
||||
- [UniRef](https://github.com/FoundationVision/UniRef): Unified model for segmentation tasks designed as foundation model plug-in
|
||||
- [FlashFace](https://github.com/ali-vilab/FlashFace): High-fidelity human image customization and face swapping framework
|
||||
- [ReNO](https://github.com/ExplainableML/ReNO): Reward-based Noise Optimization to improve text-to-image quality during inference
|
||||
|
||||
### Not Planned
|
||||
- [Bria FIBO](https://huggingface.co/briaai/FIBO): Fully JSON based
|
||||
- [Bria FiboEdit](https://github.com/huggingface/diffusers/commit/d7a1c31f4f85bae5a9e01cdce49bd7346bd8ccd6): Fully JSON based
|
||||
- [LoRAdapter](https://github.com/CompVis/LoRAdapter): Not recently updated
|
||||
- [SD3 UltraEdit](https://github.com/HaozheZhao/UltraEdit): Based on SD3
|
||||
- [PowerPaint](https://github.com/open-mmlab/PowerPaint): Based on SD15
|
||||
- [FreeCustom](https://github.com/aim-uofa/FreeCustom): Based on SD15
|
||||
- [AnyDoor](https://github.com/ali-vilab/AnyDoor): Based on SD21
|
||||
- [AnyText2](https://github.com/tyxsspa/AnyText2): Based on SD15
|
||||
- [DragonDiffusion](https://github.com/MC-E/DragonDiffusion): Based on SD15
|
||||
- [DenseDiffusion](https://github.com/naver-ai/DenseDiffusion): Based on SD15
|
||||
- [IC-Light](https://github.com/lllyasviel/IC-Light): Based on SD15
|
||||
|
||||
## Migration
|
||||
|
||||
### Asyncio
|
||||
|
||||
- Policy system is deprecated and will be removed in **Python 3.16**
|
||||
- [Python 3.14 removals - asyncio](https://docs.python.org/3.14/whatsnew/3.14.html#id10)
|
||||
- https://docs.python.org/3.14/library/asyncio-policy.html
|
||||
- Affected files:
|
||||
- [`webui.py`](webui.py)
|
||||
- [`cli/sdapi.py`](cli/sdapi.py)
|
||||
- Migration:
|
||||
- [asyncio.run](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.run)
|
||||
- [asyncio.Runner](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.Runner)
|
||||
- Policy system is deprecated and will be removed in Python 3.16
|
||||
[Python 3.14 removals - asyncio](https://docs.python.org/3.14/whatsnew/3.14.html#id10)
|
||||
https://docs.python.org/3.14/library/asyncio-policy.html
|
||||
Affected files:
|
||||
[`webui.py`](webui.py)
|
||||
[`cli/sdapi.py`](cli/sdapi.py)
|
||||
Migration:
|
||||
[asyncio.run](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.run)
|
||||
[asyncio.Runner](https://docs.python.org/3.14/library/asyncio-runner.html#asyncio.Runner)
|
||||
|
||||
#### rmtree
|
||||
### rmtree
|
||||
|
||||
- `onerror` deprecated and replaced with `onexc` in **Python 3.12**
|
||||
- `onerror` deprecated and replaced with `onexc` in Python 3.12
|
||||
``` python
|
||||
def excRemoveReadonly(func, path, exc: BaseException):
|
||||
import stat
|
||||
|
|
|
|||
|
|
@ -0,0 +1,35 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
get list of all samplers and details of current sampler
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import urllib3
|
||||
import requests
|
||||
|
||||
|
||||
# Connection settings for the SD.Next server queried by this script.
url = "http://127.0.0.1:7860"
user = ""
password = ""

log_format = '%(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(level = logging.INFO, format = log_format)
log = logging.getLogger("sd")
# Requests below are sent with verify=False, so silence the resulting warning.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

log.info('available samplers')
# Basic auth is only attached when both user and password are non-empty.
auth = requests.auth.HTTPBasicAuth(user, password) if user and password else None
req = requests.get(f'{url}/sdapi/v1/samplers', verify=False, auth=auth, timeout=60)
if req.status_code != 200:
    log.error({ 'url': req.url, 'request': req.status_code, 'reason': req.reason })
    # sys.exit instead of the site-provided exit(), which is not guaranteed to exist
    sys.exit(1)
res = req.json()
for item in res:
    log.info(item)

log.info('current sampler')
req = requests.get(f'{url}/sdapi/v1/sampler', verify=False, auth=auth, timeout=60)
# Check the second request as well, so a failed call is reported instead of
# being decoded as if it were a sampler config.
if req.status_code != 200:
    log.error({ 'url': req.url, 'request': req.status_code, 'reason': req.reason })
    sys.exit(1)
res = req.json()
log.info(res)
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
#!/usr/bin/env python
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
import urllib3
|
||||
|
||||
|
||||
# Server location and optional credentials, overridable through environment variables.
_env = os.environ
sd_url = _env.get('SDAPI_URL', "http://127.0.0.1:7860")
sd_username = _env.get('SDAPI_USR')
sd_password = _env.get('SDAPI_PWD')
options = dict(
    save_images=True,
    send_images=True,
)

logging.basicConfig(format = '%(asctime)s %(levelname)s: %(message)s', level = logging.INFO)
log = logging.getLogger(__name__)
# Requests are issued with verify=False; suppress the matching urllib3 warning.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
def auth():
    """Return an HTTP basic-auth object, or None when credentials are incomplete.

    Credentials are read from the module-level ``sd_username`` and
    ``sd_password``; both must be set for authentication to be used.
    """
    if sd_username is None or sd_password is None:
        return None
    return requests.auth.HTTPBasicAuth(sd_username, sd_password)
|
||||
|
||||
|
||||
def get(endpoint: str, dct: dict = None):
    """Issue a GET request against ``sd_url`` + ``endpoint``.

    Args:
        endpoint: path (and optional query string) appended to ``sd_url``.
        dct: optional payload passed to ``requests.get`` as ``json``.

    Returns:
        The decoded JSON body when the server answers with HTTP 200,
        otherwise a dict with ``error``, ``reason`` and ``url`` keys.
    """
    response = requests.get(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
    if response.status_code == 200:
        return response.json()
    return { 'error': response.status_code, 'reason': response.reason, 'url': response.url }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # get() reports HTTP failures as a dict carrying an 'error' key rather than
    # raising, so each result is checked before being treated as a list.
    options = get('/sdapi/v1/xyz-grid')
    if isinstance(options, dict) and 'error' in options:
        log.error(options)
        raise SystemExit(1)
    log.info(f'api-xyzgrid-options: {len(options)}')
    for option in options:
        log.info(f' {option}')
    details = get('/sdapi/v1/xyz-grid?option=upscaler')
    if isinstance(details, dict) and 'error' in details:
        log.error(details)
        raise SystemExit(1)
    # Guard against an empty result so indexing the first entry cannot fail.
    choices = details[0]['choices'] if details else []
    for choice in choices:
        log.info(f' {choice}')
|
||||
|
|
@ -0,0 +1,260 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Ensure we can import modules
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
|
||||
from modules.errors import log
|
||||
from modules.res4lyf import (
|
||||
BASE, SIMPLE, VARIANTS,
|
||||
RESUnifiedScheduler, RESMultistepScheduler, RESDEISMultistepScheduler,
|
||||
ETDRKScheduler, LawsonScheduler, ABNorsettScheduler, PECScheduler,
|
||||
RiemannianFlowScheduler, RESSinglestepScheduler, RESSinglestepSDEScheduler,
|
||||
RESMultistepSDEScheduler, SimpleExponentialScheduler, LinearRKScheduler,
|
||||
LobattoScheduler, GaussLegendreScheduler, RungeKutta44Scheduler,
|
||||
RungeKutta57Scheduler, RungeKutta67Scheduler, SpecializedRKScheduler,
|
||||
BongTangentScheduler, CommonSigmaScheduler, RadauIIAScheduler,
|
||||
LangevinDynamicsScheduler
|
||||
)
|
||||
from modules.schedulers.scheduler_vdm import VDMScheduler
|
||||
from modules.schedulers.scheduler_unipc_flowmatch import FlowUniPCMultistepScheduler
|
||||
from modules.schedulers.scheduler_ufogen import UFOGenScheduler
|
||||
from modules.schedulers.scheduler_tdd import TDDScheduler
|
||||
from modules.schedulers.scheduler_tcd import TCDScheduler
|
||||
from modules.schedulers.scheduler_flashflow import FlashFlowMatchEulerDiscreteScheduler
|
||||
from modules.schedulers.scheduler_dpm_flowmatch import FlowMatchDPMSolverMultistepScheduler
|
||||
from modules.schedulers.scheduler_dc import DCSolverMultistepScheduler
|
||||
from modules.schedulers.scheduler_bdia import BDIA_DDIMScheduler
|
||||
|
||||
def test_scheduler(name, scheduler_class, config):
|
||||
try:
|
||||
scheduler = scheduler_class(**config)
|
||||
except Exception as e:
|
||||
log.error(f'scheduler="{name}" cls={scheduler_class} config={config} error="Init failed: {e}"')
|
||||
return False
|
||||
|
||||
num_steps = 20
|
||||
scheduler.set_timesteps(num_steps)
|
||||
|
||||
sample = torch.randn((1, 4, 64, 64))
|
||||
has_changed = False
|
||||
t0 = time.time()
|
||||
messages = []
|
||||
|
||||
try:
|
||||
for i, t in enumerate(scheduler.timesteps):
|
||||
# Simulate model output (noise or x0 or v), Using random noise for stability check
|
||||
model_output = torch.randn_like(sample)
|
||||
|
||||
# Scaling Check
|
||||
step_idx = scheduler.step_index if hasattr(scheduler, "step_index") and scheduler.step_index is not None else i
|
||||
# Clamp index
|
||||
if hasattr(scheduler, 'sigmas'):
|
||||
step_idx = min(step_idx, len(scheduler.sigmas) - 1)
|
||||
sigma = scheduler.sigmas[step_idx]
|
||||
else:
|
||||
sigma = torch.tensor(1.0) # Dummy for non-sigma schedulers
|
||||
|
||||
# Re-introduce scaling calculation first
|
||||
scaled_sample = scheduler.scale_model_input(sample, t)
|
||||
|
||||
if config.get("prediction_type") == "flow_prediction" or name in ["UFOGenScheduler", "TDDScheduler", "TCDScheduler", "BDIA_DDIMScheduler", "DCSolverMultistepScheduler"]:
|
||||
# Some new schedulers don't use K-diffusion scaling
|
||||
expected_scale = 1.0
|
||||
else:
|
||||
expected_scale = 1.0 / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# Simple check with loose tolerance due to float precision
|
||||
expected_scaled_sample = sample * expected_scale
|
||||
if not torch.allclose(scaled_sample, expected_scaled_sample, atol=1e-4):
|
||||
# If failed, double check if it's just 'sample' (no scaling)
|
||||
if torch.allclose(scaled_sample, sample, atol=1e-4):
|
||||
messages.append('warning="scaling is identity"')
|
||||
else:
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} expected={expected_scale} error="scaling mismatch"')
|
||||
return False
|
||||
|
||||
if torch.isnan(scaled_sample).any():
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="NaN in scaled_sample"')
|
||||
return False
|
||||
|
||||
if torch.isinf(scaled_sample).any():
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Inf in scaled_sample"')
|
||||
return False
|
||||
|
||||
output = scheduler.step(model_output, t, sample)
|
||||
|
||||
# Shape and Dtype check
|
||||
if output.prev_sample.shape != sample.shape:
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Shape mismatch: {output.prev_sample.shape} vs {sample.shape}"')
|
||||
return False
|
||||
if output.prev_sample.dtype != sample.dtype:
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Dtype mismatch: {output.prev_sample.dtype} vs {sample.dtype}"')
|
||||
return False
|
||||
|
||||
# Update check: Did the sample change?
|
||||
if not torch.equal(sample, output.prev_sample):
|
||||
has_changed = True
|
||||
|
||||
# Sample Evolution Check
|
||||
step_diff = (sample - output.prev_sample).abs().mean().item()
|
||||
if step_diff < 1e-6:
|
||||
messages.append(f'warning="minimal sample change: {step_diff}"')
|
||||
|
||||
sample = output.prev_sample
|
||||
|
||||
if torch.isnan(sample).any():
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="NaN in sample"')
|
||||
return False
|
||||
|
||||
if torch.isinf(sample).any():
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="Inf in sample"')
|
||||
return False
|
||||
|
||||
# Divergence check
|
||||
if sample.abs().max() > 1e10:
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} error="divergence detected"')
|
||||
return False
|
||||
|
||||
# External check for Sigma Monotonicity
|
||||
if hasattr(scheduler, 'sigmas'):
|
||||
sigmas = scheduler.sigmas.cpu().numpy()
|
||||
if len(sigmas) > 1:
|
||||
diffs = np.diff(sigmas) # Check if potentially monotonic decreasing (standard) OR increasing (some flow/inverse setups). We allow flat sections (diff=0) hence 1e-6 slack
|
||||
is_monotonic_decreasing = np.all(diffs <= 1e-6)
|
||||
is_monotonic_increasing = np.all(diffs >= -1e-6)
|
||||
if not (is_monotonic_decreasing or is_monotonic_increasing):
|
||||
messages.append('warning="sigmas are not monotonic"')
|
||||
|
||||
except Exception as e:
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} exception: {e}')
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
if not has_changed:
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} error="sample never changed"')
|
||||
return False
|
||||
|
||||
final_std = sample.std().item()
|
||||
if final_std > 50.0 or final_std < 0.1:
|
||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} std={final_std} error="variance drift"')
|
||||
|
||||
t1 = time.time()
|
||||
messages = list(set(messages))
|
||||
log.info(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} time={t1-t0} messages={messages}')
|
||||
return True
|
||||
|
||||
def run_tests():
    """Run ``test_scheduler`` over every scheduler group and configuration sweep.

    Groups are run in order: BASE (with class-specific parameter sweeps), SIMPLE,
    VARIANTS, flow-prediction schedulers and the extended sdnext schedulers.
    Results are only logged; nothing is returned.
    """
    # flow_prediction is special, usually requires flow sigmas or specific setup, checking standard ones first
    prediction_types = ["epsilon", "v_prediction", "sample"]

    # Constructor keyword and the values swept for specific base classes;
    # every value is combined with every prediction type.
    sweeps = {
        RESUnifiedScheduler: ("rk_type", ["res_2m", "res_3m", "res_2s", "res_3s", "res_5s", "res_6s", "deis_1s", "deis_2m", "deis_3m"]),
        RESMultistepScheduler: ("variant", ["res_2m", "res_3m", "deis_2m", "deis_3m"]),
        RESDEISMultistepScheduler: ("solver_order", list(range(1, 6))),
        ETDRKScheduler: ("variant", ["etdrk2_2s", "etdrk3_a_3s", "etdrk3_b_3s", "etdrk4_4s", "etdrk4_4s_alt"]),
        LawsonScheduler: ("variant", ["lawson2a_2s", "lawson2b_2s", "lawson4_4s"]),
        ABNorsettScheduler: ("variant", ["abnorsett_2m", "abnorsett_3m", "abnorsett_4m"]),
        PECScheduler: ("variant", ["pec423_2h2s", "pec433_2h3s"]),
    }
    riemannian_metrics = ["euclidean", "hyperbolic", "spherical", "lorentzian"]

    # Test BASE schedulers with their specific parameters
    log.warning('type="base"')
    for name, cls in BASE:
        # Every class is tested with its defaults for each prediction type first
        configs = [{"prediction_type": pt} for pt in prediction_types]

        if cls in sweeps:
            key, values = sweeps[cls]
            configs.extend({key: value, "prediction_type": pt} for value in values for pt in prediction_types)
        elif cls == RiemannianFlowScheduler:
            # Flow usually uses v or raw, but epsilon check matches others
            configs.extend({"metric_type": metric, "prediction_type": "epsilon"} for metric in riemannian_metrics)

        for conf in configs:
            test_scheduler(name, cls, conf)

    log.warning('type="simple"')
    for name, cls in SIMPLE:
        for pt in prediction_types:
            test_scheduler(name, cls, {"prediction_type": pt})

    log.warning('type="variants"')
    for name, cls in VARIANTS:
        # these classes preset their variants/rk_types in __init__ so we just test prediction types
        for pt in prediction_types:
            test_scheduler(name, cls, {"prediction_type": pt})

    # Extra robustness check: Flow Prediction Type
    log.warning('type="flow"')
    flow_schedulers = [
        # res4lyf schedulers
        RESUnifiedScheduler, RESMultistepScheduler, ABNorsettScheduler,
        RESSinglestepScheduler, RESSinglestepSDEScheduler, RESDEISMultistepScheduler,
        RESMultistepSDEScheduler, ETDRKScheduler, LawsonScheduler, PECScheduler,
        SimpleExponentialScheduler, LinearRKScheduler, LobattoScheduler,
        GaussLegendreScheduler, RungeKutta44Scheduler, RungeKutta57Scheduler,
        RungeKutta67Scheduler, SpecializedRKScheduler, BongTangentScheduler,
        CommonSigmaScheduler, RadauIIAScheduler, LangevinDynamicsScheduler,
        RiemannianFlowScheduler,
        # sdnext schedulers
        FlowUniPCMultistepScheduler, FlashFlowMatchEulerDiscreteScheduler, FlowMatchDPMSolverMultistepScheduler,
    ]
    for cls in flow_schedulers:
        test_scheduler(cls.__name__, cls, {"prediction_type": "flow_prediction", "use_flow_sigmas": True})

    log.warning('type="sdnext"')
    extended_schedulers = [
        VDMScheduler,
        UFOGenScheduler,
        TDDScheduler,
        TCDScheduler,
        DCSolverMultistepScheduler,
        BDIA_DDIMScheduler,
    ]
    for prediction_type in prediction_types:
        for cls in extended_schedulers:
            test_scheduler(cls.__name__, cls, {"prediction_type": prediction_type})


if __name__ == "__main__":
    run_tests()
|
||||
|
|
@ -0,0 +1,847 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Tagger Settings Test Suite
|
||||
|
||||
Tests all WaifuDiffusion and DeepBooru tagger settings to verify they're properly
|
||||
mapped and affect output correctly.
|
||||
|
||||
Usage:
|
||||
python cli/test-tagger.py [image_path]
|
||||
|
||||
If no image path is provided, uses a built-in test image.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
# Add parent directory to path for imports
|
||||
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, script_dir)
|
||||
os.chdir(script_dir)
|
||||
|
||||
# Suppress installer output during import
|
||||
os.environ['SD_INSTALL_QUIET'] = '1'
|
||||
|
||||
# Initialize cmd_args properly with all argument groups
|
||||
import modules.cmd_args
|
||||
import installer
|
||||
|
||||
# Add installer args to the parser
|
||||
installer.add_args(modules.cmd_args.parser)
|
||||
|
||||
# Parse with empty args to get defaults
|
||||
modules.cmd_args.parsed, _ = modules.cmd_args.parser.parse_known_args([])
|
||||
|
||||
# Modules that depend on cmd_args can now be imported safely (they are imported lazily inside the functions below)
|
||||
|
||||
|
||||
# Default test images (in order of preference)
|
||||
DEFAULT_TEST_IMAGES = [
    'html/sdnext-robot-2k.jpg', # SD.Next robot mascot
    # venv entries use a wildcard so they match whichever Python version the venv was created with
    'venv/lib/python*/site-packages/gradio/test_data/lion.jpg',
    'venv/lib/python*/site-packages/gradio/test_data/cheetah1.jpg',
    'venv/lib/python*/site-packages/skimage/data/astronaut.png',
    'venv/lib/python*/site-packages/skimage/data/coffee.png',
]


def find_test_image(base_dir=None):
    """Find a suitable test image from defaults.

    Args:
        base_dir: Directory the default paths are resolved against; defaults to
            the module-level ``script_dir``.

    Returns:
        Path of the first existing image, in the preference order of
        ``DEFAULT_TEST_IMAGES``, or None when none exists.
    """
    import glob

    root = script_dir if base_dir is None else base_dir
    # Escape the root so glob metacharacters in the checkout path are taken literally
    safe_root = glob.escape(root)
    for pattern in DEFAULT_TEST_IMAGES:
        matches = sorted(glob.glob(os.path.join(safe_root, pattern)))
        if matches:
            return matches[0]
    return None
|
||||
|
||||
|
||||
def create_test_image():
    """Build a synthetic 512x512 RGB image (an ellipse with a rectangle on top) as a fallback."""
    from PIL import Image, ImageDraw

    canvas = Image.new('RGB', (512, 512), color=(200, 150, 100))
    pen = ImageDraw.Draw(canvas)
    # Shapes are drawn in this order so the rectangle overlaps the ellipse
    pen.ellipse([100, 100, 400, 400], fill=(255, 200, 150), outline=(100, 50, 0))
    pen.rectangle([150, 200, 350, 350], fill=(150, 100, 200))
    return canvas
|
||||
|
||||
|
||||
class TaggerTest:
|
||||
"""Test harness for tagger settings."""
|
||||
|
||||
def __init__(self):
    """Start with empty result buckets, no test image and both models marked as not loaded."""
    self.results = {bucket: [] for bucket in ('passed', 'failed', 'skipped')}
    self.test_image = None
    # Both flags are set by setup() from the return values of the model loaders
    self.waifudiffusion_loaded = self.deepbooru_loaded = False
|
||||
|
||||
def log_pass(self, msg):
    """Print a [PASS] line and record ``msg`` in the 'passed' results."""
    line = f" [PASS] {msg}"
    print(line)
    passed = self.results['passed']
    passed.append(msg)
|
||||
|
||||
def log_fail(self, msg):
    """Print a [FAIL] line and record ``msg`` in the 'failed' results."""
    line = f" [FAIL] {msg}"
    print(line)
    failed = self.results['failed']
    failed.append(msg)
|
||||
|
||||
def log_skip(self, msg):
    """Print a [SKIP] line and record ``msg`` in the 'skipped' results."""
    line = f" [SKIP] {msg}"
    print(line)
    skipped = self.results['skipped']
    skipped.append(msg)
|
||||
|
||||
def log_warn(self, msg):
    """Print a [WARN] line; warnings are recorded in the 'skipped' results."""
    line = f" [WARN] {msg}"
    print(line)
    # There is no separate bucket for warnings: they are stored with the skips
    skipped = self.results['skipped']
    skipped.append(msg)
|
||||
|
||||
def setup(self):
    """Load the test image and both tagger models.

    The image comes from the first command-line argument when it names an
    existing file, otherwise from the default images, otherwise a synthetic
    image is generated. The load results are stored in
    ``waifudiffusion_loaded`` and ``deepbooru_loaded``.
    """
    from PIL import Image

    print("=" * 70)
    print("TAGGER SETTINGS TEST SUITE")
    print("=" * 70)

    # Get or create test image
    img_path = None
    if len(sys.argv) > 1:
        if os.path.exists(sys.argv[1]):
            img_path = sys.argv[1]
            print(f"\nUsing provided image: {img_path}")
        else:
            # Tell the user their path was ignored instead of silently falling back
            print(f"\nProvided image not found: {sys.argv[1]}")
    if img_path is None:
        img_path = find_test_image()
        if img_path:
            print(f"\nUsing default test image: {img_path}")

    if img_path:
        # convert() returns a new image, so the source file can be closed right away
        with Image.open(img_path) as source:
            self.test_image = source.convert('RGB')
    else:
        print("\nNo test image found, creating synthetic image...")
        self.test_image = create_test_image()

    print(f"Image size: {self.test_image.size}")

    # Load models
    print("\nLoading models...")
    from modules.interrogate import waifudiffusion, deepbooru

    t0 = time.time()
    self.waifudiffusion_loaded = waifudiffusion.load_model()
    print(f" WaifuDiffusion: {'loaded' if self.waifudiffusion_loaded else 'FAILED'} ({time.time()-t0:.1f}s)")

    t0 = time.time()
    self.deepbooru_loaded = deepbooru.load_model()
    print(f" DeepBooru: {'loaded' if self.deepbooru_loaded else 'FAILED'} ({time.time()-t0:.1f}s)")
|
||||
|
||||
def cleanup(self):
    """Unload both tagger models and force a device garbage collection."""
    separator = "=" * 70
    print("\n" + separator)
    print("CLEANUP")
    print(separator)

    from modules.interrogate import waifudiffusion, deepbooru
    from modules import devices

    for tagger_module in (waifudiffusion, deepbooru):
        tagger_module.unload_model()
    devices.torch_gc(force=True)
    print(" Models unloaded")
|
||||
|
||||
def print_summary(self):
    """Print the passed/failed/skipped lists and the pass rate over passed+failed."""
    separator = "=" * 70
    passed = self.results['passed']
    failed = self.results['failed']
    skipped = self.results['skipped']

    print("\n" + separator)
    print("TEST SUMMARY")
    print(separator)

    print(f"\n PASSED: {len(passed)}")
    for entry in passed:
        print(f" - {entry}")

    print(f"\n FAILED: {len(failed)}")
    for entry in failed:
        print(f" - {entry}")

    print(f"\n SKIPPED: {len(skipped)}")
    for entry in skipped:
        print(f" - {entry}")

    # Skipped entries do not count towards the success rate
    total = len(passed) + len(failed)
    if total > 0:
        success_rate = len(passed) / total * 100
        print(f"\n SUCCESS RATE: {success_rate:.1f}% ({len(passed)}/{total})")

    print("\n" + separator)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: ONNX Providers Detection
|
||||
# =========================================================================
|
||||
def test_onnx_providers(self):
    """Verify ONNX runtime providers are properly detected.

    Checks that onnxruntime imports, that it reports available providers, that
    ``devices.onnx`` is non-empty and only lists available providers, and, when
    WaifuDiffusion is loaded, reports the providers of its session.
    """
    print("\n" + "=" * 70)
    print("TEST: ONNX Providers Detection")
    print("=" * 70)

    from modules import devices

    # Test 1: onnxruntime can be imported
    try:
        import onnxruntime as ort
        self.log_pass(f"onnxruntime imported: version={ort.__version__}")
    except ImportError as e:
        self.log_fail(f"onnxruntime import failed: {e}")
        return

    # Test 2: Get available providers
    available = ort.get_available_providers()
    if available and len(available) > 0:
        self.log_pass(f"Available providers: {available}")
    else:
        self.log_fail("No ONNX providers available")
        return

    # Test 3: devices.onnx is properly configured
    configured = devices.onnx
    if configured is not None and len(configured) > 0:
        self.log_pass(f"devices.onnx configured: {configured}")
    else:
        self.log_fail(f"devices.onnx not configured: {configured}")
        # Nothing to iterate below; also avoids a TypeError when devices.onnx is None
        configured = []

    # Test 4: Configured providers exist in available providers
    for provider in configured:
        if provider in available:
            self.log_pass(f"Provider '{provider}' is available")
        else:
            self.log_fail(f"Provider '{provider}' configured but not available")

    # Test 5: If WaifuDiffusion loaded, check session providers
    if self.waifudiffusion_loaded:
        from modules.interrogate import waifudiffusion
        if waifudiffusion.tagger.session is not None:
            session_providers = waifudiffusion.tagger.session.get_providers()
            self.log_pass(f"WaifuDiffusion session providers: {session_providers}")
        else:
            self.log_skip("WaifuDiffusion session not initialized")
|
||||
|
||||
# =========================================================================
|
||||
# TEST: Memory Management (Offload/Reload/Unload)
|
||||
# =========================================================================
|
||||
def get_memory_stats(self):
    """Return current memory usage in MB as a dict.

    Keys are ``gpu_allocated`` and ``gpu_reserved`` (0 when CUDA is unavailable)
    and ``ram_used`` (process RSS, 0 when psutil cannot be imported).
    """
    import torch

    # GPU memory (if CUDA available), otherwise zeros
    stats = {'gpu_allocated': 0, 'gpu_reserved': 0}
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        allocated_bytes = torch.cuda.memory_allocated()
        stats['gpu_allocated'] = allocated_bytes / 1024 / 1024 # MB
        reserved_bytes = torch.cuda.memory_reserved()
        stats['gpu_reserved'] = reserved_bytes / 1024 / 1024 # MB

    # CPU/RAM memory (psutil is optional)
    try:
        import psutil
        rss_bytes = psutil.Process().memory_info().rss
        stats['ram_used'] = rss_bytes / 1024 / 1024 # MB
    except ImportError:
        stats['ram_used'] = 0

    return stats
|
||||
|
||||
def test_memory_management(self):
    """Test model offload to RAM, reload to GPU, and unload with memory monitoring.

    Runs the DeepBooru checks when that model is loaded, then the WaifuDiffusion
    checks when that model is loaded. Results are recorded through the log_*
    helpers.
    """
    print("\n" + "=" * 70)
    print("TEST: Memory Management (Offload/Reload/Unload)")
    print("=" * 70)

    # Memory leak tolerance (MB) - some variance is expected
    GPU_LEAK_TOLERANCE_MB = 50
    RAM_LEAK_TOLERANCE_MB = 200

    if self.deepbooru_loaded:
        self._check_deepbooru_memory(GPU_LEAK_TOLERANCE_MB, RAM_LEAK_TOLERANCE_MB)

    if self.waifudiffusion_loaded:
        self._check_waifudiffusion_memory(GPU_LEAK_TOLERANCE_MB, RAM_LEAK_TOLERANCE_MB)

def _release_memory(self):
    """Run the Python garbage collector and empty the CUDA cache when CUDA is available."""
    import gc
    import torch

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def _check_deepbooru_memory(self, gpu_tolerance_mb, ram_tolerance_mb):
    """Exercise DeepBooru start/stop/unload and compare memory usage against a baseline."""
    from modules import devices
    from modules.interrogate import deepbooru

    print("\n DeepBooru Memory Management:")

    # Baseline memory before any operations
    self._release_memory()
    baseline = self.get_memory_stats()
    print(f" Baseline: GPU={baseline['gpu_allocated']:.1f}MB, RAM={baseline['ram_used']:.1f}MB")

    # Test 1: Check initial state (either device is reported as a pass)
    initial_device = next(deepbooru.model.model.parameters()).device
    print(f" Initial device: {initial_device}")
    if initial_device.type == 'cpu':
        self.log_pass("DeepBooru: initial state on CPU")
    else:
        self.log_pass(f"DeepBooru: initial state on {initial_device}")

    # Test 2: Move to GPU (start)
    deepbooru.model.start()
    gpu_device = next(deepbooru.model.model.parameters()).device
    after_gpu = self.get_memory_stats()
    print(f" After start(): {gpu_device} | GPU={after_gpu['gpu_allocated']:.1f}MB (+{after_gpu['gpu_allocated']-baseline['gpu_allocated']:.1f}MB)")
    if gpu_device.type == devices.device.type:
        self.log_pass(f"DeepBooru: moved to GPU ({gpu_device})")
    else:
        self.log_fail(f"DeepBooru: failed to move to GPU, got {gpu_device}")

    # Test 3: Run inference while on GPU
    try:
        tags = deepbooru.model.tag_multi(self.test_image, max_tags=3)
        after_infer = self.get_memory_stats()
        print(f" After inference: GPU={after_infer['gpu_allocated']:.1f}MB")
        if tags:
            self.log_pass(f"DeepBooru: inference on GPU works ({tags[:30]}...)")
        else:
            self.log_fail("DeepBooru: inference on GPU returned empty")
    except Exception as e:
        self.log_fail(f"DeepBooru: inference on GPU failed: {e}")

    # Test 4: Offload to CPU (stop)
    deepbooru.model.stop()
    self._release_memory()
    after_offload = self.get_memory_stats()
    cpu_device = next(deepbooru.model.model.parameters()).device
    print(f" After stop(): {cpu_device} | GPU={after_offload['gpu_allocated']:.1f}MB, RAM={after_offload['ram_used']:.1f}MB")
    if cpu_device.type == 'cpu':
        self.log_pass("DeepBooru: offloaded to CPU")
    else:
        self.log_fail(f"DeepBooru: failed to offload, still on {cpu_device}")

    # Check GPU memory returned to near baseline after offload
    gpu_diff = after_offload['gpu_allocated'] - baseline['gpu_allocated']
    if gpu_diff <= gpu_tolerance_mb:
        self.log_pass(f"DeepBooru: GPU memory cleared after offload (diff={gpu_diff:.1f}MB)")
    else:
        self.log_fail(f"DeepBooru: GPU memory leak after offload (diff={gpu_diff:.1f}MB > {gpu_tolerance_mb}MB)")

    # Test 5: Full cycle - reload and run again
    deepbooru.model.start()
    try:
        tags = deepbooru.model.tag_multi(self.test_image, max_tags=3)
        if tags:
            self.log_pass("DeepBooru: reload cycle works")
        else:
            self.log_fail("DeepBooru: reload cycle returned empty")
    except Exception as e:
        self.log_fail(f"DeepBooru: reload cycle failed: {e}")
    deepbooru.model.stop()

    # Test 6: Full unload with memory check
    deepbooru.unload_model()
    self._release_memory()
    after_unload = self.get_memory_stats()
    print(f" After unload: GPU={after_unload['gpu_allocated']:.1f}MB, RAM={after_unload['ram_used']:.1f}MB")

    if deepbooru.model.model is None:
        self.log_pass("DeepBooru: unload successful")
    else:
        self.log_fail("DeepBooru: unload failed, model still exists")

    # Check for memory leaks after full unload
    gpu_leak = after_unload['gpu_allocated'] - baseline['gpu_allocated']
    ram_leak = after_unload['ram_used'] - baseline['ram_used']
    if gpu_leak <= gpu_tolerance_mb:
        self.log_pass(f"DeepBooru: no GPU memory leak after unload (diff={gpu_leak:.1f}MB)")
    else:
        self.log_fail(f"DeepBooru: GPU memory leak detected (diff={gpu_leak:.1f}MB > {gpu_tolerance_mb}MB)")

    if ram_leak <= ram_tolerance_mb:
        self.log_pass(f"DeepBooru: no RAM leak after unload (diff={ram_leak:.1f}MB)")
    else:
        self.log_warn(f"DeepBooru: RAM increased after unload (diff={ram_leak:.1f}MB) - may be caching")

    # Reload for remaining tests; keep the loaded flag in sync so later tests
    # are skipped rather than run against an unloaded model
    self.deepbooru_loaded = deepbooru.load_model()
    if not self.deepbooru_loaded:
        self.log_fail("DeepBooru: reload after unload failed")

def _check_waifudiffusion_memory(self, gpu_tolerance_mb, ram_tolerance_mb):
    """Exercise the WaifuDiffusion session unload/reload cycle and compare memory against a baseline."""
    from modules.interrogate import waifudiffusion

    print("\n WaifuDiffusion Memory Management:")

    # Baseline memory
    self._release_memory()
    baseline = self.get_memory_stats()
    print(f" Baseline: GPU={baseline['gpu_allocated']:.1f}MB, RAM={baseline['ram_used']:.1f}MB")

    # Test 1: Session exists
    if waifudiffusion.tagger.session is not None:
        self.log_pass("WaifuDiffusion: session loaded")
    else:
        self.log_fail("WaifuDiffusion: session not loaded")
        return

    # Test 2: Get current providers
    providers = waifudiffusion.tagger.session.get_providers()
    print(f" Active providers: {providers}")
    self.log_pass(f"WaifuDiffusion: using providers {providers}")

    # Test 3: Run inference
    try:
        tags = waifudiffusion.tagger.predict(self.test_image, max_tags=3)
        after_infer = self.get_memory_stats()
        print(f" After inference: GPU={after_infer['gpu_allocated']:.1f}MB, RAM={after_infer['ram_used']:.1f}MB")
        if tags:
            self.log_pass(f"WaifuDiffusion: inference works ({tags[:30]}...)")
        else:
            self.log_fail("WaifuDiffusion: inference returned empty")
    except Exception as e:
        self.log_fail(f"WaifuDiffusion: inference failed: {e}")

    # Test 4: Unload session with memory check
    model_name = waifudiffusion.tagger.model_name # remembered so the same model can be reloaded
    waifudiffusion.unload_model()
    self._release_memory()
    after_unload = self.get_memory_stats()
    print(f" After unload: GPU={after_unload['gpu_allocated']:.1f}MB, RAM={after_unload['ram_used']:.1f}MB")

    if waifudiffusion.tagger.session is None:
        self.log_pass("WaifuDiffusion: unload successful")
    else:
        self.log_fail("WaifuDiffusion: unload failed, session still exists")

    # Check for memory leaks after unload
    gpu_leak = after_unload['gpu_allocated'] - baseline['gpu_allocated']
    ram_leak = after_unload['ram_used'] - baseline['ram_used']
    if gpu_leak <= gpu_tolerance_mb:
        self.log_pass(f"WaifuDiffusion: no GPU memory leak after unload (diff={gpu_leak:.1f}MB)")
    else:
        self.log_fail(f"WaifuDiffusion: GPU memory leak detected (diff={gpu_leak:.1f}MB > {gpu_tolerance_mb}MB)")

    if ram_leak <= ram_tolerance_mb:
        self.log_pass(f"WaifuDiffusion: no RAM leak after unload (diff={ram_leak:.1f}MB)")
    else:
        self.log_warn(f"WaifuDiffusion: RAM increased after unload (diff={ram_leak:.1f}MB) - may be caching")

    # Test 5: Reload session
    waifudiffusion.load_model(model_name)
    after_reload = self.get_memory_stats()
    print(f" After reload: GPU={after_reload['gpu_allocated']:.1f}MB, RAM={after_reload['ram_used']:.1f}MB")
    # Keep the loaded flag in sync so later tests are skipped if the reload failed
    self.waifudiffusion_loaded = waifudiffusion.tagger.session is not None
    if self.waifudiffusion_loaded:
        self.log_pass("WaifuDiffusion: reload successful")
    else:
        self.log_fail("WaifuDiffusion: reload failed")

    # Test 6: Inference after reload
    try:
        tags = waifudiffusion.tagger.predict(self.test_image, max_tags=3)
        if tags:
            self.log_pass("WaifuDiffusion: inference after reload works")
        else:
            self.log_fail("WaifuDiffusion: inference after reload returned empty")
    except Exception as e:
        self.log_fail(f"WaifuDiffusion: inference after reload failed: {e}")

    # Final memory check after full cycle
    self._release_memory()
    final = self.get_memory_stats()
    print(f" Final (after full cycle): GPU={final['gpu_allocated']:.1f}MB, RAM={final['ram_used']:.1f}MB")
|
||||
|
||||
# =========================================================================
|
||||
# TEST: Settings Existence
|
||||
# =========================================================================
|
||||
def test_settings_exist(self):
    """Check that each tagger setting is present on shared.opts and log its value."""
    separator = "=" * 70
    print("\n" + separator)
    print("TEST: Settings Existence")
    print(separator)

    from modules import shared

    # (setting name, expected type); only presence is checked, the type is informational
    settings = [
        ('tagger_threshold', float),
        ('tagger_include_rating', bool),
        ('tagger_max_tags', int),
        ('tagger_sort_alpha', bool),
        ('tagger_use_spaces', bool),
        ('tagger_escape_brackets', bool),
        ('tagger_exclude_tags', str),
        ('tagger_show_scores', bool),
        ('waifudiffusion_model', str),
        ('waifudiffusion_character_threshold', float),
        ('interrogate_offload', bool),
    ]

    absent = object() # sentinel distinguishing a missing option from any stored value
    for setting, _expected_type in settings:
        value = getattr(shared.opts, setting, absent)
        if value is absent:
            self.log_fail(f"{setting} - NOT FOUND")
            continue
        self.log_pass(f"{setting} = {value!r}")
|
||||
|
||||
# =========================================================================
|
||||
# TEST: Parameter Effect - Tests a single parameter on both taggers
|
||||
# =========================================================================
|
||||
def test_parameter(self, param_name, test_func, waifudiffusion_supported=True, deepbooru_supported=True):
|
||||
"""Test a parameter on both WaifuDiffusion and DeepBooru."""
|
||||
print(f"\n Testing: {param_name}")
|
||||
|
||||
if waifudiffusion_supported and self.waifudiffusion_loaded:
|
||||
try:
|
||||
result = test_func('waifudiffusion')
|
||||
if result is True:
|
||||
self.log_pass(f"WaifuDiffusion: {param_name}")
|
||||
elif result is False:
|
||||
self.log_fail(f"WaifuDiffusion: {param_name}")
|
||||
else:
|
||||
self.log_skip(f"WaifuDiffusion: {param_name} - {result}")
|
||||
except Exception as e:
|
||||
self.log_fail(f"WaifuDiffusion: {param_name} - {e}")
|
||||
elif waifudiffusion_supported:
|
||||
self.log_skip(f"WaifuDiffusion: {param_name} - model not loaded")
|
||||
|
||||
if deepbooru_supported and self.deepbooru_loaded:
|
||||
try:
|
||||
result = test_func('deepbooru')
|
||||
if result is True:
|
||||
self.log_pass(f"DeepBooru: {param_name}")
|
||||
elif result is False:
|
||||
self.log_fail(f"DeepBooru: {param_name}")
|
||||
else:
|
||||
self.log_skip(f"DeepBooru: {param_name} - {result}")
|
||||
except Exception as e:
|
||||
self.log_fail(f"DeepBooru: {param_name} - {e}")
|
||||
elif deepbooru_supported:
|
||||
self.log_skip(f"DeepBooru: {param_name} - model not loaded")
|
||||
|
||||
def tag(self, tagger, **kwargs):
    """Run the selected tagger on the test image, forwarding ``kwargs`` to it.

    'waifudiffusion' selects the WaifuDiffusion predictor; any other value
    selects DeepBooru.
    """
    if tagger != 'waifudiffusion':
        from modules.interrogate import deepbooru
        return deepbooru.model.tag(self.test_image, **kwargs)

    from modules.interrogate import waifudiffusion
    return waifudiffusion.tagger.predict(self.test_image, **kwargs)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: general_threshold
|
||||
# =========================================================================
|
||||
def test_threshold(self):
|
||||
"""Test that threshold affects tag count."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST: general_threshold effect")
|
||||
print("=" * 70)
|
||||
|
||||
def check_threshold(tagger):
|
||||
tags_high = self.tag(tagger, general_threshold=0.9)
|
||||
tags_low = self.tag(tagger, general_threshold=0.1)
|
||||
|
||||
count_high = len(tags_high.split(', ')) if tags_high else 0
|
||||
count_low = len(tags_low.split(', ')) if tags_low else 0
|
||||
|
||||
print(f" {tagger}: threshold=0.9 -> {count_high} tags, threshold=0.1 -> {count_low} tags")
|
||||
|
||||
if count_low > count_high:
|
||||
return True
|
||||
elif count_low == count_high == 0:
|
||||
return "no tags returned"
|
||||
else:
|
||||
return "threshold effect unclear"
|
||||
|
||||
self.test_parameter('general_threshold', check_threshold)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: max_tags
|
||||
# =========================================================================
|
||||
def test_max_tags(self):
|
||||
"""Test that max_tags limits output."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST: max_tags effect")
|
||||
print("=" * 70)
|
||||
|
||||
def check_max_tags(tagger):
|
||||
tags_5 = self.tag(tagger, general_threshold=0.1, max_tags=5)
|
||||
tags_50 = self.tag(tagger, general_threshold=0.1, max_tags=50)
|
||||
|
||||
count_5 = len(tags_5.split(', ')) if tags_5 else 0
|
||||
count_50 = len(tags_50.split(', ')) if tags_50 else 0
|
||||
|
||||
print(f" {tagger}: max_tags=5 -> {count_5} tags, max_tags=50 -> {count_50} tags")
|
||||
|
||||
return count_5 <= 5
|
||||
|
||||
self.test_parameter('max_tags', check_max_tags)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: use_spaces
|
||||
# =========================================================================
|
||||
def test_use_spaces(self):
|
||||
"""Test that use_spaces converts underscores to spaces."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST: use_spaces effect")
|
||||
print("=" * 70)
|
||||
|
||||
def check_use_spaces(tagger):
|
||||
tags_under = self.tag(tagger, use_spaces=False, max_tags=10)
|
||||
tags_space = self.tag(tagger, use_spaces=True, max_tags=10)
|
||||
|
||||
print(f" {tagger} use_spaces=False: {tags_under[:50]}...")
|
||||
print(f" {tagger} use_spaces=True: {tags_space[:50]}...")
|
||||
|
||||
# Check if underscores are converted to spaces
|
||||
has_underscore_before = '_' in tags_under
|
||||
has_underscore_after = '_' in tags_space.replace(', ', ',') # ignore comma-space
|
||||
|
||||
# If there were underscores before but not after, it worked
|
||||
if has_underscore_before and not has_underscore_after:
|
||||
return True
|
||||
# If there were never underscores, inconclusive
|
||||
elif not has_underscore_before:
|
||||
return "no underscores in tags to convert"
|
||||
else:
|
||||
return False
|
||||
|
||||
self.test_parameter('use_spaces', check_use_spaces)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: escape_brackets
|
||||
# =========================================================================
|
||||
def test_escape_brackets(self):
|
||||
"""Test that escape_brackets escapes special characters."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST: escape_brackets effect")
|
||||
print("=" * 70)
|
||||
|
||||
def check_escape_brackets(tagger):
|
||||
tags_escaped = self.tag(tagger, escape_brackets=True, max_tags=30, general_threshold=0.1)
|
||||
tags_raw = self.tag(tagger, escape_brackets=False, max_tags=30, general_threshold=0.1)
|
||||
|
||||
print(f" {tagger} escape=True: {tags_escaped[:60]}...")
|
||||
print(f" {tagger} escape=False: {tags_raw[:60]}...")
|
||||
|
||||
# Check for escaped brackets (\\( or \\))
|
||||
has_escaped = '\\(' in tags_escaped or '\\)' in tags_escaped
|
||||
has_unescaped = '(' in tags_raw.replace('\\(', '') or ')' in tags_raw.replace('\\)', '')
|
||||
|
||||
if has_escaped:
|
||||
return True
|
||||
elif has_unescaped:
|
||||
# Has brackets but not escaped - fail
|
||||
return False
|
||||
else:
|
||||
return "no brackets in tags to escape"
|
||||
|
||||
self.test_parameter('escape_brackets', check_escape_brackets)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: sort_alpha
|
||||
# =========================================================================
|
||||
def test_sort_alpha(self):
|
||||
"""Test that sort_alpha sorts tags alphabetically."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST: sort_alpha effect")
|
||||
print("=" * 70)
|
||||
|
||||
def check_sort_alpha(tagger):
|
||||
tags_conf = self.tag(tagger, sort_alpha=False, max_tags=20, general_threshold=0.1)
|
||||
tags_alpha = self.tag(tagger, sort_alpha=True, max_tags=20, general_threshold=0.1)
|
||||
|
||||
list_conf = [t.strip() for t in tags_conf.split(',')]
|
||||
list_alpha = [t.strip() for t in tags_alpha.split(',')]
|
||||
|
||||
print(f" {tagger} by_confidence: {', '.join(list_conf[:5])}...")
|
||||
print(f" {tagger} alphabetical: {', '.join(list_alpha[:5])}...")
|
||||
|
||||
is_sorted = list_alpha == sorted(list_alpha)
|
||||
return is_sorted
|
||||
|
||||
self.test_parameter('sort_alpha', check_sort_alpha)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: exclude_tags
|
||||
# =========================================================================
|
||||
def test_exclude_tags(self):
|
||||
"""Test that exclude_tags removes specified tags."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST: exclude_tags effect")
|
||||
print("=" * 70)
|
||||
|
||||
def check_exclude_tags(tagger):
|
||||
tags_all = self.tag(tagger, max_tags=50, general_threshold=0.1, exclude_tags='')
|
||||
tag_list = [t.strip().replace(' ', '_') for t in tags_all.split(',')]
|
||||
|
||||
if len(tag_list) < 2:
|
||||
return "not enough tags to test"
|
||||
|
||||
# Exclude the first tag
|
||||
tag_to_exclude = tag_list[0]
|
||||
tags_filtered = self.tag(tagger, max_tags=50, general_threshold=0.1, exclude_tags=tag_to_exclude)
|
||||
|
||||
print(f" {tagger} without exclusion: {tags_all[:50]}...")
|
||||
print(f" {tagger} excluding '{tag_to_exclude}': {tags_filtered[:50]}...")
|
||||
|
||||
# Check if the exact tag was removed by parsing the filtered list
|
||||
filtered_list = [t.strip().replace(' ', '_') for t in tags_filtered.split(',')]
|
||||
# Also check space variant
|
||||
tag_space_variant = tag_to_exclude.replace('_', ' ')
|
||||
tag_present = tag_to_exclude in filtered_list or tag_space_variant in [t.strip() for t in tags_filtered.split(',')]
|
||||
return not tag_present
|
||||
|
||||
self.test_parameter('exclude_tags', check_exclude_tags)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: tagger_show_scores (via shared.opts)
|
||||
# =========================================================================
|
||||
def test_show_scores(self):
    """Check that ``shared.opts.tagger_show_scores`` adds confidence scores to tags.

    The option is toggled on the shared options object for the duration of
    each check and is always restored afterwards, including when tagging
    raises, so a failing backend cannot leave the setting changed for the
    remaining tests.
    """
    print("\n" + "=" * 70)
    print("TEST: tagger_show_scores effect")
    print("=" * 70)

    from modules import shared

    def check_show_scores(tagger):
        original = shared.opts.tagger_show_scores
        try:
            shared.opts.tagger_show_scores = False
            tags_no_scores = self.tag(tagger, max_tags=5)

            shared.opts.tagger_show_scores = True
            tags_with_scores = self.tag(tagger, max_tags=5)
        finally:
            # Restore the previous value even if one of the tag calls failed.
            shared.opts.tagger_show_scores = original

        print(f" {tagger} show_scores=False: {tags_no_scores[:50]}...")
        print(f" {tagger} show_scores=True: {tags_with_scores[:50]}...")

        # Scored output is detected by the presence of both ':' and '('.
        has_scores = ':' in tags_with_scores and '(' in tags_with_scores
        no_scores = ':' not in tags_no_scores

        return has_scores and no_scores

    self.test_parameter('tagger_show_scores', check_show_scores)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: include_rating
|
||||
# =========================================================================
|
||||
def test_include_rating(self):
|
||||
"""Test that include_rating includes/excludes rating tags."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST: include_rating effect")
|
||||
print("=" * 70)
|
||||
|
||||
def check_include_rating(tagger):
|
||||
tags_no_rating = self.tag(tagger, include_rating=False, max_tags=100, general_threshold=0.01)
|
||||
tags_with_rating = self.tag(tagger, include_rating=True, max_tags=100, general_threshold=0.01)
|
||||
|
||||
print(f" {tagger} include_rating=False: {tags_no_rating[:60]}...")
|
||||
print(f" {tagger} include_rating=True: {tags_with_rating[:60]}...")
|
||||
|
||||
# Rating tags typically start with "rating:" or are like "safe", "questionable", "explicit"
|
||||
rating_keywords = ['rating:', 'safe', 'questionable', 'explicit', 'general', 'sensitive']
|
||||
|
||||
has_rating_before = any(kw in tags_no_rating.lower() for kw in rating_keywords)
|
||||
has_rating_after = any(kw in tags_with_rating.lower() for kw in rating_keywords)
|
||||
|
||||
if has_rating_after and not has_rating_before:
|
||||
return True
|
||||
elif has_rating_after and has_rating_before:
|
||||
return "rating tags appear in both (may need very low threshold)"
|
||||
elif not has_rating_after:
|
||||
return "no rating tags detected"
|
||||
else:
|
||||
return False
|
||||
|
||||
self.test_parameter('include_rating', check_include_rating)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: character_threshold (WaifuDiffusion only)
|
||||
# =========================================================================
|
||||
def test_character_threshold(self):
|
||||
"""Test that character_threshold affects character tag count (WaifuDiffusion only)."""
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST: character_threshold effect (WaifuDiffusion only)")
|
||||
print("=" * 70)
|
||||
|
||||
def check_character_threshold(tagger):
|
||||
if tagger != 'waifudiffusion':
|
||||
return "not supported"
|
||||
|
||||
# Character threshold only affects character tags
|
||||
# We need an image with character tags to properly test this
|
||||
tags_high = self.tag(tagger, character_threshold=0.99, general_threshold=0.5)
|
||||
tags_low = self.tag(tagger, character_threshold=0.1, general_threshold=0.5)
|
||||
|
||||
print(f" {tagger} char_threshold=0.99: {tags_high[:50]}...")
|
||||
print(f" {tagger} char_threshold=0.10: {tags_low[:50]}...")
|
||||
|
||||
# If thresholds are different, the setting is at least being applied
|
||||
# Hard to verify without an image with known character tags
|
||||
return True # Setting exists and is applied (verified by code inspection)
|
||||
|
||||
self.test_parameter('character_threshold', check_character_threshold, deepbooru_supported=False)
|
||||
|
||||
# =========================================================================
|
||||
# TEST: Unified Interface
|
||||
# =========================================================================
|
||||
def test_unified_interface(self):
    """Exercise ``tagger.tag()`` once per backend through the unified interface.

    A call that returns without raising is logged as a pass and an
    exception as a failure. A backend that is not loaded, or a missing
    non-DeepBooru model name, is logged as skipped so the summary shows
    that nothing was verified for it.
    """
    print("\n" + "=" * 70)
    print("TEST: Unified tagger.tag() interface")
    print("=" * 70)

    from modules.interrogate import tagger

    # Test WaifuDiffusion through unified interface
    if self.waifudiffusion_loaded:
        try:
            models = tagger.get_models()
            # The first entry that is not DeepBooru is used as the WaifuDiffusion model.
            waifudiffusion_model = next((m for m in models if m != 'DeepBooru'), None)
            if waifudiffusion_model:
                tags = tagger.tag(self.test_image, model_name=waifudiffusion_model, max_tags=5)
                print(f" WaifuDiffusion ({waifudiffusion_model}): {tags[:50]}...")
                self.log_pass("Unified interface: WaifuDiffusion")
            else:
                self.log_skip("Unified interface: WaifuDiffusion - no model available")
        except Exception as e:
            self.log_fail(f"Unified interface: WaifuDiffusion - {e}")
    else:
        self.log_skip("Unified interface: WaifuDiffusion - model not loaded")

    # Test DeepBooru through unified interface
    if self.deepbooru_loaded:
        try:
            tags = tagger.tag(self.test_image, model_name='DeepBooru', max_tags=5)
            print(f" DeepBooru: {tags[:50]}...")
            self.log_pass("Unified interface: DeepBooru")
        except Exception as e:
            self.log_fail(f"Unified interface: DeepBooru - {e}")
    else:
        self.log_skip("Unified interface: DeepBooru - model not loaded")
|
||||
|
||||
def run_all_tests(self):
    """Run the whole suite and report whether it finished without failures.

    ``setup()`` runs first, then every test method in a fixed order.
    ``cleanup()`` is executed even when a test method raises, after which
    the exception propagates to the caller. The summary is printed only
    when all test methods completed.

    Returns:
        True when ``self.results['failed']`` is empty, False otherwise.
    """
    self.setup()

    try:
        self.test_onnx_providers()
        self.test_memory_management()
        self.test_settings_exist()
        self.test_threshold()
        self.test_max_tags()
        self.test_use_spaces()
        self.test_escape_brackets()
        self.test_sort_alpha()
        self.test_exclude_tags()
        self.test_show_scores()
        self.test_include_rating()
        self.test_character_threshold()
        self.test_unified_interface()
    finally:
        # Always release whatever setup() acquired, even on an unexpected error.
        self.cleanup()

    self.print_summary()

    return len(self.results['failed']) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The process exit status mirrors the suite outcome: 0 when nothing failed, 1 otherwise.
    sys.exit(0 if TaggerTest().run_all_tests() else 1)
|
||||
|
|
@ -128,5 +128,12 @@
|
|||
"preview": "shuttleai--shuttle-jaguar.jpg",
|
||||
"tags": "community",
|
||||
"skip": true
|
||||
},
|
||||
"Anima": {
|
||||
"path": "CalamitousFelicitousness/Anima-sdnext-diffusers",
|
||||
"preview": "CalamitousFelicitousness--Anima-sdnext-diffusers.png",
|
||||
"desc": "Modified Cosmos-Predict-2B that replaces the T5-11B text encoder with Qwen3-0.6B. Anima is a 2 billion parameter text-to-image model created via a collaboration between CircleStone Labs and Comfy Org. It is focused mainly on anime concepts, characters, and styles, but is also capable of generating a wide variety of other non-photorealistic content. The model is designed for making illustrations and artistic images, and will not work well at realism.",
|
||||
"tags": "community",
|
||||
"skip": true
|
||||
}
|
||||
}
|
||||
|
|
@ -143,6 +143,15 @@
|
|||
"date": "2025 January"
|
||||
},
|
||||
|
||||
"Z-Image": {
|
||||
"path": "Tongyi-MAI/Z-Image",
|
||||
"preview": "Tongyi-MAI--Z-Image.jpg",
|
||||
"desc": "Z-Image, an efficient image generation foundation model built on a Single-Stream Diffusion Transformer architecture. It preserves the complete training signal with full CFG support, enabling aesthetic versatility from hyper-realistic photography to anime, enhanced output diversity, and robust negative prompting for artifact suppression. Ideal base for LoRA training, ControlNet, and semantic conditioning.",
|
||||
"skip": true,
|
||||
"extras": "sampler: Default, cfg_scale: 4.0, steps: 50",
|
||||
"size": 20.3,
|
||||
"date": "2026 January"
|
||||
},
|
||||
"Z-Image-Turbo": {
|
||||
"path": "Tongyi-MAI/Z-Image-Turbo",
|
||||
"preview": "Tongyi-MAI--Z-Image-Turbo.jpg",
|
||||
|
|
@ -53,6 +53,7 @@ const jsConfig = defineConfig([
|
|||
generateForever: 'readonly',
|
||||
showContributors: 'readonly',
|
||||
opts: 'writable',
|
||||
monitorOption: 'readonly',
|
||||
sortUIElements: 'readonly',
|
||||
all_gallery_buttons: 'readonly',
|
||||
selected_gallery_button: 'readonly',
|
||||
|
|
@ -98,6 +99,8 @@ const jsConfig = defineConfig([
|
|||
idbAdd: 'readonly',
|
||||
idbCount: 'readonly',
|
||||
idbFolderCleanup: 'readonly',
|
||||
idbClearAll: 'readonly',
|
||||
idbIsReady: 'readonly',
|
||||
initChangelog: 'readonly',
|
||||
sendNotification: 'readonly',
|
||||
monitorConnection: 'readonly',
|
||||
|
|
@ -241,6 +244,9 @@ const jsonConfig = defineConfig([
|
|||
plugins: { json },
|
||||
language: 'json/json',
|
||||
extends: ['json/recommended'],
|
||||
rules: {
|
||||
'json/no-empty-keys': 'off',
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@
|
|||
{"id":"","label":"Embedding","localized":"","reload":"","hint":"Textual inversion embedding is a trained embedded information about the subject"},
|
||||
{"id":"","label":"Hypernetwork","localized":"","reload":"","hint":"Small trained neural network that modifies behavior of the loaded model"},
|
||||
{"id":"","label":"VLM Caption","localized":"","reload":"","hint":"Analyze image using vision language model"},
|
||||
{"id":"","label":"CLiP Interrogate","localized":"","reload":"","hint":"Analyze image using CLiP model"},
|
||||
{"id":"","label":"OpenCLiP","localized":"","reload":"","hint":"Analyze image using CLiP model via OpenCLiP"},
|
||||
{"id":"","label":"VAE","localized":"","reload":"","hint":"Variational Auto Encoder: model used to run image decode at the end of generate"},
|
||||
{"id":"","label":"History","localized":"","reload":"","hint":"List of previous generations that can be further reprocessed"},
|
||||
{"id":"","label":"UI disable variable aspect ratio","localized":"","reload":"","hint":"When disabled, all thumbnails appear as squared images"},
|
||||
|
|
|
|||
53
installer.py
53
installer.py
|
|
@ -112,7 +112,7 @@ def install_traceback(suppress: list = []):
|
|||
width = os.environ.get("SD_TRACEWIDTH", console.width if console else None)
|
||||
if width is not None:
|
||||
width = int(width)
|
||||
traceback_install(
|
||||
log.excepthook = traceback_install(
|
||||
console=console,
|
||||
extra_lines=int(os.environ.get("SD_TRACELINES", 1)),
|
||||
max_frames=int(os.environ.get("SD_TRACEFRAMES", 16)),
|
||||
|
|
@ -168,7 +168,6 @@ def setup_logging():
|
|||
def get(self):
|
||||
return self.buffer
|
||||
|
||||
|
||||
class LogFilter(logging.Filter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
@ -215,6 +214,23 @@ def setup_logging():
|
|||
logging.Logger.trace = partialmethod(logging.Logger.log, logging.TRACE)
|
||||
logging.trace = partial(logging.log, logging.TRACE)
|
||||
|
||||
def exception_hook(e: Exception, suppress=()):
    """Report exception ``e`` on the console and, separately, in the log file.

    The exception is first passed to ``log.excepthook`` for console output.
    It is then rendered with rich into a captured string and logged at
    CRITICAL level while every ``RichHandler`` is detached from ``log``, so
    that this second copy does not go through the console handler.

    Args:
        e: the exception to report; its own type and traceback are used.
        suppress: forwarded to rich's ``Traceback.from_exception`` as ``suppress``.
    """
    from rich.traceback import Traceback
    tb = Traceback.from_exception(type(e), e, e.__traceback__, show_locals=False, max_frames=16, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width)
    # print-to-console, does not get printed-to-file
    # Use the exception that was passed in rather than sys.exc_info(), which
    # is empty when this hook is called outside an active ``except`` block.
    log.excepthook(type(e), e, e.__traceback__)
    # print-to-file, temporarily disable-console-handler
    for handler in log.handlers.copy():
        if isinstance(handler, RichHandler):
            log.removeHandler(handler)
    try:
        with console.capture() as capture:
            console.print(tb)
        log.critical(capture.get())
    finally:
        # Re-attach the console handler even if rendering or logging failed.
        log.addHandler(rh)
|
||||
|
||||
log.traceback = exception_hook
|
||||
|
||||
level = logging.DEBUG if (args.debug or args.trace) else logging.INFO
|
||||
log.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd`
|
||||
log.print = rprint
|
||||
|
|
@ -240,8 +256,10 @@ def setup_logging():
|
|||
)
|
||||
|
||||
logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null
|
||||
|
||||
pretty_install(console=console)
|
||||
install_traceback()
|
||||
|
||||
while log.hasHandlers() and len(log.handlers) > 0:
|
||||
log.removeHandler(log.handlers[0])
|
||||
|
||||
|
|
@ -288,7 +306,6 @@ def setup_logging():
|
|||
logging.getLogger("torch").setLevel(logging.ERROR)
|
||||
logging.getLogger("ControlNet").handlers = log.handlers
|
||||
logging.getLogger("lycoris").handlers = log.handlers
|
||||
# logging.getLogger("DeepSpeed").handlers = log.handlers
|
||||
ts('log', t_start)
|
||||
|
||||
|
||||
|
|
@ -712,9 +729,9 @@ def install_cuda():
|
|||
log.info('CUDA: nVidia toolkit detected')
|
||||
ts('cuda', t_start)
|
||||
if args.use_nightly:
|
||||
cmd = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --extra-index-url https://download.pytorch.org/whl/nightly/cu126')
|
||||
cmd = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 --extra-index-url https://download.pytorch.org/whl/nightly/cu130')
|
||||
else:
|
||||
cmd = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+cu128 torchvision==0.24.1+cu128 --index-url https://download.pytorch.org/whl/cu128')
|
||||
cmd = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+cu128 torchvision==0.25.0+cu128 --index-url https://download.pytorch.org/whl/cu128')
|
||||
return cmd
|
||||
|
||||
|
||||
|
|
@ -765,7 +782,6 @@ def install_rocm_zluda():
|
|||
|
||||
if sys.platform == "win32":
|
||||
if args.use_zluda:
|
||||
#check_python(supported_minors=[10, 11, 12, 13], reason='ZLUDA backend requires a Python version between 3.10 and 3.13')
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.7.1+cu118 torchvision==0.22.1+cu118 --index-url https://download.pytorch.org/whl/cu118')
|
||||
|
||||
if args.device_id is not None:
|
||||
|
|
@ -795,6 +811,7 @@ def install_rocm_zluda():
|
|||
torch_command = os.environ.get('TORCH_COMMAND', f'torch torchvision --index-url https://rocm.nightlies.amd.com/{device.therock}')
|
||||
else:
|
||||
check_python(supported_minors=[12], reason='ROCm: Windows preview python==3.12 required')
|
||||
# torch 2.8.0a0 is the last version with rocm 6.4 support
|
||||
torch_command = os.environ.get('TORCH_COMMAND', '--no-cache-dir https://repo.radeon.com/rocm/windows/rocm-rel-6.4.4/torch-2.8.0a0%2Bgitfc14c65-cp312-cp312-win_amd64.whl https://repo.radeon.com/rocm/windows/rocm-rel-6.4.4/torchvision-0.24.0a0%2Bc85f008-cp312-cp312-win_amd64.whl')
|
||||
else:
|
||||
#check_python(supported_minors=[10, 11, 12, 13, 14], reason='ROCm backend requires a Python version between 3.10 and 3.13')
|
||||
|
|
@ -804,7 +821,11 @@ def install_rocm_zluda():
|
|||
else: # oldest rocm version on nightly is 7.0
|
||||
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0')
|
||||
else:
|
||||
if rocm.version is None or float(rocm.version) >= 6.4: # assume the latest if version check fails
|
||||
if rocm.version is None or float(rocm.version) >= 7.1: # assume the latest if version check fails
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+rocm7.1 torchvision==0.25.0+rocm7.1 --index-url https://download.pytorch.org/whl/rocm7.1')
|
||||
elif rocm.version == "7.0": # assume the latest if version check fails
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+rocm7.0 torchvision==0.25.0+rocm7.0 --index-url https://download.pytorch.org/whl/rocm7.0')
|
||||
elif rocm.version == "6.4":
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+rocm6.4 torchvision==0.24.1+rocm6.4 --index-url https://download.pytorch.org/whl/rocm6.4')
|
||||
elif rocm.version == "6.3":
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+rocm6.3 torchvision==0.24.1+rocm6.3 --index-url https://download.pytorch.org/whl/rocm6.3')
|
||||
|
|
@ -841,7 +862,7 @@ def install_ipex():
|
|||
if args.use_nightly:
|
||||
torch_command = os.environ.get('TORCH_COMMAND', '--upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu')
|
||||
else:
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+xpu torchvision==0.24.1+xpu --index-url https://download.pytorch.org/whl/xpu')
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+xpu torchvision==0.25.0+xpu --index-url https://download.pytorch.org/whl/xpu')
|
||||
|
||||
ts('ipex', t_start)
|
||||
return torch_command
|
||||
|
|
@ -854,13 +875,13 @@ def install_openvino():
|
|||
|
||||
#check_python(supported_minors=[10, 11, 12, 13], reason='OpenVINO backend requires a Python version between 3.10 and 3.13')
|
||||
if sys.platform == 'darwin':
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1 torchvision==0.24.1')
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0 torchvision==0.25.0')
|
||||
else:
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.9.1+cpu torchvision==0.24.1 --index-url https://download.pytorch.org/whl/cpu')
|
||||
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.10.0+cpu torchvision==0.25.0 --index-url https://download.pytorch.org/whl/cpu')
|
||||
|
||||
if not (args.skip_all or args.skip_requirements):
|
||||
install(os.environ.get('OPENVINO_COMMAND', 'openvino==2025.3.0'), 'openvino')
|
||||
install(os.environ.get('NNCF_COMMAND', 'nncf==2.18.0'), 'nncf')
|
||||
install(os.environ.get('OPENVINO_COMMAND', 'openvino==2025.4.1'), 'openvino')
|
||||
install(os.environ.get('NNCF_COMMAND', 'nncf==2.19.0'), 'nncf')
|
||||
ts('openvino', t_start)
|
||||
return torch_command
|
||||
|
||||
|
|
@ -1427,6 +1448,7 @@ def set_environment():
|
|||
os.environ.setdefault('TORCH_CUDNN_V8_API_ENABLED', '1')
|
||||
os.environ.setdefault('TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD', '1')
|
||||
os.environ.setdefault('TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL', '1')
|
||||
os.environ.setdefault('MIOPEN_FIND_MODE', '2')
|
||||
os.environ.setdefault('UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS', '1')
|
||||
os.environ.setdefault('USE_TORCH', '1')
|
||||
os.environ.setdefault('UV_INDEX_STRATEGY', 'unsafe-any-match')
|
||||
|
|
@ -1540,7 +1562,7 @@ def check_ui(ver):
|
|||
|
||||
t_start = time.time()
|
||||
if not same(ver):
|
||||
log.debug(f'Branch mismatch: sdnext={ver["branch"]} ui={ver["ui"]}')
|
||||
log.debug(f'Branch mismatch: {ver}')
|
||||
cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir('extensions-builtin/sdnext-modernui')
|
||||
|
|
@ -1548,10 +1570,7 @@ def check_ui(ver):
|
|||
git('checkout ' + target, ignore=True, optional=True)
|
||||
os.chdir(cwd)
|
||||
ver = get_version(force=True)
|
||||
if not same(ver):
|
||||
log.debug(f'Branch synchronized: {ver["branch"]}')
|
||||
else:
|
||||
log.debug(f'Branch sync failed: sdnext={ver["branch"]} ui={ver["ui"]}')
|
||||
log.debug(f'Branch sync: {ver}')
|
||||
except Exception as e:
|
||||
log.debug(f'Branch switch: {e}')
|
||||
os.chdir(cwd)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
let ws;
|
||||
let url;
|
||||
let currentImage = null;
|
||||
let currentGalleryFolder = null;
|
||||
let pruneImagesTimer;
|
||||
let outstanding = 0;
|
||||
let lastSort = 0;
|
||||
|
|
@ -20,6 +21,7 @@ const el = {
|
|||
search: undefined,
|
||||
status: undefined,
|
||||
btnSend: undefined,
|
||||
clearCacheFolder: undefined,
|
||||
};
|
||||
|
||||
const SUPPORTED_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'tiff', 'jp2', 'jxl', 'gif', 'mp4', 'mkv', 'avi', 'mjpeg', 'mpg', 'avr'];
|
||||
|
|
@ -117,9 +119,12 @@ function updateGalleryStyles() {
|
|||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
transition-duration: 0.2s;
|
||||
transition-property: color, opacity, background-color, border-color;
|
||||
transition-timing-function: ease-out;
|
||||
}
|
||||
.gallery-folder:hover {
|
||||
background-color: var(--button-primary-background-fill-hover);
|
||||
background-color: var(--button-primary-background-fill-hover, var(--sd-button-hover-color));
|
||||
}
|
||||
.gallery-folder-selected {
|
||||
background-color: var(--sd-button-selected-color);
|
||||
|
|
@ -258,6 +263,14 @@ class SimpleFunctionQueue {
|
|||
this.#queue = [];
|
||||
}
|
||||
|
||||
static abortLogger(identifier, result) {
|
||||
if (typeof result === 'string' || (result instanceof DOMException && result.name === 'AbortError')) {
|
||||
log(identifier, result?.message || result);
|
||||
} else {
|
||||
error(identifier, result.message);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {{
|
||||
* signal: AbortSignal,
|
||||
|
|
@ -301,6 +314,8 @@ class SimpleFunctionQueue {
|
|||
// HTML Elements
|
||||
|
||||
class GalleryFolder extends HTMLElement {
|
||||
static folders = new Set();
|
||||
|
||||
constructor(folder) {
|
||||
super();
|
||||
// Support both old format (string) and new format (object with path and label)
|
||||
|
|
@ -314,21 +329,173 @@ class GalleryFolder extends HTMLElement {
|
|||
this.style.overflowX = 'hidden';
|
||||
this.shadow = this.attachShadow({ mode: 'open' });
|
||||
this.shadow.adoptedStyleSheets = [folderStylesheet];
|
||||
|
||||
this.div = document.createElement('div');
|
||||
}
|
||||
|
||||
connectedCallback() {
|
||||
const div = document.createElement('div');
|
||||
div.className = 'gallery-folder';
|
||||
div.innerHTML = `<span class="gallery-folder-icon">\uf03e</span> ${this.label}`;
|
||||
div.title = this.name; // Show full path on hover
|
||||
div.addEventListener('click', () => {
|
||||
for (const folder of el.folders.children) {
|
||||
if (folder.name === this.name) folder.shadow.firstElementChild.classList.add('gallery-folder-selected');
|
||||
else folder.shadow.firstElementChild.classList.remove('gallery-folder-selected');
|
||||
if (GalleryFolder.folders.has(this)) return; // Element is just being moved
|
||||
|
||||
this.div.className = 'gallery-folder';
|
||||
this.div.innerHTML = `<span class="gallery-folder-icon">\uf03e</span> ${this.label}`;
|
||||
this.div.title = this.name; // Show full path on hover
|
||||
this.div.addEventListener('click', () => { this.updateSelected(); }); // Ensures 'this' isn't the div in the called method
|
||||
this.div.addEventListener('click', fetchFilesWS); // eslint-disable-line no-use-before-define
|
||||
this.shadow.appendChild(this.div);
|
||||
GalleryFolder.folders.add(this);
|
||||
}
|
||||
|
||||
async disconnectedCallback() {
|
||||
await Promise.resolve(); // Wait for other microtasks (such as element moving)
|
||||
if (this.isConnected) return;
|
||||
GalleryFolder.folders.delete(this);
|
||||
}
|
||||
|
||||
updateSelected() {
|
||||
this.div.classList.add('gallery-folder-selected');
|
||||
for (const folder of GalleryFolder.folders) {
|
||||
if (folder !== this) {
|
||||
folder.div.classList.remove('gallery-folder-selected');
|
||||
}
|
||||
});
|
||||
div.addEventListener('click', fetchFilesWS); // eslint-disable-line no-use-before-define
|
||||
this.shadow.appendChild(div);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Fetch thumbnail data for a file once awaitForOutstanding() allows it.
 * The in-flight counter `outstanding` is incremented for the duration of the
 * request and always decremented afterwards.
 * @param {string} fn - File path sent as the `file` query parameter
 * @param {AbortSignal} signal - Passed to awaitForOutstanding()
 * @returns {Promise<object|undefined>} Parsed JSON response, or undefined when
 *   the request failed or the response was empty or carried an error
 */
async function delayFetchThumb(fn, signal) {
  await awaitForOutstanding(16, signal);
  try {
    outstanding++;
    const ts = Date.now().toString();
    const res = await authFetch(`${window.api}/browser/thumb?file=${encodeURI(fn)}&ts=${ts}`, { priority: 'low' });
    if (!res.ok) {
      error(`fetchThumb: ${res.statusText}`);
      return undefined;
    }
    const json = await res.json();
    if (!res || !json || json.error || Object.keys(json).length === 0) {
      // json may be null here (a literal `null` body), so guard the property access.
      if (json?.error) error(`fetchThumb: ${json.error}`);
      return undefined;
    }
    return json;
  } finally {
    outstanding--;
  }
}
|
||||
|
||||
/**
 * Custom element that renders one gallery thumbnail inside its own shadow root.
 *
 * The thumbnail is taken from the IndexedDB cache when available (and `opts.browser_cache`
 * is enabled), otherwise fetched through `delayFetchThumb`; if that fetch throws,
 * the full-size image is used instead.
 */
class GalleryFile extends HTMLElement {
  /** @type {AbortSignal} */
  #signal;

  /**
   * @param {string} folder - Folder the gallery is browsing
   * @param {string} file - File name relative to `folder` (may contain sub-directories)
   * @param {AbortSignal} signal - Signal checked to cancel work for an outdated gallery view
   */
  constructor(folder, file, signal) {
    super();
    this.folder = folder;
    this.name = file;
    this.#signal = signal;
    this.size = 0;
    this.mtime = 0;
    this.hash = undefined;
    this.exif = '';
    this.width = 0;
    this.height = 0;
    this.src = `${this.folder}/${this.name}`;
    this.shadow = this.attachShadow({ mode: 'open' });
    this.shadow.adoptedStyleSheets = [fileStylesheet];

    this.firstRun = true;
  }

  /**
   * Populate the element the first time it is attached to the document.
   * Later attachments (the element being moved) are ignored.
   */
  async connectedCallback() {
    if (!this.firstRun) return; // Element is just being moved
    this.firstRun = false;

    // Check separator state early to hide the element immediately
    const dir = this.name.match(/(.*)[/\\]/);
    if (dir && dir[1]) {
      const dirPath = dir[1];
      const isOpen = separatorStates.get(dirPath);
      if (isOpen === false) {
        this.style.display = 'none';
      }
    }

    // Normalize path to ensure consistent hash regardless of which folder view is used
    const normalizedPath = this.src.replace(/\/+/g, '/').replace(/\/$/, '');
    this.hash = await getHash(`${normalizedPath}/${this.size}/${this.mtime}`); // eslint-disable-line no-use-before-define
    const cachedData = (this.hash && opts.browser_cache) ? await idbGet(this.hash).catch(() => undefined) : undefined;
    const img = document.createElement('img');
    img.className = 'gallery-file';
    img.loading = 'lazy';
    img.onload = () => {
      // Resolve fallback dimensions BEFORE building the tooltip so it never reports "0 x 0"
      if (!cachedData && opts.browser_cache) {
        if ((this.width === 0) || (this.height === 0)) { // fetch thumb failed so we use actual image
          this.width = img.naturalWidth;
          this.height = img.naturalHeight;
        }
      }
      img.title += `\nResolution: ${this.width} x ${this.height}`;
      this.title = img.title;
    };
    let ok = true;
    if (cachedData?.img) {
      img.src = cachedData.img;
      this.exif = cachedData.exif;
      this.width = cachedData.width;
      this.height = cachedData.height;
      this.size = cachedData.size;
      this.mtime = new Date(cachedData.mtime);
    } else {
      try {
        const json = await delayFetchThumb(this.src, this.#signal);
        if (!json) {
          ok = false;
        } else {
          img.src = json.data;
          this.exif = json.exif;
          this.width = json.width;
          this.height = json.height;
          this.size = json.size;
          this.mtime = new Date(json.mtime);
          if (opts.browser_cache) {
            try {
              await idbAdd({
                hash: this.hash,
                folder: this.folder,
                file: this.name,
                size: this.size,
                mtime: this.mtime,
                width: this.width,
                height: this.height,
                src: this.src,
                exif: this.exif,
                img: img.src,
                // exif: await getExif(img), // alternative client-side exif
                // img: await createThumb(img), // alternative client-side thumb
              });
            } catch (cacheErr) {
              // A failed cache write must not discard the thumbnail that was fetched successfully
              error('GalleryFile: unable to cache thumbnail', cacheErr);
            }
          }
        }
      } catch (err) { // thumb fetch failed so assign actual image
        img.src = `file=${this.src}`;
      }
    }
    if (this.#signal.aborted) { // Do not change the operations order from here...
      return;
    }
    galleryHashes.add(this.hash);
    if (!ok) {
      return;
    } // ... to here unless modifications are also being made to maintenance functionality and the usage of AbortController/AbortSignal
    img.onclick = () => {
      setGallerySelectionByElement(this, { send: true });
    };
    img.title = `Folder: ${this.folder}\nFile: ${this.name}\nSize: ${this.size.toLocaleString()} bytes\nModified: ${this.mtime.toLocaleString()}`;
    this.title = img.title;

    // Final visibility check based on search term.
    const shouldDisplayBasedOnSearch = this.title.toLowerCase().includes(el.search.value.toLowerCase());
    if (this.style.display !== 'none') { // Only proceed if not already hidden by a closed separator
      this.style.display = shouldDisplayBasedOnSearch ? 'unset' : 'none';
    }

    this.shadow.appendChild(img);
  }
}
|
||||
|
||||
|
|
@ -459,148 +626,6 @@ async function addSeparators() {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Request thumbnail data for a file from the `/browser/thumb` endpoint.
 *
 * Waits on `awaitForOutstanding(16, signal)` before issuing the request
 * (presumably a cap on concurrent thumbnail requests — confirm against that helper),
 * and tracks the request in the shared `outstanding` counter while it is in flight.
 *
 * @param {string} fn - Path of the file to fetch a thumbnail for
 * @param {AbortSignal} signal - Passed to `awaitForOutstanding`
 * @returns {Promise<Object|undefined>} Parsed JSON response, or `undefined` when the
 *   response is not OK, is empty, or carries an `error` field
 */
async function delayFetchThumb(fn, signal) {
  await awaitForOutstanding(16, signal);
  try {
    outstanding++;
    const ts = Date.now().toString();
    // NOTE(review): encodeURI leaves `&`, `#`, `+` and `?` unescaped, so file names containing
    // them may be truncated in the query string — confirm server-side decoding before switching
    // to encodeURIComponent.
    const res = await authFetch(`${window.api}/browser/thumb?file=${encodeURI(fn)}&ts=${ts}`, { priority: 'low' });
    if (!res.ok) {
      error(`fetchThumb: ${res.statusText}`);
      return undefined;
    }
    const json = await res.json();
    if (!json || json.error || Object.keys(json).length === 0) {
      // `json` may be null here (a literal `null` body), so the error lookup must be null-safe
      if (json?.error) error(`fetchThumb: ${json.error}`);
      return undefined;
    }
    return json;
  } finally {
    outstanding--; // always release the slot, including on fetch/parse failure
  }
}
|
||||
|
||||
/**
 * Custom element that renders one gallery thumbnail inside its own shadow root.
 *
 * The thumbnail is taken from the IndexedDB cache when available (and `opts.browser_cache`
 * is enabled), otherwise fetched through `delayFetchThumb`; if that fetch throws,
 * the full-size image is used instead.
 */
class GalleryFile extends HTMLElement {
  /** @type {AbortSignal} */
  #signal;

  /**
   * @param {string} folder - Folder the gallery is browsing
   * @param {string} file - File name relative to `folder` (may contain sub-directories)
   * @param {AbortSignal} signal - Signal checked to cancel work for an outdated gallery view
   */
  constructor(folder, file, signal) {
    super();
    this.folder = folder;
    this.name = file;
    this.#signal = signal;
    this.size = 0;
    this.mtime = 0;
    this.hash = undefined;
    this.exif = '';
    this.width = 0;
    this.height = 0;
    this.src = `${this.folder}/${this.name}`;
    this.shadow = this.attachShadow({ mode: 'open' });
    this.shadow.adoptedStyleSheets = [fileStylesheet];
  }

  /**
   * Populate the element when it is attached to the document.
   * Does nothing if the shadow root already holds content.
   */
  async connectedCallback() {
    if (this.shadow.children.length > 0) {
      return;
    }

    // Check separator state early to hide the element immediately
    const dir = this.name.match(/(.*)[/\\]/);
    if (dir && dir[1]) {
      const dirPath = dir[1];
      const isOpen = separatorStates.get(dirPath);
      if (isOpen === false) {
        this.style.display = 'none';
      }
    }

    // Normalize path to ensure consistent hash regardless of which folder view is used
    const normalizedPath = this.src.replace(/\/+/g, '/').replace(/\/$/, '');
    this.hash = await getHash(`${normalizedPath}/${this.size}/${this.mtime}`); // eslint-disable-line no-use-before-define
    const cachedData = (this.hash && opts.browser_cache) ? await idbGet(this.hash).catch(() => undefined) : undefined;
    const img = document.createElement('img');
    img.className = 'gallery-file';
    img.loading = 'lazy';
    img.onload = () => {
      // Resolve fallback dimensions BEFORE building the tooltip so it never reports "0 x 0"
      if (!cachedData && opts.browser_cache) {
        if ((this.width === 0) || (this.height === 0)) { // fetch thumb failed so we use actual image
          this.width = img.naturalWidth;
          this.height = img.naturalHeight;
        }
      }
      img.title += `\nResolution: ${this.width} x ${this.height}`;
      this.title = img.title;
    };
    let ok = true;
    if (cachedData?.img) {
      img.src = cachedData.img;
      this.exif = cachedData.exif;
      this.width = cachedData.width;
      this.height = cachedData.height;
      this.size = cachedData.size;
      this.mtime = new Date(cachedData.mtime);
    } else {
      try {
        const json = await delayFetchThumb(this.src, this.#signal);
        if (!json) {
          ok = false;
        } else {
          img.src = json.data;
          this.exif = json.exif;
          this.width = json.width;
          this.height = json.height;
          this.size = json.size;
          this.mtime = new Date(json.mtime);
          if (opts.browser_cache) {
            // Store file's actual parent directory (not browsed folder) for consistent cleanup
            const fileDir = this.src.replace(/\/+/g, '/').replace(/\/[^/]+$/, '');
            try {
              await idbAdd({
                hash: this.hash,
                folder: fileDir,
                file: this.name,
                size: this.size,
                mtime: this.mtime,
                width: this.width,
                height: this.height,
                src: this.src,
                exif: this.exif,
                img: img.src,
                // exif: await getExif(img), // alternative client-side exif
                // img: await createThumb(img), // alternative client-side thumb
              });
            } catch (cacheErr) {
              // A failed cache write must not discard the thumbnail that was fetched successfully
              error('GalleryFile: unable to cache thumbnail', cacheErr);
            }
          }
        }
      } catch (err) { // thumb fetch failed so assign actual image
        img.src = `file=${this.src}`;
      }
    }
    if (this.#signal.aborted) { // Do not change the operations order from here...
      return;
    }
    galleryHashes.add(this.hash);
    if (!ok) {
      return;
    } // ... to here unless modifications are also being made to maintenance functionality and the usage of AbortController/AbortSignal
    img.onclick = () => {
      setGallerySelectionByElement(this, { send: true });
    };
    img.title = `Folder: ${this.folder}\nFile: ${this.name}\nSize: ${this.size.toLocaleString()} bytes\nModified: ${this.mtime.toLocaleString()}`;
    if (this.shadow.children.length > 0) {
      return; // avoid double-adding
    }
    this.title = img.title;

    // Final visibility check based on search term.
    const shouldDisplayBasedOnSearch = this.title.toLowerCase().includes(el.search.value.toLowerCase());
    if (this.style.display !== 'none') { // Only proceed if not already hidden by a closed separator
      this.style.display = shouldDisplayBasedOnSearch ? 'unset' : 'none';
    }

    this.shadow.appendChild(img);
  }
}
|
||||
|
||||
// methods
|
||||
|
||||
const gallerySendImage = (_images) => [currentImage]; // invoked by gradio button
|
||||
|
|
@ -919,9 +944,10 @@ async function gallerySort(btn) {
|
|||
/**
|
||||
* Generate and display the overlay to announce cleanup is in progress.
|
||||
* @param {number} count - Number of entries being cleaned up
|
||||
* @param {boolean} all - Indicate that all thumbnails are being cleared
|
||||
* @returns {ClearMsgCallback}
|
||||
*/
|
||||
function showCleaningMsg(count) {
|
||||
function showCleaningMsg(count, all = false) {
|
||||
// Rendering performance isn't a priority since this doesn't run often
|
||||
const parent = el.folders.parentElement;
|
||||
const cleaningOverlay = document.createElement('div');
|
||||
|
|
@ -936,7 +962,7 @@ function showCleaningMsg(count) {
|
|||
msgText.style.cssText = 'font-size: 1.2em';
|
||||
msgInfo.style.cssText = 'font-size: 0.9em; text-align: center;';
|
||||
msgText.innerText = 'Thumbnail cleanup...';
|
||||
msgInfo.innerText = `Found ${count} old entries`;
|
||||
msgInfo.innerText = all ? 'Clearing all entries' : `Found ${count} old entries`;
|
||||
anim.classList.add('idbBusyAnim');
|
||||
|
||||
msgDiv.append(msgText, msgInfo);
|
||||
|
|
@ -945,16 +971,17 @@ function showCleaningMsg(count) {
|
|||
return () => { cleaningOverlay.remove(); };
|
||||
}
|
||||
|
||||
const maintenanceQueue = new SimpleFunctionQueue('Maintenance');
|
||||
const maintenanceQueue = new SimpleFunctionQueue('Gallery Maintenance');
|
||||
|
||||
/**
|
||||
* Handles calling the cleanup function for the thumbnail cache
|
||||
* @param {string} folder - Folder to clean
|
||||
* @param {number} imgCount - Expected number of images in gallery
|
||||
* @param {AbortController} controller - AbortController that's handling this task
|
||||
* @param {boolean} force - Force full cleanup of the folder
|
||||
*/
|
||||
async function thumbCacheCleanup(folder, imgCount, controller) {
|
||||
if (!opts.browser_cache) return;
|
||||
async function thumbCacheCleanup(folder, imgCount, controller, force = false) {
|
||||
if (!opts.browser_cache && !force) return;
|
||||
try {
|
||||
if (typeof folder !== 'string' || typeof imgCount !== 'number') {
|
||||
throw new Error('Function called with invalid arguments');
|
||||
|
|
@ -971,14 +998,14 @@ async function thumbCacheCleanup(folder, imgCount, controller) {
|
|||
callback: async () => {
|
||||
log(`Thumbnail DB cleanup: Checking if "${folder}" needs cleaning`);
|
||||
const t0 = performance.now();
|
||||
const staticGalleryHashes = new Set(galleryHashes); // External context should be safe since this function run is guarded by AbortController/AbortSignal in the SimpleFunctionQueue
|
||||
const keptGalleryHashes = force ? new Set() : new Set(galleryHashes.values()); // External context should be safe since this function run is guarded by AbortController/AbortSignal in the SimpleFunctionQueue
|
||||
const cachedHashesCount = await idbCount(folder)
|
||||
.catch((e) => {
|
||||
error(`Thumbnail DB cleanup: Error when getting entry count for "${folder}".`, e);
|
||||
return Infinity; // Forces next check to fail if something went wrong
|
||||
});
|
||||
const cleanupCount = cachedHashesCount - staticGalleryHashes.size;
|
||||
if (cleanupCount < 500 || !Number.isFinite(cleanupCount)) {
|
||||
const cleanupCount = cachedHashesCount - keptGalleryHashes.size;
|
||||
if (!force && (cleanupCount < 500 || !Number.isFinite(cleanupCount))) {
|
||||
// Don't run when there aren't many excess entries
|
||||
return;
|
||||
}
|
||||
|
|
@ -988,30 +1015,95 @@ async function thumbCacheCleanup(folder, imgCount, controller) {
|
|||
return;
|
||||
}
|
||||
const cb_clearMsg = showCleaningMsg(cleanupCount);
|
||||
const tRun = Date.now(); // Doesn't need high resolution
|
||||
await idbFolderCleanup(staticGalleryHashes, folder, controller.signal)
|
||||
await idbFolderCleanup(keptGalleryHashes, folder, controller.signal)
|
||||
.then((delcount) => {
|
||||
const t1 = performance.now();
|
||||
log(`Thumbnail DB cleanup: folder=${folder} kept=${staticGalleryHashes.size} deleted=${delcount} time=${Math.floor(t1 - t0)}ms`);
|
||||
log(`Thumbnail DB cleanup: folder=${folder} kept=${keptGalleryHashes.size} deleted=${delcount} time=${Math.floor(t1 - t0)}ms`);
|
||||
currentGalleryFolder = null;
|
||||
el.clearCacheFolder.innerText = '<select a folder first>';
|
||||
updateStatusWithSort('Thumbnail cache cleared');
|
||||
})
|
||||
.catch((reason) => {
|
||||
if (typeof reason === 'string' || (reason instanceof DOMException && reason.name === 'AbortError')) {
|
||||
log('Thumbnail DB cleanup:', reason?.message || reason);
|
||||
} else {
|
||||
error('Thumbnail DB cleanup:', reason.message);
|
||||
}
|
||||
SimpleFunctionQueue.abortLogger('Thumbnail DB cleanup:', reason);
|
||||
})
|
||||
.finally(async () => {
|
||||
// Ensure at least enough time to see that it's a message and not the UI breaking/flickering
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, Math.min(1000, Math.max(1000 - (Date.now() - tRun), 0))); // Total display time of at least 1 second
|
||||
});
|
||||
await new Promise((resolve) => { setTimeout(resolve, 1000); }); // Delay removal by 1 second to ensure at least minimum visibility
|
||||
cb_clearMsg();
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Abort the current maintenance controller, install a fresh one, and reset
 * the gallery's hash set, progress bar and selection.
 * @param {*} reason - Abort reason handed to the previous controller
 * @returns {AbortController} The newly installed controller
 */
function resetGalleryState(reason) {
  const replacement = new AbortController();
  maintenanceController.abort(reason); // abort the outgoing controller before swapping it out
  maintenanceController = replacement;

  // Must happen AFTER the AbortController steps
  galleryHashes.clear();
  galleryProgressBar.clear();
  resetGallerySelection();

  return replacement;
}
|
||||
|
||||
/**
 * Option-change handler: when the browser cache setting is switched to `false`,
 * reset the gallery state and queue a task that clears every cached thumbnail.
 * @param {boolean} browser_cache - New value of the `browser_cache` option
 */
function clearCacheIfDisabled(browser_cache) {
  if (browser_cache !== false) return;

  log('Thumbnail DB cleanup:', 'Image gallery cache setting disabled. Clearing cache.');
  const controller = resetGalleryState('Clearing all thumbnails from cache');

  const clearEverything = async () => {
    const started = performance.now();
    const hideOverlay = showCleaningMsg(0, true);
    try {
      await idbClearAll(controller.signal);
      log(`Thumbnail DB cleanup: Cache cleared. time=${Math.floor(performance.now() - started)}ms`);
      currentGalleryFolder = null;
      el.clearCacheFolder.innerText = '<select a folder first>';
      updateStatusWithSort('Thumbnail cache cleared');
    } catch (e) {
      SimpleFunctionQueue.abortLogger('Thumbnail DB cleanup:', e);
    } finally {
      // Keep the overlay up for one more second before removing it
      await new Promise((resolve) => { setTimeout(resolve, 1000); });
      hideOverlay();
    }
  };

  maintenanceQueue.enqueue({
    signal: controller.signal,
    callback: clearEverything,
  });
}
|
||||
|
||||
/**
 * Insert the "clear thumbnail cache for folder" control after the
 * `#setting_browser_cache` setting and wire up its double-click handler.
 * @returns {boolean} `true` when the setting element exists and the control was added
 */
function addCacheClearLabel() { // Don't use async
  const setting = document.querySelector('#setting_browser_cache');
  if (!setting) return false;

  const label = document.createElement('span');
  label.style.cssText = 'font-weight: bold; text-decoration: underline; cursor: pointer; color: var(--color-blue); user-select: none;';
  label.innerText = '<select a folder first>';

  const wrapper = document.createElement('div');
  wrapper.style.marginBlock = '0.75rem';
  wrapper.append('Clear the thumbnail cache for: ', label, ' (double-click)');

  setting.parentElement.insertAdjacentElement('afterend', wrapper);
  el.clearCacheFolder = label;

  label.addEventListener('dblclick', (evt) => {
    evt.preventDefault();
    evt.stopPropagation();
    if (!currentGalleryFolder) return;
    // Flash the label green for one second as click feedback
    el.clearCacheFolder.style.color = 'var(--color-green)';
    setTimeout(() => {
      el.clearCacheFolder.style.color = 'var(--color-blue)';
    }, 1000);
    const controller = resetGalleryState('Clearing folder thumbnails cache');
    el.files.innerHTML = '';
    thumbCacheCleanup(currentGalleryFolder, 0, controller, true);
  });

  return true;
}
|
||||
|
||||
async function fetchFilesHT(evt, controller) {
|
||||
const t0 = performance.now();
|
||||
const fragment = document.createDocumentFragment();
|
||||
|
|
@ -1049,12 +1141,8 @@ async function fetchFilesHT(evt, controller) {
|
|||
|
||||
async function fetchFilesWS(evt) { // fetch file-by-file list over websockets
|
||||
if (!url) return;
|
||||
const controller = new AbortController(); // Only called here because fetchFilesHT isn't called directly
|
||||
maintenanceController.abort('Gallery update'); // Abort previous controller
|
||||
maintenanceController = controller; // Point to new controller for next time
|
||||
galleryHashes.clear(); // Must happen AFTER the AbortController steps
|
||||
galleryProgressBar.clear();
|
||||
resetGallerySelection();
|
||||
// Abort previous controller and point to new controller for next time
|
||||
const controller = resetGalleryState('Gallery update'); // Called here because fetchFilesHT isn't called directly
|
||||
|
||||
el.files.innerHTML = '';
|
||||
updateGalleryStyles();
|
||||
|
|
@ -1068,6 +1156,10 @@ async function fetchFilesWS(evt) { // fetch file-by-file list over websockets
|
|||
return;
|
||||
}
|
||||
log(`gallery: connected=${wsConnected} state=${ws?.readyState} url=${ws?.url}`);
|
||||
currentGalleryFolder = evt.target.name;
|
||||
if (el.clearCacheFolder) {
|
||||
el.clearCacheFolder.innerText = currentGalleryFolder;
|
||||
}
|
||||
if (!wsConnected) {
|
||||
await fetchFilesHT(evt, controller); // fallback to http
|
||||
return;
|
||||
|
|
@ -1115,26 +1207,17 @@ async function fetchFilesWS(evt) { // fetch file-by-file list over websockets
|
|||
ws.send(encodeURI(evt.target.name));
|
||||
}
|
||||
|
||||
async function pruneImages() {
|
||||
// TODO replace img.src with placeholder for images that are not visible
|
||||
}
|
||||
|
||||
async function galleryVisible() {
|
||||
async function updateFolders() {
|
||||
// if (el.folders.children.length > 0) return;
|
||||
const res = await authFetch(`${window.api}/browser/folders`);
|
||||
if (!res || res.status !== 200) return;
|
||||
el.folders.innerHTML = '';
|
||||
url = res.url.split('/sdapi')[0].replace('http', 'ws'); // update global url as ws need fqdn
|
||||
const folders = await res.json();
|
||||
el.folders.innerHTML = '';
|
||||
for (const folder of folders) {
|
||||
const f = new GalleryFolder(folder);
|
||||
el.folders.appendChild(f);
|
||||
}
|
||||
pruneImagesTimer = setInterval(pruneImages, 1000);
|
||||
}
|
||||
|
||||
/** Stop the periodic image-pruning timer, if one is running. */
async function galleryHidden() {
  if (!pruneImagesTimer) return;
  clearInterval(pruneImagesTimer);
}
|
||||
|
||||
async function monitorGalleries() {
|
||||
|
|
@ -1165,6 +1248,32 @@ async function setOverlayAnimation() {
|
|||
document.head.append(busyAnimation);
|
||||
}
|
||||
|
||||
/**
 * Retry once per second until the cache-clear label can be added (or the retry
 * budget runs out), then start monitoring the `browser_cache` option.
 */
async function galleryClearInit() {
  let failedAttempts = 0;
  const retryTimer = setInterval(() => {
    if (!addCacheClearLabel()) {
      // Give up once 60 failed attempts have been recorded before this one
      const budgetSpent = failedAttempts === 60;
      failedAttempts += 1;
      if (!budgetSpent) return;
    }
    clearInterval(retryTimer);
    monitorOption('browser_cache', clearCacheIfDisabled);
  }, 1000);
}
|
||||
|
||||
/**
 * Put a blocking task at the front of the maintenance queue that polls
 * `idbIsReady()` once per second, up to 60 times, and throws if the
 * thumbnail cache is still not ready afterwards.
 */
async function blockQueueUntilReady() {
  // Standalone AbortSignal whose controller is discarded, so it can't be aborted
  const neverAborted = new AbortController().signal;

  const waitForCache = async () => {
    for (let waited = 0; !idbIsReady() && waited < 60; waited++) {
      await new Promise((resolve) => { setTimeout(resolve, 1000); });
    }
    if (!idbIsReady()) {
      throw new Error('Timed out waiting for thumbnail cache');
    }
  };

  // Add block to maintenanceQueue until cache is ready
  maintenanceQueue.enqueue({
    signal: neverAborted,
    callback: waitForCache,
  });
}
|
||||
|
||||
async function initGallery() { // triggered on gradio change to monitor when ui gets sufficiently constructed
|
||||
log('initGallery');
|
||||
el.folders = gradioApp().getElementById('tab-gallery-folders');
|
||||
|
|
@ -1175,9 +1284,12 @@ async function initGallery() { // triggered on gradio change to monitor when ui
|
|||
error('initGallery', 'Missing gallery elements');
|
||||
return;
|
||||
}
|
||||
|
||||
blockQueueUntilReady(); // Run first
|
||||
updateGalleryStyles();
|
||||
injectGalleryStatusCSS();
|
||||
setOverlayAnimation();
|
||||
galleryClearInit();
|
||||
const progress = gradioApp().getElementById('tab-gallery-progress');
|
||||
if (progress) {
|
||||
galleryProgressBar.attachTo(progress);
|
||||
|
|
@ -1188,12 +1300,9 @@ async function initGallery() { // triggered on gradio change to monitor when ui
|
|||
el.btnSend = gradioApp().getElementById('tab-gallery-send-image');
|
||||
document.getElementById('tab-gallery-files').style.height = opts.logmonitor_show ? '75vh' : '85vh';
|
||||
|
||||
const intersectionObserver = new IntersectionObserver((entries) => {
|
||||
if (entries[0].intersectionRatio <= 0) galleryHidden();
|
||||
if (entries[0].intersectionRatio > 0) galleryVisible();
|
||||
});
|
||||
intersectionObserver.observe(el.folders);
|
||||
monitorGalleries();
|
||||
updateFolders();
|
||||
monitorOption('browser_folders', updateFolders);
|
||||
}
|
||||
|
||||
// register on startup
|
||||
|
|
|
|||
|
|
@ -36,6 +36,41 @@ async function initIndexDB() {
|
|||
if (!db) await createDB();
|
||||
}
|
||||
|
||||
/**
 * Report whether the database handle has been set.
 * Uses the same falsy test as the `if (!db)` guards in this file, so an
 * `undefined` handle is treated as "not ready" just like `null`.
 * @returns {boolean}
 */
function idbIsReady() {
  return Boolean(db);
}
|
||||
|
||||
/**
 * Reusable setup for handling IDB transactions.
 *
 * Wires the transaction's `onabort`, `onerror` and `oncomplete` handlers to the
 * supplied Promise callbacks and aborts the transaction when `signal` fires.
 * The abort listener is removed from `signal` as soon as the transaction settles.
 *
 * @param {Object} resources - Required resources for implementation
 * @param {IDBTransaction} resources.transaction
 * @param {AbortSignal} resources.signal
 * @param {Function} resources.resolve
 * @param {Function} resources.reject
 * @param {*} resolveValue - Value to resolve the outer Promise with
 * @returns {() => void} - Function for manually aborting the transaction
 */
function configureTransactionAbort({ transaction, signal, resolve, reject }, resolveValue) {
  function abortTransaction() {
    signal.removeEventListener('abort', abortTransaction);
    transaction.abort();
  }
  signal.addEventListener('abort', abortTransaction);
  transaction.onabort = () => {
    signal.removeEventListener('abort', abortTransaction);
    reject(new DOMException(`Aborting database transaction. ${signal.reason}`, 'AbortError'));
  };
  transaction.onerror = (e) => {
    signal.removeEventListener('abort', abortTransaction);
    // The Error constructor's second argument is an options object; pass the underlying
    // failure as `cause` so it is preserved instead of being silently ignored.
    reject(new Error('Database transaction error.', { cause: transaction.error ?? e }));
  };
  transaction.oncomplete = () => {
    signal.removeEventListener('abort', abortTransaction);
    resolve(resolveValue);
  };
  return abortTransaction;
}
|
||||
|
||||
async function add(record) {
|
||||
if (!db) return null;
|
||||
return new Promise((resolve, reject) => {
|
||||
|
|
@ -150,10 +185,7 @@ async function idbFolderCleanup(keepSet, folder, signal) {
|
|||
throw new Error('IndexedDB cleaning function must be told the current active folder');
|
||||
}
|
||||
|
||||
// Use range query to match folder and all its subdirectories
|
||||
const folderNormalized = folder.replace(/\/+/g, '/').replace(/\/$/, '');
|
||||
const range = IDBKeyRange.bound(folderNormalized, `${folderNormalized}\uffff`, false, true);
|
||||
let removals = new Set(await idbGetAllKeys('folder', range));
|
||||
let removals = new Set(await idbGetAllKeys('folder', folder));
|
||||
removals = removals.difference(keepSet); // Don't need to keep full set in memory
|
||||
const totalRemovals = removals.size;
|
||||
if (signal.aborted) {
|
||||
|
|
@ -161,31 +193,20 @@ async function idbFolderCleanup(keepSet, folder, signal) {
|
|||
}
|
||||
return new Promise((resolve, reject) => {
|
||||
const transaction = db.transaction('thumbs', 'readwrite');
|
||||
function abortTransaction() {
|
||||
signal.removeEventListener('abort', abortTransaction);
|
||||
transaction.abort();
|
||||
}
|
||||
signal.addEventListener('abort', abortTransaction);
|
||||
transaction.onabort = () => {
|
||||
signal.removeEventListener('abort', abortTransaction);
|
||||
reject(`Aborting. ${signal.reason}`); // eslint-disable-line prefer-promise-reject-errors
|
||||
};
|
||||
transaction.onerror = () => {
|
||||
signal.removeEventListener('abort', abortTransaction);
|
||||
reject(new Error('Database transaction error'));
|
||||
};
|
||||
transaction.oncomplete = async () => {
|
||||
signal.removeEventListener('abort', abortTransaction);
|
||||
resolve(totalRemovals);
|
||||
};
|
||||
const props = { transaction, signal, resolve, reject };
|
||||
configureTransactionAbort(props, totalRemovals);
|
||||
const store = transaction.objectStore('thumbs');
|
||||
removals.forEach((entry) => { store.delete(entry); });
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
const store = transaction.objectStore('thumbs');
|
||||
removals.forEach((entry) => { store.delete(entry); });
|
||||
} catch (err) {
|
||||
error(err);
|
||||
abortTransaction();
|
||||
}
|
||||
/**
 * Remove every record from the `thumbs` object store.
 * @param {AbortSignal} signal - Aborts the transaction when triggered
 * @returns {Promise<null>} Resolves with `null` once the transaction completes
 *   (also `null` immediately when no database handle is available)
 */
async function idbClearAll(signal) {
  if (!db) return null;
  const runClear = (resolve, reject) => {
    const transaction = db.transaction(['thumbs'], 'readwrite');
    configureTransactionAbort({ transaction, signal, resolve, reject }, null);
    const store = transaction.objectStore('thumbs');
    store.clear();
  };
  return new Promise(runClear);
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,31 +1,63 @@
|
|||
/**
 * Derive a short display name from `opts.sd_model_checkpoint`:
 * drops a trailing `[...]` suffix and any leading path, falling back to
 * 'unknown model' when nothing is left.
 * @returns {string}
 */
const getModel = () => {
  const checkpoint = opts?.sd_model_checkpoint || '';
  if (checkpoint) {
    const segments = checkpoint
      .replace(/\s*\[.*\]\s*$/, '') // remove trailing [hash]
      .split(/[\\/]/); // split on / or \
    const basename = segments[segments.length - 1].trim();
    return basename || 'unknown model';
  }
  return 'unknown model';
};
|
||||
/**
 * Static holder for the connection indicator's state (server version info,
 * online flag and the DOM element that displays it).
 */
class ConnectionMonitorState {
  static element;
  static version = '';
  static commit = '';
  static branch = '';
  static online = false;

  /**
   * @returns {string} Display name of the checkpoint in `opts.sd_model_checkpoint`,
   *   or 'unknown model' when none is set
   */
  static getModel() {
    const cp = opts?.sd_model_checkpoint || '';
    return cp ? this.trimModelName(cp) : 'unknown model';
  }

  /**
   * Reduce a checkpoint title to its bare file name.
   * @param {string} name - Checkpoint title, optionally with a path and a trailing `[hash]`
   * @returns {string} Last path segment without the trailing `[hash]`, or 'unknown model' if empty
   */
  static trimModelName(name) {
    // Remove only the LAST bracket group: `[^\]]*` cannot cross a closing bracket, so names
    // that themselves contain brackets (e.g. "model [v2].safetensors [hash]") are not truncated
    // at their first "[". Then split on / or \, take the last segment and trim.
    return name.replace(/\s*\[[^\]]*\]\s*$/, '').split(/[\\/]/).pop().trim() || 'unknown model';
  }

  /**
   * Store the latest server status.
   * @param {Object} data
   * @param {boolean} data.online
   * @param {string} data.updated - Stored as `version`
   * @param {string} data.commit
   * @param {string} data.branch
   */
  static setData({ online, updated, commit, branch }) {
    this.online = online;
    this.version = updated;
    this.commit = commit;
    this.branch = branch;
  }

  /** @param {HTMLElement} el - Element whose hint and background reflect the state */
  static setElement(el) {
    this.element = el;
  }

  /**
   * Build the hint markup for the indicator.
   * @param {string} [modelOverride] - Checkpoint title to show instead of the current option value
   * @returns {string} HTML fragment
   */
  static toHTML(modelOverride) {
    return `
      Version: <b>${this.version}</b><br>
      Commit: <b>${this.commit}</b><br>
      Branch: <b>${this.branch}</b><br>
      Status: ${this.online ? '<b style="color:lime">online</b>' : '<b style="color:darkred">offline</b>'}<br>
      Model: <b>${modelOverride ? this.trimModelName(modelOverride) : this.getModel()}</b><br>
      Since: ${new Date().toLocaleString()}<br>
    `;
  }

  /**
   * Refresh the indicator element's hint text and background colour.
   * @param {string} [incomingModel] - Checkpoint title to display instead of the current option value
   */
  static updateState(incomingModel) {
    this.element.dataset.hint = this.toHTML(incomingModel);
    this.element.style.backgroundColor = this.online ? 'var(--sd-main-accent-color)' : 'var(--color-error)';
  }
}
|
||||
|
||||
let monitorAutoUpdating = false;
|
||||
|
||||
async function updateIndicator(online, data, msg) {
|
||||
const el = document.getElementById('logo_nav');
|
||||
if (!el || !data) return;
|
||||
const status = online ? '<b style="color:lime">online</b>' : '<b style="color:darkred">offline</b>';
|
||||
const date = new Date();
|
||||
const template = `
|
||||
Version: <b>${data.updated}</b><br>
|
||||
Commit: <b>${data.commit}</b><br>
|
||||
Branch: <b>${data.branch}</b><br>
|
||||
Status: ${status}<br>
|
||||
Model: <b>${getModel()}</b><br>
|
||||
Since: ${date.toLocaleString()}<br>
|
||||
`;
|
||||
ConnectionMonitorState.setElement(el);
|
||||
if (!monitorAutoUpdating) {
|
||||
monitorOption('sd_model_checkpoint', (newVal) => { ConnectionMonitorState.updateState(newVal); }); // Runs before opt actually changes
|
||||
monitorAutoUpdating = true;
|
||||
}
|
||||
ConnectionMonitorState.setData({ online, ...data });
|
||||
ConnectionMonitorState.updateState();
|
||||
if (online) {
|
||||
el.dataset.hint = template;
|
||||
el.style.backgroundColor = 'var(--sd-main-accent-color)';
|
||||
log('monitorConnection: online', data);
|
||||
} else {
|
||||
el.dataset.hint = template;
|
||||
el.style.backgroundColor = 'var(--color-error)';
|
||||
log('monitorConnection: offline', msg);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,10 @@ const monitoredOpts = [
|
|||
{ sd_backend: () => gradioApp().getElementById('refresh_sd_model_checkpoint')?.click() },
|
||||
];
|
||||
|
||||
/**
 * Register a callback in `monitoredOpts` for the named option.
 * @param {string} option - Option key to watch
 * @param {Function} callback - Stored under `option` in the new entry
 */
function monitorOption(option, callback) {
  const entry = { [option]: callback };
  monitoredOpts.push(entry);
}
|
||||
|
||||
const AppyOpts = [
|
||||
{ compact_view: (val, old) => toggleCompact(val, old) },
|
||||
{ gradio_theme: (val, old) => setTheme(val, old) },
|
||||
|
|
@ -25,17 +29,15 @@ async function updateOpts(json_string) {
|
|||
|
||||
const t1 = performance.now();
|
||||
for (const op of monitoredOpts) {
|
||||
const key = Object.keys(op)[0];
|
||||
const callback = op[key];
|
||||
if (opts[key] && opts[key] !== settings_data.values[key]) {
|
||||
log('updateOpt', key, opts[key], settings_data.values[key]);
|
||||
const [key, callback] = Object.entries(op)[0];
|
||||
if (Object.hasOwn(opts, key) && opts[key] !== new_opts[key]) {
|
||||
log('updateOpt', key, opts[key], new_opts[key]);
|
||||
if (callback) callback(new_opts[key], opts[key]);
|
||||
}
|
||||
}
|
||||
|
||||
for (const op of AppyOpts) {
|
||||
const key = Object.keys(op)[0];
|
||||
const callback = op[key];
|
||||
const [key, callback] = Object.entries(op)[0];
|
||||
if (callback) callback(new_opts[key], opts[key]);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -574,7 +574,7 @@ function toggleCompact(val, old) {
|
|||
|
||||
function previewTheme() {
|
||||
let name = gradioApp().getElementById('setting_gradio_theme').querySelectorAll('input')?.[0].value || '';
|
||||
fetch(`${window.subpath}/file=html/themes.json`)
|
||||
fetch(`${window.subpath}/file=data/themes.json`)
|
||||
.then((res) => {
|
||||
res.json()
|
||||
.then((themes) => {
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 255 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 64 KiB |
|
|
@ -103,6 +103,7 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str])
|
||||
self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int)
|
||||
self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"])
|
||||
self.add_api_route("/sdapi/v1/sampler", endpoints.get_sampler, methods=["GET"], response_model=dict)
|
||||
|
||||
# lora api
|
||||
from modules.api import loras
|
||||
|
|
@ -116,6 +117,10 @@ class Api:
|
|||
from modules.api import nudenet
|
||||
nudenet.register_api()
|
||||
|
||||
# xyz-grid api
|
||||
from modules.api import xyz_grid
|
||||
xyz_grid.register_api()
|
||||
|
||||
# civitai api
|
||||
from modules.civitai import api_civitai
|
||||
api_civitai.register_api()
|
||||
|
|
|
|||
|
|
@ -6,8 +6,28 @@ from modules.api import models, helpers
|
|||
|
||||
|
||||
def get_samplers():
|
||||
from modules import sd_samplers
|
||||
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
||||
from modules import sd_samplers_diffusers
|
||||
all_samplers = []
|
||||
for k, v in sd_samplers_diffusers.config.items():
|
||||
if k in ['All', 'Default', 'Res4Lyf']:
|
||||
continue
|
||||
all_samplers.append({
|
||||
'name': k,
|
||||
'options': v,
|
||||
})
|
||||
return all_samplers
|
||||
|
||||
def get_sampler():
    """Describe the scheduler attached to the currently loaded model.

    Returns:
        A dict with the scheduler class name under 'name' and its config entries
        (keys starting with an underscore are left out) under 'options'.
        An empty dict when no model is loaded or the model has no `scheduler` attribute.
    """
    if not shared.sd_loaded or shared.sd_model is None:
        return {}
    if not hasattr(shared.sd_model, 'scheduler'):
        return {}
    active = shared.sd_model.scheduler
    public_options = {}
    for key, value in active.config.items():
        if key.startswith('_'): # underscore-prefixed config entries are not reported
            continue
        public_options[key] = value
    return {
        'name': active.__class__.__name__,
        'options': public_options
    }
|
||||
|
||||
def get_sd_vaes():
|
||||
from modules.sd_vae import vae_dict
|
||||
|
|
@ -75,6 +95,13 @@ def get_interrogate():
|
|||
from modules.interrogate.openclip import refresh_clip_models
|
||||
return ['deepdanbooru'] + refresh_clip_models()
|
||||
|
||||
def get_schedulers():
    """Return the schedulers reported by `modules.sd_samplers.list_samplers`.

    Each entry is also written to the log at debug level; listing schedulers is a
    routine read-only query and must not surface as a critical error.
    """
    from modules.sd_samplers import list_samplers
    all_schedulers = list_samplers()
    for s in all_schedulers:
        shared.log.debug(s)
    return all_schedulers
|
||||
|
||||
def post_interrogate(req: models.ReqInterrogate):
|
||||
if req.image is None or len(req.image) < 64:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
|
|
|||
|
|
@ -86,8 +86,7 @@ class PydanticModelGenerator:
|
|||
|
||||
class ItemSampler(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
aliases: List[str] = Field(title="Aliases")
|
||||
options: Dict[str, str] = Field(title="Options")
|
||||
options: dict
|
||||
|
||||
class ItemVae(BaseModel):
|
||||
model_name: str = Field(title="Model Name")
|
||||
|
|
@ -199,6 +198,11 @@ class ItemExtension(BaseModel):
|
|||
commit_date: Union[str, int] = Field(title="Commit Date", description="Extension Repository Commit Date")
|
||||
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
|
||||
|
||||
# Response item describing a single scheduler: display name, class name and its options.
# Documented with comments rather than a class docstring so the generated schema description stays unchanged.
class ItemScheduler(BaseModel):
    name: str = Field(title="Name", description="Scheduler name")
    cls: str = Field(title="Class", description="Scheduler class name")
    options: Dict[str, Any] = Field(title="Options", description="Dictionary of scheduler options")
|
||||
|
||||
### request/response classes
|
||||
|
||||
ReqTxt2Img = PydanticModelGenerator(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
from typing import List
|
||||
|
||||
|
||||
def xyz_grid_enum(option: str = "") -> List[dict]:
    """Enumerate XYZ-grid axis options, optionally filtered by label.

    Args:
        option: Case-insensitive filter. When empty, every axis is returned.
            Otherwise only axes whose label starts or ends with it are returned.

    Returns:
        One dict per axis with 'label', 'type' (type name), 'cost' and 'choices'.
        'choices' is a bool telling whether the axis defines choices; for filtered
        queries a callable `choices` is invoked and its result is reported instead.
    """
    from scripts.xyz import xyz_grid_classes # pylint: disable=no-name-in-module
    needle = option.lower()
    matched = []
    for axis in xyz_grid_classes.axis_options:
        entry = {
            'label': axis.label,
            'type': axis.type.__name__,
            'cost': axis.cost,
            'choices': axis.choices is not None,
        }
        if len(option) > 0:
            label = axis.label.lower()
            if not (label.startswith(needle) or label.endswith(needle)):
                continue
            if callable(axis.choices): # choices are only resolved for explicitly requested axes
                entry['choices'] = axis.choices()
        matched.append(entry)
    return matched
|
||||
|
||||
|
||||
def register_api() -> None:
    """Register the `GET /sdapi/v1/xyz-grid` route, served by `xyz_grid_enum`, on the shared API instance."""
    from modules.shared import api as api_instance # imported at call time; presumably the API object is only available after server init - confirm
    api_instance.add_api_route("/sdapi/v1/xyz-grid", xyz_grid_enum, methods=["GET"], response_model=List[dict])
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
from __future__ import annotations
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
|
||||
class ErrorLimiterTrigger(BaseException): # Use BaseException to avoid being caught by "except Exception:".
    """Signal raised by `ErrorLimiter.notify` once a limiter has used up its allowed number of errors.

    Attributes:
        name: Identifier of the limiter that ran out.
    """

    def __init__(self, name: str, *args):
        super().__init__(*args)
        self.name = name
|
||||
|
||||
|
||||
class ErrorLimiterAbort(RuntimeError):
    """Error raised by the `limit_errors` context manager when its error limit was reached."""

    def __init__(self, msg: str):
        super().__init__(msg)
|
||||
|
||||
|
||||
class ErrorLimiter:
|
||||
_store: dict[str, int] = {}
|
||||
|
||||
@classmethod
|
||||
def start(cls, name: str, limit: int = 5):
|
||||
cls._store[name] = limit
|
||||
|
||||
@classmethod
|
||||
def notify(cls, name: str | Iterable[str]): # Can be manually triggered if execution is spread across multiple files
|
||||
if isinstance(name, str):
|
||||
name = (name,)
|
||||
for key in name:
|
||||
if key in cls._store.keys():
|
||||
cls._store[key] = cls._store[key] - 1
|
||||
if cls._store[key] <= 0:
|
||||
raise ErrorLimiterTrigger(key)
|
||||
|
||||
@classmethod
|
||||
def end(cls, name: str):
|
||||
cls._store.pop(name)
|
||||
|
||||
|
||||
@contextmanager
def limit_errors(name: str, limit: int = 5):
    """Limiter for aborting execution after being triggered a specified number of times (default 5).

    >>> with limit_errors("identifier", limit=5) as elimit:
    ...     while do_thing():
    ...         if (something_bad):
    ...             print("Something bad happened")
    ...             elimit()  # In this example, the `with` block exits with ErrorLimiterAbort on the 5th call
    ...         try:
    ...             something_broken()
    ...         except Exception:
    ...             print("Encountered an exception")
    ...             elimit()  # Count is shared across all calls

    Args:
        name (str): Identifier.
        limit (int, optional): Abort after `limit` number of triggers. Defaults to 5.

    Raises:
        ErrorLimiterAbort: Subclass of RuntimeError.

    Yields:
        Callable: Notification function to indicate that an error occurred.
    """
    try:
        ErrorLimiter.start(name, limit)
        yield lambda: ErrorLimiter.notify(name)
    except ErrorLimiterTrigger as e:
        # "from None" hides the internal trigger so only the abort shows up in the traceback
        raise ErrorLimiterAbort(f"HALTING. Too many errors during '{e.name}'") from None
    finally:
        # the limiter is removed however the block is left
        ErrorLimiter.end(name)
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
import warnings
|
||||
from installer import get_log, get_console, setup_logging, install_traceback
|
||||
from modules.errorlimiter import ErrorLimiterAbort
|
||||
|
||||
|
||||
log = get_log()
|
||||
|
|
@ -16,9 +17,18 @@ def install(suppress=[]):
|
|||
|
||||
|
||||
def display(e: Exception, task: str, suppress=[]):
|
||||
log.error(f"{task or 'error'}: {type(e).__name__}")
|
||||
if isinstance(e, ErrorLimiterAbort):
|
||||
return
|
||||
log.critical(f"{task or 'error'}: {type(e).__name__}")
|
||||
"""
|
||||
trace = traceback.format_exc()
|
||||
log.error(trace)
|
||||
for line in traceback.format_tb(e.__traceback__):
|
||||
log.error(repr(line))
|
||||
console = get_console()
|
||||
console.print_exception(show_locals=False, max_frames=16, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width)
|
||||
"""
|
||||
log.traceback(e, suppress=suppress)
|
||||
|
||||
|
||||
def display_once(e: Exception, task):
|
||||
|
|
|
|||
|
|
@ -151,33 +151,30 @@ def deactivate(p, extra_network_data=None, force=shared.opts.lora_force_reload):
|
|||
re_extra_net = re.compile(r"<(\w+):([^>]+)>")
|
||||
|
||||
|
||||
def parse_prompt(prompt):
|
||||
res = defaultdict(list)
|
||||
def parse_prompt(prompt: str | None) -> tuple[str, defaultdict[str, list[ExtraNetworkParams]]]:
|
||||
res: defaultdict[str, list[ExtraNetworkParams]] = defaultdict(list)
|
||||
if prompt is None:
|
||||
return prompt, res
|
||||
return "", res
|
||||
if isinstance(prompt, list):
|
||||
shared.log.warning(f"parse_prompt was called with a list instead of a string: {prompt}")
|
||||
return parse_prompts(prompt)
|
||||
|
||||
def found(m):
|
||||
name = m.group(1)
|
||||
args = m.group(2)
|
||||
def found(m: re.Match[str]):
|
||||
name, args = m.group(1, 2)
|
||||
res[name].append(ExtraNetworkParams(items=args.split(":")))
|
||||
return ""
|
||||
if isinstance(prompt, list):
|
||||
prompt = [re.sub(re_extra_net, found, p) for p in prompt]
|
||||
else:
|
||||
prompt = re.sub(re_extra_net, found, prompt)
|
||||
return prompt, res
|
||||
|
||||
updated_prompt = re.sub(re_extra_net, found, prompt)
|
||||
return updated_prompt, res
|
||||
|
||||
|
||||
def parse_prompts(prompts):
|
||||
res = []
|
||||
extra_data = None
|
||||
if prompts is None:
|
||||
return prompts, extra_data
|
||||
|
||||
def parse_prompts(prompts: list[str]):
|
||||
updated_prompt_list: list[str] = []
|
||||
extra_data: defaultdict[str, list[ExtraNetworkParams]] = defaultdict(list)
|
||||
for prompt in prompts:
|
||||
updated_prompt, parsed_extra_data = parse_prompt(prompt)
|
||||
if extra_data is None:
|
||||
if not extra_data:
|
||||
extra_data = parsed_extra_data
|
||||
res.append(updated_prompt)
|
||||
updated_prompt_list.append(updated_prompt)
|
||||
|
||||
return res, extra_data
|
||||
return updated_prompt_list, extra_data
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ def face_id(
|
|||
ip_model_dict["faceid_embeds"] = face_embeds # overwrite placeholder
|
||||
faceid_model.set_scale(scale)
|
||||
|
||||
if p.all_prompts is None or len(p.all_prompts) == 0:
|
||||
if not p.all_prompts:
|
||||
processing.process_init(p)
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
for n in range(p.n_iter):
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ def instant_id(p: processing.StableDiffusionProcessing, app, source_images, stre
|
|||
sd_models.move_model(shared.sd_model, devices.device) # move pipeline to device
|
||||
|
||||
# pipeline specific args
|
||||
if p.all_prompts is None or len(p.all_prompts) == 0:
|
||||
if not p.all_prompts:
|
||||
processing.process_init(p)
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
orig_prompt_attention = shared.opts.prompt_attention
|
||||
|
|
@ -73,8 +73,8 @@ def instant_id(p: processing.StableDiffusionProcessing, app, source_images, stre
|
|||
p.task_args['controlnet_conditioning_scale'] = float(conditioning)
|
||||
p.task_args['ip_adapter_scale'] = float(strength)
|
||||
shared.log.debug(f"InstantID args: {p.task_args}")
|
||||
p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts is not None else p.prompt
|
||||
p.task_args['negative_prompt'] = p.all_negative_prompts[0] if p.all_negative_prompts is not None else p.negative_prompt
|
||||
p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts else p.prompt
|
||||
p.task_args['negative_prompt'] = p.all_negative_prompts[0] if p.all_negative_prompts else p.negative_prompt
|
||||
p.task_args['image_embeds'] = face_embeds[0] # overwrite placeholder
|
||||
|
||||
# run processing
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def photo_maker(p: processing.StableDiffusionProcessing, app, model: str, input_
|
|||
return None
|
||||
|
||||
# validate prompt
|
||||
if p.all_prompts is None or len(p.all_prompts) == 0:
|
||||
if not p.all_prompts:
|
||||
processing.process_init(p)
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
trigger_ids = shared.sd_model.tokenizer.encode(trigger) + shared.sd_model.tokenizer_2.encode(trigger)
|
||||
|
|
@ -61,7 +61,7 @@ def photo_maker(p: processing.StableDiffusionProcessing, app, model: str, input_
|
|||
shared.opts.data['prompt_attention'] = 'fixed' # otherwise need to deal with class_tokens_mask
|
||||
p.task_args['input_id_images'] = input_images
|
||||
p.task_args['start_merge_step'] = int(start * p.steps)
|
||||
p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts is not None else p.prompt
|
||||
p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts else p.prompt
|
||||
|
||||
is_v2 = 'v2' in model
|
||||
if is_v2:
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def vae_decode_simple(latents):
|
|||
def vae_decode_tiny(latents):
|
||||
global taesd # pylint: disable=global-statement
|
||||
if taesd is None:
|
||||
from modules import sd_vae_taesd
|
||||
from modules.vae import sd_vae_taesd
|
||||
taesd, _variant = sd_vae_taesd.get_model(variant='TAE HunyuanVideo')
|
||||
shared.log.debug(f'Video VAE: type=Tiny cls={taesd.__class__.__name__} latents={latents.shape}')
|
||||
with devices.inference_context():
|
||||
|
|
@ -56,7 +56,7 @@ def vae_decode_tiny(latents):
|
|||
|
||||
|
||||
def vae_decode_remote(latents):
|
||||
# from modules.sd_vae_remote import remote_decode
|
||||
# from modules.vae.sd_vae_remote import remote_decode
|
||||
# images = remote_decode(latents, model_type='hunyuanvideo')
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
images = remote_decode(
|
||||
|
|
|
|||
|
|
@ -309,16 +309,18 @@ def worker(
|
|||
break
|
||||
|
||||
total_generated_frames, _video_filename = save_video(
|
||||
None,
|
||||
history_pixels,
|
||||
mp4_fps,
|
||||
mp4_codec,
|
||||
mp4_opt,
|
||||
mp4_ext,
|
||||
mp4_sf,
|
||||
mp4_video,
|
||||
mp4_frames,
|
||||
mp4_interpolate,
|
||||
p=None,
|
||||
pixels=history_pixels,
|
||||
audio=None,
|
||||
binary=None,
|
||||
mp4_fps=mp4_fps,
|
||||
mp4_codec=mp4_codec,
|
||||
mp4_opt=mp4_opt,
|
||||
mp4_ext=mp4_ext,
|
||||
mp4_sf=mp4_sf,
|
||||
mp4_video=mp4_video,
|
||||
mp4_frames=mp4_frames,
|
||||
mp4_interpolate=mp4_interpolate,
|
||||
pbar=pbar,
|
||||
stream=stream,
|
||||
metadata=metadata,
|
||||
|
|
@ -327,7 +329,23 @@ def worker(
|
|||
except AssertionError:
|
||||
shared.log.info('FramePack: interrupted')
|
||||
if shared.opts.keep_incomplete:
|
||||
save_video(None, history_pixels, mp4_fps, mp4_codec, mp4_opt, mp4_ext, mp4_sf, mp4_video, mp4_frames, mp4_interpolate=0, stream=stream, metadata=metadata)
|
||||
save_video(
|
||||
p=None,
|
||||
pixels=history_pixels,
|
||||
audio=None,
|
||||
binary=None,
|
||||
mp4_fps=mp4_fps,
|
||||
mp4_codec=mp4_codec,
|
||||
mp4_opt=mp4_opt,
|
||||
mp4_ext=mp4_ext,
|
||||
mp4_sf=mp4_sf,
|
||||
mp4_video=mp4_video,
|
||||
mp4_frames=mp4_frames,
|
||||
mp4_interpolate=0,
|
||||
pbar=pbar,
|
||||
stream=stream,
|
||||
metadata=metadata,
|
||||
)
|
||||
except Exception as e:
|
||||
shared.log.error(f'FramePack: {e}')
|
||||
errors.display(e, 'FramePack')
|
||||
|
|
|
|||
|
|
@ -17,6 +17,10 @@ debug('Trace: PASTE')
|
|||
parse_generation_parameters = parse # compatibility
|
||||
infotext_to_setting_name_mapping = mapping # compatibility
|
||||
|
||||
# Mapping of aliases to metadata parameter names, populated automatically from component labels/elem_ids
|
||||
# This allows users to use component labels, elem_ids, or metadata names in the "skip params" setting
|
||||
param_aliases: dict[str, str] = {}
|
||||
|
||||
|
||||
class ParamBinding:
|
||||
def __init__(self, paste_button, tabname: str, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
||||
|
|
@ -74,7 +78,8 @@ def image_from_url_text(filedata):
|
|||
filedata = filedata[len("data:image/jxl;base64,"):]
|
||||
filebytes = base64.decodebytes(filedata.encode('utf-8'))
|
||||
image = Image.open(io.BytesIO(filebytes))
|
||||
images.read_info_from_image(image)
|
||||
image.load()
|
||||
# images.read_info_from_image(image)
|
||||
return image
|
||||
|
||||
|
||||
|
|
@ -85,9 +90,36 @@ def add_paste_fields(tabname: str, init_img: gr.Image | gr.HTML | None, fields:
|
|||
except Exception as e:
|
||||
shared.log.error(f"Paste fields: tab={tabname} fields={fields} {e}")
|
||||
field_names[tabname] = []
|
||||
|
||||
# Build param_aliases automatically from component labels and elem_ids
|
||||
if fields is not None:
|
||||
for component, metadata_name in fields:
|
||||
if metadata_name is None or callable(metadata_name):
|
||||
continue
|
||||
metadata_lower = metadata_name.lower()
|
||||
# Extract label from component (e.g., "Batch size" -> maps to "Batch-2")
|
||||
label = getattr(component, 'label', None)
|
||||
if label and isinstance(label, str):
|
||||
label_lower = label.lower()
|
||||
if label_lower != metadata_lower and label_lower not in param_aliases:
|
||||
param_aliases[label_lower] = metadata_lower
|
||||
# Extract elem_id and derive variable name (e.g., "txt2img_batch_size" -> "batch_size")
|
||||
elem_id = getattr(component, 'elem_id', None)
|
||||
if elem_id and isinstance(elem_id, str):
|
||||
# Strip common prefixes like "txt2img_", "img2img_", "control_"
|
||||
var_name = elem_id
|
||||
for prefix in ['txt2img_', 'img2img_', 'control_', 'video_', 'extras_']:
|
||||
if var_name.startswith(prefix):
|
||||
var_name = var_name[len(prefix):]
|
||||
break
|
||||
var_name_lower = var_name.lower()
|
||||
if var_name_lower != metadata_lower and var_name_lower not in param_aliases:
|
||||
param_aliases[var_name_lower] = metadata_lower
|
||||
|
||||
# backwards compatibility for existing extensions
|
||||
debug(f'Paste fields: tab={tabname} fields={field_names[tabname]}')
|
||||
debug(f'All fields: {get_all_fields()}')
|
||||
debug(f'Param aliases: {param_aliases}')
|
||||
import modules.ui
|
||||
if tabname == 'txt2img':
|
||||
modules.ui.txt2img_paste_fields = fields # compatibility
|
||||
|
|
@ -133,10 +165,22 @@ def should_skip(param: str):
|
|||
skip_params = [p.strip().lower() for p in shared.opts.disable_apply_params.split(",")]
|
||||
if not shared.opts.clip_skip_enabled:
|
||||
skip_params += ['clip skip']
|
||||
|
||||
# Expand skip_params with aliases (e.g., "batch_size" -> "batch-2")
|
||||
expanded_skip = set(skip_params)
|
||||
for skip in skip_params:
|
||||
if skip in param_aliases:
|
||||
expanded_skip.add(param_aliases[skip])
|
||||
|
||||
# Check if param should be skipped
|
||||
param_lower = param.lower()
|
||||
# Also check normalized name (without -1/-2) so "batch" skips both "batch-1" and "batch-2"
|
||||
param_normalized = param_lower.replace('-1', '').replace('-2', '')
|
||||
|
||||
all_params = [p.lower() for p in get_all_fields()]
|
||||
valid = any(p in all_params for p in skip_params)
|
||||
skip = param.lower() in skip_params
|
||||
debug(f'Check: param="{param}" valid={valid} skip={skip}')
|
||||
skip = param_lower in expanded_skip or param_normalized in expanded_skip
|
||||
debug(f'Check: param="{param}" valid={valid} skip={skip} expanded={expanded_skip}')
|
||||
return skip
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -104,6 +104,9 @@ def on_tmpdir_changed():
|
|||
def cleanup_tmpdr():
|
||||
temp_dir = shared.opts.temp_dir
|
||||
if temp_dir == "" or not os.path.isdir(temp_dir):
|
||||
temp_dir = os.path.join(paths.temp_dir, "gradio")
|
||||
shared.log.debug(f'Temp folder: path="{temp_dir}"')
|
||||
if not os.path.isdir(temp_dir):
|
||||
return
|
||||
for root, _dirs, files in os.walk(temp_dir, topdown=False):
|
||||
for name in files:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from modules.json_helpers import readfile, writefile
|
|||
from modules.paths import data_path
|
||||
|
||||
|
||||
cache_filename = os.path.join(data_path, "cache.json")
|
||||
cache_filename = os.path.join(data_path, 'data', 'cache.json')
|
||||
cache_data = None
|
||||
progress_ok = True
|
||||
|
||||
|
|
|
|||
|
|
@ -311,7 +311,7 @@ def parse_novelai_metadata(data: dict):
|
|||
return geninfo
|
||||
|
||||
|
||||
def read_info_from_image(image: Image.Image, watermark: bool = False):
|
||||
def read_info_from_image(image: Image.Image, watermark: bool = False) -> tuple[str, dict]:
|
||||
if image is None:
|
||||
return '', {}
|
||||
if isinstance(image, str):
|
||||
|
|
@ -322,9 +322,11 @@ def read_info_from_image(image: Image.Image, watermark: bool = False):
|
|||
return '', {}
|
||||
items = image.info or {}
|
||||
geninfo = items.pop('parameters', None) or items.pop('UserComment', None) or ''
|
||||
if geninfo is not None and len(geninfo) > 0:
|
||||
if isinstance(geninfo, dict):
|
||||
if 'UserComment' in geninfo:
|
||||
geninfo = geninfo['UserComment']
|
||||
geninfo = geninfo['UserComment'] # Info was nested
|
||||
else:
|
||||
geninfo = '' # Unknown format. Ignore contents
|
||||
items['UserComment'] = geninfo
|
||||
|
||||
if "exif" in items:
|
||||
|
|
@ -342,7 +344,7 @@ def read_info_from_image(image: Image.Image, watermark: bool = False):
|
|||
val = round(val[0] / val[1], 2)
|
||||
if val is not None and key in ExifTags.TAGS: # add known tags
|
||||
if ExifTags.TAGS[key] == 'UserComment': # add geninfo from UserComment
|
||||
geninfo = val
|
||||
geninfo = str(val)
|
||||
items['parameters'] = val
|
||||
else:
|
||||
items[ExifTags.TAGS[key]] = val
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ from pathlib import Path
|
|||
from modules import shared, errors
|
||||
|
||||
|
||||
debug = errors.log.trace if os.environ.get('SD_NAMEGEN_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
debug= os.environ.get('SD_NAMEGEN_DEBUG', None) is not None
|
||||
debug_log = errors.log.trace if debug else lambda *args, **kwargs: None
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
||||
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
||||
|
|
@ -66,9 +67,9 @@ class FilenameGenerator:
|
|||
|
||||
def __init__(self, p, seed, prompt, image=None, grid=False, width=None, height=None):
|
||||
if p is None:
|
||||
debug('Filename generator init skip')
|
||||
debug_log('Filename generator init skip')
|
||||
else:
|
||||
debug(f'Filename generator init: seed={seed} prompt="{prompt}"')
|
||||
debug_log(f'Filename generator init: seed={seed} prompt="{prompt}"')
|
||||
self.p = p
|
||||
if seed is not None and int(seed) > 0:
|
||||
self.seed = seed
|
||||
|
|
@ -163,7 +164,7 @@ class FilenameGenerator:
|
|||
def prompt_sanitize(self, prompt):
|
||||
invalid_chars = '#<>:\'"\\|?*\n\t\r'
|
||||
sanitized = prompt.translate({ ord(x): '_' for x in invalid_chars }).strip()
|
||||
debug(f'Prompt sanitize: input="{prompt}" output={sanitized}')
|
||||
debug_log(f'Prompt sanitize: input="{prompt}" output="{sanitized}"')
|
||||
return sanitized
|
||||
|
||||
def sanitize(self, filename):
|
||||
|
|
@ -200,7 +201,7 @@ class FilenameGenerator:
|
|||
while len(os.path.abspath(fn)) > max_length:
|
||||
fn = fn[:-1]
|
||||
fn += ext
|
||||
debug(f'Filename sanitize: input="{filename}" parts={parts} output="{fn}" ext={ext} max={max_length} len={len(fn)}')
|
||||
debug_log(f'Filename sanitize: input="{filename}" parts={parts} output="{fn}" ext={ext} max={max_length} len={len(fn)}')
|
||||
return fn
|
||||
|
||||
def safe_int(self, s):
|
||||
|
|
@ -234,25 +235,38 @@ class FilenameGenerator:
|
|||
|
||||
def apply(self, x):
|
||||
res = ''
|
||||
if debug:
|
||||
for k in self.replacements.keys():
|
||||
try:
|
||||
fn = self.replacements.get(k, None)
|
||||
debug_log(f'Namegen: key={k} value={fn(self)}')
|
||||
except Exception as e:
|
||||
shared.log.error(f'Namegen: key={k} {e}')
|
||||
errors.display(e, 'namegen')
|
||||
for m in re_pattern.finditer(x):
|
||||
text, pattern = m.groups()
|
||||
if pattern is None:
|
||||
res += text
|
||||
continue
|
||||
pattern_args = []
|
||||
while True:
|
||||
m = re_pattern_arg.match(pattern)
|
||||
if m is None:
|
||||
break
|
||||
pattern, arg = m.groups()
|
||||
pattern_args.insert(0, arg)
|
||||
debug_log(f'Filename apply: text="{text}" pattern="{pattern}"')
|
||||
if isinstance(pattern, list):
|
||||
pattern = ' '.join(pattern)
|
||||
if pattern is None or not isinstance(pattern, str) or pattern.strip() == '':
|
||||
debug_log(f'Filename skip: pattern="{pattern}"')
|
||||
res += text
|
||||
continue
|
||||
|
||||
_pattern = pattern
|
||||
pattern_args = []
|
||||
while True:
|
||||
m = re_pattern_arg.match(_pattern)
|
||||
if m is None:
|
||||
break
|
||||
_pattern, arg = m.groups()
|
||||
pattern_args.insert(0, arg)
|
||||
|
||||
fun = self.replacements.get(pattern.lower(), None)
|
||||
if fun is not None:
|
||||
try:
|
||||
debug(f'Filename apply: pattern={pattern.lower()} args={pattern_args}')
|
||||
replacement = fun(self, *pattern_args)
|
||||
debug_log(f'Filename apply: pattern="{pattern}" args={pattern_args} replacement="{replacement}"')
|
||||
except Exception as e:
|
||||
replacement = None
|
||||
errors.display(e, 'namegen')
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import threading
|
|||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from modules import modelloader, paths, devices, shared, sd_models
|
||||
from modules import modelloader, paths, devices, shared
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
load_lock = threading.Lock()
|
||||
|
|
@ -35,21 +35,55 @@ class DeepDanbooru:
|
|||
|
||||
def start(self):
|
||||
self.load()
|
||||
sd_models.move_model(self.model, devices.device)
|
||||
self.model.to(devices.device)
|
||||
|
||||
def stop(self):
|
||||
if shared.opts.interrogate_offload:
|
||||
sd_models.move_model(self.model, devices.cpu)
|
||||
self.model.to(devices.cpu)
|
||||
devices.torch_gc()
|
||||
|
||||
def tag(self, pil_image):
|
||||
def tag(self, pil_image, **kwargs):
|
||||
self.start()
|
||||
res = self.tag_multi(pil_image)
|
||||
res = self.tag_multi(pil_image, **kwargs)
|
||||
self.stop()
|
||||
|
||||
return res
|
||||
|
||||
def tag_multi(self, pil_image, force_disable_ranks=False):
|
||||
def tag_multi(
|
||||
self,
|
||||
pil_image,
|
||||
general_threshold: float = None,
|
||||
include_rating: bool = None,
|
||||
exclude_tags: str = None,
|
||||
max_tags: int = None,
|
||||
sort_alpha: bool = None,
|
||||
use_spaces: bool = None,
|
||||
escape_brackets: bool = None,
|
||||
):
|
||||
"""Run inference and return formatted tag string.
|
||||
|
||||
Args:
|
||||
pil_image: PIL Image to tag
|
||||
general_threshold: Threshold for tag scores (0-1)
|
||||
include_rating: Whether to include rating tags
|
||||
exclude_tags: Comma-separated tags to exclude
|
||||
max_tags: Maximum number of tags to return
|
||||
sort_alpha: Sort tags alphabetically vs by confidence
|
||||
use_spaces: Use spaces instead of underscores
|
||||
escape_brackets: Escape parentheses/brackets in tags
|
||||
|
||||
Returns:
|
||||
Formatted tag string
|
||||
"""
|
||||
# Use settings defaults if not specified
|
||||
general_threshold = general_threshold or shared.opts.tagger_threshold
|
||||
include_rating = include_rating if include_rating is not None else shared.opts.tagger_include_rating
|
||||
exclude_tags = exclude_tags or shared.opts.tagger_exclude_tags
|
||||
max_tags = max_tags or shared.opts.tagger_max_tags
|
||||
sort_alpha = sort_alpha if sort_alpha is not None else shared.opts.tagger_sort_alpha
|
||||
use_spaces = use_spaces if use_spaces is not None else shared.opts.tagger_use_spaces
|
||||
escape_brackets = escape_brackets if escape_brackets is not None else shared.opts.tagger_escape_brackets
|
||||
|
||||
if isinstance(pil_image, list):
|
||||
pil_image = pil_image[0] if len(pil_image) > 0 else None
|
||||
if isinstance(pil_image, dict) and 'name' in pil_image:
|
||||
|
|
@ -58,35 +92,237 @@ class DeepDanbooru:
|
|||
return ''
|
||||
pic = pil_image.resize((512, 512), resample=Image.Resampling.LANCZOS).convert("RGB")
|
||||
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
||||
with devices.inference_context(), devices.autocast():
|
||||
x = torch.from_numpy(a).to(devices.device)
|
||||
with devices.inference_context():
|
||||
x = torch.from_numpy(a).to(device=devices.device, dtype=devices.dtype)
|
||||
y = self.model(x)[0].detach().float().cpu().numpy()
|
||||
probability_dict = {}
|
||||
for tag, probability in zip(self.model.tags, y):
|
||||
if probability < shared.opts.deepbooru_score_threshold:
|
||||
for current, probability in zip(self.model.tags, y):
|
||||
if probability < general_threshold:
|
||||
continue
|
||||
if tag.startswith("rating:"):
|
||||
if current.startswith("rating:") and not include_rating:
|
||||
continue
|
||||
probability_dict[tag] = probability
|
||||
if shared.opts.deepbooru_sort_alpha:
|
||||
probability_dict[current] = probability
|
||||
if sort_alpha:
|
||||
tags = sorted(probability_dict)
|
||||
else:
|
||||
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
|
||||
res = []
|
||||
filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
|
||||
for tag in [x for x in tags if x not in filtertags]:
|
||||
probability = probability_dict[tag]
|
||||
tag_outformat = tag
|
||||
if shared.opts.deepbooru_use_spaces:
|
||||
filtertags = {x.strip().replace(' ', '_') for x in exclude_tags.split(",")}
|
||||
for filtertag in [x for x in tags if x not in filtertags]:
|
||||
probability = probability_dict[filtertag]
|
||||
tag_outformat = filtertag
|
||||
if use_spaces:
|
||||
tag_outformat = tag_outformat.replace('_', ' ')
|
||||
if shared.opts.deepbooru_escape:
|
||||
if escape_brackets:
|
||||
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
|
||||
if shared.opts.interrogate_score and not force_disable_ranks:
|
||||
if shared.opts.tagger_show_scores:
|
||||
tag_outformat = f"({tag_outformat}:{probability:.2f})"
|
||||
res.append(tag_outformat)
|
||||
if len(res) > shared.opts.deepbooru_max_tags:
|
||||
res = res[:shared.opts.deepbooru_max_tags]
|
||||
if max_tags > 0 and len(res) > max_tags:
|
||||
res = res[:max_tags]
|
||||
return ", ".join(res)
|
||||
|
||||
|
||||
model = DeepDanbooru()
|
||||
|
||||
|
||||
def _save_tags_to_file(img_path, tags_str: str, save_append: bool) -> bool:
    """Save tags to the `.txt` file next to an image, with error handling.

    Args:
        img_path: Path of the image file; must provide `with_suffix()` (e.g. `pathlib.Path`)
        tags_str: Tags string to save
        save_append: If True and a non-empty caption file already exists, append to it; otherwise overwrite

    Returns:
        True if save succeeded, False otherwise (the error is logged, not raised)
    """
    try:
        txt_path = img_path.with_suffix('.txt')
        # append only to a file that already has content: adding ', tags' to an empty file would leave a dangling leading comma
        append = save_append and txt_path.exists() and txt_path.stat().st_size > 0
        if append:
            with open(txt_path, 'a', encoding='utf-8') as f:
                f.write(f', {tags_str}')
        else:
            with open(txt_path, 'w', encoding='utf-8') as f:
                f.write(tags_str)
        return True
    except Exception as e:
        shared.log.error(f'DeepBooru batch: failed to save file="{img_path}" error={e}')
        return False
|
||||
|
||||
|
||||
def get_models() -> list:
    """Return list of available DeepBooru models (always the single entry "DeepBooru")."""
    return ["DeepBooru"]
|
||||
|
||||
|
||||
def load_model(model_name: str = None) -> bool: # pylint: disable=unused-argument
    """Load the DeepBooru model.

    Args:
        model_name: Unused; accepted so the signature takes a model name.

    Returns:
        True when the model object is available after loading, False when it is
        still missing or loading raised (the error is logged).
    """
    loaded = False
    try:
        model.load()
        loaded = model.model is not None
    except Exception as e:
        shared.log.error(f'DeepBooru load: {e}')
    return loaded
|
||||
|
||||
|
||||
def unload_model():
    """Drop the DeepBooru model reference and run a forced garbage collection; does nothing when no model is loaded."""
    if model.model is None:
        return
    shared.log.debug('DeepBooru unload')
    model.model = None
    devices.torch_gc(force=True)
|
||||
|
||||
|
||||
def tag(image, **kwargs) -> str:
    """Tag an image using DeepBooru.

    Args:
        image: PIL Image to tag
        **kwargs: Tagger parameters (general_threshold, include_rating, exclude_tags,
            max_tags, sort_alpha, use_spaces, escape_brackets)

    Returns:
        Formatted tag string, or "Exception <type>" when tagging failed (the error is logged)
    """
    import time
    t0 = time.time()
    jobid = shared.state.begin('DeepBooru Tag')
    try:
        # getattr instead of image.size: inputs without a size attribute (list/dict) must not raise before tagging starts
        shared.log.info(f'DeepBooru: image_size={getattr(image, "size", None)}')
        result = model.tag(image, **kwargs)
        shared.log.debug(f'DeepBooru: complete time={time.time()-t0:.2f} tags={len(result.split(", ")) if result else 0}')
    except Exception as e:
        result = f"Exception {type(e)}"
        shared.log.error(f'DeepBooru: {e}')
    finally:
        # always close the job that was opened above, whatever happened in between
        shared.state.end(jobid)
    return result
|
||||
|
||||
|
||||
def batch(
    model_name: str, # pylint: disable=unused-argument
    batch_files: list,
    batch_folder: str,
    batch_str: str,
    save_output: bool = True,
    save_append: bool = False,
    recursive: bool = False,
    **kwargs
) -> str:
    """Process multiple images in batch mode.

    Images are gathered from three optional sources (explicit files, the folder
    containing the first entry of the folder picker, and a folder path string),
    de-duplicated by resolved path and tagged one by one.

    Args:
        model_name: Model name (ignored, only DeepBooru available)
        batch_files: List of file paths, dicts with a 'name' key, or objects with a .name attribute
        batch_folder: Folder picker value; when it is a non-empty list, the parent folder of its first entry is scanned
        batch_str: Folder path as string
        save_output: Save caption to .txt files
        save_append: Append to existing caption files
        recursive: Recursively process subfolders
        **kwargs: Additional arguments forwarded to ``model.tag_multi``

    Returns:
        One line per processed image ("name: tags", tags truncated to 100 chars,
        or "name: ERROR - ..." on failure), joined with newlines; '' if no images were found
    """
    import time
    from pathlib import Path
    import rich.progress as rp

    # Load model
    model.load()

    image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif'}

    def list_folder_images(folder):
        """Return image files in folder (optionally recursive), sorted for a deterministic order."""
        entries = folder.rglob('*') if recursive else folder.glob('*')
        # Compare the lower-cased suffix so files such as IMG_0001.JPG are not skipped on case-sensitive filesystems
        return sorted(p for p in entries if p.suffix.lower() in image_extensions and p.is_file())

    def entry_path(entry):
        """Return the Path of a picker entry given as dict or object with .name, else None."""
        if isinstance(entry, dict):
            return Path(entry['name'])
        if hasattr(entry, 'name'):
            return Path(entry.name)
        return None

    # Collect image files
    image_files = []

    # From file picker
    for f in batch_files or []:
        picked = entry_path(f)
        image_files.append(picked if picked is not None else Path(f))

    # From folder picker: the folder is derived from the first picked entry
    if batch_folder and isinstance(batch_folder, list):
        first = entry_path(batch_folder[0])
        folder_path = first.parent if first is not None else None
        if folder_path and folder_path.is_dir():
            image_files.extend(list_folder_images(folder_path))

    # From string path
    if batch_str and batch_str.strip():
        folder_path = Path(batch_str.strip())
        if folder_path.is_dir():
            image_files.extend(list_folder_images(folder_path))

    # Remove duplicates while preserving order
    seen = set()
    unique_files = []
    for f in image_files:
        f_resolved = f.resolve()
        if f_resolved not in seen:
            seen.add(f_resolved)
            unique_files.append(f)
    image_files = unique_files

    if not image_files:
        shared.log.warning('DeepBooru batch: no images found')
        return ''

    t0 = time.time()
    jobid = shared.state.begin('DeepBooru Batch')
    shared.log.info(f'DeepBooru batch: images={len(image_files)} write={save_output} append={save_append} recursive={recursive}')

    results = []
    model.start()

    # Progress bar
    pbar = rp.Progress(rp.TextColumn('[cyan]DeepBooru:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)

    try:
        with pbar:
            task = pbar.add_task(total=len(image_files), description='starting...')
            for img_path in image_files:
                pbar.update(task, advance=1, description=str(img_path.name))
                if shared.state.interrupted:
                    shared.log.info('DeepBooru batch: interrupted')
                    break
                try:
                    # Context manager closes the file handle; a long batch would otherwise leak one per image
                    with Image.open(img_path) as image:
                        tags_str = model.tag_multi(image, **kwargs)

                    if save_output:
                        _save_tags_to_file(img_path, tags_str, save_append)

                    results.append(f'{img_path.name}: {tags_str[:100]}...' if len(tags_str) > 100 else f'{img_path.name}: {tags_str}')
                except Exception as e:
                    # A single bad file must not abort the whole batch
                    shared.log.error(f'DeepBooru batch: file="{img_path}" error={e}')
                    results.append(f'{img_path.name}: ERROR - {e}')
    finally:
        # Always pair model.start()/state.begin() with their counterparts, even on unexpected errors
        model.stop()
        shared.state.end(jobid)

    elapsed = time.time() - t0
    shared.log.info(f'DeepBooru batch: complete images={len(results)} time={elapsed:.1f}s')

    return '\n'.join(results)
|
||||
|
|
|
|||
|
|
@ -20,10 +20,21 @@ def interrogate(image):
|
|||
prompt = openclip.interrogate(image, mode=shared.opts.interrogate_clip_mode)
|
||||
shared.log.debug(f'Interrogate: time={time.time()-t0:.2f} answer="{prompt}"')
|
||||
return prompt
|
||||
elif shared.opts.interrogate_default_type == 'DeepBooru':
|
||||
shared.log.info(f'Interrogate: type={shared.opts.interrogate_default_type}')
|
||||
from modules.interrogate import deepbooru
|
||||
prompt = deepbooru.model.tag(image)
|
||||
elif shared.opts.interrogate_default_type == 'Tagger':
|
||||
shared.log.info(f'Interrogate: type={shared.opts.interrogate_default_type} model="{shared.opts.waifudiffusion_model}"')
|
||||
from modules.interrogate import tagger
|
||||
prompt = tagger.tag(
|
||||
image=image,
|
||||
model_name=shared.opts.waifudiffusion_model,
|
||||
general_threshold=shared.opts.tagger_threshold,
|
||||
character_threshold=shared.opts.waifudiffusion_character_threshold,
|
||||
include_rating=shared.opts.tagger_include_rating,
|
||||
exclude_tags=shared.opts.tagger_exclude_tags,
|
||||
max_tags=shared.opts.tagger_max_tags,
|
||||
sort_alpha=shared.opts.tagger_sort_alpha,
|
||||
use_spaces=shared.opts.tagger_use_spaces,
|
||||
escape_brackets=shared.opts.tagger_escape_brackets,
|
||||
)
|
||||
shared.log.debug(f'Interrogate: time={time.time()-t0:.2f} answer="{prompt}"')
|
||||
return prompt
|
||||
elif shared.opts.interrogate_default_type == 'VLM':
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from modules.interrogate import vqa_detection
|
|||
|
||||
|
||||
# Debug logging - function-based to avoid circular import
|
||||
debug_enabled = os.environ.get('SD_VQA_DEBUG', None) is not None
|
||||
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
|
||||
|
||||
def debug(*args, **kwargs):
|
||||
if debug_enabled:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import time
|
||||
from collections import namedtuple
|
||||
import threading
|
||||
import re
|
||||
|
|
@ -7,6 +8,23 @@ from PIL import Image
|
|||
from modules import devices, paths, shared, errors, sd_models
|
||||
|
||||
|
||||
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
|
||||
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
def _apply_blip2_fix(model, processor):
    """Patch a BLIP2 model/processor pair so it works with newer transformers versions.

    Registers an "<image>" special token on the processor's tokenizer, resizes the
    model's token embeddings to match and records the new token's index in the model
    config. Does nothing when the model config has no ``num_query_tokens`` attribute.
    """
    from transformers import AddedToken

    if hasattr(model.config, 'num_query_tokens'):
        processor.num_query_tokens = model.config.num_query_tokens
        processor.tokenizer.add_tokens(
            [AddedToken("<image>", normalized=False, special=True)],
            special_tokens=True,
        )
        vocab_size = len(processor.tokenizer)
        model.resize_token_embeddings(vocab_size, pad_to_multiple_of=64)
        # The image token is taken to be the last tokenizer entry
        model.config.image_token_index = len(processor.tokenizer) - 1
        debug_log(f'CLIP load: applied BLIP2 tokenizer fix num_query_tokens={model.config.num_query_tokens}')
|
||||
|
||||
|
||||
caption_models = {
|
||||
'blip-base': 'Salesforce/blip-image-captioning-base',
|
||||
'blip-large': 'Salesforce/blip-image-captioning-large',
|
||||
|
|
@ -79,10 +97,14 @@ def load_interrogator(clip_model, blip_model):
|
|||
clip_interrogator.clip_interrogator.CAPTION_MODELS = caption_models
|
||||
global ci # pylint: disable=global-statement
|
||||
if ci is None:
|
||||
shared.log.debug(f'Interrogate load: clip="{clip_model}" blip="{blip_model}"')
|
||||
t0 = time.time()
|
||||
device = devices.get_optimal_device()
|
||||
cache_path = os.path.join(paths.models_path, 'Interrogator')
|
||||
shared.log.info(f'CLIP load: clip="{clip_model}" blip="{blip_model}" device={device}')
|
||||
debug_log(f'CLIP load: cache_path="{cache_path}" max_length={shared.opts.interrogate_clip_max_length} chunk_size={shared.opts.interrogate_clip_chunk_size} flavor_count={shared.opts.interrogate_clip_flavor_count} offload={shared.opts.interrogate_offload}')
|
||||
interrogator_config = clip_interrogator.Config(
|
||||
device=devices.get_optimal_device(),
|
||||
cache_path=os.path.join(paths.models_path, 'Interrogator'),
|
||||
device=device,
|
||||
cache_path=cache_path,
|
||||
clip_model_name=clip_model,
|
||||
caption_model_name=blip_model,
|
||||
quiet=True,
|
||||
|
|
@ -93,22 +115,39 @@ def load_interrogator(clip_model, blip_model):
|
|||
caption_offload=shared.opts.interrogate_offload,
|
||||
)
|
||||
ci = clip_interrogator.Interrogator(interrogator_config)
|
||||
if blip_model.startswith('blip2-'):
|
||||
_apply_blip2_fix(ci.caption_model, ci.caption_processor)
|
||||
shared.log.debug(f'CLIP load: time={time.time()-t0:.2f}')
|
||||
elif clip_model != ci.config.clip_model_name or blip_model != ci.config.caption_model_name:
|
||||
ci.config.clip_model_name = clip_model
|
||||
ci.config.clip_model = None
|
||||
ci.load_clip_model()
|
||||
ci.config.caption_model_name = blip_model
|
||||
ci.config.caption_model = None
|
||||
ci.load_caption_model()
|
||||
t0 = time.time()
|
||||
if clip_model != ci.config.clip_model_name:
|
||||
shared.log.info(f'CLIP load: clip="{clip_model}" reloading')
|
||||
debug_log(f'CLIP load: previous clip="{ci.config.clip_model_name}"')
|
||||
ci.config.clip_model_name = clip_model
|
||||
ci.config.clip_model = None
|
||||
ci.load_clip_model()
|
||||
if blip_model != ci.config.caption_model_name:
|
||||
shared.log.info(f'CLIP load: blip="{blip_model}" reloading')
|
||||
debug_log(f'CLIP load: previous blip="{ci.config.caption_model_name}"')
|
||||
ci.config.caption_model_name = blip_model
|
||||
ci.config.caption_model = None
|
||||
ci.load_caption_model()
|
||||
if blip_model.startswith('blip2-'):
|
||||
_apply_blip2_fix(ci.caption_model, ci.caption_processor)
|
||||
shared.log.debug(f'CLIP load: time={time.time()-t0:.2f}')
|
||||
else:
|
||||
debug_log(f'CLIP: models already loaded clip="{clip_model}" blip="{blip_model}"')
|
||||
|
||||
|
||||
def unload_clip_model():
    """Move the interrogator's caption and CLIP models to CPU when offloading is enabled."""
    if ci is None or not shared.opts.interrogate_offload:
        return
    shared.log.debug('CLIP unload: offloading models to CPU')
    for component in (ci.caption_model, ci.clip_model):
        sd_models.move_model(component, devices.cpu)
    # Flag both models as offloaded on the interrogator instance
    ci.caption_offloaded = True
    ci.clip_offloaded = True
    devices.torch_gc()
    debug_log('CLIP unload: complete')
|
||||
|
||||
|
||||
def interrogate(image, mode, caption=None):
|
||||
|
|
@ -119,6 +158,8 @@ def interrogate(image, mode, caption=None):
|
|||
if image is None:
|
||||
return ''
|
||||
image = image.convert("RGB")
|
||||
t0 = time.time()
|
||||
debug_log(f'CLIP: mode="{mode}" image_size={image.size} caption={caption is not None} min_flavors={shared.opts.interrogate_clip_min_flavors} max_flavors={shared.opts.interrogate_clip_max_flavors}')
|
||||
if mode == 'best':
|
||||
prompt = ci.interrogate(image, caption=caption, min_flavors=shared.opts.interrogate_clip_min_flavors, max_flavors=shared.opts.interrogate_clip_max_flavors, )
|
||||
elif mode == 'caption':
|
||||
|
|
@ -131,22 +172,27 @@ def interrogate(image, mode, caption=None):
|
|||
prompt = ci.interrogate_negative(image, max_flavors=shared.opts.interrogate_clip_max_flavors)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown mode {mode}")
|
||||
debug_log(f'CLIP: mode="{mode}" time={time.time()-t0:.2f} result="{prompt[:100]}..."' if len(prompt) > 100 else f'CLIP: mode="{mode}" time={time.time()-t0:.2f} result="{prompt}"')
|
||||
return prompt
|
||||
|
||||
|
||||
def interrogate_image(image, clip_model, blip_model, mode):
    """Interrogate a single image with the requested CLIP/BLIP models.

    Args:
        image: PIL image to interrogate
        clip_model: CLIP model name passed to ``load_interrogator``
        blip_model: caption model name passed to ``load_interrogator``
        mode: interrogation mode passed to ``interrogate``

    Returns:
        The generated prompt, or ``"Exception <type>"`` if any step fails
    """
    jobid = shared.state.begin('Interrogate CLiP')
    t0 = time.time()
    shared.log.info(f'CLIP: mode="{mode}" clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
    try:
        if shared.sd_loaded:
            from modules.sd_models import apply_balanced_offload # prevent circular import
            apply_balanced_offload(shared.sd_model)
            debug_log('CLIP: applied balanced offload to sd_model')
        load_interrogator(clip_model, blip_model)
        image = image.convert('RGB')
        prompt = interrogate(image, mode)
        devices.torch_gc()
        shared.log.debug(f'CLIP: complete time={time.time()-t0:.2f}')
    except Exception as e:
        prompt = f"Exception {type(e)}"
        # Log the failure once, under the CLIP prefix used by the rest of this module
        shared.log.error(f'CLIP: {e}')
        errors.display(e, 'Interrogate')
    shared.state.end(jobid)
    return prompt
|
||||
|
|
@ -162,8 +208,11 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
|
|||
from modules.files_cache import list_files
|
||||
files += list(list_files(batch_str, ext_filter=['.png', '.jpg', '.jpeg', '.webp', '.jxl'], recursive=recursive))
|
||||
if len(files) == 0:
|
||||
shared.log.warning('Interrogate batch: type=clip no images')
|
||||
shared.log.warning('CLIP batch: no images found')
|
||||
return ''
|
||||
t0 = time.time()
|
||||
shared.log.info(f'CLIP batch: mode="{mode}" images={len(files)} clip="{clip_model}" blip="{blip_model}" write={write} append={append}')
|
||||
debug_log(f'CLIP batch: recursive={recursive} files={files[:5]}{"..." if len(files) > 5 else ""}')
|
||||
jobid = shared.state.begin('Interrogate batch')
|
||||
prompts = []
|
||||
|
||||
|
|
@ -171,6 +220,7 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
|
|||
if write:
|
||||
file_mode = 'w' if not append else 'a'
|
||||
writer = BatchWriter(os.path.dirname(files[0]), mode=file_mode)
|
||||
debug_log(f'CLIP batch: writing to "{os.path.dirname(files[0])}" mode="{file_mode}"')
|
||||
import rich.progress as rp
|
||||
pbar = rp.Progress(rp.TextColumn('[cyan]Caption:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
|
||||
with pbar:
|
||||
|
|
@ -179,6 +229,7 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
|
|||
pbar.update(task, advance=1, description=file)
|
||||
try:
|
||||
if shared.state.interrupted:
|
||||
shared.log.info('CLIP batch: interrupted')
|
||||
break
|
||||
image = Image.open(file).convert('RGB')
|
||||
prompt = interrogate(image, mode)
|
||||
|
|
@ -186,19 +237,23 @@ def interrogate_batch(batch_files, batch_folder, batch_str, clip_model, blip_mod
|
|||
if write:
|
||||
writer.add(file, prompt)
|
||||
except OSError as e:
|
||||
shared.log.error(f'Interrogate batch: {e}')
|
||||
shared.log.error(f'CLIP batch: file="{file}" error={e}')
|
||||
if write:
|
||||
writer.close()
|
||||
ci.config.quiet = False
|
||||
unload_clip_model()
|
||||
shared.state.end(jobid)
|
||||
shared.log.info(f'CLIP batch: complete images={len(prompts)} time={time.time()-t0:.2f}')
|
||||
return '\n\n'.join(prompts)
|
||||
|
||||
|
||||
def analyze_image(image, clip_model, blip_model):
|
||||
t0 = time.time()
|
||||
shared.log.info(f'CLIP analyze: clip="{clip_model}" blip="{blip_model}" image_size={image.size if image else None}')
|
||||
load_interrogator(clip_model, blip_model)
|
||||
image = image.convert('RGB')
|
||||
image_features = ci.image_to_features(image)
|
||||
debug_log(f'CLIP analyze: features shape={image_features.shape if hasattr(image_features, "shape") else "unknown"}')
|
||||
top_mediums = ci.mediums.rank(image_features, 5)
|
||||
top_artists = ci.artists.rank(image_features, 5)
|
||||
top_movements = ci.movements.rank(image_features, 5)
|
||||
|
|
@ -209,6 +264,7 @@ def analyze_image(image, clip_model, blip_model):
|
|||
movement_ranks = dict(sorted(zip(top_movements, ci.similarities(image_features, top_movements)), key=lambda x: x[1], reverse=True))
|
||||
trending_ranks = dict(sorted(zip(top_trendings, ci.similarities(image_features, top_trendings)), key=lambda x: x[1], reverse=True))
|
||||
flavor_ranks = dict(sorted(zip(top_flavors, ci.similarities(image_features, top_flavors)), key=lambda x: x[1], reverse=True))
|
||||
shared.log.debug(f'CLIP analyze: complete time={time.time()-t0:.2f}')
|
||||
|
||||
# Format labels as text
|
||||
def format_category(name, ranks):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,79 @@
|
|||
# Unified Tagger Interface - Dispatches to WaifuDiffusion or DeepBooru based on model selection
|
||||
# Provides a common interface for the Booru Tags tab
|
||||
|
||||
from modules import shared
|
||||
|
||||
DEEPBOORU_MODEL = "DeepBooru"
|
||||
|
||||
|
||||
def get_models() -> list:
    """List every selectable tagger: DeepBooru first, then the WaifuDiffusion models."""
    from modules.interrogate import waifudiffusion
    return [DEEPBOORU_MODEL, *waifudiffusion.get_models()]
|
||||
|
||||
|
||||
def refresh_models() -> list:
    """Return the current model list; simply delegates to get_models()."""
    models = get_models()
    return models
|
||||
|
||||
|
||||
def is_deepbooru(model_name: str) -> bool:
    """Return True when model_name is the DeepBooru entry rather than a WaifuDiffusion model."""
    return DEEPBOORU_MODEL == model_name
|
||||
|
||||
|
||||
def load_model(model_name: str) -> bool:
    """Load the backend that serves model_name and return that backend's result."""
    if not is_deepbooru(model_name):
        from modules.interrogate import waifudiffusion
        return waifudiffusion.load_model(model_name)
    from modules.interrogate import deepbooru
    return deepbooru.load_model()
|
||||
|
||||
|
||||
def unload_model():
    """Unload DeepBooru and WaifuDiffusion alike so neither keeps memory allocated."""
    from modules.interrogate import deepbooru
    from modules.interrogate import waifudiffusion
    for backend in (deepbooru, waifudiffusion):
        backend.unload_model()
|
||||
|
||||
|
||||
def tag(image, model_name: str = None, **kwargs) -> str:
    """Tag one image with whichever backend owns the selected model.

    Args:
        image: PIL Image to tag
        model_name: DeepBooru or a WaifuDiffusion model name; when None the
            ``waifudiffusion_model`` setting is used
        **kwargs: Extra arguments handed to the backend's tag()

    Returns:
        Formatted tag string
    """
    selected = shared.opts.waifudiffusion_model if model_name is None else model_name

    if not is_deepbooru(selected):
        from modules.interrogate import waifudiffusion
        return waifudiffusion.tag(image, model_name=selected, **kwargs)

    from modules.interrogate import deepbooru
    return deepbooru.tag(image, **kwargs)
|
||||
|
||||
|
||||
def batch(model_name: str, **kwargs) -> str:
    """Run batch tagging through the backend that owns model_name.

    Args:
        model_name: DeepBooru or a WaifuDiffusion model name
        **kwargs: Extra arguments handed to the backend's batch()

    Returns:
        Combined tag results
    """
    # Import lazily so only the backend actually used gets imported
    if is_deepbooru(model_name):
        from modules.interrogate import deepbooru as backend
    else:
        from modules.interrogate import waifudiffusion as backend
    return backend.batch(model_name=model_name, **kwargs)
|
||||
|
|
@ -13,7 +13,7 @@ from modules.interrogate import vqa_detection
|
|||
|
||||
|
||||
# Debug logging - function-based to avoid circular import
|
||||
debug_enabled = os.environ.get('SD_VQA_DEBUG', None) is not None
|
||||
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
|
||||
|
||||
def debug(*args, **kwargs):
|
||||
if debug_enabled:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,544 @@
|
|||
# WaifuDiffusion Tagger - ONNX-based anime/illustration tagging
|
||||
# Based on SmilingWolf's tagger models: https://huggingface.co/SmilingWolf
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import threading
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from modules import shared, devices, errors
|
||||
|
||||
|
||||
# Debug logging - enable with SD_INTERROGATE_DEBUG environment variable
|
||||
debug_enabled = os.environ.get('SD_INTERROGATE_DEBUG', None) is not None
|
||||
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
|
||||
|
||||
re_special = re.compile(r'([\\()])')
|
||||
load_lock = threading.Lock()
|
||||
|
||||
# WaifuDiffusion model repository mappings
|
||||
# Every repository lives under the SmilingWolf namespace and is named after the model
WAIFUDIFFUSION_MODELS = {
    name: f"SmilingWolf/{name}"
    for name in (
        # v3 models (latest, recommended)
        "wd-eva02-large-tagger-v3",
        "wd-vit-tagger-v3",
        "wd-convnext-tagger-v3",
        "wd-swinv2-tagger-v3",
        # v2 models
        "wd-v1-4-moat-tagger-v2",
        "wd-v1-4-swinv2-tagger-v2",
        "wd-v1-4-convnext-tagger-v2",
        "wd-v1-4-convnextv2-tagger-v2",
        "wd-v1-4-vit-tagger-v2",
    )
}

# Tag categories from selected_tags.csv
CATEGORY_GENERAL, CATEGORY_CHARACTER, CATEGORY_RATING = 0, 4, 9
|
||||
|
||||
|
||||
class WaifuDiffusionTagger:
    """WaifuDiffusion Tagger using ONNX inference.

    Keeps one ONNX Runtime session together with the tag names and tag categories
    read from the model's selected_tags.csv. Loading is serialized through the
    module-level load_lock.
    """

    def __init__(self):
        self.session = None # ONNX Runtime inference session, None while unloaded
        self.tags = None # tag names, in CSV row order
        self.tag_categories = None # category id per tag, parallel to self.tags
        self.model_name = None
        self.model_path = None # local snapshot folder holding model.onnx and selected_tags.csv
        self.image_size = 448 # Standard for WD models

    def load(self, model_name: str = None):
        """Load the ONNX model and tags from HuggingFace.

        Args:
            model_name: Key of WAIFUDIFFUSION_MODELS; defaults to the waifudiffusion_model setting

        Returns:
            True if the model is loaded (or already was), False on unknown model or any load failure
        """
        import huggingface_hub

        if model_name is None:
            model_name = shared.opts.waifudiffusion_model
        if model_name not in WAIFUDIFFUSION_MODELS:
            shared.log.error(f'WaifuDiffusion: unknown model "{model_name}"')
            return False

        with load_lock:
            if self.session is not None and self.model_name == model_name:
                debug_log(f'WaifuDiffusion: model already loaded model="{model_name}"')
                return True # Already loaded

            # Unload previous model if different
            if self.model_name != model_name and self.session is not None:
                debug_log(f'WaifuDiffusion: switching model from "{self.model_name}" to "{model_name}"')
                self.unload()

            repo_id = WAIFUDIFFUSION_MODELS[model_name]
            t0 = time.time()
            shared.log.info(f'WaifuDiffusion load: model="{model_name}" repo="{repo_id}"')

            try:
                # Download only ONNX model and tags CSV (skip safetensors/msgpack variants)
                debug_log(f'WaifuDiffusion load: downloading from HuggingFace cache_dir="{shared.opts.hfcache_dir}"')
                self.model_path = huggingface_hub.snapshot_download(
                    repo_id,
                    cache_dir=shared.opts.hfcache_dir,
                    allow_patterns=["model.onnx", "selected_tags.csv"],
                )
                debug_log(f'WaifuDiffusion load: model_path="{self.model_path}"')

                # Load ONNX model
                model_file = os.path.join(self.model_path, "model.onnx")
                if not os.path.exists(model_file):
                    shared.log.error(f'WaifuDiffusion load: model file not found: {model_file}')
                    return False

                import onnxruntime as ort

                debug_log(f'WaifuDiffusion load: onnxruntime version={ort.__version__}')

                self.session = ort.InferenceSession(model_file, providers=devices.onnx)
                self.model_name = model_name

                # Get actual providers used
                actual_providers = self.session.get_providers()
                debug_log(f'WaifuDiffusion load: active providers={actual_providers}')

                # Load tags from CSV; raises if the CSV is missing, handled below like any other load failure
                self._load_tags()

                load_time = time.time() - t0
                shared.log.debug(f'WaifuDiffusion load: time={load_time:.2f} tags={len(self.tags)}')
                debug_log(f'WaifuDiffusion load: input_name={self.session.get_inputs()[0].name} output_name={self.session.get_outputs()[0].name}')
                return True

            except Exception as e:
                shared.log.error(f'WaifuDiffusion load: failed error={e}')
                errors.display(e, 'WaifuDiffusion load')
                self.unload()
                return False

    def _load_tags(self):
        """Load tags and categories from selected_tags.csv.

        Raises:
            FileNotFoundError: if selected_tags.csv is missing from the model folder
        """
        import csv

        csv_path = os.path.join(self.model_path, "selected_tags.csv")
        if not os.path.exists(csv_path):
            # Fail explicitly instead of leaving self.tags as None, which only surfaced later as a TypeError
            raise FileNotFoundError(f'tags file not found: {csv_path}')

        self.tags = []
        self.tag_categories = []

        with open(csv_path, 'r', encoding='utf-8') as f:
            for row in csv.DictReader(f):
                self.tags.append(row['name'])
                self.tag_categories.append(int(row['category']))

        # Count tags by category
        category_counts = {}
        for cat in self.tag_categories:
            category_counts[cat] = category_counts.get(cat, 0) + 1
        debug_log(f'WaifuDiffusion load: tag categories={category_counts}')

    def unload(self):
        """Unload the model and free resources."""
        if self.session is not None:
            shared.log.debug(f'WaifuDiffusion unload: model="{self.model_name}"')
            self.session = None
            self.tags = None
            self.tag_categories = None
            self.model_name = None
            self.model_path = None
            devices.torch_gc(force=True)
            debug_log('WaifuDiffusion unload: complete')
        else:
            debug_log('WaifuDiffusion unload: no model loaded')

    def preprocess_image(self, image: Image.Image) -> np.ndarray:
        """Preprocess image for WaifuDiffusion model input.

        - Convert to RGB
        - Pad to square with white background
        - Resize to self.image_size (448x448 by default)
        - Convert to float32 without rescaling (pixel values stay in 0-255)
        - Reverse channel order to BGR (as used by these models)
        - Add a leading batch dimension

        Returns:
            float32 array of shape (1, image_size, image_size, 3)
        """
        original_size = image.size
        original_mode = image.mode

        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Pad to square with white background
        w, h = image.size
        max_dim = max(w, h)
        pad_left = (max_dim - w) // 2
        pad_top = (max_dim - h) // 2

        padded = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
        padded.paste(image, (pad_left, pad_top))

        # Resize to model input size
        if max_dim != self.image_size:
            padded = padded.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)

        # Convert to float32 numpy array (values are not rescaled)
        img_array = np.array(padded, dtype=np.float32)

        # Convert RGB to BGR (model expects BGR)
        img_array = img_array[:, :, ::-1]

        # Add batch dimension
        img_array = np.expand_dims(img_array, axis=0)

        debug_log(f'WaifuDiffusion preprocess: original_size={original_size} mode={original_mode} padded_size={max_dim} output_shape={img_array.shape}')
        return img_array

    def predict(
        self,
        image: Image.Image,
        general_threshold: float = None,
        character_threshold: float = None,
        include_rating: bool = None,
        exclude_tags: str = None,
        max_tags: int = None,
        sort_alpha: bool = None,
        use_spaces: bool = None,
        escape_brackets: bool = None,
    ) -> str:
        """Run inference and return formatted tag string.

        Every optional argument falls back to the matching setting only when it is
        None, so explicit values such as a threshold of 0.0, max_tags=0 (no limit)
        or an empty exclude_tags string are honored.

        Args:
            image: PIL Image to tag; a list (first element used) or a dict with a 'name' path is also accepted
            general_threshold: Threshold for general tags (0-1)
            character_threshold: Threshold for character tags (0-1)
            include_rating: Whether to include rating tags
            exclude_tags: Comma-separated tags to exclude
            max_tags: Maximum number of tags to return; 0 or less disables the limit
            sort_alpha: Sort tags alphabetically vs by confidence
            use_spaces: Use spaces instead of underscores
            escape_brackets: Escape parentheses/backslashes in tags

        Returns:
            Formatted tag string; '' if no image was given or the model could not be loaded
        """
        t0 = time.time()

        # Use settings defaults only for arguments that were not specified
        general_threshold = general_threshold if general_threshold is not None else shared.opts.tagger_threshold
        character_threshold = character_threshold if character_threshold is not None else shared.opts.waifudiffusion_character_threshold
        include_rating = include_rating if include_rating is not None else shared.opts.tagger_include_rating
        exclude_tags = exclude_tags if exclude_tags is not None else shared.opts.tagger_exclude_tags
        max_tags = max_tags if max_tags is not None else shared.opts.tagger_max_tags
        sort_alpha = sort_alpha if sort_alpha is not None else shared.opts.tagger_sort_alpha
        use_spaces = use_spaces if use_spaces is not None else shared.opts.tagger_use_spaces
        escape_brackets = escape_brackets if escape_brackets is not None else shared.opts.tagger_escape_brackets

        debug_log(f'WaifuDiffusion predict: general_threshold={general_threshold} character_threshold={character_threshold} max_tags={max_tags} include_rating={include_rating} sort_alpha={sort_alpha}')

        # Handle input variations
        if isinstance(image, list):
            image = image[0] if len(image) > 0 else None
        if isinstance(image, dict) and 'name' in image:
            image = Image.open(image['name'])
        if image is None:
            shared.log.error('WaifuDiffusion predict: no image provided')
            return ''

        # Load model if needed
        if self.session is None:
            if not self.load():
                return ''

        # Preprocess image
        img_input = self.preprocess_image(image)

        # Run inference
        t_infer = time.time()
        input_name = self.session.get_inputs()[0].name
        output_name = self.session.get_outputs()[0].name
        probs = self.session.run([output_name], {input_name: img_input})[0][0]
        infer_time = time.time() - t_infer
        debug_log(f'WaifuDiffusion predict: inference time={infer_time:.3f}s output_shape={probs.shape}')

        # Build tag list with probabilities
        tag_probs = {}
        # Normalize user-entered names to the lower-case, underscore form they are compared in
        exclude_set = {x.strip().replace(' ', '_').lower() for x in exclude_tags.split(',') if x.strip()}
        if exclude_set:
            debug_log(f'WaifuDiffusion predict: exclude_tags={exclude_set}')

        general_count = 0
        character_count = 0
        rating_count = 0

        for tag_name, category, prob in zip(self.tags, self.tag_categories, probs):
            # Skip excluded tags
            if tag_name.lower() in exclude_set:
                continue

            # Apply category-specific thresholds
            if category == CATEGORY_RATING:
                if not include_rating:
                    continue
                # Always include rating if enabled
                tag_probs[tag_name] = float(prob)
                rating_count += 1
            elif category == CATEGORY_CHARACTER:
                if prob >= character_threshold:
                    tag_probs[tag_name] = float(prob)
                    character_count += 1
            elif category == CATEGORY_GENERAL:
                if prob >= general_threshold:
                    tag_probs[tag_name] = float(prob)
                    general_count += 1
            else:
                # Other categories use general threshold
                if prob >= general_threshold:
                    tag_probs[tag_name] = float(prob)

        debug_log(f'WaifuDiffusion predict: matched tags general={general_count} character={character_count} rating={rating_count} total={len(tag_probs)}')

        # Sort tags
        if sort_alpha:
            sorted_tags = sorted(tag_probs.keys())
        else:
            sorted_tags = [t for t, _ in sorted(tag_probs.items(), key=lambda x: -x[1])]

        # Limit number of tags
        if max_tags > 0 and len(sorted_tags) > max_tags:
            sorted_tags = sorted_tags[:max_tags]
            debug_log(f'WaifuDiffusion predict: limited to max_tags={max_tags}')

        # Format output
        result = []
        for tag_name in sorted_tags:
            formatted_tag = tag_name
            if use_spaces:
                formatted_tag = formatted_tag.replace('_', ' ')
            if escape_brackets:
                formatted_tag = re_special.sub(r'\\\1', formatted_tag)
            if shared.opts.tagger_show_scores:
                formatted_tag = f"({formatted_tag}:{tag_probs[tag_name]:.2f})"
            result.append(formatted_tag)

        output = ", ".join(result)
        total_time = time.time() - t0
        debug_log(f'WaifuDiffusion predict: complete tags={len(result)} time={total_time:.2f} result="{output[:100]}..."' if len(output) > 100 else f'WaifuDiffusion predict: complete tags={len(result)} time={total_time:.2f} result="{output}"')

        return output

    def tag(self, image: Image.Image, **kwargs) -> str:
        """Alias for predict() to match deepbooru interface."""
        return self.predict(image, **kwargs)
|
||||
|
||||
|
||||
# Global tagger instance
|
||||
tagger = WaifuDiffusionTagger()
|
||||
|
||||
|
||||
def _save_tags_to_file(img_path, tags_str: str, save_append: bool) -> bool:
    """Save tags to a text file next to the image, with error handling.

    The target file is `img_path` with its suffix replaced by `.txt`.

    Args:
        img_path: Path object of the image file (must support `with_suffix`)
        tags_str: Tags string to save
        save_append: If True and the caption file already has content, append
            the tags after a `, ` separator; otherwise (over)write the file

    Returns:
        True if save succeeded, False otherwise (the error is logged, not raised)
    """
    try:
        txt_path = img_path.with_suffix('.txt')
        # Only append when there is existing content: appending to an empty
        # file would otherwise produce a caption starting with ", ".
        appending = save_append and txt_path.exists() and txt_path.stat().st_size > 0
        with open(txt_path, 'a' if appending else 'w', encoding='utf-8') as f:
            f.write(f', {tags_str}' if appending else tags_str)
        return True
    except Exception as e:
        # Best-effort: a failed save must not abort the surrounding batch run
        shared.log.error(f'WaifuDiffusion batch: failed to save file="{img_path}" error={e}')
        return False
|
||||
|
||||
|
||||
def get_models() -> list:
    """List the model names registered in WAIFUDIFFUSION_MODELS."""
    return [*WAIFUDIFFUSION_MODELS]
|
||||
|
||||
|
||||
def refresh_models() -> list:
    """Return the current list of available models.

    The list is static for now, so this simply delegates to get_models();
    it could later be extended to check for locally cached models.
    """
    models = get_models()
    return models
|
||||
|
||||
|
||||
def load_model(model_name: str = None) -> bool:
    """Load a WaifuDiffusion model into the module-level tagger.

    Args:
        model_name: Model to load; passed to tagger.load() as-is (including None)

    Returns:
        The value returned by tagger.load()
    """
    loaded = tagger.load(model_name)
    return loaded
|
||||
|
||||
|
||||
def unload_model():
    """Ask the module-level tagger to unload its current model."""
    tagger.unload()
|
||||
|
||||
|
||||
def tag(image: Image.Image, model_name: str = None, **kwargs) -> str:
    """Tag an image using WaifuDiffusion tagger.

    Errors raised while loading or predicting are logged and reported through
    the return value instead of being raised to the caller.

    Args:
        image: PIL Image to tag
        model_name: Model to use (loaded first if it differs from the one currently loaded)
        **kwargs: Additional arguments passed to predict()

    Returns:
        Formatted tag string, or an "Exception <type>" string if tagging failed
    """
    t0 = time.time()
    jobid = shared.state.begin('WaifuDiffusion Tag')
    shared.log.info(f'WaifuDiffusion: model="{model_name or tagger.model_name or shared.opts.waifudiffusion_model}" image_size={image.size if image else None}')

    try:
        if model_name and model_name != tagger.model_name:
            tagger.load(model_name)
        result = tagger.predict(image, **kwargs)
        shared.log.debug(f'WaifuDiffusion: complete time={time.time()-t0:.2f} tags={len(result.split(", ")) if result else 0}')
        # Offload model if setting enabled
        if shared.opts.interrogate_offload:
            tagger.unload()
    except Exception as e:
        result = f"Exception {type(e)}"
        shared.log.error(f'WaifuDiffusion: {e}')
        errors.display(e, 'WaifuDiffusion Tag')
    finally:
        # Always close the job opened by state.begin(), even if something other
        # than Exception propagates or the error handler itself fails
        shared.state.end(jobid)

    return result
|
||||
|
||||
|
||||
def _batch_folder_images(folder, recursive: bool, extensions) -> list:
    """Return the image files found in `folder`, sorted for a deterministic processing order.

    Suffixes are compared case-insensitively so that files such as `photo.JPG`
    are also picked up on case-sensitive filesystems.
    """
    candidates = folder.rglob('*') if recursive else folder.glob('*')
    return sorted(p for p in candidates if p.suffix.lower() in extensions and p.is_file())


def batch(
    model_name: str,
    batch_files: list,
    batch_folder: str,
    batch_str: str,
    save_output: bool = True,
    save_append: bool = False,
    recursive: bool = False,
    **kwargs
) -> str:
    """Process multiple images in batch mode.

    Images are gathered from three sources (explicit files, a folder chosen in
    the file picker, a folder given as a string), de-duplicated by resolved
    path, and tagged one by one. A failure on one image is logged and recorded
    in the result; it does not stop the batch.

    Args:
        model_name: Model to use; if empty, the default model is loaded only when no session exists
        batch_files: List of file entries (dicts with a 'name' key, objects with a `.name`, or plain paths)
        batch_folder: Folder selection from file picker (list of entries; the parent of the first entry is scanned)
        batch_str: Folder path as string
        save_output: Save caption to .txt files
        save_append: Append to existing caption files
        recursive: Recursively process subfolders
        **kwargs: Additional arguments passed to predict()

    Returns:
        Combined tag results, one line per image (each truncated to 100 characters of tags),
        or an empty string if no images were found
    """
    from pathlib import Path

    # Load model
    if model_name:
        tagger.load(model_name)
    elif tagger.session is None:
        tagger.load()

    # Collect image files
    image_files = []
    image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif'}

    # From file picker
    if batch_files:
        for f in batch_files:
            if isinstance(f, dict):
                image_files.append(Path(f['name']))
            elif hasattr(f, 'name'):
                image_files.append(Path(f.name))
            else:
                image_files.append(Path(f))

    # From folder picker: the folder is taken as the parent of the first picked entry
    if batch_folder:
        folder_path = None
        if isinstance(batch_folder, list) and len(batch_folder) > 0:
            first = batch_folder[0]
            if isinstance(first, dict):
                folder_path = Path(first['name']).parent
            elif hasattr(first, 'name'):
                folder_path = Path(first.name).parent
        if folder_path is not None and folder_path.is_dir():
            image_files.extend(_batch_folder_images(folder_path, recursive, image_extensions))

    # From string path
    if batch_str and batch_str.strip():
        folder_path = Path(batch_str.strip())
        if folder_path.is_dir():
            image_files.extend(_batch_folder_images(folder_path, recursive, image_extensions))

    # Remove duplicates while preserving order (first occurrence of each resolved path wins)
    unique_files = {}
    for f in image_files:
        unique_files.setdefault(f.resolve(), f)
    image_files = list(unique_files.values())

    if not image_files:
        shared.log.warning('WaifuDiffusion batch: no images found')
        return ''

    t0 = time.time()
    jobid = shared.state.begin('WaifuDiffusion Batch')
    shared.log.info(f'WaifuDiffusion batch: model="{tagger.model_name}" images={len(image_files)} write={save_output} append={save_append} recursive={recursive}')
    debug_log(f'WaifuDiffusion batch: files={[str(f) for f in image_files[:5]]}{"..." if len(image_files) > 5 else ""}')

    results = []

    # Progress bar
    import rich.progress as rp
    pbar = rp.Progress(rp.TextColumn('[cyan]WaifuDiffusion:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)

    try:
        with pbar:
            task = pbar.add_task(total=len(image_files), description='starting...')
            for img_path in image_files:
                pbar.update(task, advance=1, description=str(img_path.name))
                try:
                    if shared.state.interrupted:
                        shared.log.info('WaifuDiffusion batch: interrupted')
                        break

                    # Context manager closes the underlying file handle once tagging is done
                    with Image.open(img_path) as image:
                        tags_str = tagger.predict(image, **kwargs)

                    if save_output:
                        _save_tags_to_file(img_path, tags_str, save_append)

                    results.append(f'{img_path.name}: {tags_str[:100]}...' if len(tags_str) > 100 else f'{img_path.name}: {tags_str}')

                except Exception as e:
                    # Keep going: one unreadable image must not abort the whole batch
                    shared.log.error(f'WaifuDiffusion batch: file="{img_path}" error={e}')
                    results.append(f'{img_path.name}: ERROR - {e}')

        elapsed = time.time() - t0
        shared.log.info(f'WaifuDiffusion batch: complete images={len(results)} time={elapsed:.1f}s')
    finally:
        # Always close the job opened by state.begin()
        shared.state.end(jobid)

    return '\n'.join(results)
|
||||
|
|
@ -5,14 +5,18 @@ Lightweight IP-Adapter applied to existing pipeline in Diffusers
|
|||
- IP adapters: https://huggingface.co/h94/IP-Adapter
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from typing import TYPE_CHECKING
|
||||
from PIL import Image
|
||||
import diffusers
|
||||
import transformers
|
||||
from modules import processing, shared, devices, sd_models, errors, model_quant
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
|
||||
clip_loaded = None
|
||||
adapters_loaded = []
|
||||
|
|
@ -160,7 +164,7 @@ def unapply(pipe, unload: bool = False): # pylint: disable=arguments-differ
|
|||
pass
|
||||
|
||||
|
||||
def load_image_encoder(pipe: diffusers.DiffusionPipeline, adapter_names: list[str]):
|
||||
def load_image_encoder(pipe: DiffusionPipeline, adapter_names: list[str]):
|
||||
global clip_loaded # pylint: disable=global-statement
|
||||
for adapter_name in adapter_names:
|
||||
# which clip to use
|
||||
|
|
|
|||
|
|
@ -47,10 +47,10 @@ def readfile(filename: str, silent: bool = False, lock: bool = False, *, as_type
|
|||
log.debug(f'Read: file="{filename}" json={len(data)} bytes={os.path.getsize(filename)} time={t1-t0:.3f} fn={fn}')
|
||||
except FileNotFoundError as err:
|
||||
if not silent:
|
||||
log.debug(f'Reading failed: {filename} {err}')
|
||||
log.debug(f'Read failed: file="{filename}" {err}')
|
||||
except Exception as err:
|
||||
if not silent:
|
||||
log.error(f'Reading failed: {filename} {err}')
|
||||
log.error(f'Read failed: file="{filename}" {err}')
|
||||
try:
|
||||
if locking_available and lock_file is not None:
|
||||
lock_file.release_read_lock()
|
||||
|
|
|
|||
|
|
@ -46,6 +46,13 @@ except Exception as e:
|
|||
sys.exit(1)
|
||||
timer.startup.record("scipy")
|
||||
|
||||
try:
|
||||
import atexit
|
||||
import torch._inductor.async_compile as ac
|
||||
atexit.unregister(ac.shutdown_compile_workers)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import torch # pylint: disable=C0411
|
||||
if torch.__version__.startswith('2.5.0'):
|
||||
errors.log.warning(f'Disabling cuDNN for SDP on torch={torch.__version__}')
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import re
|
|||
import time
|
||||
import torch
|
||||
import diffusers.models.lora
|
||||
from modules.errorlimiter import ErrorLimiter
|
||||
from modules.lora import lora_common as l
|
||||
from modules import shared, devices, errors, model_quant
|
||||
|
||||
|
|
@ -141,6 +142,7 @@ def network_calc_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.
|
|||
if l.debug:
|
||||
errors.display(e, 'LoRA')
|
||||
raise RuntimeError('LoRA apply weight') from e
|
||||
ErrorLimiter.notify(("network_activate", "network_deactivate"))
|
||||
continue
|
||||
return batch_updown, batch_ex_bias
|
||||
|
||||
|
|
|
|||
|
|
@ -269,7 +269,7 @@ def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=Non
|
|||
continue
|
||||
if net is None:
|
||||
failed_to_load_networks.append(name)
|
||||
shared.log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed')
|
||||
shared.log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} not found')
|
||||
continue
|
||||
if hasattr(sd_model, 'embedding_db'):
|
||||
sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from contextlib import nullcontext
|
||||
import time
|
||||
import rich.progress as rp
|
||||
from modules.errorlimiter import limit_errors
|
||||
from modules.lora import lora_common as l
|
||||
from modules.lora.lora_apply import network_apply_weights, network_apply_direct, network_backup_weights, network_calc_weights
|
||||
from modules import shared, devices, sd_models
|
||||
|
|
@ -12,61 +13,62 @@ default_components = ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'text_
|
|||
|
||||
def network_activate(include=[], exclude=[]):
|
||||
t0 = time.time()
|
||||
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
|
||||
if shared.opts.diffusers_offload_mode == "sequential":
|
||||
sd_models.disable_offload(sd_model)
|
||||
sd_models.move_model(sd_model, device=devices.cpu)
|
||||
device = None
|
||||
modules = {}
|
||||
components = include if len(include) > 0 else default_components
|
||||
components = [x for x in components if x not in exclude]
|
||||
active_components = []
|
||||
for name in components:
|
||||
component = getattr(sd_model, name, None)
|
||||
if component is not None and hasattr(component, 'named_modules'):
|
||||
active_components.append(name)
|
||||
modules[name] = list(component.named_modules())
|
||||
total = sum(len(x) for x in modules.values())
|
||||
if len(l.loaded_networks) > 0:
|
||||
pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=activate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
|
||||
task = pbar.add_task(description='' , total=total)
|
||||
else:
|
||||
task = None
|
||||
pbar = nullcontext()
|
||||
applied_weight = 0
|
||||
applied_bias = 0
|
||||
with devices.inference_context(), pbar:
|
||||
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in l.loaded_networks) if len(l.loaded_networks) > 0 else ()
|
||||
applied_layers.clear()
|
||||
backup_size = 0
|
||||
for component in modules.keys():
|
||||
device = getattr(sd_model, component, None).device
|
||||
for _, module in modules[component]:
|
||||
network_layer_name = getattr(module, 'network_layer_name', None)
|
||||
current_names = getattr(module, "network_current_names", ())
|
||||
if getattr(module, 'weight', None) is None or shared.state.interrupted or (network_layer_name is None) or (current_names == wanted_names):
|
||||
with limit_errors("network_activate"):
|
||||
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
|
||||
if shared.opts.diffusers_offload_mode == "sequential":
|
||||
sd_models.disable_offload(sd_model)
|
||||
sd_models.move_model(sd_model, device=devices.cpu)
|
||||
device = None
|
||||
modules = {}
|
||||
components = include if len(include) > 0 else default_components
|
||||
components = [x for x in components if x not in exclude]
|
||||
active_components = []
|
||||
for name in components:
|
||||
component = getattr(sd_model, name, None)
|
||||
if component is not None and hasattr(component, 'named_modules'):
|
||||
active_components.append(name)
|
||||
modules[name] = list(component.named_modules())
|
||||
total = sum(len(x) for x in modules.values())
|
||||
if len(l.loaded_networks) > 0:
|
||||
pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=activate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
|
||||
task = pbar.add_task(description='' , total=total)
|
||||
else:
|
||||
task = None
|
||||
pbar = nullcontext()
|
||||
applied_weight = 0
|
||||
applied_bias = 0
|
||||
with devices.inference_context(), pbar:
|
||||
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in l.loaded_networks) if len(l.loaded_networks) > 0 else ()
|
||||
applied_layers.clear()
|
||||
backup_size = 0
|
||||
for component in modules.keys():
|
||||
device = getattr(sd_model, component, None).device
|
||||
for _, module in modules[component]:
|
||||
network_layer_name = getattr(module, 'network_layer_name', None)
|
||||
current_names = getattr(module, "network_current_names", ())
|
||||
if getattr(module, 'weight', None) is None or shared.state.interrupted or (network_layer_name is None) or (current_names == wanted_names):
|
||||
if task is not None:
|
||||
pbar.update(task, advance=1)
|
||||
continue
|
||||
backup_size += network_backup_weights(module, network_layer_name, wanted_names)
|
||||
batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name)
|
||||
if shared.opts.lora_fuse_native:
|
||||
network_apply_direct(module, batch_updown, batch_ex_bias, device=device)
|
||||
else:
|
||||
network_apply_weights(module, batch_updown, batch_ex_bias, device=device)
|
||||
if batch_updown is not None or batch_ex_bias is not None:
|
||||
applied_layers.append(network_layer_name)
|
||||
applied_weight += 1 if batch_updown is not None else 0
|
||||
applied_bias += 1 if batch_ex_bias is not None else 0
|
||||
batch_updown, batch_ex_bias = None, None
|
||||
del batch_updown, batch_ex_bias
|
||||
module.network_current_names = wanted_names
|
||||
if task is not None:
|
||||
pbar.update(task, advance=1)
|
||||
continue
|
||||
backup_size += network_backup_weights(module, network_layer_name, wanted_names)
|
||||
batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name)
|
||||
if shared.opts.lora_fuse_native:
|
||||
network_apply_direct(module, batch_updown, batch_ex_bias, device=device)
|
||||
else:
|
||||
network_apply_weights(module, batch_updown, batch_ex_bias, device=device)
|
||||
if batch_updown is not None or batch_ex_bias is not None:
|
||||
applied_layers.append(network_layer_name)
|
||||
applied_weight += 1 if batch_updown is not None else 0
|
||||
applied_bias += 1 if batch_ex_bias is not None else 0
|
||||
batch_updown, batch_ex_bias = None, None
|
||||
del batch_updown, batch_ex_bias
|
||||
module.network_current_names = wanted_names
|
||||
if task is not None:
|
||||
bs = round(backup_size/1024/1024/1024, 2) if backup_size > 0 else None
|
||||
pbar.update(task, advance=1, description=f'networks={len(l.loaded_networks)} modules={active_components} layers={total} weights={applied_weight} bias={applied_bias} backup={bs} device={device}')
|
||||
bs = round(backup_size/1024/1024/1024, 2) if backup_size > 0 else None
|
||||
pbar.update(task, advance=1, description=f'networks={len(l.loaded_networks)} modules={active_components} layers={total} weights={applied_weight} bias={applied_bias} backup={bs} device={device}')
|
||||
|
||||
if task is not None and len(applied_layers) == 0:
|
||||
pbar.remove_task(task) # hide progress bar for no action
|
||||
if task is not None and len(applied_layers) == 0:
|
||||
pbar.remove_task(task) # hide progress bar for no action
|
||||
l.timer.activate += time.time() - t0
|
||||
if l.debug and len(l.loaded_networks) > 0:
|
||||
shared.log.debug(f'Network load: type=LoRA networks={[n.name for n in l.loaded_networks]} modules={active_components} layers={total} weights={applied_weight} bias={applied_bias} backup={round(backup_size/1024/1024/1024, 2)} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers} device={device} time={l.timer.summary}')
|
||||
|
|
@ -81,49 +83,49 @@ def network_deactivate(include=[], exclude=[]):
|
|||
if len(l.previously_loaded_networks) == 0:
|
||||
return
|
||||
t0 = time.time()
|
||||
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
|
||||
if shared.opts.diffusers_offload_mode == "sequential":
|
||||
sd_models.disable_offload(sd_model)
|
||||
sd_models.move_model(sd_model, device=devices.cpu)
|
||||
modules = {}
|
||||
with limit_errors("network_deactivate"):
|
||||
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
|
||||
if shared.opts.diffusers_offload_mode == "sequential":
|
||||
sd_models.disable_offload(sd_model)
|
||||
sd_models.move_model(sd_model, device=devices.cpu)
|
||||
modules = {}
|
||||
|
||||
components = include if len(include) > 0 else ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'unet', 'transformer']
|
||||
components = [x for x in components if x not in exclude]
|
||||
active_components = []
|
||||
for name in components:
|
||||
component = getattr(sd_model, name, None)
|
||||
if component is not None and hasattr(component, 'named_modules'):
|
||||
modules[name] = list(component.named_modules())
|
||||
active_components.append(name)
|
||||
total = sum(len(x) for x in modules.values())
|
||||
if len(l.previously_loaded_networks) > 0 and l.debug:
|
||||
pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=deactivate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
|
||||
task = pbar.add_task(description='', total=total)
|
||||
else:
|
||||
task = None
|
||||
pbar = nullcontext()
|
||||
with devices.inference_context(), pbar:
|
||||
applied_layers.clear()
|
||||
for component in modules.keys():
|
||||
device = getattr(sd_model, component, None).device
|
||||
for _, module in modules[component]:
|
||||
network_layer_name = getattr(module, 'network_layer_name', None)
|
||||
if shared.state.interrupted or network_layer_name is None:
|
||||
components = include if len(include) > 0 else ['text_encoder', 'text_encoder_2', 'text_encoder_3', 'unet', 'transformer']
|
||||
components = [x for x in components if x not in exclude]
|
||||
active_components = []
|
||||
for name in components:
|
||||
component = getattr(sd_model, name, None)
|
||||
if component is not None and hasattr(component, 'named_modules'):
|
||||
modules[name] = list(component.named_modules())
|
||||
active_components.append(name)
|
||||
total = sum(len(x) for x in modules.values())
|
||||
if len(l.previously_loaded_networks) > 0 and l.debug:
|
||||
pbar = rp.Progress(rp.TextColumn('[cyan]Network: type=LoRA action=deactivate'), rp.BarColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
|
||||
task = pbar.add_task(description='', total=total)
|
||||
else:
|
||||
task = None
|
||||
pbar = nullcontext()
|
||||
with devices.inference_context(), pbar:
|
||||
applied_layers.clear()
|
||||
for component in modules.keys():
|
||||
device = getattr(sd_model, component, None).device
|
||||
for _, module in modules[component]:
|
||||
network_layer_name = getattr(module, 'network_layer_name', None)
|
||||
if shared.state.interrupted or network_layer_name is None:
|
||||
if task is not None:
|
||||
pbar.update(task, advance=1)
|
||||
continue
|
||||
batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name, use_previous=True)
|
||||
if shared.opts.lora_fuse_native:
|
||||
network_apply_direct(module, batch_updown, batch_ex_bias, device=device, deactivate=True)
|
||||
else:
|
||||
network_apply_weights(module, batch_updown, batch_ex_bias, device=device, deactivate=True)
|
||||
if batch_updown is not None or batch_ex_bias is not None:
|
||||
applied_layers.append(network_layer_name)
|
||||
del batch_updown, batch_ex_bias
|
||||
module.network_current_names = ()
|
||||
if task is not None:
|
||||
pbar.update(task, advance=1)
|
||||
continue
|
||||
batch_updown, batch_ex_bias = network_calc_weights(module, network_layer_name, use_previous=True)
|
||||
if shared.opts.lora_fuse_native:
|
||||
network_apply_direct(module, batch_updown, batch_ex_bias, device=device, deactivate=True)
|
||||
else:
|
||||
network_apply_weights(module, batch_updown, batch_ex_bias, device=device, deactivate=True)
|
||||
if batch_updown is not None or batch_ex_bias is not None:
|
||||
applied_layers.append(network_layer_name)
|
||||
del batch_updown, batch_ex_bias
|
||||
module.network_current_names = ()
|
||||
if task is not None:
|
||||
pbar.update(task, advance=1, description=f'networks={len(l.previously_loaded_networks)} modules={active_components} layers={total} unapply={len(applied_layers)}')
|
||||
|
||||
pbar.update(task, advance=1, description=f'networks={len(l.previously_loaded_networks)} modules={active_components} layers={total} unapply={len(applied_layers)}')
|
||||
l.timer.deactivate = time.time() - t0
|
||||
if l.debug and len(l.previously_loaded_networks) > 0:
|
||||
shared.log.debug(f'Network deactivate: type=LoRA networks={[n.name for n in l.previously_loaded_networks]} modules={active_components} layers={total} apply={len(applied_layers)} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers} time={l.timer.summary}')
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ def create_ui(prompt, negative, styles, overrides, init_image, init_strength, la
|
|||
generate = gr.Button('Generate', elem_id="ltx_generate_btn", variant='primary', visible=False)
|
||||
with gr.Row():
|
||||
ltx_models = [m.name for m in models['LTX Video']] if 'LTX Video' in models else ['None']
|
||||
model = gr.Dropdown(label='LTX model', choices=ltx_models, value=ltx_models[0])
|
||||
model = gr.Dropdown(label='LTX model', choices=ltx_models, value=ltx_models[0], elem_id="ltx_model")
|
||||
with gr.Accordion(open=False, label="Condition", elem_id='ltx_condition_accordion'):
|
||||
with gr.Tabs():
|
||||
with gr.Tab('Video', id='ltx_condition_video_tab'):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
import os
|
||||
from modules.paths import data_path
|
||||
from installer import log
|
||||
|
||||
|
||||
files = [
|
||||
'cache.json',
|
||||
'metadata.json',
|
||||
'html/extensions.json',
|
||||
'html/previews.json',
|
||||
'html/upscalers.json',
|
||||
'html/reference.json',
|
||||
'html/themes.json',
|
||||
'html/reference-quant.json',
|
||||
'html/reference-distilled.json',
|
||||
'html/reference-community.json',
|
||||
'html/reference-cloud.json',
|
||||
]
|
||||
|
||||
|
||||
def migrate_data():
    """Move legacy data files into the `data` subfolder of `data_path`.

    For every entry in `files`, the file at `<data_path>/<entry>` is renamed to
    `<data_path>/data/<basename of entry>`. An existing target is never
    overwritten (a warning is logged instead), and a failed move is logged
    without raising.
    """
    target_dir = os.path.join(data_path, "data")
    for f in files:
        old_filename = os.path.join(data_path, f)
        new_filename = os.path.join(target_dir, os.path.basename(f))
        if not os.path.exists(old_filename):
            continue
        if os.path.exists(new_filename):
            log.warning(f'Migrating: file="{old_filename}" target="{new_filename}" skip existing')
            continue
        log.info(f'Migrating: file="{old_filename}" target="{new_filename}"')
        try:
            # The target folder may not exist yet (e.g. custom data dir); rename would fail without it
            os.makedirs(target_dir, exist_ok=True)
            os.rename(old_filename, new_filename)
        except Exception as e:
            log.error(f'Migrating: file="{old_filename}" target="{new_filename}" {e}')
|
||||
|
||||
|
||||
migrate_data()
|
||||
|
|
@ -58,7 +58,7 @@ def get_model_type(pipe):
|
|||
model_type = 'sana'
|
||||
elif "HiDream" in name:
|
||||
model_type = 'h1'
|
||||
elif "Cosmos2TextToImage" in name:
|
||||
elif "Cosmos2TextToImage" in name or "AnimaTextToImage" in name:
|
||||
model_type = 'cosmos'
|
||||
elif "FLite" in name:
|
||||
model_type = 'flite'
|
||||
|
|
|
|||
|
|
@ -11,8 +11,10 @@ if TYPE_CHECKING:
|
|||
from modules.ui_components import DropdownEditable
|
||||
|
||||
|
||||
def options_section(section_identifier: tuple[str, str], options_dict: dict[str, OptionInfo | LegacyOption]):
|
||||
def options_section(section_identifier: tuple[str, str], options_dict: dict[str, OptionInfo | LegacyOption]) -> dict[str, OptionInfo | LegacyOption]:
|
||||
"""Set the `section` value for all OptionInfo/LegacyOption items"""
|
||||
if len(section_identifier) > 2:
|
||||
section_identifier = section_identifier[:2]
|
||||
for v in options_dict.values():
|
||||
v.section = section_identifier
|
||||
return options_dict
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ def apply(p: processing.StableDiffusionProcessing): # pylint: disable=arguments-
|
|||
cls = unapply()
|
||||
if p.pag_scale == 0:
|
||||
return
|
||||
if 'PAG' in cls.__name__:
|
||||
if cls is not None and 'PAG' in cls.__name__:
|
||||
pass
|
||||
elif detect.is_sd15(cls):
|
||||
if sd_models.get_diffusers_task(shared.sd_model) != sd_models.DiffusersTaskType.TEXT_2_IMAGE:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import sys
|
|||
import json
|
||||
import shlex
|
||||
import argparse
|
||||
import tempfile
|
||||
from installer import log
|
||||
|
||||
|
||||
|
|
@ -18,12 +19,16 @@ cli = parser.parse_known_args(argv)[0]
|
|||
parser.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(cli.data_dir, 'config.json')), help="Use specific server configuration file, default: %(default)s") # twice because we want data_dir
|
||||
cli = parser.parse_known_args(argv)[0]
|
||||
config_path = cli.config if os.path.isabs(cli.config) else os.path.join(cli.data_dir, cli.config)
|
||||
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf8') as f:
|
||||
config = json.load(f)
|
||||
except Exception:
|
||||
config = {}
|
||||
|
||||
temp_dir = config.get('temp_dir', '')
|
||||
if len(temp_dir) == 0:
|
||||
temp_dir = tempfile.gettempdir()
|
||||
reference_path = os.path.join('models', 'Reference')
|
||||
modules_path = os.path.dirname(os.path.realpath(__file__))
|
||||
script_path = os.path.dirname(modules_path)
|
||||
|
|
|
|||
|
|
@ -422,7 +422,7 @@ class YoloRestorer(Detailer):
|
|||
pc.image_mask = [item.mask]
|
||||
pc.overlay_images = []
|
||||
# explicitly disable for detailer pass
|
||||
pc.enable_hr = False
|
||||
pc.enable_hr = False
|
||||
pc.do_not_save_samples = True
|
||||
pc.do_not_save_grid = True
|
||||
# set recursion flag to avoid nested detailer calls
|
||||
|
|
|
|||
|
|
@ -243,13 +243,13 @@ def process_init(p: StableDiffusionProcessing):
|
|||
seed = get_fixed_seed(p.seed)
|
||||
subseed = get_fixed_seed(p.subseed)
|
||||
reset_prompts = False
|
||||
if p.all_prompts is None:
|
||||
if not p.all_prompts:
|
||||
p.all_prompts = p.prompt if isinstance(p.prompt, list) else p.batch_size * p.n_iter * [p.prompt]
|
||||
reset_prompts = True
|
||||
if p.all_negative_prompts is None:
|
||||
if not p.all_negative_prompts:
|
||||
p.all_negative_prompts = p.negative_prompt if isinstance(p.negative_prompt, list) else p.batch_size * p.n_iter * [p.negative_prompt]
|
||||
reset_prompts = True
|
||||
if p.all_seeds is None:
|
||||
if not p.all_seeds:
|
||||
reset_prompts = True
|
||||
if type(seed) == list:
|
||||
p.all_seeds = [int(s) for s in seed]
|
||||
|
|
@ -262,7 +262,7 @@ def process_init(p: StableDiffusionProcessing):
|
|||
for i in range(len(p.all_prompts)):
|
||||
seed = get_fixed_seed(p.seed)
|
||||
p.all_seeds.append(int(seed) + (i if p.subseed_strength == 0 else 0))
|
||||
if p.all_subseeds is None:
|
||||
if not p.all_subseeds:
|
||||
if type(subseed) == list:
|
||||
p.all_subseeds = [int(s) for s in subseed]
|
||||
else:
|
||||
|
|
@ -270,8 +270,8 @@ def process_init(p: StableDiffusionProcessing):
|
|||
if reset_prompts:
|
||||
if not hasattr(p, 'keep_prompts'):
|
||||
p.all_prompts, p.all_negative_prompts = shared.prompt_styles.apply_styles_to_prompts(p.all_prompts, p.all_negative_prompts, p.styles, p.all_seeds)
|
||||
p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
|
||||
p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
|
||||
p.prompts = p.all_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)]
|
||||
p.negative_prompts = p.all_negative_prompts[(p.iteration * p.batch_size):((p.iteration+1) * p.batch_size)]
|
||||
p.prompts, _ = extra_networks.parse_prompts(p.prompts)
|
||||
|
||||
|
||||
|
|
@ -427,13 +427,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
continue
|
||||
|
||||
if not hasattr(p, 'keep_prompts'):
|
||||
p.prompts = p.all_prompts[n * p.batch_size:(n+1) * p.batch_size]
|
||||
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n+1) * p.batch_size]
|
||||
p.seeds = p.all_seeds[n * p.batch_size:(n+1) * p.batch_size]
|
||||
p.subseeds = p.all_subseeds[n * p.batch_size:(n+1) * p.batch_size]
|
||||
p.prompts = p.all_prompts[(n * p.batch_size):((n+1) * p.batch_size)]
|
||||
p.negative_prompts = p.all_negative_prompts[(n * p.batch_size):((n+1) * p.batch_size)]
|
||||
p.seeds = p.all_seeds[(n * p.batch_size):((n+1) * p.batch_size)]
|
||||
p.subseeds = p.all_subseeds[(n * p.batch_size):((n+1) * p.batch_size)]
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||
if len(p.prompts) == 0:
|
||||
if not p.prompts:
|
||||
break
|
||||
p.prompts, p.network_data = extra_networks.parse_prompts(p.prompts)
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
|
|
@ -469,8 +469,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
p.scripts.postprocess_batch(p, samples, batch_number=n)
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
p.prompts = p.all_prompts[n * p.batch_size:(n+1) * p.batch_size]
|
||||
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n+1) * p.batch_size]
|
||||
p.prompts = p.all_prompts[(n * p.batch_size):((n+1) * p.batch_size)]
|
||||
p.negative_prompts = p.all_negative_prompts[(n * p.batch_size):((n+1) * p.batch_size)]
|
||||
batch_params = scripts_manager.PostprocessBatchListArgs(list(samples))
|
||||
p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
|
||||
samples = batch_params.images
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ def task_specific_kwargs(p, model):
|
|||
if 'hires' not in p.ops:
|
||||
p.ops.append('img2img')
|
||||
if p.vae_type == 'Remote':
|
||||
from modules.sd_vae_remote import remote_encode
|
||||
from modules.vae.sd_vae_remote import remote_encode
|
||||
p.init_images = remote_encode(p.init_images)
|
||||
task_args = {
|
||||
'image': p.init_images,
|
||||
|
|
@ -117,7 +117,7 @@ def task_specific_kwargs(p, model):
|
|||
p.ops.append('inpaint')
|
||||
mask_image = p.task_args.get('image_mask', None) or getattr(p, 'image_mask', None) or getattr(p, 'mask', None)
|
||||
if p.vae_type == 'Remote':
|
||||
from modules.sd_vae_remote import remote_encode
|
||||
from modules.vae.sd_vae_remote import remote_encode
|
||||
p.init_images = remote_encode(p.init_images)
|
||||
# mask_image = remote_encode(mask_image)
|
||||
task_args = {
|
||||
|
|
@ -269,7 +269,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
|
|||
kwargs['output_type'] = 'np' # only set latent if model has vae
|
||||
|
||||
# model specific
|
||||
if 'Kandinsky' in model.__class__.__name__ or 'Cosmos2' in model.__class__.__name__ or 'OmniGen2' in model.__class__.__name__:
|
||||
if 'Kandinsky' in model.__class__.__name__ or 'Cosmos2' in model.__class__.__name__ or 'Anima' in model.__class__.__name__ or 'OmniGen2' in model.__class__.__name__:
|
||||
kwargs['output_type'] = 'np' # only set latent if model has vae
|
||||
if 'StableCascade' in model.__class__.__name__:
|
||||
kwargs.pop("guidance_scale") # remove
|
||||
|
|
|
|||
|
|
@ -308,15 +308,14 @@ class StableDiffusionProcessing:
|
|||
shared.log.error(f'Override: {override_settings} {e}')
|
||||
self.override_settings = {}
|
||||
|
||||
# null items initialized later
|
||||
self.prompts = None
|
||||
self.negative_prompts = None
|
||||
self.all_prompts = None
|
||||
self.all_negative_prompts = None
|
||||
self.prompts = []
|
||||
self.negative_prompts = []
|
||||
self.all_prompts = []
|
||||
self.all_negative_prompts = []
|
||||
self.seeds = []
|
||||
self.subseeds = []
|
||||
self.all_seeds = None
|
||||
self.all_subseeds = None
|
||||
self.all_seeds = []
|
||||
self.all_subseeds = []
|
||||
|
||||
# a1111 compatibility items
|
||||
self.seed_enable_extras: bool = True
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@ https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space
|
|||
|
||||
import os
|
||||
import torch
|
||||
from modules import shared, sd_vae_taesd, devices
|
||||
from modules import shared, devices
|
||||
from modules.vae import sd_vae_taesd
|
||||
|
||||
|
||||
debug_enabled = os.environ.get('SD_HDR_DEBUG', None) is not None
|
||||
|
|
|
|||
|
|
@ -563,9 +563,9 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
|
|||
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.INPAINTING) # force pipeline
|
||||
if len(getattr(p, 'init_images', [])) == 0:
|
||||
p.init_images = [TF.to_pil_image(torch.rand((3, getattr(p, 'height', 512), getattr(p, 'width', 512))))]
|
||||
if p.prompts is None or len(p.prompts) == 0:
|
||||
if not p.prompts:
|
||||
p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
|
||||
if p.negative_prompts is None or len(p.negative_prompts) == 0:
|
||||
if not p.negative_prompts:
|
||||
p.negative_prompts = p.all_negative_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
|
||||
|
||||
sd_models_compile.openvino_recompile_model(p, hires=False, refiner=False) # recompile if a parameter changes
|
||||
|
|
|
|||
|
|
@ -386,7 +386,7 @@ def calculate_base_steps(p, use_denoise_start, use_refiner_start):
|
|||
if len(getattr(p, 'timesteps', [])) > 0:
|
||||
return None
|
||||
cls = shared.sd_model.__class__.__name__
|
||||
if 'Flex' in cls or 'Kontext' in cls or 'Edit' in cls or 'Wan' in cls or 'Flux2' in cls or 'Layered' in cls:
|
||||
if shared.sd_model_type not in ['sd', 'sdxl']:
|
||||
steps = p.steps
|
||||
elif is_modular():
|
||||
steps = p.steps
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No
|
|||
"Steps": p.steps,
|
||||
"Size": f"{p.width}x{p.height}" if hasattr(p, 'width') and hasattr(p, 'height') else None,
|
||||
"Sampler": p.sampler_name if p.sampler_name != 'Default' else None,
|
||||
"Scheduler": shared.sd_model.scheduler.__class__.__name__ if getattr(shared.sd_model, 'scheduler', None) is not None else None,
|
||||
"Seed": all_seeds[index],
|
||||
"Seed resize from": None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}",
|
||||
"CFG scale": p.cfg_scale if p.cfg_scale > 1.0 else 1.0,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import os
|
|||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
from modules import shared, devices, sd_models, sd_vae, sd_vae_taesd, errors
|
||||
from modules import shared, devices, sd_models, sd_vae, errors
|
||||
from modules.vae import sd_vae_taesd
|
||||
|
||||
|
||||
debug = os.environ.get('SD_VAE_DEBUG', None) is not None
|
||||
|
|
@ -286,13 +287,13 @@ def vae_decode(latents, model, output_type='np', vae_type='Full', width=None, he
|
|||
|
||||
if vae_type == 'Remote':
|
||||
jobid = shared.state.begin('Remote VAE')
|
||||
from modules.sd_vae_remote import remote_decode
|
||||
from modules.vae.sd_vae_remote import remote_decode
|
||||
tensors = remote_decode(latents=latents, width=width, height=height)
|
||||
shared.state.end(jobid)
|
||||
if tensors is not None and len(tensors) > 0:
|
||||
return vae_postprocess(tensors, model, output_type)
|
||||
if vae_type == 'Repa':
|
||||
from modules.sd_vae_repa import repa_load
|
||||
from modules.vae.sd_vae_repa import repa_load
|
||||
vae = repa_load(latents)
|
||||
vae_type = 'Full'
|
||||
if vae is not None:
|
||||
|
|
@ -310,14 +311,17 @@ def vae_decode(latents, model, output_type='np', vae_type='Full', width=None, he
|
|||
latents = latents.unsqueeze(0)
|
||||
if latents.shape[-1] <= 4: # not a latent, likely an image
|
||||
decoded = latents.float().cpu().numpy()
|
||||
elif vae_type == 'Full' and hasattr(model, "vae"):
|
||||
decoded = full_vae_decode(latents=latents, model=model)
|
||||
elif hasattr(model, "vqgan"):
|
||||
decoded = full_vqgan_decode(latents=latents, model=model)
|
||||
else:
|
||||
elif vae_type == 'Tiny':
|
||||
decoded = taesd_vae_decode(latents=latents)
|
||||
if torch.is_tensor(decoded):
|
||||
decoded = 2.0 * decoded - 1.0 # typical normalized range
|
||||
elif hasattr(model, "vqgan"):
|
||||
decoded = full_vqgan_decode(latents=latents, model=model)
|
||||
elif hasattr(model, "vae"):
|
||||
decoded = full_vae_decode(latents=latents, model=model)
|
||||
else:
|
||||
shared.log.error('VAE not found in model')
|
||||
decoded = []
|
||||
|
||||
images = vae_postprocess(decoded, model, output_type)
|
||||
if shared.cmd_opts.profile or debug:
|
||||
|
|
@ -339,11 +343,14 @@ def vae_encode(image, model, vae_type='Full'): # pylint: disable=unused-variable
|
|||
shared.log.error('VAE not found in model')
|
||||
return []
|
||||
tensor = f.to_tensor(image.convert("RGB")).unsqueeze(0).to(devices.device, devices.dtype_vae)
|
||||
if vae_type == 'Full':
|
||||
if vae_type == 'Tiny':
|
||||
latents = taesd_vae_encode(image=tensor)
|
||||
elif vae_type == 'Full' and hasattr(model, 'vae'):
|
||||
tensor = tensor * 2 - 1
|
||||
latents = full_vae_encode(image=tensor, model=shared.sd_model)
|
||||
else:
|
||||
latents = taesd_vae_encode(image=tensor)
|
||||
shared.log.error('VAE not found in model')
|
||||
latents = []
|
||||
devices.torch_gc()
|
||||
shared.state.end(jobid)
|
||||
return latents
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,267 @@
|
|||
# res4lyf
|
||||
|
||||
from .abnorsett_scheduler import ABNorsettScheduler
|
||||
from .bong_tangent_scheduler import BongTangentScheduler
|
||||
from .common_sigma_scheduler import CommonSigmaScheduler
|
||||
from .deis_scheduler_alt import RESDEISMultistepScheduler
|
||||
from .etdrk_scheduler import ETDRKScheduler
|
||||
from .gauss_legendre_scheduler import GaussLegendreScheduler
|
||||
from .langevin_dynamics_scheduler import LangevinDynamicsScheduler
|
||||
from .lawson_scheduler import LawsonScheduler
|
||||
from .linear_rk_scheduler import LinearRKScheduler
|
||||
from .lobatto_scheduler import LobattoScheduler
|
||||
from .pec_scheduler import PECScheduler
|
||||
from .radau_iia_scheduler import RadauIIAScheduler
|
||||
from .res_multistep_scheduler import RESMultistepScheduler
|
||||
from .res_multistep_sde_scheduler import RESMultistepSDEScheduler
|
||||
from .res_singlestep_scheduler import RESSinglestepScheduler
|
||||
from .res_singlestep_sde_scheduler import RESSinglestepSDEScheduler
|
||||
from .res_unified_scheduler import RESUnifiedScheduler
|
||||
from .riemannian_flow_scheduler import RiemannianFlowScheduler
|
||||
from .rungekutta_44s_scheduler import RungeKutta44Scheduler
|
||||
from .rungekutta_57s_scheduler import RungeKutta57Scheduler
|
||||
from .rungekutta_67s_scheduler import RungeKutta67Scheduler
|
||||
from .simple_exponential_scheduler import SimpleExponentialScheduler
|
||||
from .specialized_rk_scheduler import SpecializedRKScheduler
|
||||
|
||||
from .variants import (
|
||||
ABNorsett2MScheduler,
|
||||
ABNorsett3MScheduler,
|
||||
ABNorsett4MScheduler,
|
||||
SigmaArcsineScheduler,
|
||||
DEIS1MultistepScheduler,
|
||||
DEIS2MScheduler,
|
||||
DEIS2MultistepScheduler,
|
||||
DEIS3MScheduler,
|
||||
DEIS3MultistepScheduler,
|
||||
DEISUnified1SScheduler,
|
||||
DEISUnified2MScheduler,
|
||||
DEISUnified3MScheduler,
|
||||
SigmaEasingScheduler,
|
||||
ETDRK2Scheduler,
|
||||
ETDRK3AScheduler,
|
||||
ETDRK3BScheduler,
|
||||
ETDRK4AltScheduler,
|
||||
ETDRK4Scheduler,
|
||||
FlowEuclideanScheduler,
|
||||
FlowHyperbolicScheduler,
|
||||
Lawson2AScheduler,
|
||||
Lawson2BScheduler,
|
||||
Lawson4Scheduler,
|
||||
LinearRK2Scheduler,
|
||||
LinearRK3Scheduler,
|
||||
LinearRK4Scheduler,
|
||||
LinearRKMidpointScheduler,
|
||||
LinearRKRalsstonScheduler,
|
||||
Lobatto2Scheduler,
|
||||
Lobatto3Scheduler,
|
||||
Lobatto4Scheduler,
|
||||
FlowLorentzianScheduler,
|
||||
PEC2H2SScheduler,
|
||||
PEC2H3SScheduler,
|
||||
RadauIIA2Scheduler,
|
||||
RadauIIA3Scheduler,
|
||||
RES2MScheduler,
|
||||
RES2MSDEScheduler,
|
||||
RES2SScheduler,
|
||||
RES2SSDEScheduler,
|
||||
RES3MScheduler,
|
||||
RES3MSDEScheduler,
|
||||
RES3SScheduler,
|
||||
RES3SSDEScheduler,
|
||||
RES5SScheduler,
|
||||
RES5SSDEScheduler,
|
||||
RES6SScheduler,
|
||||
RES6SSDEScheduler,
|
||||
RESUnified2MScheduler,
|
||||
RESUnified2SScheduler,
|
||||
RESUnified3MScheduler,
|
||||
RESUnified3SScheduler,
|
||||
RESUnified5SScheduler,
|
||||
RESUnified6SScheduler,
|
||||
SigmaSigmoidScheduler,
|
||||
SigmaSineScheduler,
|
||||
SigmaSmoothScheduler,
|
||||
FlowSphericalScheduler,
|
||||
GaussLegendre2SScheduler,
|
||||
GaussLegendre3SScheduler,
|
||||
GaussLegendre4SScheduler,
|
||||
)
|
||||
|
||||
# public names re-exported by this package: base scheduler classes first, then the preconfigured variants
__all__ = [ # noqa: RUF022
    # Base
    "RESUnifiedScheduler", "RESMultistepScheduler", "RESMultistepSDEScheduler",
    "RESSinglestepScheduler", "RESSinglestepSDEScheduler", "RESDEISMultistepScheduler",
    "ETDRKScheduler", "LawsonScheduler", "ABNorsettScheduler", "PECScheduler",
    "BongTangentScheduler", "RiemannianFlowScheduler", "LangevinDynamicsScheduler",
    "CommonSigmaScheduler", "SimpleExponentialScheduler",
    "LinearRKScheduler", "LobattoScheduler", "RadauIIAScheduler", "GaussLegendreScheduler",
    "SpecializedRKScheduler",
    # Variants
    "RES2MScheduler", "RES3MScheduler",
    "DEIS2MScheduler", "DEIS3MScheduler",
    "RES2MSDEScheduler", "RES3MSDEScheduler",
    "RES2SScheduler", "RES3SScheduler", "RES5SScheduler", "RES6SScheduler",
    "RES2SSDEScheduler", "RES3SSDEScheduler", "RES5SSDEScheduler", "RES6SSDEScheduler",
    "ETDRK2Scheduler", "ETDRK3AScheduler", "ETDRK3BScheduler", "ETDRK4Scheduler", "ETDRK4AltScheduler",
    "Lawson2AScheduler", "Lawson2BScheduler", "Lawson4Scheduler",
    "ABNorsett2MScheduler", "ABNorsett3MScheduler", "ABNorsett4MScheduler",
    "PEC2H2SScheduler", "PEC2H3SScheduler",
    "FlowEuclideanScheduler", "FlowHyperbolicScheduler", "FlowSphericalScheduler", "FlowLorentzianScheduler",
    "SigmaSigmoidScheduler", "SigmaSineScheduler", "SigmaEasingScheduler", "SigmaArcsineScheduler", "SigmaSmoothScheduler",
    "DEISUnified1SScheduler", "DEISUnified2MScheduler", "DEISUnified3MScheduler",
    "RESUnified2MScheduler", "RESUnified3MScheduler",
    "RESUnified2SScheduler", "RESUnified3SScheduler", "RESUnified5SScheduler", "RESUnified6SScheduler",
    "DEIS1MultistepScheduler", "DEIS2MultistepScheduler", "DEIS3MultistepScheduler",
    "LinearRK2Scheduler", "LinearRK3Scheduler", "LinearRK4Scheduler",
    "LinearRKRalsstonScheduler", "LinearRKMidpointScheduler",
    "Lobatto2Scheduler", "Lobatto3Scheduler", "Lobatto4Scheduler",
    "RadauIIA2Scheduler", "RadauIIA3Scheduler",
    "GaussLegendre2SScheduler", "GaussLegendre3SScheduler", "GaussLegendre4SScheduler",
    "RungeKutta44Scheduler", "RungeKutta57Scheduler", "RungeKutta67Scheduler",
]
|
||||
|
||||
# (display name, scheduler class) pairs for the configurable base schedulers
BASE = [
    ("RES Unified", RESUnifiedScheduler),
    ("RES Multistep", RESMultistepScheduler),
    ("RES Multistep SDE", RESMultistepSDEScheduler),
    ("RES Singlestep", RESSinglestepScheduler),
    ("RES Singlestep SDE", RESSinglestepSDEScheduler),
    ("DEIS Multistep", RESDEISMultistepScheduler),
    ("ETDRK", ETDRKScheduler),
    ("Lawson", LawsonScheduler),
    ("ABNorsett", ABNorsettScheduler),
    ("PEC", PECScheduler),
    ("Common Sigma", CommonSigmaScheduler),
    ("Riemannian Flow", RiemannianFlowScheduler),
    ("Specialized RK", SpecializedRKScheduler),
]
|
||||
|
||||
# (display name, scheduler class) pairs for the schedulers listed separately from BASE and VARIANTS
SIMPLE = [
    ("Bong Tangent", BongTangentScheduler),
    ("Langevin Dynamics", LangevinDynamicsScheduler),
    ("Simple Exponential", SimpleExponentialScheduler),
]
|
||||
|
||||
# (display name, scheduler class) pairs for the preconfigured variants imported from .variants
# plus the Runge-Kutta schedulers; list order is preserved as-is
VARIANTS = [
    # RES / DEIS multistep and singlestep
    ("RES 2M", RES2MScheduler),
    ("RES 3M", RES3MScheduler),
    ("DEIS 2M", DEIS2MScheduler),
    ("DEIS 3M", DEIS3MScheduler),
    ("RES 2M SDE", RES2MSDEScheduler),
    ("RES 3M SDE", RES3MSDEScheduler),
    ("RES 2S", RES2SScheduler),
    ("RES 3S", RES3SScheduler),
    ("RES 5S", RES5SScheduler),
    ("RES 6S", RES6SScheduler),
    ("RES 2S SDE", RES2SSDEScheduler),
    ("RES 3S SDE", RES3SSDEScheduler),
    ("RES 5S SDE", RES5SSDEScheduler),
    ("RES 6S SDE", RES6SSDEScheduler),
    # ETDRK / Lawson / ABNorsett / PEC
    ("ETDRK 2", ETDRK2Scheduler),
    ("ETDRK 3A", ETDRK3AScheduler),
    ("ETDRK 3B", ETDRK3BScheduler),
    ("ETDRK 4", ETDRK4Scheduler),
    ("ETDRK 4 Alt", ETDRK4AltScheduler),
    ("Lawson 2A", Lawson2AScheduler),
    ("Lawson 2B", Lawson2BScheduler),
    ("Lawson 4", Lawson4Scheduler),
    ("ABNorsett 2M", ABNorsett2MScheduler),
    ("ABNorsett 3M", ABNorsett3MScheduler),
    ("ABNorsett 4M", ABNorsett4MScheduler),
    ("PEC 2H2S", PEC2H2SScheduler),
    ("PEC 2H3S", PEC2H3SScheduler),
    # flow and sigma families
    ("Euclidean Flow", FlowEuclideanScheduler),
    ("Hyperbolic Flow", FlowHyperbolicScheduler),
    ("Spherical Flow", FlowSphericalScheduler),
    ("Lorentzian Flow", FlowLorentzianScheduler),
    ("Sigmoid Sigma", SigmaSigmoidScheduler),
    ("Sine Sigma", SigmaSineScheduler),
    ("Easing Sigma", SigmaEasingScheduler),
    ("Arcsine Sigma", SigmaArcsineScheduler),
    ("Smoothstep Sigma", SigmaSmoothScheduler),
    # unified and DEIS multistep
    ("DEIS Unified 1", DEISUnified1SScheduler),
    ("DEIS Unified 2", DEISUnified2MScheduler),
    ("DEIS Unified 3", DEISUnified3MScheduler),
    ("RES Unified 2M", RESUnified2MScheduler),
    ("RES Unified 3M", RESUnified3MScheduler),
    ("RES Unified 2S", RESUnified2SScheduler),
    ("RES Unified 3S", RESUnified3SScheduler),
    ("RES Unified 5S", RESUnified5SScheduler),
    ("RES Unified 6S", RESUnified6SScheduler),
    ("DEIS Multistep 1", DEIS1MultistepScheduler),
    ("DEIS Multistep 2", DEIS2MultistepScheduler),
    ("DEIS Multistep 3", DEIS3MultistepScheduler),
    # Runge-Kutta style families
    ("Linear-RK 2", LinearRK2Scheduler),
    ("Linear-RK 3", LinearRK3Scheduler),
    ("Linear-RK 4", LinearRK4Scheduler),
    ("Linear-RK Ralston", LinearRKRalsstonScheduler),
    ("Linear-RK Midpoint", LinearRKMidpointScheduler),
    ("Lobatto 2", Lobatto2Scheduler),
    ("Lobatto 3", Lobatto3Scheduler),
    ("Lobatto 4", Lobatto4Scheduler),
    ("Radau-IIA 2", RadauIIA2Scheduler),
    ("Radau-IIA 3", RadauIIA3Scheduler),
    ("Gauss-Legendre 2S", GaussLegendre2SScheduler),
    ("Gauss-Legendre 3S", GaussLegendre3SScheduler),
    ("Gauss-Legendre 4S", GaussLegendre4SScheduler),
    ("Runge-Kutta 4/4", RungeKutta44Scheduler),
    ("Runge-Kutta 5/7", RungeKutta57Scheduler),
    ("Runge-Kutta 6/7", RungeKutta67Scheduler),
]
|
||||
|
|
@ -0,0 +1,340 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
from .phi_functions import Phi
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ABNorsettScheduler(SchedulerMixin, ConfigMixin):
    """
    Adams-Bashforth Norsett (ABNorsett) scheduler.

    Multistep scheduler that keeps a short history of model outputs, x0 estimates and sigmas
    (at most as many entries as the order digit encoded in `variant`) and combines them in `step`:

    - `flow_prediction`: integrates the model output over `sigma_next - sigma` using Euler on the
      first step and variable-step Adams-Bashforth 2 afterwards
    - any other prediction type: `exp(-h) * sample + h * sum(b_i * x0_i)` with
      `h = -log(sigma_next / sigma)` and coefficients `b_i` built from `Phi`
    """

    _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        variant: Literal["abnorsett_2m", "abnorsett_3m", "abnorsett_4m"] = "abnorsett_2m",
        use_analytic_solution: bool = True,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
    ):
        """
        Build the beta / alpha tables and reset the multistep buffers.

        All arguments are stored on `self.config` by `register_to_config`; only the beta-related
        ones are consumed here, the rest are read later by `set_timesteps` and `step`.

        Raises:
            NotImplementedError: if `beta_schedule` is not one of `linear`, `scaled_linear`, `squaredcos_cap_v2`
                and `trained_betas` is not given.
        """
        from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does not exist.")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Buffer for multistep
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []
        self._step_index = None
        self._begin_index = None
        self.init_noise_sigma = 1.0

    @property
    def step_index(self) -> Optional[int]:
        """Index of the current step into `self.sigmas`, or None before the first `step` / `scale_model_input` call."""
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """Explicit start index set via `set_begin_index`, or None when it should be derived from the timestep."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Force the step index that the next `_init_step_index` call will start from."""
        self._begin_index = begin_index

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
        """
        Compute `self.timesteps` and `self.sigmas` for a run and clear all multistep state.

        Args:
            num_inference_steps: number of denoising steps
            device: target device for the resulting tensors
            mu: value forwarded to `get_dynamic_shift` when `use_dynamic_shifting` is enabled
            dtype: dtype of the resulting tensors

        Raises:
            ValueError: if `config.timestep_spacing` is not `linspace`, `leading` or `trailing`.
        """
        from .scheduler_utils import (
            apply_shift,
            get_dynamic_shift,
            get_sigmas_beta,
            get_sigmas_exponential,
            get_sigmas_karras,
        )

        self.num_inference_steps = num_inference_steps

        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
            timesteps -= 1
        else:
            raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")

        # sigma = sqrt((1 - alpha_bar) / alpha_bar), interpolated at the selected (possibly fractional) timesteps
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # optional re-spacing between the smallest and largest interpolated sigma; first enabled flag wins
        if self.config.use_karras_sigmas:
            sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_exponential_sigmas:
            sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_beta_sigmas:
            sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_flow_sigmas:
            # sigma_min / sigma_max are not constructor arguments, so fall back when absent from config
            s_min = getattr(self.config, "sigma_min", None)
            s_max = getattr(self.config, "sigma_max", None)
            if s_min is None:
                s_min = 0.001
            if s_max is None:
                s_max = 1.0
            sigmas = np.linspace(s_max, s_min, num_inference_steps)

        if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
            shift = self.config.shift
            if self.config.use_dynamic_shifting and mu is not None:
                shift = get_dynamic_shift(
                    mu,
                    self.config.base_shift,
                    self.config.max_shift,
                    self.config.base_image_seq_len,
                    self.config.max_image_seq_len,
                )
            sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

            # Map shifted sigmas back to timesteps (linear mapping for flow): t = sigma * num_train_timesteps,
            # so the model receives the time embedding that matches the shifted noise level.
            # Assumes flow sigmas are in [1.0, 0.0] range (before shift) and the model expects [1000, 0].
            # NOTE(review): placement of this remap under the shift branch is presumed - confirm it is not meant to run unconditionally
            timesteps = sigmas * self.config.num_train_timesteps

        # trailing 0.0 lets step() read sigmas[step + 1] on the final step
        self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self._step_index = None
        self._begin_index = None
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Look up the position of `timestep` in `schedule_timesteps` (defaults to `self.timesteps`)."""
        from .scheduler_utils import index_for_timestep
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise `original_samples` for the given `timesteps` using the current sigma schedule."""
        from .scheduler_utils import add_noise_to_sample
        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Scale the model input for the current step.

        Returns `sample` unchanged for `flow_prediction`, otherwise `sample / sqrt(sigma**2 + 1)`.
        Initializes the step index from `timestep` on first use.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Advance `sample` from the current sigma to the next one.

        Args:
            model_output: raw model prediction, interpreted according to `config.prediction_type`
            timestep: current timestep, only used to initialize the step index on the first call
            sample: current noisy sample
            return_dict: return a `SchedulerOutput` when True, else a 1-tuple

        Returns:
            `SchedulerOutput(prev_sample=...)` or `(prev_sample,)`.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        step = self._step_index
        sigma = self.sigmas[step]
        sigma_next = self.sigmas[step + 1]

        # RECONSTRUCT X0
        prediction_type = self.config.prediction_type
        if prediction_type in ("epsilon", "flow_prediction"):
            x0 = sample - sigma * model_output
        elif prediction_type == "v_prediction":
            alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
            sigma_t = sigma * alpha_t
            x0 = alpha_t * sample - sigma_t * model_output
        else:
            # "sample" and any unrecognized prediction type: the model output is used as x0 directly
            x0 = model_output

        self.model_outputs.append(model_output)
        self.x0_outputs.append(x0)
        self.prev_sigmas.append(sigma)

        # the order is the digit in the variant name, e.g. "abnorsett_2m" -> 2
        order = int(self.config.variant[-2])
        # until the history is full only a lower-order update is possible
        curr_order = min(len(self.prev_sigmas), order)

        if sigma_next == 0:
            # final step: return the x0 estimate directly
            x_next = x0
        elif prediction_type == "flow_prediction":
            x_next = self._flow_update(sample, sigma, sigma_next, curr_order)
        else:
            x_next = self._exponential_update(sample, sigma, sigma_next, curr_order)

        self._step_index += 1

        # keep at most `order` history entries
        if len(self.x0_outputs) > order:
            self.x0_outputs.pop(0)
            self.model_outputs.pop(0)
            self.prev_sigmas.pop(0)

        if not return_dict:
            return (x_next,)

        return SchedulerOutput(prev_sample=x_next)

    def _exponential_update(self, sample: torch.Tensor, sigma: torch.Tensor, sigma_next: torch.Tensor, curr_order: int) -> torch.Tensor:
        """Exponential integrator update `exp(-h) * sample + h * sum(b_i * x0_i)` over the x0 history (newest first)."""
        h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
        phi = Phi(h, [0], getattr(self.config, "use_analytic_solution", True))

        # Multi-step coefficients b for ABNorsett family; b[0] weights the newest x0
        if curr_order == 2:
            b2 = -phi(2)
            b1 = phi(1) - b2
            b = [b1, b2]
        elif curr_order == 3:
            b2 = -2 * phi(2) - 2 * phi(3)
            b3 = 0.5 * phi(2) + phi(3)
            b1 = phi(1) - (b2 + b3)
            b = [b1, b2, b3]
        elif curr_order == 4:
            b2 = -3 * phi(2) - 5 * phi(3) - 3 * phi(4)
            b3 = 1.5 * phi(2) + 4 * phi(3) + 3 * phi(4)
            b4 = -1 / 3 * phi(2) - phi(3) - phi(4)
            b1 = phi(1) - (b2 + b3 + b4)
            b = [b1, b2, b3, b4]
        else:
            b = [phi(1)]

        # Apply coefficients to x0 buffer; zip stops at whichever of the two is shorter
        res = torch.zeros_like(sample)
        for b_val, x0_hist in zip(b, reversed(self.x0_outputs)):
            res += b_val * x0_hist

        return torch.exp(-h) * sample + h * res

    def _flow_update(self, sample: torch.Tensor, sigma: torch.Tensor, sigma_next: torch.Tensor, curr_order: int) -> torch.Tensor:
        """Flow-matching update: Euler for the first step, variable-step Adams-Bashforth 2 once history exists."""
        # x_{n+1} = x_n + \int_{t_n}^{t_{n+1}} v(t) dt
        dt = sigma_next - sigma
        # Current derivative v_n is the newest model output
        v_n = self.model_outputs[-1]

        if curr_order >= 2:
            # orders 3 and 4 intentionally use the same AB2 update as order 2
            v_nm1 = self.model_outputs[-2]
            dt_prev = sigma - self.prev_sigmas[-2]
            if abs(dt_prev) < 1e-8:
                # Fallback to Euler if division by zero risk
                return sample + dt * v_n
            # x_{n+1} = x_n + dt * [ (1 + r/2) * v_n - (r/2) * v_{n-1} ] where r = dt / dt_prev
            r = dt / dt_prev
            c0 = 1 + 0.5 * r
            c1 = -0.5 * r
            return sample + dt * (c0 * v_n + c1 * v_nm1)

        # Euler: x_{n+1} = x_n + dt * v_n
        return sample + dt * v_n

    def _init_step_index(self, timestep):
        """Set the step index from `begin_index` when present, otherwise by locating `timestep` in the schedule."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def __len__(self):
        """Number of training timesteps."""
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,278 @@
|
|||
# Copyright 2025 The RES4LYF Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class BongTangentScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
BongTangent scheduler using Exponential Integrator step.
|
||||
"""
|
||||
|
||||
_compatibles: ClassVar[List[str]] = []
|
||||
order = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    start: float = 1.0,
    middle: float = 0.5,
    end: float = 0.0,
    pivot_1: float = 0.6,
    pivot_2: float = 0.6,
    slope_1: float = 0.2,
    slope_2: float = 0.2,
    pad: bool = False,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    rescale_betas_zero_snr: bool = False,
    use_karras_sigmas: bool = False,
    use_exponential_sigmas: bool = False,
    use_beta_sigmas: bool = False,
    use_flow_sigmas: bool = False,
    shift: float = 1.0,
    use_dynamic_shifting: bool = False,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
):
    """
    Build the training beta/alpha tables and reset the per-run state.

    All arguments are stored on ``self.config`` by ``register_to_config``. Only the
    beta-schedule arguments and ``rescale_betas_zero_snr`` are consumed here; the
    remaining ones are read later from the config.

    Raises:
        NotImplementedError: if ``beta_schedule`` is not ``"linear"``,
            ``"scaled_linear"`` or ``"squaredcos_cap_v2"``.
    """
    from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

    if beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared.
        betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented")

    if rescale_betas_zero_snr:
        betas = rescale_zero_terminal_snr(betas)

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Per-run state; populated by set_timesteps().
    self.num_inference_steps = None
    self.sigmas = torch.Tensor([])
    self.timesteps = torch.Tensor([])
    self.init_noise_sigma = 1.0
    self._step_index = None
    self._begin_index = None
|
||||
|
||||
@property
def step_index(self) -> Optional[int]:
    """Current value of the internal ``_step_index`` counter (may be ``None``)."""
    current = self._step_index
    return current
|
||||
|
||||
@property
def begin_index(self) -> Optional[int]:
    """Value last stored by ``set_begin_index`` (may be ``None``)."""
    first = self._begin_index
    return first
|
||||
|
||||
def set_begin_index(self, begin_index: int = 0) -> None:
    """
    Store the index that ``_init_step_index`` will use as the starting step.

    Args:
        begin_index: value assigned to ``_begin_index``.
    """
    self._begin_index = begin_index
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Scale ``sample`` for the model at the current step.

    Initialises the step index from ``timestep`` if it has not been set yet. For
    ``prediction_type == "flow_prediction"`` the sample is returned unchanged;
    otherwise it is divided by ``sqrt(sigma**2 + 1)`` for the current sigma.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    if self.config.prediction_type != "flow_prediction":
        current_sigma = self.sigmas[self._step_index]
        return sample / ((current_sigma**2 + 1) ** 0.5)
    return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
    """
    Build the timesteps and the two-stage tangent sigma schedule for a sampling run.

    Args:
        num_inference_steps: number of sampling steps.
        device: device the resulting ``sigmas`` / ``timesteps`` tensors are moved to.
        mu: optional value forwarded to ``get_dynamic_shift`` when dynamic shifting is enabled.
        dtype: dtype of the resulting tensors and of the tangent curve computation.

    Raises:
        ValueError: if the configured ``timestep_spacing`` is not ``"linspace"``,
            ``"leading"`` or ``"trailing"``.

    Side effects: sets ``num_inference_steps``, ``sigmas`` (with a trailing 0.0),
    ``timesteps`` and ``init_noise_sigma``, and resets ``_step_index`` / ``_begin_index``.
    """
    from .scheduler_utils import (
        apply_shift,
        get_dynamic_shift,
        get_sigmas_beta,
        get_sigmas_exponential,
        get_sigmas_flow,
        get_sigmas_karras,
    )

    cfg = self.config
    self.num_inference_steps = num_inference_steps
    spacing = getattr(cfg, "timestep_spacing", "linspace")

    if spacing == "linspace":
        timesteps = np.linspace(0, cfg.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
    elif spacing == "leading":
        ratio = cfg.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * ratio).round()[::-1].copy()
        timesteps += getattr(cfg, "steps_offset", 0)
    elif spacing == "trailing":
        ratio = cfg.num_train_timesteps / self.num_inference_steps
        timesteps = (np.arange(cfg.num_train_timesteps, 0, -ratio)).round().copy()
        timesteps -= 1
    else:
        raise ValueError(f"timestep_spacing {spacing} is not supported.")

    # Sigma range derived from alphas_cumprod: first entry is the smallest sigma, last the largest.
    sigma_table = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    sigma_max, sigma_min = sigma_table[-1], sigma_table[0]
    sigma_mid = (sigma_max + sigma_min) / 2  # value where the two tangent stages meet

    steps = num_inference_steps
    # NOTE(review): "midpoint" is read with a 0.5 fallback; confirm the config actually defines it.
    stage_1_len = int(steps * getattr(cfg, "midpoint", 0.5))
    stage_2_len = steps - stage_1_len + 1
    pivot_1 = int(steps * getattr(cfg, "pivot_1", 0.6))
    pivot_2 = int(steps * getattr(cfg, "pivot_2", 0.6))

    # Slopes are rescaled by the step count relative to 40 steps.
    slope_1 = getattr(cfg, "slope_1", 0.2) / (steps / 40)
    slope_2 = getattr(cfg, "slope_2", 0.2) / (steps / 40)

    # "start" only has an effect when it is greater than 1.0 (it then scales sigma_max).
    start_cfg = getattr(cfg, "start", 1.0)
    start_val = sigma_max * start_cfg if start_cfg > 1.0 else sigma_max

    first_stage = self._get_bong_tangent_sigmas(stage_1_len, slope_1, pivot_1, start_val, sigma_mid, dtype=dtype)
    second_stage = self._get_bong_tangent_sigmas(stage_2_len, slope_2, pivot_2 - stage_1_len, sigma_mid, sigma_min, dtype=dtype)

    # Drop the last value of stage one; stage two starts at the same mid value.
    sigmas_list = first_stage[:-1] + second_stage
    if getattr(cfg, "pad", False):
        sigmas_list.append(0.0)

    sigmas = np.array(sigmas_list)

    # Optional resampling between the schedule's last and first sigma; the first enabled flag wins.
    resamplers = (
        ("use_karras_sigmas", get_sigmas_karras),
        ("use_exponential_sigmas", get_sigmas_exponential),
        ("use_beta_sigmas", get_sigmas_beta),
        ("use_flow_sigmas", get_sigmas_flow),
    )
    for flag, resample in resamplers:
        if getattr(cfg, flag, False):
            sigmas = resample(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
            break

    shift = getattr(cfg, "shift", 1.0)
    dynamic = getattr(cfg, "use_dynamic_shifting", False)
    if shift != 1.0 or dynamic:
        if dynamic and mu is not None:
            shift = get_dynamic_shift(
                mu,
                getattr(cfg, "base_shift", 0.5),
                getattr(cfg, "max_shift", 1.5),
                getattr(cfg, "base_image_seq_len", 256),
                getattr(cfg, "max_image_seq_len", 4096),
            )
        sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

    self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
    self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

    self._step_index = None
    self._begin_index = None
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Look up the schedule index for ``timestep`` via ``scheduler_utils.index_for_timestep``.

    ``schedule_timesteps`` defaults to ``self.timesteps`` when not given.
    """
    from .scheduler_utils import index_for_timestep as _lookup_index

    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    return _lookup_index(timestep, schedule)
|
||||
|
||||
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Noise ``original_samples`` by delegating to ``scheduler_utils.add_noise_to_sample``.

    The helper receives this scheduler's ``sigmas`` and ``timesteps`` tables together
    with the requested ``timesteps``.
    """
    from .scheduler_utils import add_noise_to_sample

    noisy = add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
    return noisy
|
||||
|
||||
def _get_bong_tangent_sigmas(self, steps: int, slope: float, pivot: int, start: float, end: float, dtype: torch.dtype = torch.float32) -> List[float]:
    """
    Sample an arctangent-shaped curve running from ``start`` to ``end``.

    The curve ``((2 / pi) * atan(-slope * (x - pivot)) + 1) / 2`` is evaluated at
    ``x = 0 .. steps - 1`` and rescaled so that the first value is ``start`` and the
    last is ``end``.

    Args:
        steps: number of values to produce.
        slope: steepness of the arctangent.
        pivot: x position of the curve's inflection point.
        start: value at index 0.
        end: value at index ``steps - 1``.
        dtype: dtype used for the computation.

    Returns:
        ``steps`` floats as a Python list.
    """
    positions = torch.arange(steps, dtype=dtype)

    def bong_fn(val):
        return ((2 / torch.pi) * torch.atan(-slope * (val - pivot)) + 1) / 2

    smax = bong_fn(torch.tensor(0.0))
    smin = bong_fn(torch.tensor(steps - 1.0))
    srange = smax - smin

    if srange == 0:
        # Degenerate curve (single step or zero slope): normalising would divide by
        # zero and yield NaN, so fall back to a plain linear ramp from start to end.
        return torch.linspace(float(start), float(end), steps, dtype=dtype).tolist()

    sscale = start - end
    sigmas = (bong_fn(positions) - smin) * (1 / srange) * sscale + end
    return sigmas.tolist()
|
||||
|
||||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[float, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance ``sample`` by one step with an exponential-integrator update.

    Args:
        model_output: raw output of the model for the current step.
        timestep: current timestep; only used to initialise the step index on the first call.
        sample: current sample; the non-flow branches assume the sigma-space convention
            implied by ``scale_model_input`` (which divides by ``sqrt(sigma**2 + 1)``).
        return_dict: when False a 1-tuple is returned instead of ``SchedulerOutput``.

    Returns:
        ``SchedulerOutput(prev_sample=...)`` or ``(prev_sample,)``.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    sigma = self.sigmas[self._step_index]
    sigma_next = self.sigmas[self._step_index + 1]

    # Reconstruct the x0 estimate from the model output.
    prediction_type = self.config.prediction_type
    if prediction_type in ("epsilon", "flow_prediction"):
        x0 = sample - sigma * model_output
    elif prediction_type == "v_prediction":
        # The model sees sample / sqrt(sigma^2 + 1), so the sample term carries
        # alpha_t twice: x0 = sample / (sigma^2 + 1) - sigma / sqrt(sigma^2 + 1) * v.
        # Using a single alpha_t factor here does not recover x0.
        denom = sigma**2 + 1
        x0 = sample / denom - model_output * (sigma / denom**0.5)
    else:
        # "sample" and any unrecognised prediction type: treat the output as x0.
        x0 = model_output

    # Exponential integrator update; exp(-h) equals sigma_next / sigma.
    if sigma_next == 0:
        x_next = x0
    else:
        h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
        decay = torch.exp(-h)
        x_next = decay * sample + (1 - decay) * x0

    self._step_index += 1

    if not return_dict:
        return (x_next,)

    return SchedulerOutput(prev_sample=x_next)
|
||||
|
||||
def _init_step_index(self, timestep):
    """
    Initialise ``_step_index`` for the first call of a sampling run.

    A configured ``begin_index`` takes precedence; otherwise the index is resolved
    from ``timestep`` with ``index_for_timestep`` after moving tensor timesteps to
    the device of ``self.timesteps``.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
|
||||
|
||||
def __len__(self):
    """Return ``num_train_timesteps`` from the scheduler config."""
    total = self.config.num_train_timesteps
    return total
|
||||
|
|
@ -0,0 +1,263 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import ClassVar, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class CommonSigmaScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Common Sigma scheduler using Exponential Integrator step.
|
||||
"""
|
||||
|
||||
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
|
||||
order: ClassVar[int] = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    profile: Literal["sigmoid", "sine", "easing", "arcsine", "smoothstep"] = "sigmoid",
    variant: str = "logistic",
    strength: float = 1.0,
    gain: float = 1.0,
    offset: float = 0.0,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    rescale_betas_zero_snr: bool = False,
    use_karras_sigmas: bool = False,
    use_exponential_sigmas: bool = False,
    use_beta_sigmas: bool = False,
    use_flow_sigmas: bool = False,
    shift: float = 1.0,
    use_dynamic_shifting: bool = False,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
):
    """
    Build the training beta/alpha tables and reset the per-run state.

    All arguments are stored on ``self.config`` by ``register_to_config``. Only the
    beta-schedule arguments, ``num_train_timesteps`` and ``rescale_betas_zero_snr``
    are consumed here.

    Raises:
        NotImplementedError: if ``beta_schedule`` is not ``"linear"``,
            ``"scaled_linear"`` or ``"squaredcos_cap_v2"``.
    """
    from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

    if beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared.
        betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} does not exist.")

    if rescale_betas_zero_snr:
        betas = rescale_zero_terminal_snr(betas)

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Standard deviation of the initial noise distribution.
    self.init_noise_sigma = 1.0

    # Settable values; timesteps start as the full descending training range.
    descending = np.arange(0, num_train_timesteps)[::-1].copy()
    self.timesteps = torch.from_numpy(descending)
    self.sigmas = torch.Tensor([])
    self.num_inference_steps = None

    self._step_index = None
    self._begin_index = None
|
||||
|
||||
@property
def step_index(self) -> Optional[int]:
    """Current value of the internal ``_step_index`` counter (may be ``None``)."""
    current = self._step_index
    return current
|
||||
|
||||
@property
def begin_index(self) -> Optional[int]:
    """Value last stored by ``set_begin_index`` (may be ``None``)."""
    first = self._begin_index
    return first
|
||||
|
||||
def set_begin_index(self, begin_index: int = 0) -> None:
    """
    Store the index that ``_init_step_index`` will use as the starting step.

    Args:
        begin_index: value assigned to ``_begin_index``.
    """
    self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
    """
    Build the timesteps and the profile-shaped sigma schedule for a sampling run.

    Args:
        num_inference_steps: number of sampling steps.
        device: device the resulting ``sigmas`` / ``timesteps`` tensors are moved to.
        mu: optional value forwarded to ``get_dynamic_shift`` when dynamic shifting is enabled.
        dtype: dtype of the resulting tensors.

    Raises:
        ValueError: if the configured ``timestep_spacing`` is not ``"linspace"``,
            ``"leading"`` or ``"trailing"``.

    Side effects: sets ``num_inference_steps``, ``sigmas`` (with a trailing 0.0),
    ``timesteps`` and ``init_noise_sigma``, and resets the step/begin indices and
    the ``model_outputs`` / ``x0_outputs`` / ``prev_sigmas`` lists.
    """
    from .scheduler_utils import (
        apply_shift,
        get_dynamic_shift,
        get_sigmas_beta,
        get_sigmas_exponential,
        get_sigmas_flow,
        get_sigmas_karras,
    )

    self.num_inference_steps = num_inference_steps

    if self.config.timestep_spacing == "linspace":
        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
    elif self.config.timestep_spacing == "leading":
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
        # steps_offset is not one of this scheduler's __init__ arguments, so it is
        # only present on the config if loaded from elsewhere; default to 0 instead
        # of failing with AttributeError.
        timesteps += getattr(self.config, "steps_offset", 0)
    elif self.config.timestep_spacing == "trailing":
        step_ratio = self.config.num_train_timesteps / self.num_inference_steps
        timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
        timesteps -= 1
    else:
        raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")

    # Sigma range derived from alphas_cumprod: first entry is the smallest sigma, last the largest.
    base_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    sigma_max = base_sigmas[-1]
    sigma_min = base_sigmas[0]

    t = torch.linspace(0, 1, num_inference_steps)
    profile = self.config.profile
    variant = self.config.variant
    gain = self.config.gain
    offset = self.config.offset

    # `result` is the blend weight: 0 maps to sigma_max, 1 maps to sigma_min.
    if profile == "sigmoid":
        x = gain * (t * 10 - 5 + offset)
        if variant == "logistic":
            result = 1.0 / (1.0 + torch.exp(-x))
        elif variant == "tanh":
            result = (torch.tanh(x) + 1) / 2
        else:
            result = torch.sigmoid(x)
    elif profile == "sine":
        result = torch.sin(t * math.pi / 2)
    elif profile == "easing":
        result = t * t * (3 - 2 * t)
    elif profile == "arcsine":
        result = torch.arcsin(t) / (math.pi / 2)
    else:
        # Any other profile value (including "smoothstep") falls through to a linear ramp.
        result = t

    # Map profile to sigma range
    sigmas = (sigma_max * (1 - result) + sigma_min * result).cpu().numpy()

    # Optional resampling between the schedule's last and first sigma.
    if self.config.use_karras_sigmas:
        sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
    elif self.config.use_exponential_sigmas:
        sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
    elif self.config.use_beta_sigmas:
        sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
    elif self.config.use_flow_sigmas:
        sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()

    if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
        shift = self.config.shift
        if self.config.use_dynamic_shifting and mu is not None:
            shift = get_dynamic_shift(
                mu,
                self.config.base_shift,
                self.config.max_shift,
                self.config.base_image_seq_len,
                self.config.max_image_seq_len,
            )
        sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

    self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
    self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

    self._step_index = None
    self._begin_index = None
    self.model_outputs = []
    self.x0_outputs = []
    self.prev_sigmas = []
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Look up the schedule index for ``timestep`` via ``scheduler_utils.index_for_timestep``.

    ``schedule_timesteps`` defaults to ``self.timesteps`` when not given.
    """
    from .scheduler_utils import index_for_timestep as _lookup_index

    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    return _lookup_index(timestep, schedule)
|
||||
|
||||
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Noise ``original_samples`` by delegating to ``scheduler_utils.add_noise_to_sample``.

    The helper receives this scheduler's ``sigmas`` and ``timesteps`` tables together
    with the requested ``timesteps``.
    """
    from .scheduler_utils import add_noise_to_sample

    noisy = add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
    return noisy
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Scale ``sample`` for the model at the current step.

    Initialises the step index from ``timestep`` if it has not been set yet. For
    ``prediction_type == "flow_prediction"`` the sample is returned unchanged;
    otherwise it is divided by ``sqrt(sigma**2 + 1)`` for the current sigma.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    if self.config.prediction_type != "flow_prediction":
        current_sigma = self.sigmas[self._step_index]
        return sample / ((current_sigma**2 + 1) ** 0.5)
    return sample
|
||||
|
||||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[float, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance ``sample`` by one step with an exponential-integrator update.

    Args:
        model_output: raw output of the model for the current step.
        timestep: current timestep; only used to initialise the step index on the first call.
        sample: current sample; the non-flow branches assume the sigma-space convention
            implied by ``scale_model_input`` (which divides by ``sqrt(sigma**2 + 1)``).
        return_dict: when False a 1-tuple is returned instead of ``SchedulerOutput``.

    Returns:
        ``SchedulerOutput(prev_sample=...)`` or ``(prev_sample,)``.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    step_index = self._step_index
    sigma = self.sigmas[step_index]
    sigma_next = self.sigmas[step_index + 1]

    # Reconstruct the x0 estimate from the model output.
    prediction_type = self.config.prediction_type
    if prediction_type in ("epsilon", "flow_prediction"):
        x0 = sample - sigma * model_output
    elif prediction_type == "v_prediction":
        # The model sees sample / sqrt(sigma^2 + 1), so the sample term carries
        # alpha_t twice: x0 = sample / (sigma^2 + 1) - sigma / sqrt(sigma^2 + 1) * v.
        # Using a single alpha_t factor here does not recover x0.
        denom = sigma**2 + 1
        x0 = sample / denom - model_output * (sigma / denom**0.5)
    else:
        # "sample" and any unrecognised prediction type: treat the output as x0.
        x0 = model_output

    # Exponential integrator update; exp(-h) equals sigma_next / sigma.
    if sigma_next == 0:
        x_next = x0
    else:
        h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
        decay = torch.exp(-h)
        x_next = decay * sample + (1 - decay) * x0

    self._step_index += 1

    if not return_dict:
        return (x_next,)

    return SchedulerOutput(prev_sample=x_next)
|
||||
|
||||
def _init_step_index(self, timestep):
    """
    Initialise ``_step_index`` for the first call of a sampling run.

    A configured ``begin_index`` takes precedence; otherwise the index is resolved
    from ``timestep`` with ``index_for_timestep`` after moving tensor timesteps to
    the device of ``self.timesteps``.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    lookup = timestep.to(self.timesteps.device) if isinstance(timestep, torch.Tensor) else timestep
    self._step_index = self.index_for_timestep(lookup)
|
||||
|
||||
def __len__(self):
    """Return ``num_train_timesteps`` from the scheduler config."""
    total = self.config.num_train_timesteps
    return total
|
||||
|
|
@ -0,0 +1,403 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
from .phi_functions import Phi
|
||||
|
||||
|
||||
def get_def_integral_2(a, b, start, end, c):
    """
    Integrate ``(x - a) * (x - b) / ((c - a) * (c - b))`` over ``[start, end]``.

    The numerator is the closed-form definite integral of the quadratic
    ``x**2 - (a + b) * x + a * b``; the denominator normalises by its value at ``c``.
    """
    cubic_term = (end**3 - start**3) / 3
    quadratic_term = (end**2 - start**2) * (a + b) / 2
    linear_term = (end - start) * a * b
    numerator = cubic_term - quadratic_term + linear_term
    denominator = (c - a) * (c - b)
    return numerator / denominator
|
||||
|
||||
|
||||
def get_def_integral_3(a, b, c, start, end, d):
    """
    Integrate ``(x - a) * (x - b) * (x - c) / ((d - a) * (d - b) * (d - c))`` over ``[start, end]``.

    The numerator is the closed-form definite integral of the expanded cubic; the
    denominator normalises by the cubic's value at ``d``.
    """
    quartic_term = (end**4 - start**4) / 4
    cubic_term = (end**3 - start**3) * (a + b + c) / 3
    quadratic_term = (end**2 - start**2) * (a * b + a * c + b * c) / 2
    linear_term = (end - start) * a * b * c
    numerator = quartic_term - cubic_term + quadratic_term - linear_term
    denominator = (d - a) * (d - b) * (d - c)
    return numerator / denominator
|
||||
|
||||
|
||||
class RESDEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
RESDEISMultistepScheduler: Diffusion Explicit Iterative Sampler with high-order multistep.
|
||||
Adapted from the RES4LYF repository.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    prediction_type: str = "epsilon",
    use_karras_sigmas: bool = False,
    use_exponential_sigmas: bool = False,
    use_beta_sigmas: bool = False,
    use_flow_sigmas: bool = False,
    sigma_min: Optional[float] = None,
    sigma_max: Optional[float] = None,
    rho: float = 7.0,
    shift: Optional[float] = None,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    use_dynamic_shifting: bool = False,
    timestep_spacing: str = "linspace",
    solver_order: int = 2,
    use_analytic_solution: bool = True,
    clip_sample: bool = False,
    sample_max_value: float = 1.0,
    set_alpha_to_one: bool = False,
    skip_prk_steps: bool = False,
    interpolation_type: str = "linear",
    steps_offset: int = 0,
    timestep_type: str = "discrete",
    rescale_betas_zero_snr: bool = False,
    final_sigmas_type: str = "zero",
):
    """
    Build the training beta/alpha tables and reset the solver state.

    All arguments are stored on ``self.config`` by ``register_to_config``. Only
    ``trained_betas``, the beta-schedule arguments and ``num_train_timesteps`` are
    consumed here.

    Raises:
        NotImplementedError: if ``trained_betas`` is None and ``beta_schedule`` is
            neither ``"linear"`` nor ``"scaled_linear"``.
    """
    if trained_betas is not None:
        # Explicit betas take precedence over any named schedule.
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared.
        betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented")

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Schedule placeholders; timesteps start as the full descending training range.
    descending = np.arange(0, num_train_timesteps)[::-1].copy()
    self.timesteps = torch.from_numpy(descending)
    self.num_inference_steps = None
    self.sigmas = None
    self.init_noise_sigma = 1.0

    # Internal state
    self._step_index = None
    self._sigmas_cpu = None
    self.model_outputs = []
    self.hist_samples = []
    self.prev_sigmas = []
    self.all_coeffs = []
|
||||
|
||||
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    mu: Optional[float] = None,
    dtype: torch.dtype = torch.float32,
):
    """
    Build timesteps, sigmas and the per-step multistep coefficients for a sampling run.

    Args:
        num_inference_steps: number of sampling steps.
        device: device the resulting ``sigmas`` / ``timesteps`` tensors are moved to.
        mu: shift factor applied to the sigmas when ``use_dynamic_shifting`` is enabled.
        dtype: dtype of the resulting tensors.

    Raises:
        ValueError: for an unsupported ``timestep_spacing`` or ``interpolation_type``.

    Side effects: sets ``num_inference_steps``, ``sigmas`` (with a trailing 0.0),
    ``timesteps``, ``init_noise_sigma``, ``_sigmas_cpu`` and ``all_coeffs``, and
    clears the step index and all history lists.
    """
    self.num_inference_steps = num_inference_steps
    num_train = self.config.num_train_timesteps

    # 1. Spacing
    if self.config.timestep_spacing == "linspace":
        timesteps = np.linspace(num_train - 1, 0, num_inference_steps, dtype=float).copy()
    elif self.config.timestep_spacing == "leading":
        step_ratio = num_train // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
    elif self.config.timestep_spacing == "trailing":
        step_ratio = num_train / num_inference_steps
        # Walk down from the last training timestep. Starting the range at
        # num_inference_steps (as before) yields a single timestep whenever
        # step_ratio exceeds it, instead of num_inference_steps of them.
        timesteps = np.arange(num_train, 0, -step_ratio).round().copy().astype(float)
        timesteps -= 1
        timesteps = np.maximum(timesteps, 0)
    else:
        raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")

    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    # Kept from the full training table so sigmas can be mapped back to timesteps below.
    log_sigmas_all = np.log(np.maximum(sigmas, 1e-10))
    if self.config.interpolation_type == "linear":
        sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
    elif self.config.interpolation_type == "log_linear":
        sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
    else:
        raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")

    # 2. Sigma Schedule
    if self.config.use_karras_sigmas:
        sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
        sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
        rho = self.config.rho
        ramp = np.linspace(0, 1, num_inference_steps)
        sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    elif self.config.use_exponential_sigmas:
        sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
        sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
        sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
    elif self.config.use_beta_sigmas:
        sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
        sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
        alpha, beta = 0.6, 0.6
        ramp = np.linspace(0, 1, num_inference_steps)
        # NOTE(review): the ramp is drawn by random sampling from Beta(0.6, 0.6), so the
        # schedule depends on torch's RNG state; on any failure the linear ramp is kept.
        try:
            import torch.distributions as dist

            b = dist.Beta(alpha, beta)
            ramp = b.sample((num_inference_steps,)).sort().values.numpy()
        except Exception:
            pass
        sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
    elif self.config.use_flow_sigmas:
        sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)

    # 3. Shifting
    if self.config.use_dynamic_shifting and mu is not None:
        sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
    elif self.config.shift is not None:
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

    # Map back to timesteps
    if self.config.use_flow_sigmas:
        timesteps = sigmas * self.config.num_train_timesteps
    else:
        timesteps = np.interp(np.log(np.maximum(sigmas, 1e-10)), log_sigmas_all, np.arange(len(log_sigmas_all)))

    self.sigmas = torch.from_numpy(np.append(sigmas, 0.0)).to(device=device, dtype=dtype)
    self.timesteps = torch.from_numpy(timesteps + self.config.steps_offset).to(device=device, dtype=dtype)
    self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

    self._sigmas_cpu = self.sigmas.detach().cpu().numpy()

    # Precompute coefficients: one entry per step, None when the next sigma is <= 0.
    self.all_coeffs = []
    num_steps = len(timesteps)
    for i in range(num_steps):
        sigma_t = self._sigmas_cpu[i]
        sigma_next = self._sigmas_cpu[i + 1]

        if sigma_next <= 0:
            coeffs = None
        else:
            # The usable order is limited by how many earlier sigmas exist.
            current_order = min(i + 1, self.config.solver_order)
            if current_order == 1:
                coeffs = [sigma_next - sigma_t]
            else:
                ts = [self._sigmas_cpu[i - j] for j in range(current_order)]
                t_next = sigma_next
                if current_order == 2:
                    t_cur, t_prev1 = ts[0], ts[1]
                    coeff_cur = ((t_next - t_prev1) ** 2 - (t_cur - t_prev1) ** 2) / (2 * (t_cur - t_prev1))
                    coeff_prev1 = (t_next - t_cur) ** 2 / (2 * (t_prev1 - t_cur))
                    coeffs = [coeff_cur, coeff_prev1]
                elif current_order == 3:
                    t_cur, t_prev1, t_prev2 = ts[0], ts[1], ts[2]
                    coeffs = [
                        get_def_integral_2(t_prev1, t_prev2, t_cur, t_next, t_cur),
                        get_def_integral_2(t_cur, t_prev2, t_cur, t_next, t_prev1),
                        get_def_integral_2(t_cur, t_prev1, t_cur, t_next, t_prev2),
                    ]
                elif current_order == 4:
                    t_cur, t_prev1, t_prev2, t_prev3 = ts[0], ts[1], ts[2], ts[3]
                    coeffs = [
                        get_def_integral_3(t_prev1, t_prev2, t_prev3, t_cur, t_next, t_cur),
                        get_def_integral_3(t_cur, t_prev2, t_prev3, t_cur, t_next, t_prev1),
                        get_def_integral_3(t_cur, t_prev1, t_prev3, t_cur, t_next, t_prev2),
                        get_def_integral_3(t_cur, t_prev1, t_prev2, t_cur, t_next, t_prev3),
                    ]
                else:
                    coeffs = [(sigma_next - sigma_t) / sigma_t]  # Fallback to Euler
        self.all_coeffs.append(coeffs)

    # Reset history. prev_sigmas is appended in lockstep with model_outputs by the
    # flow-matching branch of step(), so it must be cleared with it; otherwise a
    # second run pairs stale sigmas with an empty model_outputs list.
    self.model_outputs = []
    self.hist_samples = []
    self.prev_sigmas = []
    self._step_index = None
|
||||
|
||||
@property
def step_index(self):
    """
    Current value of the internal ``_step_index`` counter; it is increased by 1
    after each scheduler step.
    """
    current = self._step_index
    return current
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Look up the schedule index for ``timestep`` via ``scheduler_utils.index_for_timestep``.

    ``schedule_timesteps`` defaults to ``self.timesteps`` when not given.
    """
    from .scheduler_utils import index_for_timestep as _lookup_index

    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    return _lookup_index(timestep, schedule)
|
||||
|
||||
def _init_step_index(self, timestep):
    """Resolve ``_step_index`` from ``timestep`` if it has not been set yet; otherwise do nothing."""
    if self._step_index is not None:
        return
    self._step_index = self.index_for_timestep(timestep)
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Scale ``sample`` for the model at the current step.

    Initialises the step index from ``timestep`` if it has not been set yet. For
    ``prediction_type == "flow_prediction"`` the sample is returned unchanged;
    otherwise it is divided by ``sqrt(sigma**2 + 1)`` for the current sigma.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    if self.config.prediction_type != "flow_prediction":
        current_sigma = self.sigmas[self._step_index]
        scaled = sample / ((current_sigma**2 + 1) ** 0.5)
        return scaled
    return sample
|
||||
|
||||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[float, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance `sample` by one scheduler step.

    `flow_prediction` models are integrated directly on the raw model output with a variable-step
    Adams-Bashforth rule (at most second order). Every other prediction type is first converted to a
    denoised estimate and then advanced either with a first-order update (`solver_order == 1`) or with an
    exponential multistep update whose weights come from `Phi`.

    Args:
        model_output: Raw model output at the current step.
        timestep: Current timestep; only used to initialise the step index on the first call.
        sample: Current sample.
        return_dict: Return a `SchedulerOutput` when true, otherwise a 1-tuple.

    Returns:
        `SchedulerOutput(prev_sample=...)` or `(prev_sample,)`.

    Raises:
        ValueError: If `config.prediction_type` is not one of `epsilon`, `v_prediction`,
            `flow_prediction` or `sample`.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    step_index = self._step_index
    sigma_t = self.sigmas[step_index]

    # Reconstruct x0 from the model output
    prediction_type = self.config.prediction_type
    if prediction_type in ("epsilon", "flow_prediction"):
        denoised = sample - sigma_t * model_output
    elif prediction_type == "v_prediction":
        alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
        sigma_actual = sigma_t * alpha_t
        denoised = alpha_t * sample - sigma_actual * model_output
    elif prediction_type == "sample":
        denoised = model_output
    else:
        raise ValueError(f"prediction_type error: {self.config.prediction_type}")

    if self.config.clip_sample:
        denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)

    if prediction_type == "flow_prediction":
        # Variable-step Adams-Bashforth on the velocity; the history holds raw model outputs here
        self.model_outputs.append(model_output)
        self.prev_sigmas.append(sigma_t)
        # Trim each buffer on its own so they can never drift apart in length
        del self.model_outputs[:-4]
        del self.prev_sigmas[:-4]

        dt = self.sigmas[step_index + 1] - sigma_t
        v_n = model_output

        # Bound the order by the outputs that are actually stored: if prev_sigmas holds more entries
        # than model_outputs, indexing model_outputs[-2] must still be valid.
        curr_order = min(len(self.model_outputs), len(self.prev_sigmas), 3)

        use_history = curr_order >= 2
        if use_history:
            dt_prev = sigma_t - self.prev_sigmas[-2]
            r = dt / dt_prev if abs(dt_prev) > 1e-8 else 0.0
            # Only the two-entry case rejects extreme step ratios; with three or more entries the
            # second-order weights are applied unconditionally.
            if curr_order == 2 and (dt_prev == 0 or r < -0.9 or r > 2.0):
                use_history = False

        if use_history:
            c0 = 1 + 0.5 * r
            c1 = -0.5 * r
            x_next = sample + dt * (c0 * v_n + c1 * self.model_outputs[-2])
        else:
            x_next = sample + dt * v_n

        self._step_index += 1
        if not return_dict:
            return (x_next,)
        return SchedulerOutput(prev_sample=x_next)

    sigma_next = self.sigmas[step_index + 1]

    if self.config.solver_order == 1:
        # First-order (Euler) step in x-space
        ratio = sigma_next / sigma_t
        prev_sample = ratio * sample + (1 - ratio) * denoised
    else:
        # Multistep weights based on phi functions
        h = -torch.log(sigma_next / sigma_t) if sigma_t > 0 and sigma_next > 0 else torch.zeros_like(sigma_t)
        phi = Phi(h, [0], getattr(self.config, "use_analytic_solution", True))
        phi_1 = phi(1)

        # Newest estimate first, followed by the stored denoised history (newest to oldest)
        x0s = [denoised] + self.model_outputs[::-1]
        orders = min(len(x0s), self.config.solver_order)

        def step_ratio(offset):
            """Log-sigma distance to the step `offset` positions back, relative to the current h."""
            h_back = -np.log(self._sigmas_cpu[step_index] / (self._sigmas_cpu[step_index - offset] + 1e-9))
            return torch.tensor(h_back, device=sample.device, dtype=sample.dtype) / (h + 1e-9)

        # The last steps of the schedule are always taken at first order
        near_end = self.num_inference_steps is not None and step_index >= self.num_inference_steps - 3

        res = None
        if not near_end:
            if orders == 2:
                r = step_ratio(1)
                # Outside this ratio window the history is ignored (hard restart at first order)
                if not (r < 0.5 or r > 2.0):
                    phi_2 = phi(2)
                    b2 = -phi_2 / (r + 1e-9)
                    b1 = phi_1 - b2
                    res = b1 * x0s[0] + b2 * x0s[1]
            elif orders == 3:
                r1 = step_ratio(1)
                r2 = step_ratio(2)
                if not (r1 < 0.5 or r1 > 2.0 or r2 < 0.5 or r2 > 2.0):
                    phi_2, phi_3 = phi(2), phi(3)
                    denom = r2 - r1 + 1e-9
                    b3 = (phi_3 + r1 * phi_2) / (r2 * denom)
                    b2 = -(phi_3 + r2 * phi_2) / (r1 * denom)
                    b1 = phi_1 - b2 - b3
                    res = b1 * x0s[0] + b2 * x0s[1] + b3 * x0s[2]
        if res is None:
            # First order, any unsupported order, and every rejected higher-order attempt end up here
            res = phi_1 * denoised

        # Stable update in x-space
        if sigma_next == 0:
            prev_sample = denoised
        else:
            prev_sample = torch.exp(-h) * sample + h * res

    # Store state (the history always holds x0 on this path)
    self.model_outputs.append(denoised)
    self.hist_samples.append(sample)

    if len(self.model_outputs) > 4:
        self.model_outputs.pop(0)
        self.hist_samples.pop(0)

    if self._step_index is not None:
        self._step_index += 1

    if not return_dict:
        return (prev_sample,)
    return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Add `noise` to `original_samples` for the given `timesteps`.

    The work is delegated to the shared `scheduler_utils.add_noise_to_sample` helper, which also receives
    this scheduler's `sigmas` and `timesteps`.
    """
    from .scheduler_utils import add_noise_to_sample as _apply_noise

    schedule_sigmas = self.sigmas
    schedule_timesteps = self.timesteps
    return _apply_noise(original_samples, noise, schedule_sigmas, timesteps, schedule_timesteps)
|
||||
|
||||
def __len__(self):
    """Return the configured number of training timesteps."""
    train_steps = self.config.num_train_timesteps
    return train_steps
|
||||
|
|
@ -0,0 +1,285 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
from .phi_functions import Phi
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ETDRKScheduler(SchedulerMixin, ConfigMixin):
    """
    Exponential Time Differencing Runge-Kutta (ETDRK) scheduler.

    Every call to `step` converts the model output into an x0 estimate, keeps a short history of those
    estimates and blends them with phi-function weights chosen by `variant`. Only "etdrk2_2s",
    "etdrk3_b_3s" and "etdrk4_4s" have dedicated weights; any other variant, and any step that does not
    yet have enough history, uses the first-order `phi(1) * x0` term instead.
    """

    _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        variant: Literal["etdrk2_2s", "etdrk3_a_3s", "etdrk3_b_3s", "etdrk4_4s", "etdrk4_4s_alt"] = "etdrk4_4s",
        use_analytic_solution: bool = True,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
    ):
        """
        Build the beta/alpha tables and reset the step state.

        Args:
            num_train_timesteps: Length of the training schedule.
            beta_start: First beta of the "linear" / "scaled_linear" schedules.
            beta_end: Last beta of the "linear" / "scaled_linear" schedules.
            beta_schedule: One of "linear", "scaled_linear" or "squaredcos_cap_v2".
            trained_betas: Explicit betas; when given they replace `beta_schedule`.
            rescale_betas_zero_snr: Pass the betas through `rescale_zero_terminal_snr`.

        The remaining arguments are not read here; they are read back through `self.config` by
        `set_timesteps` and `step`.

        Raises:
            NotImplementedError: If `beta_schedule` is not one of the supported names.
        """
        from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does not exist.")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        if rescale_betas_zero_snr and self.alphas_cumprod[-1] <= 0:
            # NOTE(review): presumably the rescale helper drives the terminal alpha_bar to zero, as the
            # diffusers helper of the same name does. sqrt((1 - a) / a) would then be infinite for the
            # last training step, so keep it barely positive (same constant diffusers uses).
            self.alphas_cumprod[-1] = 2**-24

        # Schedule placeholders so the attributes exist before `set_timesteps` is called
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
        self.sigmas = None

        # Buffer for multistage/multistep
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []
        self._step_index = None
        self._begin_index = None
        self.init_noise_sigma = 1.0

    @property
    def step_index(self) -> Optional[int]:
        """Index of the current step, or `None` before it has been initialised."""
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """Step index to start from, or `None` when it has not been set."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Set the step index that `_init_step_index` will start from."""
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Optional[Union[str, torch.device]] = None,
        mu: Optional[float] = None,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Build the timestep and sigma schedules for inference and clear all step state.

        Args:
            num_inference_steps: Number of steps in the schedule.
            device: Device the resulting tensors are moved to.
            mu: Value handed to `get_dynamic_shift`; only used when `use_dynamic_shifting` is enabled.
            dtype: Dtype of the resulting tensors.

        Raises:
            ValueError: If `config.timestep_spacing` is not "linspace", "leading" or "trailing".
        """
        from .scheduler_utils import (
            apply_shift,
            get_dynamic_shift,
            get_sigmas_beta,
            get_sigmas_exponential,
            get_sigmas_flow,
            get_sigmas_karras,
        )

        self.num_inference_steps = num_inference_steps

        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
            timesteps -= 1
        else:
            raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")

        # Training sigmas, interpolated at the (possibly fractional) inference timesteps
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # Optional alternative sigma schedules; the first enabled flag wins
        if self.config.use_karras_sigmas:
            sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_exponential_sigmas:
            sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_beta_sigmas:
            sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_flow_sigmas:
            sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()

        if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
            shift = self.config.shift
            if self.config.use_dynamic_shifting and mu is not None:
                shift = get_dynamic_shift(
                    mu,
                    self.config.base_shift,
                    self.config.max_shift,
                    self.config.base_image_seq_len,
                    self.config.max_image_seq_len,
                )
            sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

        # A trailing zero sigma is appended so `step` can always read sigmas[step_index + 1]
        self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self._step_index = None
        self._begin_index = None
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the index of `timestep` in `schedule_timesteps` (default: `self.timesteps`)."""
        from .scheduler_utils import index_for_timestep

        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Add `noise` to `original_samples` for `timesteps` via `scheduler_utils.add_noise_to_sample`."""
        from .scheduler_utils import add_noise_to_sample

        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Scale `sample` for the model at the current step.

        Returns `sample` unchanged for `flow_prediction`, otherwise `sample / sqrt(sigma**2 + 1)`.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        return sample / ((sigma**2 + 1) ** 0.5)

    def _to_x0(self, model_output: torch.Tensor, sample: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Convert the raw model output into an x0 estimate according to `config.prediction_type`."""
        prediction_type = self.config.prediction_type
        if prediction_type in ("epsilon", "flow_prediction"):
            return sample - sigma * model_output
        if prediction_type == "v_prediction":
            alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
            sigma_t = sigma * alpha_t
            return alpha_t * sample - sigma_t * model_output
        # "sample" and any unrecognised prediction type: the model output is used as x0 directly
        return model_output

    def _phi_weighted_x0(self, h: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
        """Blend the stored x0 history with the phi-function weights of the configured variant."""
        variant = self.config.variant
        analytic = self.config.use_analytic_solution
        history = self.x0_outputs

        stage_nodes = {
            "etdrk2_2s": [0.0, 1.0],
            "etdrk3_b_3s": [0, 4 / 9, 2 / 3],
            "etdrk4_4s": [0, 1 / 2, 1 / 2, 1],
        }.get(variant)

        if stage_nodes is None:
            # Variants without dedicated weights always take the first-order term
            return Phi(h, [0], analytic)(1) * x0

        phi = Phi(h, stage_nodes, analytic)
        if len(history) < len(stage_nodes):
            # Not enough history yet for this variant
            return phi(1) * x0

        if variant == "etdrk2_2s":
            eps_1, eps_2 = history[-2:]
            b2 = phi(2)
            b1 = phi(1) - b2
            return b1 * eps_1 + b2 * eps_2

        if variant == "etdrk3_b_3s":
            # The middle entry carries zero weight, so only the outer two contribute
            eps_1, _, eps_3 = history[-3:]
            b3 = (3 / 2) * phi(2)
            b1 = phi(1) - b3
            return b1 * eps_1 + b3 * eps_3

        # etdrk4_4s
        e1, e2, e3, e4 = history[-4:]
        b2 = 2 * phi(2) - 4 * phi(3)
        b3 = 2 * phi(2) - 4 * phi(3)
        b4 = -phi(2) + 4 * phi(3)
        b1 = phi(1) - (b2 + b3 + b4)
        return b1 * e1 + b2 * e2 + b3 * e3 + b4 * e4

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Advance `sample` by one step of the exponential integrator.

        Args:
            model_output: Raw model output at the current step.
            timestep: Current timestep; only used to initialise the step index on the first call.
            sample: Current sample.
            return_dict: Return a `SchedulerOutput` when true, otherwise a 1-tuple.

        Returns:
            `SchedulerOutput(prev_sample=...)` or `(prev_sample,)`.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        step_index = self._step_index
        sigma = self.sigmas[step_index]
        sigma_next = self.sigmas[step_index + 1]

        # Step size in log-sigma; zero when either sigma is not positive
        h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)

        x0 = self._to_x0(model_output, sample, sigma)

        self.model_outputs.append(model_output)
        self.x0_outputs.append(x0)
        self.prev_sigmas.append(sigma)

        if sigma_next == 0:
            # Final step: jump straight to the x0 estimate
            x_next = x0
        else:
            # Exponential Integrator Update
            x_next = torch.exp(-h) * sample + h * self._phi_weighted_x0(h, x0)

        self._step_index += 1

        # Buffer control
        limit = 4 if self.config.variant.startswith("etdrk4") else 3
        if len(self.x0_outputs) > limit:
            self.x0_outputs.pop(0)
            self.model_outputs.pop(0)
            self.prev_sigmas.pop(0)

        if not return_dict:
            return (x_next,)

        return SchedulerOutput(prev_sample=x_next)

    def _init_step_index(self, timestep):
        """Set the step index from `begin_index` when present, otherwise by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,384 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
class GaussLegendreScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
GaussLegendreScheduler: High-accuracy implicit symplectic integrators.
|
||||
Supports various orders (2s, 3s, 4s, 5s, 8s-diagonal).
|
||||
Adapted from the RES4LYF repository.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    prediction_type: str = "epsilon",
    variant: str = "gauss-legendre_2s",  # 2s to 8s variants
    use_karras_sigmas: bool = False,
    use_exponential_sigmas: bool = False,
    use_beta_sigmas: bool = False,
    use_flow_sigmas: bool = False,
    sigma_min: Optional[float] = None,
    sigma_max: Optional[float] = None,
    rho: float = 7.0,
    shift: Optional[float] = None,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    use_dynamic_shifting: bool = False,
    timestep_spacing: str = "linspace",
    clip_sample: bool = False,
    sample_max_value: float = 1.0,
    set_alpha_to_one: bool = False,
    skip_prk_steps: bool = False,
    interpolation_type: str = "linear",
    steps_offset: int = 0,
    timestep_type: str = "discrete",
    rescale_betas_zero_snr: bool = False,
    final_sigmas_type: str = "zero",
):
    """
    Build the beta/alpha tables and put the scheduler into its initial, un-scheduled state.

    Only `trained_betas`, `beta_schedule`, `beta_start`, `beta_end` and `num_train_timesteps` are read
    here. The remaining arguments are not used in this method; presumably `register_to_config` records
    them on `self.config` for later use.

    Raises:
        NotImplementedError: If `beta_schedule` is neither "linear" nor "scaled_linear" and no
            `trained_betas` are given.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), squared back afterwards
        root_start, root_end = beta_start**0.5, beta_end**0.5
        betas = torch.linspace(root_start, root_end, num_train_timesteps, dtype=torch.float32) ** 2
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented")

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Schedule placeholders until set_timesteps is called
    descending_steps = np.arange(0, num_train_timesteps)[::-1].copy()
    self.num_inference_steps = None
    self.timesteps = torch.from_numpy(descending_steps)
    self.sigmas = None
    self.init_noise_sigma = 1.0

    # Internal state
    self.model_outputs = []
    self.sample_at_start_of_step = None
    self._step_index = None
|
||||
|
||||
def _get_tableau(self):
    """
    Return the coefficient tableau selected by `config.variant`.

    Returns:
        Tuple `(a, b, c)` of numpy arrays: the stage matrix, the stage weights and the stage nodes.

    Raises:
        ValueError: If `config.variant` is not one of the known names.
    """
    name = self.config.variant

    def pack(a_rows, b_row, c_row):
        return np.array(a_rows), np.array(b_row), np.array(c_row)

    if name == "gauss-legendre_2s":
        root3 = 3**0.5
        return pack(
            [[1 / 4, 1 / 4 - root3 / 6], [1 / 4 + root3 / 6, 1 / 4]],
            [1 / 2, 1 / 2],
            [1 / 2 - root3 / 6, 1 / 2 + root3 / 6],
        )

    if name == "gauss-legendre_3s":
        root15 = 15**0.5
        return pack(
            [
                [5 / 36, 2 / 9 - root15 / 15, 5 / 36 - root15 / 30],
                [5 / 36 + root15 / 24, 2 / 9, 5 / 36 - root15 / 24],
                [5 / 36 + root15 / 30, 2 / 9 + root15 / 15, 5 / 36],
            ],
            [5 / 18, 4 / 9, 5 / 18],
            [1 / 2 - root15 / 10, 1 / 2, 1 / 2 + root15 / 10],
        )

    if name == "gauss-legendre_4s":
        root15 = 15**0.5
        quarter = 1 / 4
        below = 1 / 4 - root15 / 6
        above = 1 / 4 + root15 / 6
        node_lo = 1 / 2 - root15 / 10
        node_hi = 1 / 2 + root15 / 10
        return pack(
            [
                [quarter, below, above, quarter],
                [above, quarter, below, quarter],
                [quarter, above, quarter, below],
                [below, quarter, above, quarter],
            ],
            [1 / 8, 3 / 8, 3 / 8, 1 / 8],
            [node_lo, node_hi, node_hi, node_lo],
        )

    if name == "gauss-legendre_5s":
        root739 = 739**0.5
        # Leading rational of each column, shared by every row
        col0 = 4563950663 / 32115191526
        col12 = 310937500000000 / 2597974476091533
        col34 = 5236016175 / 88357462711
        a_rows = [
            [
                col0,
                (col12 + 45156250000 * root739 / 8747388808389),
                (col12 - 45156250000 * root739 / 8747388808389),
                (col34 + 709703235 * root739 / 353429850844),
                (col34 - 709703235 * root739 / 353429850844),
            ],
            [
                (col0 - 38339103 * root739 / 6250000000),
                (col12 + 9557056475401 * root739 / 3498955523355600000),
                (col12 - 14074198220719489 * root739 / 3498955523355600000),
                (col34 + 5601362553163918341 * root739 / 2208936567775000000000),
                (col34 - 5040458465159165409 * root739 / 2208936567775000000000),
            ],
            [
                (col0 + 38339103 * root739 / 6250000000),
                (col12 + 14074198220719489 * root739 / 3498955523355600000),
                (col12 - 9557056475401 * root739 / 3498955523355600000),
                (col34 + 5040458465159165409 * root739 / 2208936567775000000000),
                (col34 - 5601362553163918341 * root739 / 2208936567775000000000),
            ],
            [
                (col0 - 38209 * root739 / 7938810),
                (col12 - 359369071093750 * root739 / 70145310854471391),
                (col12 - 323282178906250 * root739 / 70145310854471391),
                (col34 - 470139 * root739 / 1413719403376),
                (col34 - 44986764863 * root739 / 21205791050640),
            ],
            [
                (col0 + 38209 * root739 / 7938810),
                (col12 + 359369071093750 * root739 / 70145310854471391),
                (col12 + 323282178906250 * root739 / 70145310854471391),
                (col34 + 44986764863 * root739 / 21205791050640),
                (col34 + 470139 * root739 / 1413719403376),
            ],
        ]
        weight_mid = 621875000000000 / 2597974476091533
        weight_outer = 10472032350 / 88357462711
        return pack(
            a_rows,
            [4563950663 / 16057595763, weight_mid, weight_mid, weight_outer, weight_outer],
            [1 / 2, 1 / 2 - 99 * root739 / 10000, 1 / 2 + 99 * root739 / 10000, 1 / 2 - root739 / 60, 1 / 2 + root739 / 60],
        )

    if name == "gauss-legendre_diag_8s":
        # Lower-triangular rows; everything right of the diagonal is zero
        lower = [
            [0.5],
            [1.0818949631055815, 0.5],
            [0.9599572962220549, 1.0869589243008327, 0.5],
            [1.0247213458032004, 0.9550588736973743, 1.0880938387323083, 0.5],
            [0.9830238267636289, 1.0287597754747493, 0.9538345351852, 1.0883471611098278, 0.5],
            [1.0122259141132982, 0.9799828723635913, 1.0296038730649779, 0.9538345351852, 1.0880938387323083, 0.5],
            [0.9912514332308026, 1.0140743558891669, 0.9799828723635913, 1.0287597754747493, 0.9550588736973743, 1.0869589243008327, 0.5],
            [1.0054828082532159, 0.9912514332308026, 1.0122259141132982, 0.9830238267636289, 1.0247213458032004, 0.9599572962220549, 1.0818949631055815, 0.5],
        ]
        a_rows = [row + [0] * (8 - len(row)) for row in lower]
        return pack(
            a_rows,
            [0.05061426814518813, 0.11119051722668724, 0.15685332293894364, 0.181341891689181, 0.181341891689181, 0.15685332293894364, 0.11119051722668724, 0.05061426814518813],
            [0.019855071751231884, 0.10166676129318663, 0.2372337950418355, 0.4082826787521751, 0.5917173212478249, 0.7627662049581645, 0.8983332387068134, 0.9801449282487681],
        )

    raise ValueError(f"Unknown variant: {name}")
|
||||
|
||||
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Optional[Union[str, torch.device]] = None,
    mu: Optional[float] = None,
    dtype: torch.dtype = torch.float32,
):
    """
    Build the stage-expanded sigma and timestep schedules and clear all step state.

    Every interval between two consecutive base sigmas is expanded into one sigma per tableau node
    (`s_curr + c * (s_next - s_curr)`), and a final zero sigma is appended.

    Args:
        num_inference_steps: Number of base sigmas before stage expansion.
        device: Device the resulting tensors are moved to.
        mu: Shift factor applied when `config.use_dynamic_shifting` is enabled.
        dtype: Dtype of the resulting tensors.

    Raises:
        ValueError: If `config.timestep_spacing` or `config.interpolation_type` is not supported.
    """
    cfg = self.config
    self.num_inference_steps = num_inference_steps

    # 1. Spacing
    if cfg.timestep_spacing == "linspace":
        timesteps = np.linspace(0, cfg.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
    elif cfg.timestep_spacing == "leading":
        step_ratio = cfg.num_train_timesteps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
    elif cfg.timestep_spacing == "trailing":
        step_ratio = cfg.num_train_timesteps / num_inference_steps
        timesteps = (np.arange(cfg.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
        timesteps -= 1
    else:
        raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")

    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    if cfg.interpolation_type == "linear":
        sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
    elif cfg.interpolation_type == "log_linear":
        sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
    else:
        raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")

    # 2. Sigma Schedule (the first enabled flag wins: karras, exponential, beta, flow)
    if cfg.use_karras_sigmas or cfg.use_exponential_sigmas or cfg.use_beta_sigmas:
        sigma_min = cfg.sigma_min if cfg.sigma_min is not None else sigmas[-1]
        sigma_max = cfg.sigma_max if cfg.sigma_max is not None else sigmas[0]
        ramp = np.linspace(0, 1, num_inference_steps)
        if cfg.use_karras_sigmas:
            rho = cfg.rho
            sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
        elif cfg.use_exponential_sigmas:
            sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
        else:
            # Warp the ramp with the Beta(0.6, 0.6) quantile function. This is deterministic and does
            # not touch the global torch RNG, so the schedule is the same on every call.
            alpha, beta = 0.6, 0.6
            try:
                from scipy.stats import beta as beta_distribution
            except ImportError:
                # Without scipy the un-warped linear ramp is used
                warped = ramp
            else:
                warped = beta_distribution.ppf(ramp, alpha, beta)
            sigmas = sigma_max * (1 - warped) + sigma_min * warped
    elif cfg.use_flow_sigmas:
        sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)

    # 3. Shifting
    if cfg.use_dynamic_shifting and mu is not None:
        sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
    elif cfg.shift is not None:
        sigmas = cfg.shift * sigmas / (1 + (cfg.shift - 1) * sigmas)

    # 4. Stage expansion: one sigma per tableau node inside every base interval, then a final zero
    _a_mat, _b_vec, c_vec = self._get_tableau()
    sigmas_expanded = [s_curr + c_val * (s_next - s_curr) for s_curr, s_next in zip(sigmas[:-1], sigmas[1:]) for c_val in c_vec]
    sigmas_expanded.append(0.0)
    sigmas_interpolated = np.array(sigmas_expanded)

    # Timesteps are a linear remapping of the expanded sigmas
    timesteps_expanded = sigmas_interpolated * cfg.num_train_timesteps

    self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
    self.timesteps = torch.from_numpy(timesteps_expanded + cfg.steps_offset).to(device=device, dtype=dtype)
    self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

    self.model_outputs = []
    self.sample_at_start_of_step = None
    self._step_index = None
|
||||
|
||||
@property
def step_index(self):
    """
    Index of the current entry in the schedule; `step` increases it by one on each call.
    """
    position = self._step_index
    return position
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Resolve `timestep` to its index within a timestep schedule.

    Falls back to `self.timesteps` when `schedule_timesteps` is not supplied; the search itself is done
    by `scheduler_utils.index_for_timestep`.
    """
    from .scheduler_utils import index_for_timestep as _find_index

    if schedule_timesteps is not None:
        return _find_index(timestep, schedule_timesteps)
    return _find_index(timestep, self.timesteps)
|
||||
|
||||
def _init_step_index(self, timestep):
    """
    Set `self._step_index` from `timestep` when no step index exists yet.

    A tensor timestep is moved to the device of `self.timesteps` before the lookup.
    """
    if self._step_index is not None:
        return
    lookup_value = timestep.to(self.timesteps.device) if isinstance(timestep, torch.Tensor) else timestep
    self._step_index = self.index_for_timestep(lookup_value)
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Prepare `sample` as model input for the current step.

    Flow-prediction models receive the sample untouched; all others receive it divided by
    `sqrt(sigma**2 + 1)`, where sigma belongs to the current step index. The step index is initialised
    from `timestep` first if it has not been set.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    if self.config.prediction_type != "flow_prediction":
        sigma_now = self.sigmas[self._step_index]
        scale = (sigma_now**2 + 1) ** 0.5
        return sample / scale
    return sample
|
||||
|
||||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[float, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance the sample by one stage of the explicit Runge-Kutta scheme.

    The expanded schedule in ``self.sigmas`` holds ``num_stages`` entries per
    outer step (one per tableau node), so ``self._step_index`` encodes both the
    outer step and the stage inside it. Stage derivatives are accumulated in
    ``self.model_outputs``; the last stage combines them with the ``b`` weights
    and clears the per-step state.

    Args:
        model_output: Raw model prediction for the current stage.
        timestep: Current timestep; only used to initialise the step index.
        sample: Current sample passed to the model for this stage.
        return_dict: When ``False`` a one-element tuple is returned instead of
            a ``SchedulerOutput``.

    Returns:
        ``SchedulerOutput(prev_sample=...)`` or ``(prev_sample,)``.

    Raises:
        ValueError: If ``prediction_type`` is not one of ``epsilon``,
            ``v_prediction``, ``flow_prediction`` or ``sample``.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    step_index = self._step_index
    a_mat, b_vec, c_vec = self._get_tableau()
    num_stages = len(c_vec)

    stage_index = step_index % num_stages
    base_step_index = (step_index // num_stages) * num_stages

    last_index = len(self.sigmas) - 1
    sigma_curr = self.sigmas[base_step_index]
    sigma_next = self.sigmas[min(base_step_index + num_stages, last_index)]
    sigma_t = self.sigmas[step_index]

    # Reconstruct the denoised estimate once, so the final-step path and the
    # regular path cannot drift apart (the final-step path used to fall back
    # silently on unknown prediction types while the regular path raised).
    prediction_type = getattr(self.config, "prediction_type", "epsilon")
    if prediction_type in ("epsilon", "flow_prediction"):
        denoised = sample - sigma_t * model_output
    elif prediction_type == "v_prediction":
        alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
        sigma_actual = sigma_t * alpha_t
        denoised = alpha_t * sample - sigma_actual * model_output
    elif prediction_type == "sample":
        denoised = model_output
    else:
        raise ValueError(f"prediction_type error: {prediction_type}")

    if getattr(self.config, "clip_sample", False):
        denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)

    if sigma_next <= 0:
        # The outer step ends at sigma == 0: return the denoised estimate directly.
        prev_sample = denoised
    else:
        h = sigma_next - sigma_curr
        # derivative = (x - x0) / sigma
        derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)

        if self.sample_at_start_of_step is None and stage_index > 0:
            # Mid-step fallback for Img2Img/Inpainting: no stage-0 state exists,
            # so take a plain Euler step to the next entry of the schedule.
            dt = self.sigmas[step_index + 1] - sigma_t
            prev_sample = sample + dt * derivative
        else:
            if stage_index == 0:
                self.model_outputs = [derivative]
                self.sample_at_start_of_step = sample
            else:
                self.model_outputs.append(derivative)

            next_stage_idx = stage_index + 1
            if next_stage_idx < num_stages:
                # Predict the sample for the next stage from the 'a' row.
                sum_ak = 0
                for j, stage_derivative in enumerate(self.model_outputs):
                    sum_ak = sum_ak + a_mat[next_stage_idx][j] * stage_derivative

                sigma_next_stage = self.sigmas[min(step_index + 1, last_index)]
                # NOTE(review): this scales by (sigma_stage - sigma_curr) = c_i * h rather
                # than h; with an unnormalised Butcher 'a' matrix the textbook update is
                # x0 + h * sum_j a_ij k_j. Behaviour kept as-is; confirm against _get_tableau.
                prev_sample = self.sample_at_start_of_step + (sigma_next_stage - sigma_curr) * sum_ak
            else:
                # Final stage: combine all stage derivatives with the 'b' weights.
                sum_bk = 0
                for j, stage_derivative in enumerate(self.model_outputs):
                    sum_bk = sum_bk + b_vec[j] * stage_derivative

                prev_sample = self.sample_at_start_of_step + h * sum_bk

                self.model_outputs = []
                self.sample_at_start_of_step = None

    self._step_index += 1

    if not return_dict:
        return (prev_sample,)
    return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Noise ``original_samples`` at the given ``timesteps``.

    This is a thin wrapper: the work is done by
    ``scheduler_utils.add_noise_to_sample``, which receives the scheduler's
    current ``sigmas`` and ``timesteps`` tables alongside the arguments.
    """
    from .scheduler_utils import add_noise_to_sample as _add_noise

    sigma_table = self.sigmas
    timestep_table = self.timesteps
    return _add_noise(original_samples, noise, sigma_table, timesteps, timestep_table)
|
||||
|
||||
def __len__(self):
    """Return the configured number of training timesteps."""
    train_steps = self.config.num_train_timesteps
    return train_steps
|
||||
|
|
@ -0,0 +1,249 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import ClassVar, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LangevinDynamicsScheduler(SchedulerMixin, ConfigMixin):
    """
    Langevin Dynamics sigma scheduler using Exponential Integrator step.

    ``set_timesteps`` builds the sigma schedule by simulating a damped, noisy
    particle that relaxes from the training-schedule maximum sigma towards a
    small terminal sigma, then sorts the visited positions into descending
    order. ``step`` moves the sample from ``sigma`` to ``sigma_next`` with a
    first-order exponential-integrator update anchored at the predicted x0.
    """

    _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
    order: ClassVar[int] = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        temperature: float = 0.5,
        friction: float = 1.0,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        rescale_betas_zero_snr: bool = False,
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
    ):
        """
        Build the training beta/alpha tables and reset the inference state.

        Raises:
            NotImplementedError: If ``beta_schedule`` is not ``linear``,
                ``scaled_linear`` or ``squaredcos_cap_v2``.
        """
        from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

        if beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does not exist.")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # Settable values (replaced by set_timesteps)
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
        self.sigmas = torch.zeros((num_train_timesteps,), dtype=torch.float32)

        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self) -> Optional[int]:
        """Index of the current step, or ``None`` before the first step."""
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """Index the first step starts from, or ``None`` if not set."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Set the index that ``_init_step_index`` will start from."""
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        generator: Optional[torch.Generator] = None,
        mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
        """
        Generate the sigma and timestep tables for ``num_inference_steps`` steps.

        Args:
            num_inference_steps: Number of denoising steps; must be at least 1.
            device: Device the resulting tensors are moved to.
            generator: Optional RNG used for the Langevin noise draws.
            mu: Value forwarded to ``get_dynamic_shift`` when dynamic shifting
                is enabled.
            dtype: dtype of the resulting tensors.

        Raises:
            ValueError: If ``num_inference_steps`` is smaller than 1.
        """
        from .scheduler_utils import (
            apply_shift,
            get_dynamic_shift,
            get_sigmas_beta,
            get_sigmas_exponential,
            get_sigmas_flow,
            get_sigmas_karras,
        )

        if num_inference_steps < 1:
            # Previously surfaced as ZeroDivisionError / math domain error below.
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")

        self.num_inference_steps = num_inference_steps

        # Discretization parameters for Langevin schedule generation
        dt = 1.0 / num_inference_steps
        sqrt_2dt = math.sqrt(2 * dt)

        start_sigma = 10.0
        if hasattr(self, "alphas_cumprod"):
            # Largest sigma of the training schedule.
            start_sigma = float(((1 - self.alphas_cumprod[-1]) / self.alphas_cumprod[-1]) ** 0.5)

        end_sigma = 0.01

        def grad_U(x):
            # Gradient of a quadratic potential centred on end_sigma.
            return x - end_sigma

        x = torch.tensor([start_sigma], dtype=dtype)
        v = torch.zeros(1)

        trajectory = [start_sigma]
        temperature = self.config.temperature
        friction = self.config.friction

        # Two velocity half-updates around one position update per step;
        # noise is injected in the second half-update only.
        for _ in range(num_inference_steps - 1):
            v = v - dt * friction * v - dt * grad_U(x) / 2
            x = x + dt * v
            noise = torch.randn(1, generator=generator) * sqrt_2dt * temperature
            v = v - dt * friction * v - dt * grad_U(x) / 2 + noise
            trajectory.append(x.item())

        # Force monotonicity to prevent negative h in step(). The copy matters:
        # the reversed view has a negative stride, which torch.from_numpy rejects
        # in the shift path below.
        sigmas = np.sort(np.array(trajectory))[::-1].copy()
        sigmas[-1] = end_sigma

        if self.config.use_karras_sigmas:
            sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_exponential_sigmas:
            sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_beta_sigmas:
            sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_flow_sigmas:
            sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()

        if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
            shift = self.config.shift
            if self.config.use_dynamic_shifting and mu is not None:
                shift = get_dynamic_shift(
                    mu,
                    self.config.base_shift,
                    self.config.max_shift,
                    self.config.base_image_seq_len,
                    self.config.max_image_seq_len,
                )
            sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

        # A trailing zero sigma lets step() read sigmas[step_index + 1] on the last step.
        self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
        # Timesteps span the configured training range (formerly hard-coded to 1000).
        self.timesteps = torch.from_numpy(
            np.linspace(self.config.num_train_timesteps, 0, num_inference_steps)
        ).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the schedule index for ``timestep`` (defaults to ``self.timesteps``)."""
        from .scheduler_utils import index_for_timestep

        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise ``original_samples`` via ``scheduler_utils.add_noise_to_sample``."""
        from .scheduler_utils import add_noise_to_sample

        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Scale ``sample`` for the model call at the current step.

        Flow-prediction models receive the sample unchanged; all other
        prediction types get ``sample / sqrt(sigma**2 + 1)``.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Advance ``sample`` from the current sigma to the next one.

        Args:
            model_output: Raw model prediction at the current step.
            timestep: Current timestep; only used to initialise the step index.
            sample: Current sample.
            return_dict: When ``False`` a one-element tuple is returned.

        Returns:
            ``SchedulerOutput(prev_sample=...)`` or ``(prev_sample,)``.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        step_index = self._step_index
        sigma = self.sigmas[step_index]
        sigma_next = self.sigmas[step_index + 1]

        # Determine denoised (x_0 prediction)
        if self.config.prediction_type == "epsilon":
            x0 = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t = 1.0 / (sigma**2 + 1)**0.5
            sigma_t = sigma * alpha_t
            x0 = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "sample":
            x0 = model_output
        elif self.config.prediction_type == "flow_prediction":
            x0 = sample - sigma * model_output
        else:
            # Unknown prediction types are treated like "sample".
            x0 = model_output

        # Exponential Integrator Update
        if sigma_next == 0:
            x_next = x0
        else:
            # h is the log-sigma step; exp(-h) == sigma_next / sigma.
            h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
            x_next = torch.exp(-h) * sample + (1 - torch.exp(-h)) * x0

        self._step_index += 1

        if not return_dict:
            return (x_next,)
        return SchedulerOutput(prev_sample=x_next)

    def _init_step_index(self, timestep):
        """Set ``_step_index`` from ``begin_index`` if given, else by timestep lookup."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,277 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LawsonScheduler(SchedulerMixin, ConfigMixin):
    """
    Lawson's integration method scheduler.

    Each step reconstructs x0 from the model output and keeps a short history
    of these estimates. Once enough history exists for the configured
    ``variant`` the update combines the stored x0 values with
    exponentially-weighted quadrature coefficients; until then a first-order
    exponential-integrator update is used.
    """

    _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        variant: Literal["lawson2a_2s", "lawson2b_2s", "lawson4_4s"] = "lawson4_4s",
        use_analytic_solution: bool = True,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
    ):
        """
        Build the training beta/alpha tables and reset the inference state.

        Raises:
            NotImplementedError: If ``beta_schedule`` is not ``linear``,
                ``scaled_linear`` or ``squaredcos_cap_v2`` and no
                ``trained_betas`` are given.
        """
        from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does not exist.")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Settable values (replaced by set_timesteps); initialised here so the
        # attributes exist before the first set_timesteps call.
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

        # Buffer for multistage/multistep
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []
        self._step_index = None
        self._begin_index = None
        self.init_noise_sigma = 1.0

    @property
    def step_index(self) -> Optional[int]:
        """Index of the current step, or ``None`` before the first step."""
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """Index the first step starts from, or ``None`` if not set."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Set the index that ``_init_step_index`` will start from."""
        self._begin_index = begin_index

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
        """
        Generate the sigma and timestep tables for ``num_inference_steps`` steps.

        Args:
            num_inference_steps: Number of denoising steps.
            device: Device the resulting tensors are moved to.
            mu: Value forwarded to ``get_dynamic_shift`` when dynamic shifting
                is enabled.
            dtype: dtype of the resulting tensors.

        Raises:
            ValueError: If ``timestep_spacing`` is not ``linspace``,
                ``leading`` or ``trailing``.
        """
        from .scheduler_utils import (
            apply_shift,
            get_dynamic_shift,
            get_sigmas_beta,
            get_sigmas_exponential,
            get_sigmas_flow,
            get_sigmas_karras,
        )

        self.num_inference_steps = num_inference_steps

        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
            timesteps -= 1
        else:
            raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")

        # Training sigmas, interpolated at the (possibly fractional) timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.use_karras_sigmas:
            sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_exponential_sigmas:
            sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_beta_sigmas:
            sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_flow_sigmas:
            sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()

        if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
            shift = self.config.shift
            if self.config.use_dynamic_shifting and mu is not None:
                shift = get_dynamic_shift(
                    mu,
                    self.config.base_shift,
                    self.config.max_shift,
                    self.config.base_image_seq_len,
                    self.config.max_image_seq_len,
                )
            sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

        # A trailing zero sigma lets step() read sigmas[step_index + 1] on the last step.
        self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self._step_index = None
        self._begin_index = None
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the schedule index for ``timestep`` (defaults to ``self.timesteps``)."""
        from .scheduler_utils import index_for_timestep

        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise ``original_samples`` via ``scheduler_utils.add_noise_to_sample``."""
        from .scheduler_utils import add_noise_to_sample

        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Scale ``sample`` for the model call at the current step.

        Flow-prediction models receive the sample unchanged; all other
        prediction types get ``sample / sqrt(sigma**2 + 1)``.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Advance ``sample`` from the current sigma to the next one.

        Args:
            model_output: Raw model prediction at the current step.
            timestep: Current timestep; only used to initialise the step index.
            sample: Current sample.
            return_dict: When ``False`` a one-element tuple is returned.

        Returns:
            ``SchedulerOutput(prev_sample=...)`` or ``(prev_sample,)``.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        step_index = self._step_index
        sigma = self.sigmas[step_index]
        sigma_next = self.sigmas[step_index + 1]

        # h is the log-sigma step; exp(-h) == sigma_next / sigma.
        h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
        exp_h = torch.exp(-h)

        # RECONSTRUCT X0
        if self.config.prediction_type == "epsilon":
            x0 = sample - sigma * model_output
        elif self.config.prediction_type == "sample":
            x0 = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
            sigma_t = sigma * alpha_t
            x0 = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            x0 = sample - sigma * model_output
        else:
            # Unknown prediction types are treated like "sample".
            x0 = model_output

        self.model_outputs.append(model_output)
        self.x0_outputs.append(x0)
        self.prev_sigmas.append(sigma)

        variant = self.config.variant

        if sigma_next == 0:
            x_next = x0
        else:
            # Lawson coefficients (anchored at x0). Each branch yields the
            # increment added to exp(-h) * sample.
            history = self.x0_outputs
            if variant == "lawson2a_2s" and len(history) >= 2:
                # Midpoint-type weight exp(-h/2) applied to the newest estimate.
                increment = h * (torch.exp(-h / 2) * history[-1])
            elif variant == "lawson2b_2s" and len(history) >= 2:
                x0_prev, x0_curr = history[-2:]
                increment = h * (0.5 * exp_h * x0_prev + 0.5 * x0_curr)
            elif variant == "lawson4_4s" and len(history) >= 4:
                e1, e2, e3, e4 = history[-4:]
                b1 = (1 / 6) * exp_h
                b2 = (1 / 3) * torch.exp(-h / 2)
                b3 = (1 / 3) * torch.exp(-h / 2)
                b4 = 1 / 6
                increment = h * (b1 * e1 + b2 * e2 + b3 * e3 + b4 * e4)
            else:
                # First-order exponential integrator (not enough history, or an
                # unrecognised variant). Computed without the former
                # "(1 - exp_h) / h ... * h" round trip, which produced NaN
                # whenever two consecutive sigmas were equal (h == 0).
                increment = (1 - exp_h) * x0

            # Update
            x_next = exp_h * sample + increment

        self._step_index += 1

        # Buffer control: keep at most as many entries as the variant reads.
        limit = 4 if variant == "lawson4_4s" else 2
        if len(self.x0_outputs) > limit:
            self.x0_outputs.pop(0)
            self.model_outputs.pop(0)
            self.prev_sigmas.pop(0)

        if not return_dict:
            return (x_next,)

        return SchedulerOutput(prev_sample=x_next)

    def _init_step_index(self, timestep):
        """Set ``_step_index`` from ``begin_index`` if given, else by timestep lookup."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,321 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
class LinearRKScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
LinearRKScheduler: Standard explicit Runge-Kutta integrators.
|
||||
Supports Ralston, Midpoint, Heun, Kutta, and standard RK4.
|
||||
Adapted from the RES4LYF repository.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    prediction_type: str = "epsilon",
    variant: str = "rk4",  # euler, heun, rk2, rk3, rk4, ralston, midpoint
    use_karras_sigmas: bool = False,
    use_exponential_sigmas: bool = False,
    use_beta_sigmas: bool = False,
    use_flow_sigmas: bool = False,
    sigma_min: Optional[float] = None,
    sigma_max: Optional[float] = None,
    rho: float = 7.0,
    shift: Optional[float] = None,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    use_dynamic_shifting: bool = False,
    timestep_spacing: str = "linspace",
    clip_sample: bool = False,
    sample_max_value: float = 1.0,
    set_alpha_to_one: bool = False,
    skip_prk_steps: bool = False,
    interpolation_type: str = "linear",
    steps_offset: int = 0,
    timestep_type: str = "discrete",
    rescale_betas_zero_snr: bool = False,
    final_sigmas_type: str = "zero",
):
    """
    Build the training beta/alpha tables and reset the inference state.

    Only ``num_train_timesteps``, ``beta_start``, ``beta_end``,
    ``beta_schedule`` and ``trained_betas`` are read here; the remaining
    arguments are recorded in the config by ``register_to_config``.

    Raises:
        NotImplementedError: If ``beta_schedule`` is neither ``linear`` nor
            ``scaled_linear`` and no ``trained_betas`` are given.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = sqrt_betas ** 2
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented")

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Placeholders until set_timesteps() is called.
    descending_steps = np.arange(0, num_train_timesteps)[::-1].copy()
    self.num_inference_steps = None
    self.timesteps = torch.from_numpy(descending_steps)
    self.sigmas = None
    self.init_noise_sigma = 1.0

    # Per-step Runge-Kutta state.
    self.model_outputs = []
    self.sample_at_start_of_step = None
    self._step_index = None
|
||||
|
||||
def _get_tableau(self):
    """
    Return the Butcher tableau ``(A, b, c)`` for the configured variant.

    The variant name is lower-cased and stripped before lookup. ``A`` is
    returned as a dense ``stages x stages`` array whose first row is zero and
    whose following rows hold the strictly-lower-triangular coefficients;
    ``b`` and ``c`` are returned as 1-D arrays.

    Raises:
        ValueError: If the variant name is not recognised.
    """
    v = str(self.config.variant).lower().strip()

    # (aliases, (lower-triangular rows of A, b weights, c nodes))
    known_tableaus = (
        (("ralston", "ralston_2s"), ([[2 / 3]], [1 / 4, 3 / 4], [0, 2 / 3])),
        (("midpoint", "midpoint_2s"), ([[1 / 2]], [0, 1], [0, 1 / 2])),
        (("heun", "heun_2s", "rk2"), ([[1]], [1 / 2, 1 / 2], [0, 1])),
        (("heun_3s",), ([[1 / 3], [0, 2 / 3]], [1 / 4, 0, 3 / 4], [0, 1 / 3, 2 / 3])),
        (("kutta", "kutta_3s", "rk3"), ([[1 / 2], [-1, 2]], [1 / 6, 2 / 3, 1 / 6], [0, 1 / 2, 1])),
        (("rk4", "rk4_4s"), ([[1 / 2], [0, 1 / 2], [0, 0, 1]], [1 / 6, 1 / 3, 1 / 3, 1 / 6], [0, 1 / 2, 1 / 2, 1])),
        (("euler",), ([], [1], [0])),
    )
    match = next((coeffs for aliases, coeffs in known_tableaus if v in aliases), None)
    if match is None:
        raise ValueError(f"Unknown variant: {v}")
    lower_rows, weights, nodes = match

    # Expand the triangular rows into a full square matrix (row 0 stays zero).
    stage_count = len(nodes)
    dense_a = np.zeros((stage_count, stage_count))
    for target_row, coeffs in zip(dense_a[1:], lower_rows):
        target_row[: len(coeffs)] = coeffs

    return dense_a, np.array(weights), np.array(nodes)
|
||||
|
||||
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
    """
    Build the stage-expanded sigma and timestep tables for inference.

    The base schedule has ``num_inference_steps`` sigmas. Every interval
    between consecutive sigmas is then expanded into one entry per tableau
    node ``c`` (linear interpolation inside the interval), and a final ``0.0``
    is appended. Timesteps are the expanded sigmas multiplied by
    ``num_train_timesteps`` plus ``steps_offset``.

    Args:
        num_inference_steps: Number of outer denoising steps.
        device: Device the resulting tensors are moved to.
        mu: Shift factor used when ``use_dynamic_shifting`` is enabled.
        dtype: dtype of the resulting tensors.

    Raises:
        ValueError: On an unsupported ``timestep_spacing`` or
            ``interpolation_type``.
    """
    cfg = self.config
    self.num_inference_steps = num_inference_steps
    train_steps = cfg.num_train_timesteps

    # 1. Spacing
    spacing = cfg.timestep_spacing
    if spacing == "linspace":
        timesteps = np.linspace(0, train_steps - 1, num_inference_steps, dtype=float)[::-1].copy()
    elif spacing == "leading":
        stride = train_steps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * stride).round()[::-1].copy().astype(float)
    elif spacing == "trailing":
        stride = train_steps / num_inference_steps
        timesteps = (np.arange(train_steps, 0, -stride)).round().copy().astype(float)
        timesteps -= 1
    else:
        raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")

    # Training sigmas interpolated at the chosen (possibly fractional) timesteps.
    train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    grid = np.arange(len(train_sigmas))
    interpolation = cfg.interpolation_type
    if interpolation == "linear":
        sigmas = np.interp(timesteps, grid, train_sigmas)
    elif interpolation == "log_linear":
        sigmas = np.exp(np.interp(timesteps, grid, np.log(train_sigmas)))
    else:
        raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")

    def resolve_bounds():
        # Explicit config bounds win; otherwise use the ends of the interpolated schedule.
        low = cfg.sigma_min if cfg.sigma_min is not None else sigmas[-1]
        high = cfg.sigma_max if cfg.sigma_max is not None else sigmas[0]
        return low, high

    # 2. Sigma Schedule
    if cfg.use_karras_sigmas:
        sigma_min, sigma_max = resolve_bounds()
        rho = cfg.rho
        ramp = np.linspace(0, 1, num_inference_steps)
        sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    elif cfg.use_exponential_sigmas:
        sigma_min, sigma_max = resolve_bounds()
        sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
    elif cfg.use_beta_sigmas:
        sigma_min, sigma_max = resolve_bounds()
        alpha, beta = 0.6, 0.6
        ramp = np.linspace(0, 1, num_inference_steps)
        # Best effort: a sorted random Beta(0.6, 0.6) sample replaces the
        # uniform ramp; any failure keeps the uniform ramp.
        # NOTE(review): the sample is drawn from the global torch RNG, so the
        # schedule differs between calls unless the caller seeds it - confirm intended.
        try:
            import torch.distributions as dist

            beta_dist = dist.Beta(alpha, beta)
            ramp = beta_dist.sample((num_inference_steps,)).sort().values.numpy()
        except Exception:
            pass
        sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
    elif cfg.use_flow_sigmas:
        sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)

    # 3. Shifting
    if cfg.use_dynamic_shifting and mu is not None:
        sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
    elif cfg.shift is not None:
        sigmas = cfg.shift * sigmas / (1 + (cfg.shift - 1) * sigmas)

    # We handle multi-history expansion: one sigma per tableau node in each interval.
    _a_mat, _b_vec, c_vec = self._get_tableau()
    sigmas_expanded = [
        s_curr + c_val * (s_next - s_curr)
        for s_curr, s_next in zip(sigmas[:-1], sigmas[1:])
        for c_val in c_vec
    ]
    sigmas_expanded.append(0.0)

    sigmas_interpolated = np.array(sigmas_expanded)
    # Linear remapping for Flow Matching
    # NOTE(review): this mapping is applied for every prediction type, not only
    # flow models - confirm that is intended for sigma ranges above 1.
    timesteps_expanded = sigmas_interpolated * train_steps

    self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
    self.timesteps = torch.from_numpy(timesteps_expanded + cfg.steps_offset).to(device=device, dtype=dtype)
    self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

    # Reset per-run Runge-Kutta state.
    self.model_outputs = []
    self.sample_at_start_of_step = None
    self._step_index = None
|
||||
|
||||
@property
def step_index(self):
    """
    Position of the scheduler inside its expanded schedule.

    ``None`` until the first call initialises it; every ``step`` call then advances it by one.
    """
    current = self._step_index
    return current
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Look up ``timestep`` in a schedule through the shared ``scheduler_utils`` helper.

    ``self.timesteps`` is searched when ``schedule_timesteps`` is not given.
    """
    from .scheduler_utils import index_for_timestep as _lookup

    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    return _lookup(timestep, schedule)
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if self._step_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Scale ``sample`` for the model at the current step.

    Flow-prediction configs get the sample back unchanged; every other
    prediction type divides it by ``sqrt(sigma**2 + 1)`` using the sigma at the
    current step index. The step index is initialised from ``timestep`` first
    if it has not been set yet.
    """
    if self._step_index is None:
        self._init_step_index(timestep)
    if self.config.prediction_type != "flow_prediction":
        sigma = self.sigmas[self._step_index]
        return sample / ((sigma**2 + 1) ** 0.5)
    return sample
|
||||
|
||||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[float, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance ``sample`` by one stage of the tableau-driven multi-stage update.

    The expanded schedule stores ``num_stages`` sigmas per macro step, so each
    call processes a single stage. Intermediate stages combine the derivatives
    collected so far with the ``a`` row of the next stage; the last stage
    combines them with the ``b`` weights and clears the per-step state.

    When the sigma that closes the current macro step is not positive, the
    denoised estimate is returned directly as the next sample.

    Args:
        model_output: Raw model prediction, interpreted per ``config.prediction_type``.
        timestep: Current timestep; only used to initialise the step index on the first call.
        sample: Current sample.
        return_dict: Return a ``SchedulerOutput`` when true, otherwise a 1-tuple.

    Returns:
        ``SchedulerOutput(prev_sample=...)`` or ``(prev_sample,)``.

    Raises:
        ValueError: If ``config.prediction_type`` is not one of ``epsilon``,
            ``v_prediction``, ``flow_prediction`` or ``sample``.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    def _emit(prev_sample):
        # Shared epilogue: advance the counter, then wrap the result.
        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return SchedulerOutput(prev_sample=prev_sample)

    a_mat, b_vec, c_vec = self._get_tableau()
    num_stages = len(c_vec)

    stage_index = self._step_index % num_stages
    base_step_index = (self._step_index // num_stages) * num_stages

    sigma_curr = self.sigmas[base_step_index]
    sigma_next_idx = min(base_step_index + num_stages, len(self.sigmas) - 1)
    sigma_next = self.sigmas[sigma_next_idx]
    sigma_t = self.sigmas[self._step_index]

    # Convert the model output to a denoised estimate once, so the terminal
    # branch below uses the same conversion as the regular stages. Previously
    # the terminal branch returned raw v-/flow-prediction outputs as the sample.
    prediction_type = self.config.prediction_type
    if prediction_type in ("epsilon", "flow_prediction"):
        denoised = sample - sigma_t * model_output
    elif prediction_type == "v_prediction":
        alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
        sigma_actual = sigma_t * alpha_t
        denoised = alpha_t * sample - sigma_actual * model_output
    elif prediction_type == "sample":
        denoised = model_output
    else:
        raise ValueError(f"prediction_type {self.config.prediction_type} is not supported.")

    if sigma_next <= 0:
        # Terminal macro step: nothing left to integrate towards.
        return _emit(denoised)

    h = sigma_next - sigma_curr

    if self.config.clip_sample:
        denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)

    # derivative = (x - x0) / sigma
    derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)

    if self.sample_at_start_of_step is None and stage_index > 0:
        # Mid-step fallback for Img2Img/Inpainting: no stored start-of-step
        # sample exists, so take a plain first-order step to the next sigma.
        sigma_next_t = self.sigmas[self._step_index + 1]
        dt = sigma_next_t - sigma_t
        return _emit(sample + dt * derivative)

    if stage_index == 0:
        # First stage of a macro step: restart the derivative history.
        self.model_outputs = [derivative]
        self.sample_at_start_of_step = sample
    else:
        self.model_outputs.append(derivative)

    next_stage_idx = stage_index + 1
    if next_stage_idx < num_stages:
        sum_ak = 0
        for weight, stage_derivative in zip(a_mat[next_stage_idx], self.model_outputs):
            sum_ak = sum_ak + weight * stage_derivative

        sigma_next_stage = self.sigmas[self._step_index + 1]

        # Update x (unnormalized sample)
        prev_sample = self.sample_at_start_of_step + (sigma_next_stage - sigma_curr) * sum_ak
    else:
        sum_bk = 0
        for weight, stage_derivative in zip(b_vec, self.model_outputs):
            sum_bk = sum_bk + weight * stage_derivative

        prev_sample = self.sample_at_start_of_step + h * sum_bk

        # Macro step finished: drop the per-step state.
        self.model_outputs = []
        self.sample_at_start_of_step = None

    return _emit(prev_sample)
|
||||
|
||||
def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    """
    Noise ``original_samples`` for the given ``timesteps``.

    Delegates to ``scheduler_utils.add_noise_to_sample`` with this scheduler's
    current ``sigmas`` and ``timesteps`` tables.
    """
    from .scheduler_utils import add_noise_to_sample as _mix

    return _mix(original_samples, noise, self.sigmas, timesteps, self.timesteps)
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,321 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
# pylint: disable=no-member
|
||||
class LobattoScheduler(SchedulerMixin, ConfigMixin):
    """
    LobattoScheduler: multi-stage sampler built on Lobatto-family Butcher tableaux.

    The ``variant`` config value selects the tableau. Accepted names are
    ``lobatto_iiia_2s``, ``lobatto_iiia_3s``, ``lobatto_iiia_4s``,
    ``lobatto_iiib_2s``, ``lobatto_iiib_3s``, ``lobatto_iiic_2s``,
    ``lobatto_iiic_3s``, ``kraaijevanger_spijker_2s``, ``qin_zhang_2s`` and
    ``pareschi_russo_2s``; any other value raises ``ValueError``.

    Stages are evaluated one per ``step`` call, using only the derivatives of
    stages already evaluated in the current macro step. Tableau entries on or
    to the right of the diagonal are therefore never read.

    Adapted from the RES4LYF repository.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        variant: str = "lobatto_iiia_3s",  # accepted names are listed in _get_tableau
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        rho: float = 7.0,
        shift: Optional[float] = None,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        use_dynamic_shifting: bool = False,
        timestep_spacing: str = "linspace",
        clip_sample: bool = False,
        sample_max_value: float = 1.0,
        set_alpha_to_one: bool = False,
        skip_prk_steps: bool = False,
        interpolation_type: str = "linear",
        steps_offset: int = 0,
        timestep_type: str = "discrete",
        rescale_betas_zero_snr: bool = False,
        final_sigmas_type: str = "zero",
    ):
        """
        Build the beta/alpha tables and reset the sampling state.

        Only the ``linear`` and ``scaled_linear`` beta schedules are implemented
        (or explicit ``trained_betas``). ``base_shift``, ``max_shift``,
        ``set_alpha_to_one``, ``skip_prk_steps``, ``timestep_type``,
        ``rescale_betas_zero_snr`` and ``final_sigmas_type`` are registered in the
        config but are not read by any method of this class.

        Raises:
            NotImplementedError: For any other ``beta_schedule``.
        """
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
        self.sigmas = None
        self.init_noise_sigma = 1.0

        # Internal state
        self.model_outputs = []
        self.sample_at_start_of_step = None
        self._step_index = None

    def _get_tableau(self):
        """
        Return the Butcher tableau ``(a, b, c)`` for ``config.variant`` as numpy arrays.

        Raises:
            ValueError: If the variant name is not recognised.
        """
        v = self.config.variant
        r5 = 5**0.5
        if v == "lobatto_iiia_2s":
            a, b, c = [[0, 0], [1 / 2, 1 / 2]], [1 / 2, 1 / 2], [0, 1]
        elif v == "lobatto_iiia_3s":
            a, b, c = [[0, 0, 0], [5 / 24, 1 / 3, -1 / 24], [1 / 6, 2 / 3, 1 / 6]], [1 / 6, 2 / 3, 1 / 6], [0, 1 / 2, 1]
        elif v == "lobatto_iiia_4s":
            a = [
                [0, 0, 0, 0],
                [(11 + r5) / 120, (25 - r5) / 120, (25 - 13 * r5) / 120, (-1 + r5) / 120],
                [(11 - r5) / 120, (25 + 13 * r5) / 120, (25 + r5) / 120, (-1 - r5) / 120],
                [1 / 12, 5 / 12, 5 / 12, 1 / 12],
            ]
            b = [1 / 12, 5 / 12, 5 / 12, 1 / 12]
            c = [0, (5 - r5) / 10, (5 + r5) / 10, 1]
        elif v == "lobatto_iiib_2s":
            a, b, c = [[1 / 2, 0], [1 / 2, 0]], [1 / 2, 1 / 2], [0, 1]
        elif v == "lobatto_iiib_3s":
            a, b, c = [[1 / 6, -1 / 6, 0], [1 / 6, 1 / 3, 0], [1 / 6, 5 / 6, 0]], [1 / 6, 2 / 3, 1 / 6], [0, 1 / 2, 1]
        elif v == "lobatto_iiic_2s":
            a, b, c = [[1 / 2, -1 / 2], [1 / 2, 1 / 2]], [1 / 2, 1 / 2], [0, 1]
        elif v == "lobatto_iiic_3s":
            a, b, c = [[1 / 6, -1 / 3, 1 / 6], [1 / 6, 5 / 12, -1 / 12], [1 / 6, 2 / 3, 1 / 6]], [1 / 6, 2 / 3, 1 / 6], [0, 1 / 2, 1]
        elif v == "kraaijevanger_spijker_2s":
            a, b, c = [[1 / 2, 0], [-1 / 2, 2]], [-1 / 2, 3 / 2], [1 / 2, 3 / 2]
        elif v == "qin_zhang_2s":
            a, b, c = [[1 / 4, 0], [1 / 2, 1 / 4]], [1 / 2, 1 / 2], [1 / 4, 3 / 4]
        elif v == "pareschi_russo_2s":
            gamma = 1 - 2**0.5 / 2
            a, b, c = [[gamma, 0], [1 - 2 * gamma, gamma]], [1 / 2, 1 / 2], [gamma, 1 - gamma]
        else:
            raise ValueError(f"Unknown variant: {v}")
        return np.array(a), np.array(b), np.array(c)

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        mu: Optional[float] = None,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Build the stage-expanded sigma/timestep schedule and reset the sampling state.

        Every pair of consecutive base sigmas is expanded into one sigma per
        tableau node ``c``, and a final ``0.0`` is appended, so the schedule has
        ``(num_inference_steps - 1) * num_stages + 1`` entries.

        Args:
            num_inference_steps: Number of base sigmas before stage expansion.
            device: Device for the resulting ``sigmas`` / ``timesteps`` tensors.
            mu: Shift factor, used only when ``config.use_dynamic_shifting`` is set.
            dtype: Dtype for the resulting tensors.

        Raises:
            ValueError: For an unknown ``timestep_spacing`` or ``interpolation_type``.
        """
        self.num_inference_steps = num_inference_steps
        num_train = self.config.num_train_timesteps

        # 1. Spacing
        spacing = self.config.timestep_spacing
        if spacing == "linspace":
            timesteps = np.linspace(0, num_train - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif spacing == "leading":
            step_ratio = num_train // num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
        elif spacing == "trailing":
            step_ratio = num_train / num_inference_steps
            timesteps = (np.arange(num_train, 0, -step_ratio)).round().copy().astype(float)
            timesteps -= 1
        else:
            raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.interpolation_type == "linear":
            sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
        elif self.config.interpolation_type == "log_linear":
            sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
        else:
            raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")

        # 2. Sigma Schedule
        if self.config.use_karras_sigmas or self.config.use_exponential_sigmas or self.config.use_beta_sigmas:
            # Explicit config bounds win; otherwise use the ends of the interpolated table.
            sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
            sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]

        if self.config.use_karras_sigmas:
            rho = self.config.rho
            ramp = np.linspace(0, 1, num_inference_steps)
            sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
        elif self.config.use_exponential_sigmas:
            sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
        elif self.config.use_beta_sigmas:
            alpha, beta = 0.6, 0.6
            ramp = np.linspace(0, 1, num_inference_steps)
            try:
                import torch.distributions as dist

                # NOTE(review): the ramp is drawn at random from Beta(alpha, beta), so the
                # schedule differs between calls unless the torch RNG is seeded.
                b = dist.Beta(alpha, beta)
                ramp = b.sample((num_inference_steps,)).sort().values.numpy()
            except Exception:
                # Best effort: keep the linear ramp when sampling is unavailable.
                pass
            sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
        elif self.config.use_flow_sigmas:
            sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)

        # 3. Shifting
        if self.config.use_dynamic_shifting and mu is not None:
            sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
        elif self.config.shift is not None:
            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

        # Expand each interval into one sigma per tableau node.
        _a_mat, _b_vec, c_vec = self._get_tableau()
        sigmas_expanded = [
            s_curr + c_val * (s_next - s_curr)
            for s_curr, s_next in zip(sigmas[:-1], sigmas[1:])
            for c_val in c_vec
        ]
        sigmas_expanded.append(0.0)  # Add the final sigma=0 for the last step

        sigmas_interpolated = np.array(sigmas_expanded)
        # Linear remapping for Flow Matching
        # NOTE(review): this sigma * num_train_timesteps mapping is applied for every
        # prediction type, not only flow matching - confirm that is intended.
        timesteps_expanded = sigmas_interpolated * self.config.num_train_timesteps

        self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
        self.timesteps = torch.from_numpy(timesteps_expanded + self.config.steps_offset).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self.model_outputs = []
        self.sample_at_start_of_step = None
        self._step_index = None

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Look up ``timestep`` in ``schedule_timesteps`` (default ``self.timesteps``) via ``scheduler_utils``."""
        from .scheduler_utils import index_for_timestep

        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def _init_step_index(self, timestep):
        """Resolve and store the starting step index; does nothing when one is already set."""
        if self._step_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return ``sample`` unchanged for flow prediction, else divided by ``sqrt(sigma**2 + 1)``."""
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        return sample / ((sigma**2 + 1) ** 0.5)

    def _predict_denoised(self, model_output: torch.Tensor, sample: torch.Tensor, sigma_t: torch.Tensor) -> torch.Tensor:
        """Convert the raw model output into a denoised estimate for ``config.prediction_type``."""
        prediction_type = self.config.prediction_type
        if prediction_type in ("epsilon", "flow_prediction"):
            return sample - sigma_t * model_output
        if prediction_type == "v_prediction":
            alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
            sigma_actual = sigma_t * alpha_t
            return alpha_t * sample - sigma_actual * model_output
        if prediction_type == "sample":
            return model_output
        raise ValueError(f"prediction_type error: {getattr(self.config, 'prediction_type', 'epsilon')}")

    def _finish_step(self, prev_sample: torch.Tensor, return_dict: bool):
        """Advance the step counter and wrap ``prev_sample`` in the requested return form."""
        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return SchedulerOutput(prev_sample=prev_sample)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Advance ``sample`` by one stage of the multi-stage update.

        Intermediate stages combine the derivatives collected so far with the
        ``a`` row of the next stage; the last stage combines them with the ``b``
        weights and clears the per-step state. When the sigma that closes the
        current macro step is not positive, the denoised estimate is returned
        directly.

        Args:
            model_output: Raw model prediction, interpreted per ``config.prediction_type``.
            timestep: Current timestep; only used to initialise the step index on the first call.
            sample: Current sample.
            return_dict: Return a ``SchedulerOutput`` when true, otherwise a 1-tuple.

        Raises:
            ValueError: If ``config.prediction_type`` is not supported.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        a_mat, b_vec, c_vec = self._get_tableau()
        num_stages = len(c_vec)

        stage_index = self._step_index % num_stages
        base_step_index = (self._step_index // num_stages) * num_stages

        sigma_curr = self.sigmas[base_step_index]
        sigma_next_idx = min(base_step_index + num_stages, len(self.sigmas) - 1)
        sigma_next = self.sigmas[sigma_next_idx]
        sigma_t = self.sigmas[self._step_index]

        # One conversion for every branch. The terminal branch used to return raw
        # v-/flow-prediction outputs as the sample instead of a denoised estimate.
        denoised = self._predict_denoised(model_output, sample, sigma_t)

        if sigma_next <= 0:
            # Terminal macro step: nothing left to integrate towards.
            return self._finish_step(denoised, return_dict)

        h = sigma_next - sigma_curr

        if self.config.clip_sample:
            denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)

        # derivative = (x - x0) / sigma
        derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)

        if self.sample_at_start_of_step is None and stage_index > 0:
            # Mid-step fallback for Img2Img/Inpainting: no stored start-of-step
            # sample exists, so take a plain first-order step to the next sigma.
            sigma_next_t = self.sigmas[self._step_index + 1]
            dt = sigma_next_t - sigma_t
            return self._finish_step(sample + dt * derivative, return_dict)

        if stage_index == 0:
            # First stage of a macro step: restart the derivative history.
            self.model_outputs = [derivative]
            self.sample_at_start_of_step = sample
        else:
            self.model_outputs.append(derivative)

        next_stage_idx = stage_index + 1
        if next_stage_idx < num_stages:
            sum_ak = 0
            for weight, stage_derivative in zip(a_mat[next_stage_idx], self.model_outputs):
                sum_ak = sum_ak + weight * stage_derivative

            sigma_next_stage = self.sigmas[self._step_index + 1]

            # Update x (unnormalized sample)
            prev_sample = self.sample_at_start_of_step + (sigma_next_stage - sigma_curr) * sum_ak
        else:
            sum_bk = 0
            for weight, stage_derivative in zip(b_vec, self.model_outputs):
                sum_bk = sum_bk + weight * stage_derivative

            prev_sample = self.sample_at_start_of_step + h * sum_bk

            # Macro step finished: drop the per-step state.
            self.model_outputs = []
            self.sample_at_start_of_step = None

        return self._finish_step(prev_sample, return_dict)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise ``original_samples`` via ``scheduler_utils.add_noise_to_sample`` using the current schedule."""
        from .scheduler_utils import add_noise_to_sample

        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def __len__(self):
        """Return ``config.num_train_timesteps``."""
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,275 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
from .phi_functions import Phi
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class PECScheduler(SchedulerMixin, ConfigMixin):
    """
    Predictor-Corrector (PEC) scheduler.

    Each ``step`` reconstructs an ``x0`` estimate from the model output, keeps a
    short history of those estimates, and blends them with phi-function weights
    selected by ``config.variant`` (``pec423_2h2s`` or ``pec433_2h3s``). Until
    enough history is available, only the current estimate is used.
    """

    _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        variant: Literal["pec423_2h2s", "pec433_2h3s"] = "pec423_2h2s",
        use_analytic_solution: bool = True,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
    ):
        """
        Build the beta/alpha tables and reset the multistep buffers.

        Raises:
            NotImplementedError: If ``beta_schedule`` is not ``linear``,
                ``scaled_linear`` or ``squaredcos_cap_v2`` and no
                ``trained_betas`` are given.
        """
        from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does not exist.")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Defined up front so the attributes exist before set_timesteps is called.
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
        self.sigmas = None

        # Buffer for multistep
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []
        self._step_index = None
        self._begin_index = None
        self.init_noise_sigma = 1.0

    @property
    def step_index(self) -> Optional[int]:
        """Current step index, or ``None`` before the first step."""
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """Explicit starting index set through ``set_begin_index``, or ``None``."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Set the index the first ``step`` starts from, instead of looking up the timestep."""
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        mu: Optional[float] = None,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Build the sigma/timestep schedule and reset the multistep buffers.

        ``self.sigmas`` receives the schedule with a trailing ``0.0`` appended, so
        it is one entry longer than ``self.timesteps``.

        Args:
            num_inference_steps: Number of sampling steps.
            device: Device for the resulting tensors.
            mu: Input to the dynamic shift, used only when ``config.use_dynamic_shifting`` is set.
            dtype: Dtype for the resulting tensors.

        Raises:
            ValueError: For an unknown ``timestep_spacing``.
        """
        from .scheduler_utils import (
            apply_shift,
            get_dynamic_shift,
            get_sigmas_beta,
            get_sigmas_exponential,
            get_sigmas_flow,
            get_sigmas_karras,
        )

        self.num_inference_steps = num_inference_steps

        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
            timesteps -= 1
        else:
            raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # Optional re-spacing between the smallest and largest interpolated sigma.
        if self.config.use_karras_sigmas:
            sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_exponential_sigmas:
            sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_beta_sigmas:
            sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_flow_sigmas:
            sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()

        if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
            shift = self.config.shift
            if self.config.use_dynamic_shifting and mu is not None:
                shift = get_dynamic_shift(
                    mu,
                    self.config.base_shift,
                    self.config.max_shift,
                    self.config.base_image_seq_len,
                    self.config.max_image_seq_len,
                )
            sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

        self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self._step_index = None
        self._begin_index = None
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Look up ``timestep`` in ``schedule_timesteps`` (default ``self.timesteps``) via ``scheduler_utils``."""
        from .scheduler_utils import index_for_timestep

        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise ``original_samples`` via ``scheduler_utils.add_noise_to_sample`` using the current schedule."""
        from .scheduler_utils import add_noise_to_sample

        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return ``sample`` unchanged for flow prediction, else divided by ``sqrt(sigma**2 + 1)``."""
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        return sample / ((sigma**2 + 1) ** 0.5)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Advance ``sample`` from the current sigma to the next one.

        With ``h = -log(sigma_next / sigma)`` the update is
        ``exp(-h) * sample + h * res``, where ``res`` is a phi-weighted blend of
        the stored ``x0`` estimates. When ``sigma_next`` is zero the current
        ``x0`` estimate is returned directly.

        Args:
            model_output: Raw model prediction, interpreted per ``config.prediction_type``.
            timestep: Current timestep; only used to initialise the step index on the first call.
            sample: Current sample.
            return_dict: Return a ``SchedulerOutput`` when true, otherwise a 1-tuple.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        step_index = self._step_index
        sigma = self.sigmas[step_index]
        sigma_next = self.sigmas[step_index + 1]

        # Log-sigma step size; zero when either sigma is not positive.
        h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)

        # RECONSTRUCT X0
        if self.config.prediction_type == "epsilon":
            # This x0 is actually a * x0 in discrete NSR space
            x0 = sample - sigma * model_output
        elif self.config.prediction_type == "sample":
            x0 = model_output
        elif self.config.prediction_type == "v_prediction":
            # This x0 is the true clean x0
            alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
            sigma_t = sigma * alpha_t
            x0 = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            x0 = sample - sigma * model_output
        else:
            # Any other prediction type is treated like "sample".
            x0 = model_output

        self.model_outputs.append(model_output)
        self.x0_outputs.append(x0)
        self.prev_sigmas.append(sigma)

        variant = self.config.variant
        phi = Phi(h, [0], getattr(self.config, "use_analytic_solution", True))

        if sigma_next == 0:
            x_next = x0
        else:
            # PEC coefficients (anchored at x0)
            if variant == "pec423_2h2s":
                if len(self.x0_outputs) < 2:
                    res = phi(1) * x0
                else:
                    x0_n, x0_p1 = self.x0_outputs[-2:]
                    b2 = (1 / 3) * phi(2) + phi(3) + phi(4)
                    b1 = phi(1) - b2
                    res = b1 * x0_n + b2 * x0_p1
            elif variant == "pec433_2h3s":
                if len(self.x0_outputs) < 3:
                    res = phi(1) * x0
                else:
                    x0_n, x0_p1, x0_p2 = self.x0_outputs[-3:]
                    b3 = (1 / 3) * phi(2) + phi(3) + phi(4)
                    b2 = 0
                    b1 = phi(1) - b3
                    res = b1 * x0_n + b2 * x0_p1 + b3 * x0_p2
            else:
                # Unknown variant: use only the current estimate.
                res = phi(1) * x0

            # Update in x-space
            x_next = torch.exp(-h) * sample + h * res

        self._step_index += 1

        # Buffer control: drop the oldest entry once the history exceeds the limit.
        limit = 3 if variant == "pec433_2h3s" else 2
        if len(self.x0_outputs) > limit:
            self.x0_outputs.pop(0)
            self.model_outputs.pop(0)
            self.prev_sigmas.pop(0)

        if not return_dict:
            return (x_next,)

        return SchedulerOutput(prev_sample=x_next)

    def _init_step_index(self, timestep):
        """Set the step index from ``begin_index`` when given, else by looking up ``timestep``."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                # Look the value up on the same device as the stored schedule.
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def __len__(self):
        """Return ``config.num_train_timesteps``."""
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from mpmath import exp as mp_exp
|
||||
from mpmath import factorial as mp_factorial
|
||||
from mpmath import mp, mpf
|
||||
|
||||
# Set precision for mpmath
|
||||
mp.dps = 80
|
||||
|
||||
|
||||
def calculate_gamma(c2: float, c3: float) -> float:
    """Return the gamma coefficient used by the RES 3-stage samplers.

    Computes ``(3*c3**3 - 2*c3) / (c2 * (2 - 3*c2))``. No guard is applied for
    a zero denominator (``c2 == 0`` or ``c2 == 2/3``).
    """
    numerator = 3 * (c3**3) - 2 * c3
    denominator = c2 * (2 - 3 * c2)
    return numerator / denominator
|
||||
|
||||
|
||||
def _torch_factorial(n: int) -> float:
    """Return ``n!`` as a Python float (computed exactly with ``math.factorial`` first)."""
    exact = math.factorial(n)
    return float(exact)
|
||||
|
||||
|
||||
def phi_standard_torch(j: int, neg_h: torch.Tensor) -> torch.Tensor:
    r"""
    Standard implementation of phi functions using torch.

    ϕj(-h) = (e^(-h) - \sum_{k=0}^{j-1} (-h)^k / k!) / (-h)^j
    For h=0, ϕj(0) = 1/j!

    The evaluation is done per element: entries with ``|neg_h| < 1e-4`` use a
    5-term Taylor expansion (avoiding the 0/0 of the closed form), all other
    entries use the closed form above. Previously a single small entry switched
    *every* entry to the truncated series, which is inaccurate for large ``|neg_h|``.

    Args:
        j: Order of the phi function; must be positive.
        neg_h: Tensor holding the argument ``-h`` (any shape).

    Returns:
        Tensor with the same shape and dtype as ``neg_h``.
    """
    assert j > 0

    inv_fact_j = 1.0 / float(math.factorial(j))

    # Work in double precision to limit cancellation in exp(z) - sum(...).
    orig_dtype = neg_h.dtype
    z = neg_h.to(torch.float64)
    small = torch.abs(z) < 1e-4

    # Taylor branch: sum_{k=0}^{4} z^k / (j+k)!  (term_k = term_{k-1} * z / (j+k)).
    series = torch.full_like(z, inv_fact_j)
    term = torch.full_like(z, inv_fact_j)
    for k in range(1, 5):
        term = term * z / (j + k)
        series = series + term

    # Closed-form branch. Small entries are replaced by 1 so the discarded lanes
    # never divide by zero (keeps the result and any gradients finite).
    safe_z = torch.where(small, torch.ones_like(z), z)
    remainder = torch.zeros_like(z)
    for k in range(j):
        remainder = remainder + (safe_z**k) / float(math.factorial(k))
    closed_form = (torch.exp(safe_z) - remainder) / (safe_z**j)

    return torch.where(small, series, closed_form).to(orig_dtype)
|
||||
|
||||
|
||||
def phi_mpmath_series(j: int, neg_h: float) -> float:
    """Evaluate phi_j(-h) with mpmath arithmetic and return it as a float.

    ``neg_h`` is first converted to ``float`` and then to ``mpf``. For a zero
    argument the limit ``1/j!`` is returned; otherwise the value is
    ``(exp(z) - sum_{k<j} z^k / k!) / z^j`` evaluated at the current mpmath precision.
    """
    order = int(j)
    z = mpf(float(neg_h))

    # phi_j(0) = 1/j!  (the closed form would be 0/0 here)
    if z == 0:
        return float(1.0 / mp_factorial(order))

    partial_sum = sum(((z**k) / mp_factorial(k) for k in range(order)), mp.mpf("0"))
    return float((mp_exp(z) - partial_sum) / (z**order))
|
||||
|
||||
|
||||
class Phi:
    """
    Callable that evaluates and memoises phi functions for one step size ``h``.

    ``phi(j, i)`` evaluates phi_j at ``-h * c`` where ``c`` is ``1.0`` when ``i < 0``
    and ``self.c[i - 1]`` otherwise. Results are cached per ``(j, i)`` pair.
    With ``analytic_solution=True`` the value is computed with mpmath and returned as a
    float; otherwise torch is used and a tensor is returned.
    """

    def __init__(self, h: torch.Tensor, c: List[Union[float, mpf]], analytic_solution: bool = True):
        self.h = h
        self.c = c
        # (j, i) -> previously computed value
        self.cache: Dict[Tuple[int, int], Union[float, torch.Tensor]] = {}
        self.analytic_solution = analytic_solution

        if not analytic_solution:
            self.phi_f = phi_standard_torch
            return

        # High-precision path: keep mpf copies of h and of every node in c.
        self.phi_f = phi_mpmath_series
        self.h_mpf = mpf(float(h))
        self.c_mpf = [mpf(float(c_val)) for c_val in c]

    def __call__(self, j: int, i: int = -1) -> Union[float, torch.Tensor]:
        key = (j, i)
        try:
            return self.cache[key]
        except KeyError:
            pass

        uses_node = i >= 0
        if uses_node:
            c_val = self.c[i - 1]
            # A zero node short-circuits to 0.0 for every order j.
            if c_val == 0:
                self.cache[key] = 0.0
                return 0.0
        else:
            c_val = 1.0

        if self.analytic_solution:
            z = -self.h_mpf * (self.c_mpf[i - 1] if uses_node else 1.0)
            if j == 0:
                result = float(mp_exp(z))
            elif z == 0:
                result = float(1.0 / mp_factorial(j))
            else:
                # Evaluated inline on the mpf value (no round-trip through float).
                partial_sum = sum(((z**k) / mp_factorial(k) for k in range(j)), mp.mpf("0"))
                result = float((mp_exp(z) - partial_sum) / (z**j))
        else:
            scaled = -self.h * float(c_val)
            result = torch.exp(scaled) if j == 0 else self.phi_f(j, scaled)

        self.cache[key] = result
        return result
|
||||
|
|
@ -0,0 +1,364 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
# pylint: disable=no-member
|
||||
class RadauIIAScheduler(SchedulerMixin, ConfigMixin):
    """
    RadauIIAScheduler: Fully implicit Runge-Kutta integrators.
    Supports variants with 2, 3, 5, 7, 9, 11 stages.
    Adapted from the RES4LYF repository.

    `set_timesteps` expands every interval of the base sigma schedule into one
    sigma per tableau node, so the pipeline calls `step` once per stage. `step`
    accumulates the per-stage derivatives and combines them with the tableau's
    `a` rows (intermediate stages) or `b` weights (last stage).
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        variant: str = "radau_iia_3s",  # 2s to 11s variants
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        rho: float = 7.0,
        shift: Optional[float] = None,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        use_dynamic_shifting: bool = False,
        timestep_spacing: str = "linspace",
        clip_sample: bool = False,
        sample_max_value: float = 1.0,
        set_alpha_to_one: bool = False,
        skip_prk_steps: bool = False,
        interpolation_type: str = "linear",
        steps_offset: int = 0,
        timestep_type: str = "discrete",
        rescale_betas_zero_snr: bool = False,
        final_sigmas_type: str = "zero",
    ):
        """Build the beta/alpha tables and reset the per-run state.

        Only `trained_betas`, `beta_schedule`, `beta_start`, `beta_end` and
        `num_train_timesteps` are used here; the remaining arguments are stored in
        the config by `register_to_config` and read later (several of them are not
        referenced anywhere in this class).

        Raises:
            NotImplementedError: if `beta_schedule` is neither "linear" nor "scaled_linear"
                and no `trained_betas` were given.
        """
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
        self.sigmas = None
        self.init_noise_sigma = 1.0

        # Internal state
        self.model_outputs = []  # stage derivatives collected for the current step
        self.sample_at_start_of_step = None  # sample seen at stage 0 of the current step
        self._step_index = None

    def _get_tableau(self):
        """Return the Butcher tableau `(a, b, c)` of the configured variant as numpy arrays.

        Raises:
            ValueError: if `config.variant` is not one of the supported `radau_iia_*s` names.
        """
        v = self.config.variant
        if v == "radau_iia_2s":
            a, b, c = [[5 / 12, -1 / 12], [3 / 4, 1 / 4]], [3 / 4, 1 / 4], [1 / 3, 1]
        elif v == "radau_iia_3s":
            r6 = 6**0.5
            a = [[11 / 45 - 7 * r6 / 360, 37 / 225 - 169 * r6 / 1800, -2 / 225 + r6 / 75], [37 / 225 + 169 * r6 / 1800, 11 / 45 + 7 * r6 / 360, -2 / 225 - r6 / 75], [4 / 9 - r6 / 36, 4 / 9 + r6 / 36, 1 / 9]]
            b, c = [4 / 9 - r6 / 36, 4 / 9 + r6 / 36, 1 / 9], [2 / 5 - r6 / 10, 2 / 5 + r6 / 10, 1]
        elif v == "radau_iia_5s":
            a = [
                [0.07299886, -0.02673533, 0.01867693, -0.01287911, 0.00504284],
                [0.15377523, 0.14621487, -0.03644457, 0.02123306, -0.00793558],
                [0.14006305, 0.29896713, 0.16758507, -0.03396910, 0.01094429],
                [0.14489431, 0.27650007, 0.32579792, 0.12875675, -0.01570892],
                [0.14371356, 0.28135602, 0.31182652, 0.22310390, 0.04000000],
            ]
            b = [0.14371356, 0.28135602, 0.31182652, 0.22310390, 0.04]
            c = [0.05710420, 0.27684301, 0.58359043, 0.86024014, 1.0]
        elif v == "radau_iia_7s":
            a = [
                [0.03754626, -0.01403933, 0.01035279, -0.00815832, 0.00638841, -0.00460233, 0.00182894],
                [0.08014760, 0.08106206, -0.02123799, 0.01400029, -0.01023419, 0.00715347, -0.00281264],
                [0.07206385, 0.17106835, 0.10961456, -0.02461987, 0.01476038, -0.00957526, 0.00367268],
                [0.07570513, 0.15409016, 0.22710774, 0.11747819, -0.02381083, 0.01270999, -0.00460884],
                [0.07391234, 0.16135561, 0.20686724, 0.23700712, 0.10308679, -0.01885414, 0.00585890],
                [0.07470556, 0.15830722, 0.21415342, 0.21987785, 0.19875212, 0.06926550, -0.00811601],
                [0.07449424, 0.15910212, 0.21235189, 0.22355491, 0.19047494, 0.11961374, 0.02040816],
            ]
            b = [0.07449424, 0.15910212, 0.21235189, 0.22355491, 0.19047494, 0.11961374, 0.02040816]
            c = [0.02931643, 0.14807860, 0.33698469, 0.55867152, 0.76923386, 0.92694567, 1.0]
        elif v == "radau_iia_9s":
            a = [
                [0.02278838, -0.00858964, 0.00645103, -0.00525753, 0.00438883, -0.00365122, 0.00294049, -0.00214927, 0.00085884],
                [0.04890795, 0.05070205, -0.01352381, 0.00920937, -0.00715571, 0.00574725, -0.00454258, 0.00328816, -0.00130907],
                [0.04374276, 0.10830189, 0.07291957, -0.01687988, 0.01070455, -0.00790195, 0.00599141, -0.00424802, 0.00167815],
                [0.04624924, 0.09656073, 0.15429877, 0.08671937, -0.01845164, 0.01103666, -0.00767328, 0.00522822, -0.00203591],
                [0.04483444, 0.10230685, 0.13821763, 0.18126393, 0.09043360, -0.01808506, 0.01019339, -0.00640527, 0.00242717],
                [0.04565876, 0.09914547, 0.14574704, 0.16364828, 0.18594459, 0.08361326, -0.01580994, 0.00813825, -0.00291047],
                [0.04520060, 0.10085371, 0.14194224, 0.17118947, 0.16978339, 0.16776829, 0.06707903, -0.01179223, 0.00360925],
                [0.04541652, 0.10006040, 0.14365284, 0.16801908, 0.17556077, 0.15588627, 0.12889391, 0.04281083, -0.00493457],
                [0.04535725, 0.10027665, 0.14319335, 0.16884698, 0.17413650, 0.15842189, 0.12359469, 0.07382701, 0.01234568],
            ]
            b = [0.04535725, 0.10027665, 0.14319335, 0.16884698, 0.17413650, 0.15842189, 0.12359469, 0.07382701, 0.01234568]
            c = [0.01777992, 0.09132361, 0.21430848, 0.37193216, 0.54518668, 0.71317524, 0.85563374, 0.95536604, 1.0]
        elif v == "radau_iia_11s":
            a = [
                [0.01528052, -0.00578250, 0.00438010, -0.00362104, 0.00309298, -0.00267283, 0.00230509, -0.00195565, 0.00159387, -0.00117286, 0.00046993],
                [0.03288398, 0.03451351, -0.00928542, 0.00641325, -0.00509546, 0.00424609, -0.00358767, 0.00300683, -0.00243267, 0.00178278, -0.00071315],
                [0.02933250, 0.07416243, 0.05114868, -0.01200502, 0.00777795, -0.00594470, 0.00480266, -0.00392360, 0.00312733, -0.00227314, 0.00090638],
                [0.03111455, 0.06578995, 0.10929963, 0.06381052, -0.01385359, 0.00855744, -0.00630764, 0.00491336, -0.00381400, 0.00273343, -0.00108397],
                [0.03005269, 0.07011285, 0.09714692, 0.13539160, 0.07147108, -0.01471024, 0.00873319, -0.00619941, 0.00459164, -0.00321333, 0.00126286],
                [0.03072807, 0.06751926, 0.10334060, 0.12083526, 0.15032679, 0.07350932, -0.01451288, 0.00829665, -0.00561283, 0.00376623, -0.00145771],
                [0.03029202, 0.06914472, 0.09972096, 0.12801064, 0.13493180, 0.15289670, 0.06975993, -0.01327455, 0.00725877, -0.00448439, 0.00168785],
                [0.03056654, 0.06813851, 0.10188107, 0.12403361, 0.14211432, 0.13829395, 0.14289135, 0.06052636, -0.01107774, 0.00559867, -0.00198773],
                [0.03040663, 0.06871881, 0.10066096, 0.12619527, 0.13848876, 0.14450774, 0.13065189, 0.12111401, 0.04655548, -0.00802620, 0.00243764],
                [0.03048412, 0.06843925, 0.10124185, 0.12518732, 0.14011843, 0.14190387, 0.13500343, 0.11262870, 0.08930604, 0.02896966, -0.00331170],
                [0.03046255, 0.06851684, 0.10108155, 0.12546269, 0.13968067, 0.14258278, 0.13393354, 0.11443306, 0.08565881, 0.04992304, 0.00826446],
            ]
            b = [0.03046255, 0.06851684, 0.10108155, 0.12546269, 0.13968067, 0.14258278, 0.13393354, 0.11443306, 0.08565881, 0.04992304, 0.00826446]
            c = [0.01191761, 0.06173207, 0.14711145, 0.26115968, 0.39463985, 0.53673877, 0.67594446, 0.80097892, 0.90171099, 0.96997097, 1.0]
        else:
            raise ValueError(f"Unknown variant: {v}")
        return np.array(a), np.array(b), np.array(c)

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Optional[Union[str, torch.device]] = None,
        mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
        """Build the stage-expanded sigma/timestep schedule and reset the stepping state.

        Args:
            num_inference_steps: Number of points in the base schedule.
            device: Device the `sigmas` / `timesteps` tensors are moved to.
            mu: Shift factor, used only when `config.use_dynamic_shifting` is set.
            dtype: dtype of the resulting `sigmas` / `timesteps` tensors.

        Raises:
            ValueError: for an unsupported `timestep_spacing` or `interpolation_type`.
        """
        self.num_inference_steps = num_inference_steps

        # 1. Spacing
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
            timesteps -= 1
        else:
            raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.config.interpolation_type == "linear":
            sigmas = np.interp(timesteps, np.arange(len(sigmas)), sigmas)
        elif self.config.interpolation_type == "log_linear":
            sigmas = np.exp(np.interp(timesteps, np.arange(len(sigmas)), np.log(sigmas)))
        else:
            raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")

        # 2. Sigma Schedule
        if self.config.use_karras_sigmas:
            sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
            sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
            rho = self.config.rho
            ramp = np.linspace(0, 1, num_inference_steps)
            sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
        elif self.config.use_exponential_sigmas:
            sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
            sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
            sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
        elif self.config.use_beta_sigmas:
            sigma_min = self.config.sigma_min if self.config.sigma_min is not None else sigmas[-1]
            sigma_max = self.config.sigma_max if self.config.sigma_max is not None else sigmas[0]
            alpha, beta = 0.6, 0.6
            # Linear ramp is the fallback if the Beta sampling below fails for any reason.
            ramp = np.linspace(0, 1, num_inference_steps)
            try:
                import torch.distributions as dist

                # NOTE(review): this draws *random* Beta samples (sorted), so the schedule
                # depends on the torch RNG state rather than on Beta quantiles — confirm intended.
                beta_dist = dist.Beta(alpha, beta)
                ramp = beta_dist.sample((num_inference_steps,)).sort().values.numpy()
            except Exception:
                # Deliberate best-effort: keep the linear ramp.
                pass
            sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
        elif self.config.use_flow_sigmas:
            sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)

        # 3. Shifting
        if self.config.use_dynamic_shifting and mu is not None:
            sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
        elif self.config.shift is not None:
            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

        # Expand each base interval into one sigma per tableau node c_k.
        _a_mat, _b_vec, c_vec = self._get_tableau()

        sigmas_expanded = []
        for i in range(len(sigmas) - 1):
            s_curr = sigmas[i]
            s_next = sigmas[i + 1]
            for c_val in c_vec:
                sigmas_expanded.append(s_curr + c_val * (s_next - s_curr))
        sigmas_expanded.append(0.0)  # terminal sigma

        sigmas_interpolated = np.array(sigmas_expanded)
        # Linear remapping for Flow Matching
        # NOTE(review): applied for every prediction type, so non-flow schedules get
        # timesteps of sigma * num_train_timesteps — confirm this is what callers expect.
        timesteps_expanded = sigmas_interpolated * self.config.num_train_timesteps

        self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
        self.timesteps = torch.from_numpy(timesteps_expanded + self.config.steps_offset).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self.model_outputs = []
        self.sample_at_start_of_step = None
        self._step_index = None

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the index of the schedule entry closest (absolute difference) to `timestep`.

        `schedule_timesteps` defaults to `self.timesteps`; tensors are moved to CPU/numpy first.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        if isinstance(schedule_timesteps, torch.Tensor):
            schedule_timesteps = schedule_timesteps.detach().cpu().numpy()

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.detach().cpu().numpy()

        return np.abs(schedule_timesteps - timestep).argmin().item()

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Scale `sample` by `1 / sqrt(sigma**2 + 1)`; returned unchanged for "flow_prediction".

        Initialises the step index from `timestep` on first use.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        return sample / ((sigma**2 + 1) ** 0.5)

    def _init_step_index(self, timestep):
        """Set `_step_index` from `timestep` if it has not been set yet."""
        if self._step_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """Advance the sample by one stage of the configured tableau.

        Args:
            model_output: Raw model prediction, interpreted per `config.prediction_type`.
            timestep: Current timestep; only used to initialise the step index on the first call.
            sample: Current sample.
            return_dict: Return a `SchedulerOutput` when true, otherwise a 1-tuple.

        Returns:
            The sample to feed into the next model call.

        Raises:
            ValueError: if `config.prediction_type` is not one of "epsilon",
                "v_prediction", "flow_prediction" or "sample".
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        a_mat, b_vec, c_vec = self._get_tableau()
        num_stages = len(c_vec)

        stage_index = self._step_index % num_stages
        base_step_index = (self._step_index // num_stages) * num_stages

        sigma_curr = self.sigmas[base_step_index]
        sigma_next_idx = min(base_step_index + num_stages, len(self.sigmas) - 1)
        sigma_next = self.sigmas[sigma_next_idx]
        sigma_t = self.sigmas[self._step_index]

        # Convert the model output to a denoised (x0) estimate. This is done before the
        # terminal check so the final result is a real x0 estimate for every prediction
        # type (previously v/flow predictions were returned unconverted there).
        prediction_type = getattr(self.config, "prediction_type", "epsilon")
        if prediction_type == "epsilon":
            denoised = sample - sigma_t * model_output
        elif prediction_type == "v_prediction":
            alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
            sigma_actual = sigma_t * alpha_t
            denoised = alpha_t * sample - sigma_actual * model_output
        elif prediction_type == "flow_prediction":
            denoised = sample - sigma_t * model_output
        elif prediction_type == "sample":
            denoised = model_output
        else:
            raise ValueError(f"prediction_type error: {prediction_type}")

        if sigma_next <= 0:
            # Last interval of the schedule: jump straight to the denoised estimate.
            prev_sample = denoised
            self._step_index += 1
            if not return_dict:
                return (prev_sample,)
            return SchedulerOutput(prev_sample=prev_sample)

        h = sigma_next - sigma_curr

        if self.config.clip_sample:
            denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)

        # derivative = (x - x0) / sigma
        derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)

        if self.sample_at_start_of_step is None and stage_index > 0:
            # Mid-step fallback for Img2Img/Inpainting: no stage-0 sample was recorded,
            # so take a plain Euler step to the next sigma instead.
            sigma_next_t = self.sigmas[self._step_index + 1]
            dt = sigma_next_t - sigma_t
            prev_sample = sample + dt * derivative
            self._step_index += 1
            if not return_dict:
                return (prev_sample,)
            return SchedulerOutput(prev_sample=prev_sample)

        if stage_index == 0:
            # First stage: start a fresh derivative buffer anchored at this sample.
            self.model_outputs = [derivative]
            self.sample_at_start_of_step = sample
        else:
            self.model_outputs.append(derivative)

        next_stage_idx = stage_index + 1
        if next_stage_idx < num_stages:
            # Intermediate stage: combine the derivatives seen so far with row `next_stage_idx` of `a`.
            sum_ak = 0
            for j in range(len(self.model_outputs)):
                sum_ak = sum_ak + a_mat[next_stage_idx][j] * self.model_outputs[j]

            sigma_next_stage = self.sigmas[self._step_index + 1]

            # Update x (unnormalized sample)
            prev_sample = self.sample_at_start_of_step + (sigma_next_stage - sigma_curr) * sum_ak
        else:
            # Final stage: weighted sum with `b`, then clear the per-step state.
            sum_bk = 0
            for j in range(len(self.model_outputs)):
                sum_bk = sum_bk + b_vec[j] * self.model_outputs[j]

            prev_sample = self.sample_at_start_of_step + h * sum_bk

            self.model_outputs = []
            self.sample_at_start_of_step = None

        self._step_index += 1

        if not return_dict:
            return (prev_sample,)
        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to `scheduler_utils.add_noise_to_sample` with this scheduler's sigmas and timesteps."""
        from .scheduler_utils import add_noise_to_sample
        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def __len__(self):
        """Return `config.num_train_timesteps`."""
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,451 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
from .phi_functions import Phi, calculate_gamma
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class RESMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
RESMultistepScheduler (Restartable Exponential Integrator) ported from RES4LYF.
|
||||
|
||||
Supports RES 2M, 3M and DEIS 2M, 3M variants.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
beta_start (`float`, defaults to 0.00085):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.012):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to "linear"):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
|
||||
prediction_type (`str`, defaults to "epsilon"):
|
||||
The prediction type of the scheduler function.
|
||||
variant (`str`, defaults to "res_2m"):
|
||||
The specific RES/DEIS variant to use. Supported: "res_2m", "res_3m", "deis_2m", "deis_3m".
|
||||
use_analytic_solution (`bool`, defaults to True):
|
||||
Whether to use high-precision analytic solutions for phi functions.
|
||||
"""
|
||||
|
||||
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
|
||||
order = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    prediction_type: str = "epsilon",
    variant: Literal["res_2m", "res_3m", "deis_2m", "deis_3m"] = "res_2m",
    use_analytic_solution: bool = True,
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
    rescale_betas_zero_snr: bool = False,
    use_karras_sigmas: bool = False,
    use_exponential_sigmas: bool = False,
    use_beta_sigmas: bool = False,
    use_flow_sigmas: bool = False,
    shift: float = 1.0,
    use_dynamic_shifting: bool = False,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
):
    """Build the beta/alpha tables and reset the multistep buffers.

    Only the beta-schedule arguments and ``rescale_betas_zero_snr`` are used
    directly here; every argument is recorded in the config by
    ``register_to_config`` for later use.

    Raises:
        NotImplementedError: if ``beta_schedule`` is not "linear", "scaled_linear"
            or "squaredcos_cap_v2".
    """
    from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

    if beta_schedule == "squaredcos_cap_v2":
        self.betas = betas_for_alpha_bar(num_train_timesteps)
    elif beta_schedule == "scaled_linear":
        # linear in sqrt(beta), then squared
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        self.betas = sqrt_betas ** 2
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented")

    if rescale_betas_zero_snr:
        self.betas = rescale_zero_terminal_snr(self.betas)

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # History buffers consumed by the multistep update
    self.model_outputs = []
    self.x0_outputs = []
    self.prev_sigmas = []

    # Stepping state
    self._step_index = None
    self._begin_index = None
    self.init_noise_sigma = 1.0
|
||||
|
||||
@property
def step_index(self) -> Optional[int]:
    """Current value of the internal step counter (``self._step_index``)."""
    current = self._step_index
    return current
|
||||
|
||||
@property
def begin_index(self) -> Optional[int]:
    """Value stored by ``set_begin_index`` (``self._begin_index``), or ``None`` if unset."""
    stored = self._begin_index
    return stored
|
||||
|
||||
def set_begin_index(self, begin_index: int = 0) -> None:
    """Store ``begin_index`` in ``self._begin_index`` (exposed through the ``begin_index`` property)."""
    setattr(self, "_begin_index", begin_index)
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """Scale ``sample`` for the model at the current step.

    Initialises the step index from ``timestep`` on first use. For
    "flow_prediction" the sample is returned untouched; otherwise it is divided
    by ``sqrt(sigma**2 + 1)`` with ``sigma`` taken at the current step index.
    """
    if self._step_index is None:
        self._init_step_index(timestep)
    if self.config.prediction_type != "flow_prediction":
        current_sigma = self.sigmas[self._step_index]
        return sample / ((current_sigma**2 + 1) ** 0.5)
    return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
    """Compute the sigma/timestep schedule for inference and reset the multistep buffers.

    Args:
        num_inference_steps: Number of schedule points to generate.
        device: Device for the resulting ``sigmas`` / ``timesteps`` tensors.
        mu: Passed to ``get_dynamic_shift`` when ``config.use_dynamic_shifting`` is set.
        dtype: dtype for the resulting tensors.

    Raises:
        ValueError: if ``config.timestep_spacing`` is not "linspace", "leading" or "trailing".
    """
    from .scheduler_utils import (
        apply_shift,
        get_dynamic_shift,
        get_sigmas_beta,
        get_sigmas_exponential,
        get_sigmas_flow,
        get_sigmas_karras,
    )

    cfg = self.config
    self.num_inference_steps = num_inference_steps

    # Base timestep grid
    spacing = cfg.timestep_spacing
    if spacing == "linspace":
        timesteps = np.linspace(0, cfg.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
    elif spacing == "leading":
        step_ratio = cfg.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
        timesteps += cfg.steps_offset
    elif spacing == "trailing":
        step_ratio = cfg.num_train_timesteps / self.num_inference_steps
        timesteps = (np.arange(cfg.num_train_timesteps, 0, -step_ratio)).round().copy()
        timesteps -= 1
    else:
        raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")

    # Base sigmas: linear 1 -> 1/1000 for flow matching, otherwise interpolated training sigmas
    train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    if cfg.use_flow_sigmas:
        sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
    else:
        sigmas = np.interp(timesteps, np.arange(0, len(train_sigmas)), train_sigmas)

    # Optional re-spacing between the last and first base sigma (first matching flag wins).
    # With only use_flow_sigmas set, the linear schedule above is already the final one.
    if cfg.use_karras_sigmas:
        sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
    elif cfg.use_exponential_sigmas:
        sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
    elif cfg.use_beta_sigmas:
        sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()

    # Static or dynamic shift
    if cfg.shift != 1.0 or cfg.use_dynamic_shifting:
        shift = cfg.shift
        if cfg.use_dynamic_shifting and mu is not None:
            shift = get_dynamic_shift(
                mu,
                cfg.base_shift,
                cfg.max_shift,
                cfg.base_image_seq_len,
                cfg.max_image_seq_len,
            )
        sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

    # Flow matching: timesteps are the sigmas rescaled to the training range
    if cfg.use_flow_sigmas:
        timesteps = sigmas * cfg.num_train_timesteps

    # A terminal sigma of 0 is appended, so sigmas has one more entry than timesteps
    self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
    self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

    # Reset stepping state and history
    self._step_index = None
    self._begin_index = None
    self.model_outputs = []
    self.x0_outputs = []
    self.prev_sigmas = []

    self.lower_order_nums = 0
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Look up the schedule index for ``timestep`` via ``scheduler_utils.index_for_timestep``.

    ``schedule_timesteps`` defaults to ``self.timesteps`` when not given.
    """
    from .scheduler_utils import index_for_timestep
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    return index_for_timestep(timestep, schedule)
|
||||
|
||||
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """Delegate to ``scheduler_utils.add_noise_to_sample`` using this scheduler's sigmas and timesteps."""
    from .scheduler_utils import add_noise_to_sample
    noisy = add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
    return noisy
|
||||
|
||||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[float, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance `sample` by one scheduler step.

    The method reconstructs a denoised estimate (`x0`) from `model_output`, records it in the multistep
    history, and then applies one of three update rules:

    * `prediction_type == "flow_prediction"`: Euler or variable-step Adams-Bashforth (AB2) in sigma space.
    * `variant` starting with `"res"`: multistep exponential integrator of order 1-3 with a first-order
      "hard restart" whenever the step-size ratios leave the `[0.5, 2.0]` band.
    * any other variant: DEIS-style update using `_get_deis_coefficients`.

    Args:
        model_output (`torch.Tensor`): Direct output of the model for the current step.
        timestep (`float` or `torch.Tensor`): Current timestep; only used to initialise the step index.
        sample (`torch.Tensor`): Current noisy sample.
        return_dict (`bool`, defaults to `True`): Return a `SchedulerOutput` instead of a plain tuple.

    Returns:
        `SchedulerOutput` with `prev_sample` set, or a 1-tuple containing the new sample.

    Raises:
        ValueError: If `config.prediction_type` is not one of the supported values.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    step = self._step_index
    sigma = self.sigmas[step]
    sigma_next = self.sigmas[step + 1]

    # Step size in log-sigma space; zero whenever either endpoint is not strictly positive.
    h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)

    # Reconstruct the denoised estimate x0 from the model output.
    prediction_type = self.config.prediction_type
    if prediction_type == "epsilon":
        x0 = sample - sigma * model_output
    elif prediction_type == "sample":
        x0 = model_output
    elif prediction_type == "v_prediction":
        alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
        sigma_t = sigma * alpha_t
        x0 = alpha_t * sample - sigma_t * model_output
    elif prediction_type == "flow_prediction":
        x0 = sample - sigma * model_output
    else:
        raise ValueError(f"prediction_type {prediction_type} is not supported.")

    self.model_outputs.append(model_output)
    self.x0_outputs.append(x0)
    self.prev_sigmas.append(sigma)

    # Nominal order is the digit before the trailing "m" of the variant name (e.g. "res_2m" -> 2).
    variant = self.config.variant
    order = int(variant[-2]) if variant.endswith("m") else 1

    # Effective order is limited by the available history; a zero sigma forces first order.
    curr_order = min(len(self.prev_sigmas), order) if sigma > 0 else 1

    if prediction_type == "flow_prediction":
        # Variable-step Adams-Bashforth for flow matching. Orders above 2 reuse the AB2 formula, so they
        # share the same stability fallback instead of extrapolating with an unchecked step ratio.
        dt = sigma_next - sigma
        v_n = model_output

        use_ab2 = False
        if curr_order >= 2:
            dt_prev = sigma - self.prev_sigmas[-2]
            r = dt / dt_prev if abs(dt_prev) > 1e-8 else 0.0
            use_ab2 = not (dt_prev == 0 or r < -0.9 or r > 2.0)

        if use_ab2:
            c0 = 1 + 0.5 * r
            c1 = -0.5 * r
            x_next = sample + dt * (c0 * v_n + c1 * self.model_outputs[-2])
        else:
            # First order, or fallback when the step ratio is outside the stable range.
            x_next = sample + dt * v_n
    else:
        if variant.startswith("res"):
            phi = Phi(h, [0], getattr(self.config, "use_analytic_solution", True))
            phi_1 = phi(1)

            # Force order 1 for the last steps of the schedule.
            if self.num_inference_steps is not None and self._step_index >= self.num_inference_steps - 3:
                curr_order = 1

            res = None  # stays None when the first-order (restart) update must be used
            if curr_order == 2:
                h_prev = -torch.log(self.prev_sigmas[-1] / self.prev_sigmas[-2])
                r_check = h_prev / (h + 1e-9)
                # Hard restart (first order) when the step-size ratio leaves [0.5, 2.0].
                if not (r_check < 0.5 or r_check > 2.0):
                    # The denominator is -h_prev / h, so b2 carries the negative sign.
                    b2 = phi(2) / ((-h_prev / h) + 1e-9)
                    b1 = phi_1 - b2
                    res = b1 * self.x0_outputs[-1] + b2 * self.x0_outputs[-2]
            elif curr_order == 3:
                # Generalised three-point (AB3-like) weights for varying step sizes.
                h_p1 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
                h_p2 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-3] + 1e-9))
                r1 = h_p1 / (h + 1e-9)
                r2 = h_p2 / (h + 1e-9)
                if not (r1 < 0.5 or r1 > 2.0 or r2 < 0.5 or r2 > 2.0):
                    phi_2, phi_3 = phi(2), phi(3)
                    denom = r2 - r1 + 1e-9
                    b3 = (phi_3 + r1 * phi_2) / (r2 * denom)
                    b2 = -(phi_3 + r2 * phi_2) / (r1 * denom)
                    b1 = phi_1 - b2 - b3
                    res = b1 * self.x0_outputs[-1] + b2 * self.x0_outputs[-2] + b3 * self.x0_outputs[-3]

            if res is None:
                res = phi_1 * x0
        else:
            # DEIS: linear multistep weights applied to the denoised history, newest entry first.
            b = self._get_deis_coefficients(curr_order, sigma, sigma_next)
            res = torch.zeros_like(sample)
            for b_val, x0_hist in zip(b[0], reversed(self.x0_outputs)):
                res += b_val * x0_hist

        # Exponential-integrator update in x-space; the final step returns the denoised estimate directly.
        if sigma_next == 0:
            x_next = x0
        else:
            x_next = torch.exp(-h) * sample + h * res

    self._step_index += 1

    # Keep at most `order` history entries.
    if len(self.model_outputs) > order:
        self.model_outputs.pop(0)
        self.x0_outputs.pop(0)
        self.prev_sigmas.pop(0)

    if not return_dict:
        return (x_next,)

    return SchedulerOutput(prev_sample=x_next)
|
||||
|
||||
def _get_res_coefficients(self, rk_type, h, c2, c3):
    """
    Build the `(a, b, ci)` coefficient triple for a RES scheme.

    Args:
        rk_type: Scheme name; `"res_2s"` and `"res_3s"` get dedicated weights, anything else is first order.
        h: Step size handed to `Phi`.
        c2: Second node; also used as divisor for the `"res_2s"` weight.
        c3: Third node; only used by `"res_3s"`.

    Returns:
        Tuple `(a, b, ci)` where `b` is a single-row list of weights, `ci` is `[0, c2, c3]` and `a` is the
        stage matrix (left empty for `"res_3s"`).
    """
    nodes = [0, c2, c3]
    analytic = getattr(self.config, "use_analytic_solution", True)
    phi = Phi(h, nodes, analytic)

    if rk_type == "res_2s":
        w2 = phi(2) / (c2 + 1e-9)
        weights = [[phi(1) - w2, w2]]
        stages = [[0, 0], [c2 * phi(1, 2), 0]]
        return stages, weights, nodes

    if rk_type == "res_3s":
        gamma = calculate_gamma(c2, c3)
        w3 = phi(2) / (gamma * c2 + c3 + 1e-9)
        w2 = gamma * w3
        weights = [[phi(1) - (w2 + w3), w2, w3]]
        # Stage matrix is intentionally left empty (simplified) for the three-stage scheme.
        return [], weights, nodes

    # Fallback: single first-order weight.
    return [[0]], [[phi(1)]], nodes
|
||||
|
||||
def _get_deis_coefficients(self, order, sigma, sigma_next):
    """
    Compute the DEIS multistep weights for the transition `sigma -> sigma_next`.

    Args:
        order: Requested order (1, 2 or 3; any other value is treated as first order).
        sigma: Current noise level.
        sigma_next: Target noise level.

    Returns:
        A single-row list `[[b1, ...]]` with one weight per history entry, newest first. A first-order row
        `[[phi_1]]` is returned whenever the step-size ratios fall outside `[0.5, 2.0]` (hard restart).
    """

    def _outside_band(ratio):
        # True when the step-size ratio leaves the range accepted by the multistep formulas.
        return ratio < 0.5 or ratio > 2.0

    if sigma > 0 and sigma_next > 0:
        h = -torch.log(sigma_next / sigma)
    else:
        h = torch.zeros_like(sigma)

    phi = Phi(h, [0], getattr(self.config, "use_analytic_solution", True))
    phi_1 = phi(1)

    if order == 2:
        h_prev = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
        r = h_prev / (h + 1e-9)
        if _outside_band(r):
            return [[phi_1]]

        # Adams-Bashforth-like weights for exponential integrators.
        w2 = -phi(2) / (r + 1e-9)
        return [[phi_1 - w2, w2]]

    if order == 3:
        h_prev1 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
        h_prev2 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-3] + 1e-9))
        r1 = h_prev1 / (h + 1e-9)
        r2 = h_prev2 / (h + 1e-9)
        if _outside_band(r1) or _outside_band(r2):
            return [[phi_1]]

        phi_2 = phi(2)
        phi_3 = phi(3)

        # Generalised three-point weights for varying step sizes.
        denom = r2 - r1 + 1e-9
        w3 = (phi_3 + r1 * phi_2) / (r2 * denom)
        w2 = -(phi_3 + r2 * phi_2) / (r1 * denom)
        w1 = phi_1 - (w2 + w3)
        return [[w1, w2, w3]]

    # Order 1 and any unrecognised order.
    return [[phi_1]]
|
||||
|
||||
def _init_step_index(self, timestep):
    """
    Initialise `self._step_index` for the first call of a sampling run.

    When no begin index has been set, the index is looked up from `timestep` via `index_for_timestep`.
    A tensor timestep is first moved to the device of `self.timesteps` so the lookup does not compare
    tensors that live on different devices. Otherwise the stored begin index is used as-is.

    Args:
        timestep: Current timestep (number or `torch.Tensor`).
    """
    if self.begin_index is None:
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        self._step_index = self.index_for_timestep(timestep)
    else:
        self._step_index = self._begin_index
|
||||
|
||||
def __len__(self):
    """Return `config.num_train_timesteps`, i.e. the number of training timesteps this scheduler is configured for."""
    cfg = self.config
    return cfg.num_train_timesteps
|
||||
|
|
@ -0,0 +1,330 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RESMultistepSDEScheduler(SchedulerMixin, ConfigMixin):
    """
    RESMultistepSDEScheduler (Stochastic Exponential Integrator) ported from RES4LYF.

    Every step reconstructs a denoised estimate, takes a deterministic exponential-integrator step down to a
    reduced noise level (`sigma_down`) and then adds fresh Gaussian noise scaled by `sigma_up`.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            First value of the beta schedule.
        beta_end (`float`, defaults to 0.012):
            Last value of the beta schedule.
        beta_schedule (`str`, defaults to "linear"):
            One of "linear", "scaled_linear" or "squaredcos_cap_v2".
        prediction_type (`str`, defaults to "epsilon"):
            How `model_output` is interpreted: "epsilon", "sample", "v_prediction" or "flow_prediction".
            Any other value is treated like "sample".
        variant (`str`, defaults to "res_2m"):
            The specific RES/DEIS variant to use. Supported: "res_2m", "res_3m".
        eta (`float`, defaults to 1.0):
            The amount of noise to add during sampling (0.0 for ODE, 1.0 for full SDE).
        use_analytic_solution (`bool`, defaults to `True`):
            Forwarded to `Phi` when the integrator coefficients are built.
        timestep_spacing (`str`, defaults to "linspace"):
            One of "linspace", "leading" or "trailing".
        steps_offset (`int`, defaults to 0):
            Offset added to the timesteps for "leading" spacing.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Pass the betas through `rescale_zero_terminal_snr`.
        use_karras_sigmas / use_exponential_sigmas / use_beta_sigmas / use_flow_sigmas (`bool`):
            Select an alternative sigma schedule; the first enabled flag in this order wins.
        shift (`float`, defaults to 1.0):
            Static shift applied to the sigmas when different from 1.0.
        use_dynamic_shifting (`bool`, defaults to `False`):
            Derive the shift from `mu` (see `set_timesteps`) via `get_dynamic_shift`.
        base_shift, max_shift, base_image_seq_len, max_image_seq_len:
            Parameters forwarded to `get_dynamic_shift`.
    """

    _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        prediction_type: str = "epsilon",
        variant: Literal["res_2m", "res_3m"] = "res_2m",
        eta: float = 1.0,
        use_analytic_solution: bool = True,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
    ):
        from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

        if beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Multistep history buffers (filled by `step`, cleared by `set_timesteps`).
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []
        self._step_index = None
        self._begin_index = None
        self.init_noise_sigma = 1.0

    @property
    def step_index(self) -> Optional[int]:
        """Index of the current step, or `None` before the first step."""
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """Index set through `set_begin_index`, or `None` when unset."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Set the index that `_init_step_index` uses instead of a timestep lookup."""
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Scale `sample` by `1 / sqrt(sigma^2 + 1)` for the current step.

        Flow-prediction models receive the sample unchanged. The step index is initialised from `timestep`
        when it has not been set yet.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
        """
        Build the timestep and sigma schedules for a sampling run and reset the multistep state.

        Args:
            num_inference_steps (`int`): Number of sampling steps.
            device (`str` or `torch.device`, *optional*): Device for the resulting tensors.
            mu (`float`, *optional*): Input for dynamic shifting; ignored unless `use_dynamic_shifting` is set.
            dtype (`torch.dtype`, defaults to `torch.float32`): Dtype of the resulting tensors.

        Raises:
            ValueError: If `config.timestep_spacing` is not supported.
        """
        from .scheduler_utils import (
            apply_shift,
            get_dynamic_shift,
            get_sigmas_beta,
            get_sigmas_exponential,
            get_sigmas_flow,
            get_sigmas_karras,
        )

        self.num_inference_steps = num_inference_steps

        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
            timesteps -= 1
        else:
            raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")

        # Training sigmas, interpolated at the (possibly fractional) inference timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # Optional re-spacing between the smallest (last) and largest (first) interpolated sigma.
        if self.config.use_karras_sigmas:
            sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_exponential_sigmas:
            sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_beta_sigmas:
            sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_flow_sigmas:
            sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()

        if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
            shift = self.config.shift
            if self.config.use_dynamic_shifting and mu is not None:
                shift = get_dynamic_shift(
                    mu,
                    self.config.base_shift,
                    self.config.max_shift,
                    self.config.base_image_seq_len,
                    self.config.max_image_seq_len,
                )
            sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

        # A trailing zero sigma marks the fully denoised end of the schedule.
        self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self._step_index = None
        self._begin_index = None
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the schedule index for `timestep`, searching `self.timesteps` unless a schedule is given."""
        from .scheduler_utils import index_for_timestep

        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise `original_samples` for `timesteps` by delegating to `scheduler_utils.add_noise_to_sample`."""
        from .scheduler_utils import add_noise_to_sample

        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Advance `sample` by one stochastic multistep RES step.

        Args:
            model_output (`torch.Tensor`): Direct output of the model for the current step.
            timestep (`float` or `torch.Tensor`): Current timestep; only used to initialise the step index.
            sample (`torch.Tensor`): Current noisy sample.
            generator (`torch.Generator`, *optional*): Random generator for the injected noise.
            return_dict (`bool`, defaults to `True`): Return a `SchedulerOutput` instead of a plain tuple.

        Returns:
            `SchedulerOutput` with `prev_sample` set, or a 1-tuple containing the new sample.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        step = self._step_index
        sigma = self.sigmas[step]
        sigma_next = self.sigmas[step + 1]

        # Step size in log-sigma space; zero whenever either endpoint is not strictly positive.
        h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)

        # Reconstruct the denoised estimate x0; unknown prediction types are treated like "sample".
        if self.config.prediction_type == "epsilon":
            x0 = sample - sigma * model_output
        elif self.config.prediction_type == "sample":
            x0 = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
            sigma_t = sigma * alpha_t
            x0 = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            x0 = sample - sigma * model_output
        else:
            x0 = model_output

        self.model_outputs.append(model_output)
        self.x0_outputs.append(x0)
        self.prev_sigmas.append(sigma)

        # Nominal order is the digit before the trailing "m" of the variant name (e.g. "res_2m" -> 2).
        variant = self.config.variant
        order = int(variant[-2]) if variant.endswith("m") else 1

        # Effective order is limited by the number of history entries collected so far.
        curr_order = min(len(self.prev_sigmas), order)

        # Multistep nodes: previous steps are expressed as negative fractions of the current step size h.
        c2, c3 = 0.5, 1.0
        rk_type = "res_1s"
        if curr_order == 2:
            h_prev = -torch.log(self.prev_sigmas[-1] / self.prev_sigmas[-2])
            c2 = (-h_prev / h).item() if h > 0 else 0.5
            rk_type = "res_2s"
        elif curr_order == 3:
            h_prev1 = -torch.log(self.prev_sigmas[-1] / self.prev_sigmas[-2])
            h_prev2 = -torch.log(self.prev_sigmas[-1] / self.prev_sigmas[-3])
            c2 = (-h_prev1 / h).item() if h > 0 else 0.5
            c3 = (-h_prev2 / h).item() if h > 0 else 1.0
            rk_type = "res_3s"

        eta = self.config.eta
        if sigma_next == 0:
            # Final step: return the denoised estimate without adding noise.
            x_next = x0
        else:
            # Ancestral SDE split of the target noise level:
            #   sigma_up   = eta * sigma_next * sqrt(1 - (sigma_next / sigma)^2)
            #   sigma_down = sqrt(sigma_next^2 - sigma_up^2)
            sigma_up = eta * (sigma_next**2 * (sigma**2 - sigma_next**2) / (sigma**2 + 1e-9)) ** 0.5
            sigma_down = (sigma_next**2 - sigma_up**2) ** 0.5

            # Deterministic step down to sigma_down; falls back to the plain step size h when
            # sigma_down is not strictly positive.
            h_det = -torch.log(sigma_down / sigma) if sigma > 0 and sigma_down > 0 else h

            # Only the coefficients for h_det are needed; the multistep weights are applied to the
            # denoised history, newest entry first.
            _a, b_det, _ci = self._get_res_coefficients(rk_type, h_det, c2, c3)
            res_det = torch.zeros_like(sample)
            for i, b_val in enumerate(b_det[0]):
                idx = len(self.x0_outputs) - 1 - i
                if idx >= 0:
                    res_det += b_val * self.x0_outputs[idx]

            x_det = torch.exp(-h_det) * sample + h_det * res_det

            # Re-inject noise scaled by sigma_up (skipped entirely for eta <= 0).
            if eta > 0:
                noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
                x_next = x_det + sigma_up * noise
            else:
                x_next = x_det

        self._step_index += 1

        # Keep at most `order` history entries.
        if len(self.x0_outputs) > order:
            self.x0_outputs.pop(0)
            self.model_outputs.pop(0)
            self.prev_sigmas.pop(0)

        if not return_dict:
            return (x_next,)

        return SchedulerOutput(prev_sample=x_next)

    def _get_res_coefficients(self, rk_type, h, c2, c3):
        """
        Build the `(a, b, ci)` coefficient triple for a RES scheme.

        `b` is a single-row list of weights, `ci` is `[0, c2, c3]` and `a` is the stage matrix (left empty
        for `"res_3s"`). Any `rk_type` other than `"res_2s"` / `"res_3s"` yields the first-order weight.
        """
        from .phi_functions import Phi, calculate_gamma

        ci = [0, c2, c3]
        phi = Phi(h, ci, self.config.use_analytic_solution)

        if rk_type == "res_2s":
            b2 = phi(2) / (c2 + 1e-9)
            b = [[phi(1) - b2, b2]]
            a = [[0, 0], [c2 * phi(1, 2), 0]]
        elif rk_type == "res_3s":
            gamma_val = calculate_gamma(c2, c3)
            b3 = phi(2) / (gamma_val * c2 + c3 + 1e-9)
            b2 = gamma_val * b3
            b = [[phi(1) - (b2 + b3), b2, b3]]
            a = []
        else:
            b = [[phi(1)]]
            a = [[0]]
        return a, b, ci

    def _init_step_index(self, timestep):
        """
        Initialise `self._step_index` from `timestep`, or from the begin index when one is set.

        A tensor timestep is moved to the device of `self.timesteps` before the lookup.
        """
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def __len__(self):
        """Return `config.num_train_timesteps`."""
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,243 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RESSinglestepScheduler(SchedulerMixin, ConfigMixin):
    """
    RESSinglestepScheduler (Multistage Exponential Integrator) ported from RES4LYF.

    The current `step` implementation applies the first-order exponential-integrator update (Euler for
    flow prediction); `variant` and `use_analytic_solution` are stored in the config but not read by the
    methods of this class.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            First value of the beta schedule.
        beta_end (`float`, defaults to 0.012):
            Last value of the beta schedule.
        beta_schedule (`str`, defaults to "linear"):
            One of "linear", "scaled_linear" or "squaredcos_cap_v2".
        prediction_type (`str`, defaults to "epsilon"):
            How `model_output` is interpreted: "epsilon", "sample", "v_prediction" or "flow_prediction".
            Any other value is treated like "sample".
        variant (`str`, defaults to "res_2s"):
            Name of the RES variant ("res_2s", "res_3s", "res_5s", "res_6s").
        timestep_spacing (`str`, defaults to "linspace"):
            One of "linspace", "leading" or "trailing".
        steps_offset (`int`, defaults to 0):
            Offset added to the timesteps for "leading" spacing.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Pass the betas through `rescale_zero_terminal_snr`.
        use_karras_sigmas / use_exponential_sigmas / use_beta_sigmas (`bool`):
            Select an alternative sigma schedule; the first enabled flag in this order wins.
        use_flow_sigmas (`bool`, defaults to `False`):
            Use a linear sigma schedule from 1.0 down to 1/1000 and derive the timesteps from it.
        shift (`float`, defaults to 1.0):
            Static shift applied to non-flow sigmas when different from 1.0.
        use_dynamic_shifting (`bool`, defaults to `False`):
            Derive the shift from `mu` (see `set_timesteps`) via `get_dynamic_shift`.
        base_shift, max_shift, base_image_seq_len, max_image_seq_len:
            Parameters forwarded to `get_dynamic_shift`.
    """

    _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        prediction_type: str = "epsilon",
        variant: Literal["res_2s", "res_3s", "res_5s", "res_6s"] = "res_2s",
        use_analytic_solution: bool = True,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
    ):
        from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

        if beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        self._step_index = None
        self._begin_index = None
        self.init_noise_sigma = 1.0

    @property
    def step_index(self) -> Optional[int]:
        """Index of the current step, or `None` before the first step."""
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """Index set through `set_begin_index`, or `None` when unset."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Set the index that `_init_step_index` uses instead of a timestep lookup."""
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Scale `sample` by `1 / sqrt(sigma^2 + 1)` for the current step.

        Flow-prediction models receive the sample unchanged. The step index is initialised from `timestep`
        when it has not been set yet.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        return sample / ((sigma**2 + 1) ** 0.5)

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
        """
        Build the timestep and sigma schedules for a sampling run and reset the step index.

        Args:
            num_inference_steps (`int`): Number of sampling steps.
            device (`str` or `torch.device`, *optional*): Device for the resulting tensors.
            mu (`float`, *optional*): Input for dynamic shifting; ignored unless `use_dynamic_shifting` is set.
            dtype (`torch.dtype`, defaults to `torch.float32`): Dtype of the resulting tensors.

        Raises:
            ValueError: If `config.timestep_spacing` is not supported.
        """
        from .scheduler_utils import (
            apply_shift,
            get_dynamic_shift,
            get_sigmas_beta,
            get_sigmas_exponential,
            get_sigmas_karras,
        )

        self.num_inference_steps = num_inference_steps

        # The spacing is validated (and timesteps computed) even for flow sigmas, which override them below.
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
            timesteps -= 1
        else:
            raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")

        if self.config.use_flow_sigmas:
            # Flow schedule: linear sigmas with timesteps derived from them. Every other option
            # (Karras / exponential / beta re-spacing and shifting) has no effect on the result here.
            # NOTE(review): shift / dynamic shifting is not applied to flow sigmas - confirm this is intended.
            sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
            timesteps = sigmas * self.config.num_train_timesteps
        else:
            # Training sigmas, interpolated at the (possibly fractional) inference timesteps.
            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

            # Optional re-spacing between the smallest (last) and largest (first) interpolated sigma.
            if self.config.use_karras_sigmas:
                sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
            elif self.config.use_exponential_sigmas:
                sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
            elif self.config.use_beta_sigmas:
                sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()

            if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
                shift = self.config.shift
                if self.config.use_dynamic_shifting and mu is not None:
                    shift = get_dynamic_shift(
                        mu,
                        self.config.base_shift,
                        self.config.max_shift,
                        self.config.base_image_seq_len,
                        self.config.max_image_seq_len,
                    )
                sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

        # A trailing zero sigma marks the fully denoised end of the schedule.
        self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the schedule index for `timestep`, searching `self.timesteps` unless a schedule is given."""
        from .scheduler_utils import index_for_timestep

        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise `original_samples` for `timesteps` by delegating to `scheduler_utils.add_noise_to_sample`."""
        from .scheduler_utils import add_noise_to_sample

        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Advance `sample` by one step.

        Flow prediction uses an Euler step in sigma space; every other prediction type uses the first-order
        exponential-integrator update towards the reconstructed denoised estimate.

        Args:
            model_output (`torch.Tensor`): Direct output of the model for the current step.
            timestep (`float` or `torch.Tensor`): Current timestep; only used to initialise the step index.
            sample (`torch.Tensor`): Current noisy sample.
            return_dict (`bool`, defaults to `True`): Return a `SchedulerOutput` instead of a plain tuple.

        Returns:
            `SchedulerOutput` with `prev_sample` set, or a 1-tuple containing the new sample.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        step = self._step_index
        sigma = self.sigmas[step]
        sigma_next = self.sigmas[step + 1]

        # Step size in log-sigma space; zero whenever either endpoint is not strictly positive.
        h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)

        # Reconstruct the denoised estimate x0; unknown prediction types are treated like "sample".
        if self.config.prediction_type == "epsilon":
            x0 = sample - sigma * model_output
        elif self.config.prediction_type == "sample":
            x0 = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
            sigma_t = sigma * alpha_t
            x0 = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            x0 = sample - sigma * model_output
        else:
            x0 = model_output

        if self.config.prediction_type == "flow_prediction":
            # Euler step along the predicted velocity.
            dt = sigma_next - sigma
            x_next = sample + dt * model_output
        elif sigma_next == 0:
            # Final step: return the denoised estimate directly.
            x_next = x0
        else:
            # For singlestep RES (multistage), a proper RK requires model evals at intermediate ci * h.
            # Here we provide the standard 1st order update as a base.
            x_next = torch.exp(-h) * sample + (1 - torch.exp(-h)) * x0

        self._step_index += 1

        if not return_dict:
            return (x_next,)

        return SchedulerOutput(prev_sample=x_next)

    def _init_step_index(self, timestep):
        """
        Initialise `self._step_index` from `timestep`, or from the begin index when one is set.

        A tensor timestep is moved to the device of `self.timesteps` before the lookup so that tensors on
        different devices are not compared.
        """
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def __len__(self):
        """Return `config.num_train_timesteps`."""
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,237 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RESSinglestepSDEScheduler(SchedulerMixin, ConfigMixin):
    """
    RESSinglestepSDEScheduler (Stochastic Multistage Exponential Integrator) ported from RES4LYF.

    `step` performs a first-order exponential-integrator move from the current sigma to a reduced
    level `sigma_down` and, when `eta > 0`, re-injects Gaussian noise scaled by `sigma_up`.
    The `variant` and `use_analytic_solution` options are stored in the config but are not read
    by any method of this class.
    """

    _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        prediction_type: str = "epsilon",
        variant: Literal["res_2s", "res_3s", "res_5s", "res_6s"] = "res_2s",
        eta: float = 1.0,
        use_analytic_solution: bool = True,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
    ):
        from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

        # Build the training beta schedule; unknown names are rejected.
        if beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Run state; `set_timesteps` fills in sigmas/timesteps and resets these.
        self._step_index = None
        self._begin_index = None
        self.init_noise_sigma = 1.0

    @property
    def step_index(self) -> Optional[int]:
        """Index of the current step, or None before the first `step`/`scale_model_input` call."""
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """Index the run starts from, or None when it is derived from the first timestep."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Pin the index used to initialise `step_index` on the next run."""
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Scale `sample` by 1 / sqrt(sigma^2 + 1) for the current step.

        Samples are returned unchanged when `prediction_type` is "flow_prediction".
        Initialises the step index from `timestep` if it has not been set yet.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
        """Build `self.timesteps` and `self.sigmas` for an inference run.

        Args:
            num_inference_steps: number of sampling steps.
            device: device the schedule tensors are moved to.
            mu: forwarded to `get_dynamic_shift` when `use_dynamic_shifting` is enabled.
            dtype: dtype of the stored schedule tensors.

        Raises:
            ValueError: if `timestep_spacing` is not "linspace", "leading" or "trailing".
        """
        from .scheduler_utils import (
            apply_shift,
            get_dynamic_shift,
            get_sigmas_beta,
            get_sigmas_exponential,
            get_sigmas_flow,
            get_sigmas_karras,
        )

        self.num_inference_steps = num_inference_steps

        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
            timesteps -= 1
        else:
            raise ValueError(f"timestep_spacing {self.config.timestep_spacing} is not supported.")

        # Training sigmas interpolated at the (possibly fractional) inference timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # Optional re-spacing between the smallest and largest interpolated sigma.
        if self.config.use_karras_sigmas:
            sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_exponential_sigmas:
            sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_beta_sigmas:
            sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif self.config.use_flow_sigmas:
            sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()

        if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
            shift = self.config.shift
            if self.config.use_dynamic_shifting and mu is not None:
                shift = get_dynamic_shift(
                    mu,
                    self.config.base_shift,
                    self.config.max_shift,
                    self.config.base_image_seq_len,
                    self.config.max_image_seq_len,
                )
            sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

        # A terminal 0.0 is appended so `step` can always read `sigmas[step + 1]`.
        self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Look up the schedule index for `timestep` (defaults to `self.timesteps`)."""
        from .scheduler_utils import index_for_timestep

        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise `original_samples` for the given `timesteps` via `add_noise_to_sample`."""
        from .scheduler_utils import add_noise_to_sample

        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """Advance `sample` by one step of the schedule.

        Args:
            model_output: raw model prediction, interpreted according to `prediction_type`.
            timestep: current timestep; only used to initialise the step index.
            sample: current noisy sample.
            generator: optional RNG used for the injected noise.
            return_dict: return a `SchedulerOutput` when True, else a one-element tuple.

        Returns:
            The sample at the next sigma, as `SchedulerOutput.prev_sample` or `(x_next,)`.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        step = self._step_index
        sigma = self.sigmas[step]
        sigma_next = self.sigmas[step + 1]
        eta = self.config.eta

        # RECONSTRUCT X0
        if self.config.prediction_type == "epsilon":
            x0 = sample - sigma * model_output
        elif self.config.prediction_type == "sample":
            x0 = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
            sigma_t = sigma * alpha_t
            x0 = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            x0 = sample - sigma * model_output
        else:
            x0 = model_output

        if sigma_next == 0:
            # Final step: return the denoised estimate, no noise is injected.
            x_next = x0
        else:
            # Ancestral split of sigma_next into a deterministic target (sigma_down) and
            # injected noise (sigma_up), with sigma_down^2 + sigma_up^2 == sigma_next^2.
            # The clamps keep both square roots real: without them a non-decreasing sigma
            # pair or eta large enough that sigma_up exceeds sigma_next yields NaN.
            variance_gap = torch.clamp(sigma**2 - sigma_next**2, min=0.0)
            sigma_up = eta * (sigma_next**2 * variance_gap / (sigma**2 + 1e-9)) ** 0.5
            sigma_up = torch.minimum(sigma_up, sigma_next)
            sigma_down = torch.clamp(sigma_next**2 - sigma_up**2, min=0.0) ** 0.5

            # exp(-h) with h = -log(sigma_down / sigma) is just sigma_down / sigma. Using the
            # ratio directly also covers sigma_down == 0, where the update collapses to x0.
            if sigma > 0:
                decay = sigma_down / sigma
            else:
                decay = torch.ones_like(sigma)

            # Deterministic update to sigma_down
            x_det = decay * sample + (1 - decay) * x0

            # Stochastic part
            if eta > 0:
                noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
                x_next = x_det + sigma_up * noise
            else:
                x_next = x_det

        self._step_index += 1

        if not return_dict:
            return (x_next,)

        return SchedulerOutput(prev_sample=x_next)

    def _init_step_index(self, timestep):
        """Initialise the step index from `begin_index`, or from `timestep` when it is unset."""
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def __len__(self):
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,342 @@
|
|||
from typing import ClassVar, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from .phi_functions import Phi
|
||||
|
||||
|
||||
class RESUnifiedScheduler(SchedulerMixin, ConfigMixin):
    """
    RESUnifiedScheduler (Exponential Integrator) ported from RES4LYF.
    Supports RES 2M, 3M, 2S, 3S, 5S, 6S
    Supports DEIS 1S, 2M, 3M

    Only the multistep `rk_type` values ("res_2m", "deis_2m", "res_3m", "deis_3m") select a
    higher-order update in `_get_coefficients`; every other value uses the first-order update.
    """

    _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
    order: ClassVar[int] = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        prediction_type: str = "epsilon",
        rk_type: str = "res_2m",
        use_analytic_solution: bool = True,
        rescale_betas_zero_snr: bool = False,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
    ):
        from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

        # Build the training beta schedule; unknown names are rejected.
        if beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        self.sigmas = torch.Tensor([])
        self.timesteps = torch.Tensor([])
        # Read by `_get_coefficients`; must exist even when only `set_sigmas` was called.
        self.num_inference_steps = None
        self._reset_history()

        self._step_index = None
        self._begin_index = None
        self.init_noise_sigma = 1.0

    def _reset_history(self) -> None:
        """Drop the multistep history so a new run never mixes in tensors from an old one."""
        self.model_outputs = []
        self.x0_outputs = []
        self.prev_sigmas = []

    def set_sigmas(self, sigmas: torch.Tensor):
        """Replace the sigma schedule and restart the run state (step index and history)."""
        self.sigmas = sigmas
        self._step_index = None
        self._reset_history()

    @property
    def step_index(self) -> Optional[int]:
        """Index of the current step, or None before the first `step`/`scale_model_input` call."""
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """Index the run starts from, or None when it is derived from the first timestep."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Pin the index used to initialise `step_index` on the next run."""
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Scale `sample` by 1 / sqrt(sigma^2 + 1); flow-prediction samples pass through unchanged."""
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        return sample / ((sigma**2 + 1) ** 0.5)

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
        """Build `self.timesteps` and `self.sigmas` for an inference run and reset all run state.

        Args:
            num_inference_steps: number of sampling steps.
            device: device the schedule tensors are moved to.
            mu: forwarded to `get_dynamic_shift` when `use_dynamic_shifting` is enabled.
            dtype: dtype of the stored schedule tensors.

        Raises:
            ValueError: if `timestep_spacing` is not "linspace", "leading" or "trailing".
        """
        from .scheduler_utils import (
            apply_shift,
            get_dynamic_shift,
            get_sigmas_beta,
            get_sigmas_exponential,
            get_sigmas_karras,
        )

        self.num_inference_steps = num_inference_steps
        timestep_spacing = getattr(self.config, "timestep_spacing", "linspace")
        steps_offset = getattr(self.config, "steps_offset", 0)
        use_flow_sigmas = getattr(self.config, "use_flow_sigmas", False)

        if timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
            timesteps += steps_offset
        elif timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy()
            timesteps -= 1
        else:
            raise ValueError(f"timestep_spacing {timestep_spacing} is not supported.")

        # Derived sigma range from alphas_cumprod
        base_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = base_sigmas[::-1].copy()  # Ensure high to low

        if getattr(self.config, "use_karras_sigmas", False):
            sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif getattr(self.config, "use_exponential_sigmas", False):
            sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif getattr(self.config, "use_beta_sigmas", False):
            sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
        elif use_flow_sigmas:
            sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)
        else:
            # Re-sample the base sigmas at the requested steps
            idx = np.linspace(0, len(base_sigmas) - 1, num_inference_steps)
            sigmas = np.interp(idx, np.arange(len(base_sigmas)), base_sigmas)[::-1].copy()

        shift = getattr(self.config, "shift", 1.0)
        use_dynamic_shifting = getattr(self.config, "use_dynamic_shifting", False)
        if shift != 1.0 or use_dynamic_shifting:
            if use_dynamic_shifting and mu is not None:
                shift = get_dynamic_shift(
                    mu,
                    getattr(self.config, "base_shift", 0.5),
                    getattr(self.config, "max_shift", 1.15),
                    getattr(self.config, "base_image_seq_len", 256),
                    getattr(self.config, "max_image_seq_len", 4096),
                )
            sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

        if use_flow_sigmas:
            # With flow sigmas the timesteps are the sigmas rescaled to the training range.
            timesteps = sigmas * self.config.num_train_timesteps

        # A terminal 0.0 is appended so `step` can always read `sigmas[step_index + 1]`.
        self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self._step_index = None
        self._begin_index = None
        # History from a previous run belongs to a different schedule (and possibly a different
        # tensor shape), so it must not feed the multistep updates of this run.
        self._reset_history()

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Look up the schedule index for `timestep` (defaults to `self.timesteps`)."""
        from .scheduler_utils import index_for_timestep

        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise `original_samples` for the given `timesteps` via `add_noise_to_sample`."""
        from .scheduler_utils import add_noise_to_sample

        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def _get_coefficients(self, sigma, sigma_next):
        """Return the weights applied to the stored x0 predictions and the log-step `h`.

        The result is `(weights, h)` where `weights[0]` multiplies the newest x0 prediction,
        `weights[1]` the previous one, and so on. A single weight means a first-order update.
        """
        h = -torch.log(sigma_next / sigma) if sigma > 0 else torch.zeros_like(sigma)
        phi = Phi(h, [], getattr(self.config, "use_analytic_solution", True))
        phi_1 = phi(1)

        history_len = len(self.x0_outputs)

        # Stability: Force Order 1 for final few steps to prevent degradation at low noise levels
        if self.num_inference_steps is not None and self._step_index >= self.num_inference_steps - 3:
            return [phi_1], h

        if self.config.rk_type in ["res_2m", "deis_2m"] and history_len >= 2:
            # prev_sigmas[-1] is the current sigma (appended by `step` before this call).
            h_prev = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
            r = h_prev / (h + 1e-9)

            # Hard Restart: if step sizes vary too wildly, fallback to order 1
            if r < 0.5 or r > 2.0:
                return [phi_1], h

            phi_2 = phi(2)
            # Correct Adams-Bashforth-like coefficients for Exponential Integrators
            b2 = -phi_2 / (r + 1e-9)
            b1 = phi_1 - b2
            return [b1, b2], h

        if self.config.rk_type in ["res_3m", "deis_3m"] and history_len >= 3:
            h_prev1 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-2] + 1e-9))
            # Spans the two previous steps together.
            h_prev2 = -torch.log(self.prev_sigmas[-1] / (self.prev_sigmas[-3] + 1e-9))
            r1 = h_prev1 / (h + 1e-9)
            r2 = h_prev2 / (h + 1e-9)

            # Hard Restart check
            # NOTE(review): r2 covers two steps, so evenly log-spaced sigmas give r2 close to 2.0
            # and sit right on the upper bound; confirm the threshold is intended.
            if r1 < 0.5 or r1 > 2.0 or r2 < 0.5 or r2 > 2.0:
                return [phi_1], h

            phi_2 = phi(2)
            phi_3 = phi(3)

            # Generalized AB3 for Exponential Integrators (Varying steps)
            denom = r2 - r1 + 1e-9
            b3 = (phi_3 + r1 * phi_2) / (r2 * denom)
            b2 = -(phi_3 + r2 * phi_2) / (r1 * denom)
            b1 = phi_1 - (b2 + b3)
            return [b1, b2, b3], h

        return [phi_1], h

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """Advance `sample` by one step of the schedule.

        Args:
            model_output: raw model prediction, interpreted according to `prediction_type`.
            timestep: current timestep; only used to initialise the step index.
            sample: current noisy sample.
            return_dict: return a `SchedulerOutput` when True, else a one-element tuple.

        Returns:
            The sample at the next sigma, as `SchedulerOutput.prev_sample` or `(x_next,)`.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self._step_index]
        sigma_next = self.sigmas[self._step_index + 1]

        h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)

        # RECONSTRUCT X0 (Matching PEC pattern)
        if self.config.prediction_type == "epsilon":
            x0 = sample - sigma * model_output
        elif self.config.prediction_type == "sample":
            x0 = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
            sigma_t = sigma * alpha_t
            x0 = alpha_t * sample - sigma_t * model_output
        elif self.config.prediction_type == "flow_prediction":
            x0 = sample - sigma * model_output
        else:
            x0 = model_output

        # Keep at most the three most recent entries of each history buffer.
        self.x0_outputs.append(x0)
        self.model_outputs.append(model_output)  # Added for AB support
        self.prev_sigmas.append(sigma)

        if len(self.x0_outputs) > 3:
            self.x0_outputs.pop(0)
            self.model_outputs.pop(0)
            self.prev_sigmas.pop(0)

        if self.config.prediction_type == "flow_prediction":
            # Variable Step Adams-Bashforth for Flow Matching
            dt = sigma_next - sigma
            v_n = model_output
            curr_order = min(len(self.prev_sigmas), 3)  # Max order 3 here

            if curr_order == 1:
                x_next = sample + dt * v_n
            else:
                sigma_prev = self.prev_sigmas[-2]
                dt_prev = sigma - sigma_prev
                r = dt / dt_prev if abs(dt_prev) > 1e-8 else 0.0
                # Only the two-entry case falls back to Euler on an extreme step ratio; with
                # three entries the AB2 update is always used (r == 0.0 already reduces it
                # to Euler).
                restart = curr_order == 2 and (dt_prev == 0 or r < -0.9 or r > 2.0)
                if restart:
                    x_next = sample + dt * v_n
                else:
                    c0 = 1 + 0.5 * r
                    c1 = -0.5 * r
                    x_next = sample + dt * (c0 * v_n + c1 * self.model_outputs[-2])

            self._step_index += 1
            if not return_dict:
                return (x_next,)
            return SchedulerOutput(prev_sample=x_next)

        # UPDATE
        if sigma_next == 0:
            # Final step: return the denoised estimate. The coefficients are skipped because
            # they would be evaluated at h = -log(0) = inf and then thrown away.
            x_next = x0
        else:
            # GET COEFFICIENTS
            b, _ = self._get_coefficients(sigma, sigma_next)

            # Weighted sum over the stored x0 predictions, newest first.
            res = b[0] * x0
            for offset, weight in enumerate(b[1:], start=2):
                res = res + weight * self.x0_outputs[-offset]

            # Propagate in x-space (unnormalized)
            x_next = torch.exp(-h) * sample + h * res

        self._step_index += 1

        if not return_dict:
            return (x_next,)

        return SchedulerOutput(prev_sample=x_next)

    def _init_step_index(self, timestep):
        """Initialise the step index from `begin_index`, or from `timestep` when it is unset."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def __len__(self):
        return self.config.num_train_timesteps
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class RiemannianFlowScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Riemannian Flow scheduler using Exponential Integrator step.
|
||||
"""
|
||||
|
||||
_compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
|
||||
order: ClassVar[int] = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    metric_type: Literal["euclidean", "hyperbolic", "spherical", "lorentzian"] = "hyperbolic",
    curvature: float = 1.0,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    rescale_betas_zero_snr: bool = False,
    use_karras_sigmas: bool = False,
    use_exponential_sigmas: bool = False,
    use_beta_sigmas: bool = False,
    use_flow_sigmas: bool = False,
    shift: float = 1.0,
    use_dynamic_shifting: bool = False,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
):
    """Build the training beta/alpha tables and placeholder run state.

    Raises:
        NotImplementedError: if `beta_schedule` is not "linear", "scaled_linear"
            or "squaredcos_cap_v2".
    """
    from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

    if beta_schedule == "squaredcos_cap_v2":
        betas = betas_for_alpha_bar(num_train_timesteps)
    elif beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared.
        root_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = root_betas ** 2
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    else:
        raise NotImplementedError(f"{beta_schedule} does not exist.")

    self.betas = rescale_zero_terminal_snr(betas) if rescale_betas_zero_snr else betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # Placeholders that `set_timesteps` overwrites
    self.num_inference_steps = None
    descending_steps = np.arange(0, num_train_timesteps)[::-1].copy()
    self.timesteps = torch.from_numpy(descending_steps)
    self.sigmas = torch.zeros((num_train_timesteps,), dtype=torch.float32)

    self._begin_index = None
    self._step_index = None
|
||||
|
||||
@property
def step_index(self) -> Optional[int]:
    """Index of the current step; None until it has been initialised for a run."""
    current = self._step_index
    return current
|
||||
|
||||
@property
def begin_index(self) -> Optional[int]:
    """Index the run starts from; None when no begin index has been set."""
    start = self._begin_index
    return start
|
||||
|
||||
def set_begin_index(self, begin_index: int = 0) -> None:
    """Store the index that `_init_step_index` uses instead of a timestep lookup."""
    setattr(self, "_begin_index", begin_index)
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
    """Build `self.timesteps` and `self.sigmas` for an inference run.

    The sigma path between the largest and smallest training sigma is shaped by
    `config.metric_type` and `config.curvature`, optionally replaced by one of the
    `use_*_sigmas` schedules, then optionally shifted.

    Args:
        num_inference_steps: number of sampling steps.
        device: device the schedule tensors are created on / moved to.
        mu: forwarded to `get_dynamic_shift` when `use_dynamic_shifting` is enabled.
        dtype: dtype of the stored schedule tensors.

    Raises:
        ValueError: if `timestep_spacing` is not "linspace", "leading" or "trailing".
    """
    from .scheduler_utils import (
        apply_shift,
        get_dynamic_shift,
        get_sigmas_beta,
        get_sigmas_exponential,
        get_sigmas_flow,
        get_sigmas_karras,
    )

    cfg = self.config
    self.num_inference_steps = num_inference_steps
    spacing = getattr(cfg, "timestep_spacing", "linspace")
    offset = getattr(cfg, "steps_offset", 0)

    if spacing not in ("linspace", "leading", "trailing"):
        raise ValueError(f"timestep_spacing {spacing} is not supported.")
    if spacing == "linspace":
        timesteps = np.linspace(0, cfg.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
    elif spacing == "leading":
        stride = cfg.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * stride).round()[::-1].copy()
        timesteps += offset
    else:  # "trailing"
        stride = cfg.num_train_timesteps / self.num_inference_steps
        timesteps = (np.arange(cfg.num_train_timesteps, 0, -stride)).round().copy()
        timesteps -= 1

    # Sigma range derived from alphas_cumprod: the last training entry supplies the
    # starting sigma and the first entry supplies the ending sigma.
    base_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    start_sigma = base_sigmas[-1]
    end_sigma = base_sigmas[0]

    t = torch.linspace(0, 1, num_inference_steps, device=device)
    metric_type = cfg.metric_type
    curvature = cfg.curvature

    def straight_line():
        # Plain linear interpolation from start_sigma (t=0) to end_sigma (t=1).
        return start_sigma * (1 - t) + end_sigma * t

    if metric_type == "hyperbolic":
        x_start = torch.tanh(torch.tensor(start_sigma / 2, device=device))
        x_end = torch.tanh(torch.tensor(end_sigma / 2, device=device))
        d = torch.acosh(torch.clamp(1 + 2 * ((x_start - x_end)**2) / ((1 - x_start**2) * (1 - x_end**2) + 1e-9), min=1.0))
        lambda_t = torch.sinh(t * d) / (torch.sinh(d) + 1e-9)
        blended = (1 - lambda_t) * x_start + lambda_t * x_end
        result = 2 * torch.atanh(torch.clamp(blended, -0.999, 0.999))
    elif metric_type == "spherical":
        k = torch.tensor(curvature, device=device)
        theta_start = start_sigma * torch.sqrt(k)
        theta_end = end_sigma * torch.sqrt(k)
        result = torch.sin((1 - t) * theta_start + t * theta_end) / torch.sqrt(k)
    elif metric_type == "lorentzian":
        gamma = 1 / torch.sqrt(torch.clamp(1 - curvature * t**2, min=1e-9))
        result = straight_line() * gamma
    else:
        # "euclidean" and any unrecognised metric both use the straight line.
        result = straight_line()

    # Keep every value inside the [smallest, largest] endpoint range.
    lower = min(start_sigma, end_sigma)
    upper = max(start_sigma, end_sigma)
    result = torch.clamp(result, min=lower, max=upper)

    if start_sigma > end_sigma:
        result, _ = torch.sort(result, descending=True)

    sigmas = result.cpu().numpy()

    # The first enabled flag wins and replaces the metric-shaped schedule.
    overrides = (
        ("use_karras_sigmas", get_sigmas_karras),
        ("use_exponential_sigmas", get_sigmas_exponential),
        ("use_beta_sigmas", get_sigmas_beta),
        ("use_flow_sigmas", get_sigmas_flow),
    )
    for flag, build_sigmas in overrides:
        if getattr(cfg, flag, False):
            sigmas = build_sigmas(num_inference_steps, sigmas[-1], sigmas[0], device=device, dtype=dtype).cpu().numpy()
            break

    shift = getattr(cfg, "shift", 1.0)
    use_dynamic_shifting = getattr(cfg, "use_dynamic_shifting", False)
    if shift != 1.0 or use_dynamic_shifting:
        if use_dynamic_shifting and mu is not None:
            shift = get_dynamic_shift(
                mu,
                getattr(cfg, "base_shift", 0.5),
                getattr(cfg, "max_shift", 1.5),
                getattr(cfg, "base_image_seq_len", 256),
                getattr(cfg, "max_image_seq_len", 4096),
            )
        sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

    # A terminal 0.0 is appended so `step` can always read `sigmas[step_index + 1]`.
    padded = np.concatenate([sigmas, [0.0]])
    self.sigmas = torch.from_numpy(padded).to(device=device, dtype=dtype)
    self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=dtype)
    self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

    self._begin_index = None
    self._step_index = None
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Return the schedule index for `timestep`.

    Delegates to `scheduler_utils.index_for_timestep`, searching `self.timesteps`
    unless `schedule_timesteps` is given.
    """
    from .scheduler_utils import index_for_timestep as _lookup

    candidates = self.timesteps if schedule_timesteps is None else schedule_timesteps
    return _lookup(timestep, candidates)
|
||||
|
||||
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """Noise `original_samples` for the given `timesteps`.

    Delegates to `scheduler_utils.add_noise_to_sample` with this scheduler's
    current `sigmas` and `timesteps` tables.
    """
    from .scheduler_utils import add_noise_to_sample as _apply_noise

    noisy = _apply_noise(original_samples, noise, self.sigmas, timesteps, self.timesteps)
    return noisy
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """Scale ``sample`` for the model call at the current step.

    Initialises the step index from ``timestep`` on first use. For
    ``flow_prediction`` the sample is returned untouched; otherwise it is
    divided by ``sqrt(sigma**2 + 1)`` using the sigma at the current step index.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    if self.config.prediction_type != "flow_prediction":
        current_sigma = self.sigmas[self._step_index]
        sample = sample / ((current_sigma**2 + 1) ** 0.5)
    return sample
|
||||
|
||||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[float, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """Advance ``sample`` by one first-order exponential-integrator step.

    Args:
        model_output: Raw model prediction; interpreted per ``config.prediction_type``.
        timestep: Only used to initialise the step index on the first call.
        sample: Current latent.
        return_dict: When false, a one-element tuple is returned instead of
            a ``SchedulerOutput``.

    Returns:
        ``SchedulerOutput(prev_sample=...)`` or ``(prev_sample,)``.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    idx = self._step_index
    sigma_now = self.sigmas[idx]
    sigma_after = self.sigmas[idx + 1]

    # Convert the model output into an x_0 estimate.
    kind = self.config.prediction_type
    if kind in ("epsilon", "flow_prediction"):
        denoised = sample - sigma_now * model_output
    elif kind == "v_prediction":
        alpha = 1.0 / (sigma_now**2 + 1) ** 0.5
        scaled_sigma = sigma_now * alpha
        denoised = alpha * sample - scaled_sigma * model_output
    else:
        # "sample" and any unrecognised type: the output is used as x_0 directly.
        denoised = model_output

    if sigma_after == 0:
        # The final step lands exactly on the x_0 estimate.
        advanced = denoised
    else:
        if sigma_now > 0 and sigma_after > 0:
            log_ratio = -torch.log(sigma_after / sigma_now)
        else:
            log_ratio = torch.zeros_like(sigma_now)
        decay = torch.exp(-log_ratio)
        advanced = decay * sample + (1 - decay) * denoised

    self._step_index += 1

    if return_dict:
        return SchedulerOutput(prev_sample=advanced)
    return (advanced,)
|
||||
|
||||
def _init_step_index(self, timestep):
    """Set ``_step_index`` before the first step.

    If ``begin_index`` is set, ``_begin_index`` is used. Otherwise the index
    is looked up from ``timestep``, moving tensors to the device of
    ``self.timesteps`` first.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
|
||||
|
||||
def __len__(self):
    """Return the configured number of training timesteps."""
    train_steps = self.config.num_train_timesteps
    return train_steps
|
||||
|
|
@ -0,0 +1,251 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
class RungeKutta44Scheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
RK4: Classical 4th-order Runge-Kutta scheduler.
|
||||
Adapted from the RES4LYF repository.
|
||||
|
||||
This scheduler uses 4 stages per step.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    prediction_type: str = "epsilon",
    use_karras_sigmas: bool = False,
    use_exponential_sigmas: bool = False,
    use_beta_sigmas: bool = False,
    use_flow_sigmas: bool = False,
    sigma_min: Optional[float] = None,
    sigma_max: Optional[float] = None,
    rho: float = 7.0,
    shift: Optional[float] = None,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    use_dynamic_shifting: bool = False,
    timestep_spacing: str = "linspace",
    clip_sample: bool = False,
    sample_max_value: float = 1.0,
    set_alpha_to_one: bool = False,
    skip_prk_steps: bool = False,
    interpolation_type: str = "linear",
    steps_offset: int = 0,
    timestep_type: str = "discrete",
    rescale_betas_zero_snr: bool = False,
    final_sigmas_type: str = "zero",
):
    """Create the beta/alpha tables and reset the multi-stage stepping state.

    Only ``trained_betas``, ``beta_schedule``, ``beta_start``, ``beta_end``
    and ``num_train_timesteps`` are consumed here; every argument is captured
    by the ``register_to_config`` decorator.

    Raises:
        NotImplementedError: ``beta_schedule`` is neither ``"linear"`` nor
            ``"scaled_linear"`` and no ``trained_betas`` were given.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared.
        betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for RungeKutta44Scheduler")

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Placeholder schedule until set_timesteps() is called.
    self.sigmas = None
    self.timesteps = torch.from_numpy(np.arange(num_train_timesteps - 1, -1, -1))
    self.init_noise_sigma = 1.0

    # Per-step scratch state used while the stages of one step are in flight.
    self.model_outputs = []
    self.sample_at_start_of_step = None
    self._sigmas_cpu = None
    self._step_index = None
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None, dtype: torch.dtype = torch.float32):
    """Build the expanded sigma/timestep schedule and reset stepping state.

    Base sigmas are interpolated from the training sigmas at
    ``num_inference_steps`` linearly spaced timesteps (optionally replaced by a
    Karras ramp). Every interval between consecutive base sigmas is expanded
    into four entries at fractions ``[0, 1/2, 1/2, 1]``, and a terminal ``0.0``
    is appended. ``mu`` is accepted but not used by this method.

    Raises:
        ValueError: ``config.interpolation_type`` is not ``"linear"`` or
            ``"log_linear"``.
    """
    cfg = self.config
    self.num_inference_steps = num_inference_steps

    # 1. Base sigmas, sampled from high noise to low noise.
    query_points = np.linspace(0, cfg.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
    train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    grid = np.arange(len(train_sigmas))
    if cfg.interpolation_type == "linear":
        sigmas = np.interp(query_points, grid, train_sigmas)
    elif cfg.interpolation_type == "log_linear":
        sigmas = np.exp(np.interp(query_points, grid, np.log(train_sigmas)))
    else:
        raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")

    if cfg.use_karras_sigmas:
        low = cfg.sigma_min if cfg.sigma_min is not None else sigmas[-1]
        high = cfg.sigma_max if cfg.sigma_max is not None else sigmas[0]
        rho = cfg.rho
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = low ** (1 / rho)
        max_inv_rho = high ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

    # 2. One sigma per stage: RK4 evaluates at c = [0, 1/2, 1/2, 1] within each interval.
    stage_fractions = [0.0, 0.5, 0.5, 1.0]
    expanded = [
        start + frac * (end - start)
        for start, end in zip(sigmas[:-1], sigmas[1:])
        for frac in stage_fractions
    ]
    expanded.append(0.0)  # terminal sigma

    # 3. Timesteps are a linear rescaling of the sigmas (plus steps_offset).
    sigma_schedule = np.array(expanded)
    timestep_schedule = sigma_schedule * cfg.num_train_timesteps

    self.sigmas = torch.from_numpy(sigma_schedule).to(device=device, dtype=dtype)
    self.timesteps = torch.from_numpy(timestep_schedule + cfg.steps_offset).to(device=device, dtype=dtype)
    self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

    # NumPy copy of the sigmas used by step() and scale_model_input().
    self._sigmas_cpu = self.sigmas.detach().cpu().numpy()

    self.model_outputs = []
    self.sample_at_start_of_step = None
    self._step_index = None
|
||||
|
||||
@property
def step_index(self):
    """Current position in the expanded schedule (``None`` until initialised); step() advances it by one per call."""
    current = self._step_index
    return current
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Return the index of the schedule entry closest to ``timestep``.

    Nearest-match (rather than exact equality) keeps the lookup robust to
    floating-point differences. ``self.timesteps`` is searched when no
    schedule is passed.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    distances = (schedule - timestep).abs()
    return distances.argmin().item()
|
||||
|
||||
def _init_step_index(self, timestep):
    """Set ``_step_index`` from ``timestep``, moving tensor timesteps to the schedule's device first."""
    target = timestep.to(self.timesteps.device) if isinstance(timestep, torch.Tensor) else timestep
    self._step_index = self.index_for_timestep(target)
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """Scale ``sample`` for the model call at the current stage.

    Initialises the step index from ``timestep`` on first use. Flow-prediction
    inputs pass through unchanged; all other prediction types are divided by
    ``sqrt(sigma**2 + 1)`` with sigma taken from the CPU sigma table.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    if self.config.prediction_type != "flow_prediction":
        stage_sigma = self._sigmas_cpu[self._step_index]
        return sample / ((stage_sigma**2 + 1) ** 0.5)
    return sample
|
||||
|
||||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[float, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """Process one stage of the classical 4-stage Runge-Kutta step.

    Each call consumes one entry of the expanded schedule. Stages 0-2 return
    the input for the next stage; stage 3 combines the four stored derivatives
    into the final result and clears the per-step state.

    Args:
        model_output: Raw model prediction; interpreted per ``config.prediction_type``.
        timestep: Only used to initialise the step index on the first call.
        sample: Latent the model was evaluated on.
        return_dict: When false, a one-element tuple is returned instead of
            a ``SchedulerOutput``.

    Raises:
        ValueError: ``prediction_type`` is not one of ``epsilon``,
            ``v_prediction``, ``flow_prediction`` or ``sample``.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    idx = self._step_index
    stage = idx % 4
    first = idx - stage  # index of stage 0 of the current 4-stage step

    # Step size in sigma space, from this step's first stage to the next step's first stage.
    sigma_start = self._sigmas_cpu[first]
    sigma_end = self._sigmas_cpu[min(first + 4, len(self._sigmas_cpu) - 1)]
    h = sigma_end - sigma_start

    sigma_t = self._sigmas_cpu[idx]

    prediction_type = getattr(self.config, "prediction_type", "epsilon")
    if prediction_type in ("epsilon", "flow_prediction"):
        denoised = sample - sigma_t * model_output
    elif prediction_type == "v_prediction":
        alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
        sigma_actual = sigma_t * alpha_t
        denoised = alpha_t * sample - sigma_actual * model_output
    elif prediction_type == "sample":
        denoised = model_output
    else:
        raise ValueError(f"prediction_type error: {prediction_type}")

    if self.config.clip_sample:
        denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)

    # derivative = (x - x0) / sigma, treated as zero once sigma has effectively vanished.
    derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)

    if self.sample_at_start_of_step is None and stage > 0:
        # Entered mid-step (e.g. img2img/inpainting start): no stage history exists,
        # so take a plain Euler move to the next schedule entry instead.
        dt = self._sigmas_cpu[self._step_index + 1] - sigma_t
        euler_sample = sample + dt * derivative
        self._step_index += 1
        if return_dict:
            return SchedulerOutput(prev_sample=euler_sample)
        return (euler_sample,)

    if stage == 0:
        self.model_outputs = [derivative]
        self.sample_at_start_of_step = sample
    else:
        self.model_outputs.append(derivative)

    origin = self.sample_at_start_of_step
    if stage in (0, 1):
        # Input for stage 2 / stage 3: y + 0.5 * h * k
        prev_sample = origin + 0.5 * h * derivative
    elif stage == 2:
        # Input for stage 4: y + h * k3
        prev_sample = origin + h * derivative
    else:
        # Final result: y + (h/6) * (k1 + 2*k2 + 2*k3 + k4)
        k1, k2, k3, k4 = self.model_outputs
        prev_sample = origin + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        self.model_outputs = []
        self.sample_at_start_of_step = None

    self._step_index += 1

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
|
||||
|
||||
def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    """Noise ``original_samples`` for the requested ``timesteps``.

    The work is done by ``scheduler_utils.add_noise_to_sample``, which receives
    this scheduler's ``sigmas`` and ``timesteps`` alongside the arguments.
    """
    from .scheduler_utils import add_noise_to_sample as _noise_sample

    noised = _noise_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
    return noised
|
||||
|
||||
def __len__(self):
    """Return the configured number of training timesteps."""
    train_steps = self.config.num_train_timesteps
    return train_steps
|
||||
|
|
@ -0,0 +1,299 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
class RungeKutta57Scheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
RK5_7S: 5th-order Runge-Kutta scheduler with 7 stages.
|
||||
Adapted from the RES4LYF repository.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    prediction_type: str = "epsilon",
    use_karras_sigmas: bool = False,
    use_exponential_sigmas: bool = False,
    use_beta_sigmas: bool = False,
    use_flow_sigmas: bool = False,
    sigma_min: Optional[float] = None,
    sigma_max: Optional[float] = None,
    rho: float = 7.0,
    shift: Optional[float] = None,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    use_dynamic_shifting: bool = False,
    timestep_spacing: str = "linspace",
    clip_sample: bool = False,
    sample_max_value: float = 1.0,
    set_alpha_to_one: bool = False,
    skip_prk_steps: bool = False,
    interpolation_type: str = "linear",
    steps_offset: int = 0,
    timestep_type: str = "discrete",
    rescale_betas_zero_snr: bool = False,
    final_sigmas_type: str = "zero",
):
    """Create the beta/alpha tables and reset the multi-stage stepping state.

    Only ``trained_betas``, ``beta_schedule``, ``beta_start``, ``beta_end``
    and ``num_train_timesteps`` are consumed here; every argument is captured
    by the ``register_to_config`` decorator.

    Raises:
        NotImplementedError: ``beta_schedule`` is neither ``"linear"`` nor
            ``"scaled_linear"`` and no ``trained_betas`` were given.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared.
        betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented")

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Placeholder schedule until set_timesteps() is called.
    self.num_inference_steps = None
    self.timesteps = torch.from_numpy(np.arange(num_train_timesteps - 1, -1, -1))
    self.sigmas = None
    self.init_noise_sigma = 1.0

    # Per-step scratch state used while the stages of one step are in flight.
    self.model_outputs = []
    self.sample_at_start_of_step = None
    self._sigmas_cpu = None
    self._step_index = None
    self._timesteps_cpu = None
|
||||
|
||||
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    mu: Optional[float] = None,
    dtype: torch.dtype = torch.float32,
):
    """Build the expanded 7-stage sigma/timestep schedule and reset stepping state.

    Steps performed:
      1. Choose ``num_inference_steps`` training timesteps per ``config.timestep_spacing``.
      2. Interpolate the training sigmas at those timesteps, then optionally
         replace them with a Karras, exponential, beta or flow sigma ramp.
      3. Optionally apply the shift ``s * sigma / (1 + (s - 1) * sigma)`` with
         ``s = mu`` (dynamic shifting) or ``s = config.shift``.
      4. Expand every interval between consecutive sigmas into seven stage
         sigmas and append a terminal ``0.0``.

    The beta ramp is computed from the Beta(0.6, 0.6) quantile function, so the
    schedule is deterministic and does not consume the global torch RNG. If
    SciPy is unavailable a uniform ramp is used instead.

    Args:
        num_inference_steps: Number of base sigmas to generate.
        device: Device for the ``sigmas`` / ``timesteps`` tensors.
        mu: Shift factor, used only when ``config.use_dynamic_shifting`` is set.
        dtype: Dtype for the ``sigmas`` / ``timesteps`` tensors.

    Raises:
        ValueError: unknown ``config.timestep_spacing`` or ``config.interpolation_type``.
    """
    cfg = self.config
    self.num_inference_steps = num_inference_steps
    train_steps = cfg.num_train_timesteps

    # 1. Spacing
    if cfg.timestep_spacing == "linspace":
        timesteps = np.linspace(train_steps - 1, 0, num_inference_steps, dtype=float).copy()
    elif cfg.timestep_spacing == "leading":
        step_ratio = train_steps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
    elif cfg.timestep_spacing == "trailing":
        step_ratio = train_steps / num_inference_steps
        timesteps = np.arange(train_steps, 0, -step_ratio).round().copy().astype(float)
        timesteps -= step_ratio
        # Rounding can push the last entry slightly below zero; clamp it.
        timesteps = np.maximum(timesteps, 0)
    else:
        raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")

    # 2. Sigmas at the selected timesteps
    train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    grid = np.arange(len(train_sigmas))
    if cfg.interpolation_type == "linear":
        sigmas = np.interp(timesteps, grid, train_sigmas)
    elif cfg.interpolation_type == "log_linear":
        sigmas = np.exp(np.interp(timesteps, grid, np.log(train_sigmas)))
    else:
        raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")

    if cfg.use_karras_sigmas or cfg.use_exponential_sigmas or cfg.use_beta_sigmas:
        # Explicit config bounds win; otherwise use the ends of the interpolated schedule.
        sigma_min = cfg.sigma_min if cfg.sigma_min is not None else sigmas[-1]
        sigma_max = cfg.sigma_max if cfg.sigma_max is not None else sigmas[0]

    if cfg.use_karras_sigmas:
        rho = cfg.rho
        ramp = np.linspace(0, 1, num_inference_steps)
        sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    elif cfg.use_exponential_sigmas:
        sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
    elif cfg.use_beta_sigmas:
        alpha, beta = 0.6, 0.6
        ramp = np.linspace(0, 1, num_inference_steps)
        try:
            from scipy.stats import beta as beta_dist
        except ImportError:
            beta_dist = None  # SciPy missing: keep the uniform ramp
        if beta_dist is not None:
            # Quantiles of Beta(alpha, beta) at evenly spaced probabilities: a
            # deterministic, monotone ramp from 0 to 1 that never touches the RNG.
            ramp = beta_dist.ppf(ramp, alpha, beta)
        sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
    elif cfg.use_flow_sigmas:
        sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)

    # 3. Shifting
    if cfg.use_dynamic_shifting and mu is not None:
        sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
    elif cfg.shift is not None:
        sigmas = cfg.shift * sigmas / (1 + (cfg.shift - 1) * sigmas)

    # 4. One sigma per stage. RK5_7s c values: [0, 1/5, 3/10, 4/5, 8/9, 1, 1]
    c_values = [0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1, 1]
    sigmas_expanded = [
        s_curr + c * (s_next - s_curr)
        for s_curr, s_next in zip(sigmas[:-1], sigmas[1:])
        for c in c_values
    ]
    sigmas_expanded.append(0.0)  # terminal sigma

    sigmas_interpolated = np.array(sigmas_expanded)
    # NOTE(review): timesteps are sigma * num_train_timesteps for every prediction
    # type, not only for flow models - confirm this is intended for epsilon /
    # v-prediction checkpoints.
    timesteps_expanded = sigmas_interpolated * cfg.num_train_timesteps

    self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
    self.timesteps = torch.from_numpy(timesteps_expanded + cfg.steps_offset).to(device=device, dtype=dtype)
    self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

    # NumPy copies used by step() and scale_model_input().
    self._sigmas_cpu = self.sigmas.detach().cpu().numpy()
    self._timesteps_cpu = self.timesteps.detach().cpu().numpy()
    self._step_index = None

    self.model_outputs = []
    self.sample_at_start_of_step = None
|
||||
|
||||
@property
def step_index(self):
    """Current position in the expanded schedule (``None`` until initialised); step() advances it by one per call."""
    current = self._step_index
    return current
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Look up the position of ``timestep`` in a timestep schedule.

    The lookup itself is delegated to ``scheduler_utils.index_for_timestep``.
    When ``schedule_timesteps`` is omitted, ``self.timesteps`` is searched.
    """
    from .scheduler_utils import index_for_timestep as _lookup_index

    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    return _lookup_index(timestep, schedule)
|
||||
|
||||
def _init_step_index(self, timestep):
    """Initialise ``_step_index`` from ``timestep`` unless it is already set.

    Tensor timesteps are moved to the device of ``self.timesteps`` before the lookup.
    """
    if self._step_index is not None:
        return

    target = timestep.to(self.timesteps.device) if isinstance(timestep, torch.Tensor) else timestep
    self._step_index = self.index_for_timestep(target)
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """Scale ``sample`` for the model call at the current stage.

    Initialises the step index from ``timestep`` on first use. Flow-prediction
    inputs pass through unchanged; all other prediction types are divided by
    ``sqrt(sigma**2 + 1)`` with sigma taken from the CPU sigma table.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    if self.config.prediction_type != "flow_prediction":
        stage_sigma = self._sigmas_cpu[self._step_index]
        return sample / ((stage_sigma**2 + 1) ** 0.5)
    return sample
|
||||
|
||||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[float, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """Process one stage of the 7-stage Dormand-Prince step.

    Each call consumes one entry of the expanded schedule. Stages 0-5 return
    the input for the next stage; stage 6 combines the stored derivatives with
    the final weights and clears the per-step state.

    Args:
        model_output: Raw model prediction; interpreted per ``config.prediction_type``.
        timestep: Only used to initialise the step index on the first call.
        sample: Latent the model was evaluated on.
        return_dict: When false, a one-element tuple is returned instead of
            a ``SchedulerOutput``.

    Raises:
        ValueError: ``prediction_type`` is not one of ``epsilon``,
            ``v_prediction``, ``flow_prediction`` or ``sample``.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    # Dormand-Prince 5(4) coefficients: row i holds the weights that build the input of stage i.
    stage_rows = (
        (),
        (1/5,),
        (3/40, 9/40),
        (44/45, -56/15, 32/9),
        (19372/6561, -25360/2187, 64448/6561, -212/729),
        (9017/3168, -355/33, 46732/5247, 49/176, -5103/18656),
        (35/384, 0, 500/1113, 125/192, -2187/6784, 11/84),
    )
    final_weights = (35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0)

    idx = self._step_index
    stage = idx % 7
    first = idx - stage  # index of stage 0 of the current 7-stage step

    # Step size in sigma space, from this step's first stage to the next step's first stage.
    sigma_start = self._sigmas_cpu[first]
    sigma_end = self._sigmas_cpu[min(first + 7, len(self._sigmas_cpu) - 1)]
    h = sigma_end - sigma_start

    sigma_t = self._sigmas_cpu[idx]

    prediction_type = getattr(self.config, "prediction_type", "epsilon")
    if prediction_type in ("epsilon", "flow_prediction"):
        denoised = sample - sigma_t * model_output
    elif prediction_type == "v_prediction":
        alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
        sigma_actual = sigma_t * alpha_t
        denoised = alpha_t * sample - sigma_actual * model_output
    elif prediction_type == "sample":
        denoised = model_output
    else:
        raise ValueError(f"prediction_type error: {prediction_type}")

    if self.config.clip_sample:
        denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)

    # derivative = (x - x0) / sigma, treated as zero once sigma has effectively vanished.
    derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)

    if self.sample_at_start_of_step is None and stage > 0:
        # Entered mid-step (e.g. img2img/inpainting start): no stage history exists,
        # so take a plain Euler move to the next schedule entry instead.
        dt = self._sigmas_cpu[self._step_index + 1] - sigma_t
        euler_sample = sample + dt * derivative
        self._step_index += 1
        if return_dict:
            return SchedulerOutput(prev_sample=euler_sample)
        return (euler_sample,)

    if stage == 0:
        self.model_outputs = [derivative]
        self.sample_at_start_of_step = sample
    else:
        self.model_outputs.append(derivative)

    # Stages 0-5 build the next stage input; stage 6 produces the step result.
    is_final_stage = stage >= 6
    weights = final_weights if is_final_stage else stage_rows[stage + 1]
    weighted_sum = torch.zeros_like(derivative)
    for j, weight in enumerate(weights):
        weighted_sum += weight * self.model_outputs[j]

    prev_sample = self.sample_at_start_of_step + h * weighted_sum

    if is_final_stage:
        # The step is complete: drop the stage history.
        self.model_outputs = []
        self.sample_at_start_of_step = None

    self._step_index += 1

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
|
||||
|
||||
def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    """Noise ``original_samples`` for the requested ``timesteps``.

    The work is done by ``scheduler_utils.add_noise_to_sample``, which receives
    this scheduler's ``sigmas`` and ``timesteps`` alongside the arguments.
    """
    from .scheduler_utils import add_noise_to_sample as _noise_sample

    noised = _noise_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)
    return noised
|
||||
|
||||
def __len__(self):
    """Return the configured number of training timesteps."""
    train_steps = self.config.num_train_timesteps
    return train_steps
|
||||
|
|
@ -0,0 +1,301 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
class RungeKutta67Scheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
RK6_7S: 6th-order Runge-Kutta scheduler with 7 stages.
|
||||
Adapted from the RES4LYF repository.
|
||||
(Note: Defined as 5th order in some contexts, but follows the 7-stage tableau).
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    prediction_type: str = "epsilon",
    use_karras_sigmas: bool = False,
    use_exponential_sigmas: bool = False,
    use_beta_sigmas: bool = False,
    use_flow_sigmas: bool = False,
    sigma_min: Optional[float] = None,
    sigma_max: Optional[float] = None,
    rho: float = 7.0,
    shift: Optional[float] = None,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    use_dynamic_shifting: bool = False,
    timestep_spacing: str = "linspace",
    clip_sample: bool = False,
    sample_max_value: float = 1.0,
    set_alpha_to_one: bool = False,
    skip_prk_steps: bool = False,
    interpolation_type: str = "linear",
    steps_offset: int = 0,
    timestep_type: str = "discrete",
    rescale_betas_zero_snr: bool = False,
    final_sigmas_type: str = "zero",
):
    """Create the beta/alpha tables and reset the multi-stage stepping state.

    Only ``trained_betas``, ``beta_schedule``, ``beta_start``, ``beta_end``
    and ``num_train_timesteps`` are consumed here; every argument is captured
    by the ``register_to_config`` decorator.

    Raises:
        NotImplementedError: ``beta_schedule`` is neither ``"linear"`` nor
            ``"scaled_linear"`` and no ``trained_betas`` were given.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared.
        betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented")

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    self.init_noise_sigma = 1.0

    # internal state; the schedule is a placeholder until set_timesteps() is called
    self.num_inference_steps = None
    self.timesteps = torch.from_numpy(np.arange(num_train_timesteps - 1, -1, -1))
    self.sigmas = None
    self.model_outputs = []
    self.sample_at_start_of_step = None
    self._sigmas_cpu = None
    self._timesteps_cpu = None
    self._step_index = None
|
||||
|
||||
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    mu: Optional[float] = None,
    dtype: torch.dtype = torch.float32,
):
    """Build the expanded 7-stage sigma/timestep schedule and reset stepping state.

    Steps performed:
      1. Choose ``num_inference_steps`` training timesteps per ``config.timestep_spacing``.
      2. Interpolate the training sigmas at those timesteps, then optionally
         replace them with a Karras, exponential, beta or flow sigma ramp.
      3. Optionally apply the shift ``s * sigma / (1 + (s - 1) * sigma)`` with
         ``s = mu`` (dynamic shifting) or ``s = config.shift``.
      4. Expand every interval between consecutive sigmas into seven stage
         sigmas and append a terminal ``0.0``.

    The beta ramp is computed from the Beta(0.6, 0.6) quantile function, so the
    schedule is deterministic and does not consume the global torch RNG. If
    SciPy is unavailable a uniform ramp is used instead.

    Args:
        num_inference_steps: Number of base sigmas to generate.
        device: Device for the ``sigmas`` / ``timesteps`` tensors.
        mu: Shift factor, used only when ``config.use_dynamic_shifting`` is set.
        dtype: Dtype for the ``sigmas`` / ``timesteps`` tensors.

    Raises:
        ValueError: unknown ``config.timestep_spacing`` or ``config.interpolation_type``.
    """
    cfg = self.config
    self.num_inference_steps = num_inference_steps
    train_steps = cfg.num_train_timesteps

    # 1. Spacing
    if cfg.timestep_spacing == "linspace":
        timesteps = np.linspace(train_steps - 1, 0, num_inference_steps, dtype=float).copy()
    elif cfg.timestep_spacing == "leading":
        step_ratio = train_steps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
    elif cfg.timestep_spacing == "trailing":
        step_ratio = train_steps / num_inference_steps
        timesteps = np.arange(train_steps, 0, -step_ratio).round().copy().astype(float)
        timesteps -= step_ratio
        # Rounding can push the last entry slightly below zero; clamp it.
        timesteps = np.maximum(timesteps, 0)
    else:
        raise ValueError(f"timestep_spacing must be one of 'linspace', 'leading', or 'trailing', got {self.config.timestep_spacing}")

    # 2. Sigmas at the selected timesteps
    train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    grid = np.arange(len(train_sigmas))
    if cfg.interpolation_type == "linear":
        sigmas = np.interp(timesteps, grid, train_sigmas)
    elif cfg.interpolation_type == "log_linear":
        sigmas = np.exp(np.interp(timesteps, grid, np.log(train_sigmas)))
    else:
        raise ValueError(f"interpolation_type must be one of 'linear' or 'log_linear', got {self.config.interpolation_type}")

    if cfg.use_karras_sigmas or cfg.use_exponential_sigmas or cfg.use_beta_sigmas:
        # Explicit config bounds win; otherwise use the ends of the interpolated schedule.
        sigma_min = cfg.sigma_min if cfg.sigma_min is not None else sigmas[-1]
        sigma_max = cfg.sigma_max if cfg.sigma_max is not None else sigmas[0]

    if cfg.use_karras_sigmas:
        rho = cfg.rho
        ramp = np.linspace(0, 1, num_inference_steps)
        sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    elif cfg.use_exponential_sigmas:
        sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_inference_steps))
    elif cfg.use_beta_sigmas:
        alpha, beta = 0.6, 0.6
        ramp = np.linspace(0, 1, num_inference_steps)
        try:
            from scipy.stats import beta as beta_dist
        except ImportError:
            beta_dist = None  # SciPy missing: keep the uniform ramp
        if beta_dist is not None:
            # Quantiles of Beta(alpha, beta) at evenly spaced probabilities: a
            # deterministic, monotone ramp from 0 to 1 that never touches the RNG.
            ramp = beta_dist.ppf(ramp, alpha, beta)
        sigmas = sigma_max * (1 - ramp) + sigma_min * ramp
    elif cfg.use_flow_sigmas:
        sigmas = np.linspace(1.0, 1 / 1000, num_inference_steps)

    # 3. Shifting
    if cfg.use_dynamic_shifting and mu is not None:
        sigmas = mu * sigmas / (1 + (mu - 1) * sigmas)
    elif cfg.shift is not None:
        sigmas = cfg.shift * sigmas / (1 + (cfg.shift - 1) * sigmas)

    # 4. One sigma per stage. RK6_7s c values: [0, 1/3, 2/3, 1/3, 1/2, 1/2, 1]
    c_values = [0, 1 / 3, 2 / 3, 1 / 3, 1 / 2, 1 / 2, 1]
    sigmas_expanded = [
        s_curr + c * (s_next - s_curr)
        for s_curr, s_next in zip(sigmas[:-1], sigmas[1:])
        for c in c_values
    ]
    sigmas_expanded.append(0.0)  # terminal sigma

    sigmas_interpolated = np.array(sigmas_expanded)
    # NOTE(review): timesteps are sigma * num_train_timesteps for every prediction
    # type, not only for flow models - confirm this is intended for epsilon /
    # v-prediction checkpoints.
    timesteps_expanded = sigmas_interpolated * cfg.num_train_timesteps

    self.sigmas = torch.from_numpy(sigmas_interpolated).to(device=device, dtype=dtype)
    self.timesteps = torch.from_numpy(timesteps_expanded + cfg.steps_offset).to(device=device, dtype=dtype)
    self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

    # NumPy copies of the schedule.
    self._sigmas_cpu = self.sigmas.detach().cpu().numpy()
    self._timesteps_cpu = self.timesteps.detach().cpu().numpy()
    self._step_index = None

    self.model_outputs = []
    self.sample_at_start_of_step = None
|
||||
|
||||
@property
def step_index(self):
    """
    Position of the scheduler inside the expanded sigma/timestep schedule.

    Read-only view of `self._step_index`: it is `None` until `_init_step_index`
    resolves it from a timestep, and `step` advances it by one on every call.
    """
    current_index = self._step_index
    return current_index
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Look up the schedule index that corresponds to `timestep`.

    Args:
        timestep: Timestep to locate.
        schedule_timesteps: Schedule to search; `self.timesteps` is used when omitted.

    Returns:
        Whatever `scheduler_utils.index_for_timestep` returns for the chosen schedule.
    """
    from .scheduler_utils import index_for_timestep as _lookup_index

    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    return _lookup_index(timestep, schedule)
|
||||
|
||||
def _init_step_index(self, timestep):
    """
    Resolve `self._step_index` from `timestep` the first time it is needed.

    Does nothing when the index is already set. Tensor timesteps are moved to the
    device of `self.timesteps` before the lookup.
    """
    if self._step_index is not None:
        return

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)

    self._step_index = self.index_for_timestep(timestep)
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Scale `sample` for the model according to the sigma at the current step index.

    The step index is initialised from `timestep` when it has not been set yet.
    With `prediction_type == "flow_prediction"` the sample is returned untouched;
    otherwise it is divided by `sqrt(sigma**2 + 1)`.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    needs_scaling = self.config.prediction_type != "flow_prediction"
    if not needs_scaling:
        return sample

    sigma = self._sigmas_cpu[self._step_index]
    denominator = (sigma**2 + 1) ** 0.5
    return sample / denominator
|
||||
|
||||
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[float, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance the sample by one stage of the 7-stage rk6_7s Runge-Kutta scheme.

    The expanded schedule holds 7 entries per macro step. Stages 0-5 return the
    sample at which the next stage must be evaluated; stage 6 combines all stored
    derivatives with the `b` weights to produce the result of the macro step.

    Args:
        model_output: Raw model prediction for the current stage.
        timestep: Current timestep; only used to initialise the step index.
        sample: Current sample (in sigma space, i.e. `x0 + sigma * noise`).
        return_dict: Return a `SchedulerOutput` when true, otherwise a 1-tuple.

    Returns:
        `SchedulerOutput(prev_sample=...)` or `(prev_sample,)`.

    Raises:
        ValueError: If `config.prediction_type` is not one of `epsilon`,
            `v_prediction`, `flow_prediction` or `sample`.
    """
    if self._step_index is None:
        self._init_step_index(timestep)

    step_index = self._step_index
    stage_index = step_index % 7

    # The step size h spans the whole macro step: from its first schedule entry to the
    # first entry of the next macro step (clamped to the final sigma).
    base_step_index = step_index - stage_index
    sigma_curr = self._sigmas_cpu[base_step_index]
    sigma_next = self._sigmas_cpu[min(base_step_index + 7, len(self._sigmas_cpu) - 1)]
    h = sigma_next - sigma_curr

    sigma_t = self._sigmas_cpu[step_index]

    prediction_type = getattr(self.config, "prediction_type", "epsilon")
    if prediction_type in ("epsilon", "flow_prediction"):
        denoised = sample - sigma_t * model_output
    elif prediction_type == "v_prediction":
        # `sample` lives in sigma space while the model saw sample / sqrt(sigma^2 + 1), so
        # x0 = sample / (sigma^2 + 1) - sigma / sqrt(sigma^2 + 1) * v. The skip coefficient is
        # 1 / (sigma^2 + 1), not 1 / sqrt(sigma^2 + 1).
        alpha_t = 1 / (sigma_t**2 + 1) ** 0.5
        denoised = sample / (sigma_t**2 + 1) - sigma_t * alpha_t * model_output
    elif prediction_type == "sample":
        denoised = model_output
    else:
        raise ValueError(f"prediction_type error: {prediction_type}")

    if self.config.clip_sample:
        denoised = denoised.clamp(-self.config.sample_max_value, self.config.sample_max_value)

    # derivative = (x - x0) / sigma; defined as zero once sigma has effectively reached 0
    derivative = (sample - denoised) / sigma_t if sigma_t > 1e-6 else torch.zeros_like(sample)

    if stage_index == 0:
        # First stage of a macro step: remember where the step started.
        self.sample_at_start_of_step = sample
        self.model_outputs = [derivative]
    elif self.sample_at_start_of_step is None:
        # Mid-step fallback for Img2Img/Inpainting: stepping started inside a macro step, so
        # there is no stored start sample; take a plain Euler step to the next schedule entry.
        dt = self._sigmas_cpu[step_index + 1] - sigma_t
        prev_sample = sample + dt * derivative
        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return SchedulerOutput(prev_sample=prev_sample)
    else:
        self.model_outputs.append(derivative)

    # Butcher Tableau A matrix for rk6_7s
    a = (
        (),
        (1 / 3,),
        (0, 2 / 3),
        (1 / 12, 1 / 3, -1 / 12),
        (-1 / 16, 9 / 8, -3 / 16, -3 / 8),
        (0, 9 / 8, -3 / 8, -3 / 4, 1 / 2),
        (9 / 44, -9 / 11, 63 / 44, 18 / 11, 0, -16 / 11),
    )
    # Butcher Tableau B weights for rk6_7s
    b = (11 / 120, 0, 27 / 40, 27 / 40, -4 / 15, -4 / 15, 11 / 120)

    is_final_stage = stage_index == 6
    # Stages 0-5 use the A row of the *next* stage; the final stage uses the B weights.
    weights = b if is_final_stage else a[stage_index + 1]

    weighted_sum = torch.zeros_like(derivative)
    for weight, stage_derivative in zip(weights, self.model_outputs):
        weighted_sum += weight * stage_derivative

    prev_sample = self.sample_at_start_of_step + h * weighted_sum

    if is_final_stage:
        # Macro step complete: clear state so the next stage 0 starts fresh.
        self.model_outputs = []
        self.sample_at_start_of_step = None

    self._step_index += 1

    if not return_dict:
        return (prev_sample,)
    return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Noise `original_samples` to the level that the schedule assigns to `timesteps`.

    Thin wrapper that forwards this scheduler's `sigmas` and `timesteps` to
    `scheduler_utils.add_noise_to_sample`.
    """
    from .scheduler_utils import add_noise_to_sample as _add_noise

    return _add_noise(original_samples, noise, self.sigmas, timesteps, self.timesteps)
|
||||
|
||||
def __len__(self):
    """
    Return the configured number of training timesteps (`config.num_train_timesteps`).
    """
    train_steps = self.config.num_train_timesteps
    return train_steps
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
import math
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import scipy.stats
|
||||
_scipy_available = True
|
||||
except ImportError:
|
||||
_scipy_available = False
|
||||
|
||||
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Build a beta schedule by discretising a continuous alpha-bar function on [0, 1].

    Args:
        num_diffusion_timesteps: Number of betas to produce.
        max_beta: Upper bound applied to every beta.
        alpha_transform_type: Which alpha-bar curve to discretise.
        dtype: dtype of the returned tensor.

    Returns:
        1-D tensor with `num_diffusion_timesteps` betas.

    Raises:
        ValueError: If `alpha_transform_type` is not a supported name.
    """

    def _cosine(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _laplace(t):
        lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
        snr = math.exp(lmb)
        return math.sqrt(snr / (1 + snr))

    def _exp(t):
        return math.exp(t * -12.0)

    for name, candidate in (("cosine", _cosine), ("laplace", _laplace), ("exp", _exp)):
        if alpha_transform_type == name:
            alpha_bar = candidate
            break
    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
    schedule = [min(1 - alpha_bar((i + 1) / total) / alpha_bar(i / total), max_beta) for i in range(total)]
    return torch.tensor(schedule, dtype=dtype)
|
||||
|
||||
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescale a beta schedule so that the final cumulative alpha is exactly zero.

    sqrt(alpha_bar) is shifted so its last entry becomes 0 and then scaled so its
    first entry keeps its original value; the result is converted back to betas.
    The input tensor is not modified.
    """
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift the terminal value to zero, then restore the initial value.
    rescaled = (sqrt_alpha_bar - last) * (first / (first - last))

    alpha_bar = rescaled**2
    # Undo the cumulative product: alpha_i = alpha_bar_i / alpha_bar_{i-1}
    per_step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - per_step_alphas
|
||||
|
||||
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu", dtype: torch.dtype = torch.float32):
    """
    Return `n` sigmas from `sigma_max` down to `sigma_min`, spaced linearly in `sigma ** (1 / rho)`.

    The values are computed with NumPy and returned as a tensor on `device` with `dtype`.
    """
    exponent = 1 / rho
    low = sigma_min**exponent
    high = sigma_max**exponent
    fraction = np.linspace(0, 1, n)
    values = (high + fraction * (low - high)) ** rho
    return torch.from_numpy(values).to(dtype=dtype, device=device)
|
||||
|
||||
def get_sigmas_exponential(n, sigma_min, sigma_max, device="cpu", dtype: torch.dtype = torch.float32):
    """
    Return `n` sigmas from `sigma_max` down to `sigma_min`, evenly spaced in log-sigma.
    """
    log_high = math.log(sigma_max)
    log_low = math.log(sigma_min)
    values = np.exp(np.linspace(log_high, log_low, n))
    return torch.from_numpy(values).to(dtype=dtype, device=device)
|
||||
|
||||
def get_sigmas_beta(n, sigma_min, sigma_max, alpha=0.6, beta=0.6, device="cpu", dtype: torch.dtype = torch.float32):
    """
    Return `n` sigmas placed at Beta(`alpha`, `beta`) quantiles between `sigma_min` and `sigma_max`.

    The quantiles are evaluated at `1 - linspace(0, 1, n)`, so the first entry
    corresponds to quantile 1 and the last to quantile 0.

    Raises:
        ImportError: If scipy could not be imported when this module was loaded.
    """
    if not _scipy_available:
        raise ImportError("scipy is required for beta sigmas")

    values = []
    for quantile in 1 - np.linspace(0, 1, n):
        ppf = scipy.stats.beta.ppf(quantile, alpha, beta)
        values.append(sigma_min + (ppf * (sigma_max - sigma_min)))

    return torch.from_numpy(np.array(values)).to(dtype=dtype, device=device)
|
||||
|
||||
def get_sigmas_flow(n, sigma_min, sigma_max, device="cpu", dtype: torch.dtype = torch.float32):
    """
    Return `n` linearly spaced flow sigmas from `sigma_max` down to `sigma_min`.
    """
    linear_sigmas = np.linspace(sigma_max, sigma_min, n)
    as_tensor = torch.from_numpy(linear_sigmas)
    return as_tensor.to(dtype=dtype, device=device)
|
||||
|
||||
def apply_shift(sigmas, shift):
    """
    Apply the rational shift `shift * s / (1 + (shift - 1) * s)` to `sigmas`.

    A `shift` of 1 leaves the values unchanged.
    """
    numerator = shift * sigmas
    denominator = 1 + (shift - 1) * sigmas
    return numerator / denominator
|
||||
|
||||
def get_dynamic_shift(mu, base_shift, max_shift, base_seq_len, max_seq_len):
    """
    Linearly interpolate a shift value for `mu`.

    The line passes through `(base_seq_len, base_shift)` and `(max_seq_len, max_shift)`
    and is evaluated at `mu` (values outside that range are extrapolated).
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return slope * mu + intercept
|
||||
|
||||
def index_for_timestep(timestep, timesteps):
    """
    Return the index of the schedule entry closest to `timestep`.

    Args:
        timestep: Scalar, sequence or tensor. When it holds more than one value
            (e.g. a per-batch tensor of repeated timesteps) the first value is used,
            because a single index is returned.
        timesteps: Schedule to search (sequence or tensor).

    Returns:
        Integer index into `timesteps` of the nearest entry.

    Raises:
        ValueError: If `timestep` or `timesteps` is empty.
    """

    def _to_float64_array(values):
        # Go through float64 on CPU: Tensor.numpy() rejects dtypes such as bfloat16, and the
        # cast has to happen after .cpu() because some devices do not support float64.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().to(torch.float64).numpy().reshape(-1)
        return np.asarray(values, dtype=np.float64).reshape(-1)

    query = _to_float64_array(timestep)
    if query.size == 0:
        raise ValueError("timestep must contain at least one value")

    schedule = _to_float64_array(timesteps)

    # Compare against a scalar so a batched timestep of shape (B,) cannot fail to broadcast
    # against (or be silently compared element-wise with) the schedule of shape (N,).
    idx = np.abs(schedule - query[0]).argmin()
    return int(idx)
|
||||
|
||||
def add_noise_to_sample(
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    sigmas: torch.Tensor,
    timestep: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Return `original_samples + sigma * noise` for the sigma that belongs to `timestep`.

    The sigma is picked from `sigmas` at the index that `index_for_timestep` finds for
    `timestep` in `timesteps`, and is cast to the dtype of `original_samples`.
    """
    position = index_for_timestep(timestep, timesteps)
    noise_level = sigmas[position].to(original_samples.dtype)
    return original_samples + noise_level * noise
|
||||
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import ClassVar, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
||||
|
||||
from diffusers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class SimpleExponentialScheduler(SchedulerMixin, ConfigMixin):
    """
    Simple Exponential sigma scheduler using Exponential Integrator step.

    The default sigma schedule is log-uniform between `sigma_max` and `sigma_min`;
    the `use_*_sigmas` flags swap in an alternative spacing between the same end
    points, and `shift` / `use_dynamic_shifting` apply a rational shift afterwards.
    Each `step` is a first-order exponential-integrator update towards the predicted
    clean sample.

    Note: `gain` is registered in the config but is not read by any method of this class.
    """

    _compatibles: ClassVar[List[str]] = [e.name for e in KarrasDiffusionSchedulers]
    order: ClassVar[int] = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        sigma_max: float = 1.0,
        sigma_min: float = 0.01,
        gain: float = 1.0,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        rescale_betas_zero_snr: bool = False,
        use_karras_sigmas: bool = False,
        use_exponential_sigmas: bool = False,
        use_beta_sigmas: bool = False,
        use_flow_sigmas: bool = False,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
    ):
        """
        Build the training beta/alpha tables and reset the inference state.

        Raises:
            NotImplementedError: If `beta_schedule` is not `linear`, `scaled_linear`
                or `squaredcos_cap_v2`.
        """
        from .scheduler_utils import betas_for_alpha_bar, rescale_zero_terminal_snr

        if beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does not exist.")

        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # Settable values; placeholders until set_timesteps() is called
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
        self.sigmas = torch.zeros((num_train_timesteps,), dtype=torch.float32)

        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self) -> Optional[int]:
        """Index of the current step, or `None` before the first `step`/`scale_model_input`."""
        return self._step_index

    @property
    def begin_index(self) -> Optional[int]:
        """Index the first step starts from, or `None` when it is derived from the timestep."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Force the first step to start at `begin_index` instead of looking up the timestep."""
        self._begin_index = begin_index

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Optional[Union[str, torch.device]] = None,
        mu: Optional[float] = None,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Build the sigma and timestep schedules for inference.

        Args:
            num_inference_steps: Number of denoising steps.
            device: Device for the resulting `sigmas` / `timesteps` tensors.
            mu: Input to the dynamic shift; only used when `use_dynamic_shifting` is set.
            dtype: dtype for the resulting `sigmas` / `timesteps` tensors.
        """
        from .scheduler_utils import (
            apply_shift,
            get_dynamic_shift,
            get_sigmas_beta,
            get_sigmas_exponential,
            get_sigmas_flow,
            get_sigmas_karras,
        )

        self.num_inference_steps = num_inference_steps

        sigmas = np.exp(np.linspace(np.log(self.config.sigma_max), np.log(self.config.sigma_min), num_inference_steps))

        # The helper results are converted straight back to NumPy, so build them on CPU in
        # float32: Tensor.numpy() rejects dtypes such as bfloat16 and a device round trip is
        # pointless. The requested device/dtype are applied once, on the final tensors.
        helper_kwargs = {"device": "cpu", "dtype": torch.float32}
        if self.config.use_karras_sigmas:
            sigmas = get_sigmas_karras(num_inference_steps, sigmas[-1], sigmas[0], **helper_kwargs).numpy()
        elif self.config.use_exponential_sigmas:
            sigmas = get_sigmas_exponential(num_inference_steps, sigmas[-1], sigmas[0], **helper_kwargs).numpy()
        elif self.config.use_beta_sigmas:
            sigmas = get_sigmas_beta(num_inference_steps, sigmas[-1], sigmas[0], **helper_kwargs).numpy()
        elif self.config.use_flow_sigmas:
            sigmas = get_sigmas_flow(num_inference_steps, sigmas[-1], sigmas[0], **helper_kwargs).numpy()

        if self.config.shift != 1.0 or self.config.use_dynamic_shifting:
            shift = self.config.shift
            if self.config.use_dynamic_shifting and mu is not None:
                shift = get_dynamic_shift(
                    mu,
                    self.config.base_shift,
                    self.config.max_shift,
                    self.config.base_image_seq_len,
                    self.config.max_image_seq_len,
                )
            sigmas = apply_shift(torch.from_numpy(sigmas), shift).numpy()

        # A trailing 0.0 gives step() a sigma_next for the final step.
        self.sigmas = torch.from_numpy(np.concatenate([sigmas, [0.0]])).to(device=device, dtype=dtype)
        # Span the configured training range rather than a hard-coded 1000.
        self.timesteps = torch.from_numpy(
            np.linspace(self.config.num_train_timesteps, 0, num_inference_steps)
        ).to(device=device, dtype=dtype)
        self.init_noise_sigma = self.sigmas.max().item() if self.sigmas.numel() > 0 else 1.0

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the index of `timestep` in `schedule_timesteps` (default: `self.timesteps`)."""
        from .scheduler_utils import index_for_timestep

        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        return index_for_timestep(timestep, schedule_timesteps)

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Noise `original_samples` to the sigma that the schedule assigns to `timesteps`."""
        from .scheduler_utils import add_noise_to_sample

        return add_noise_to_sample(original_samples, noise, self.sigmas, timesteps, self.timesteps)

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Scale `sample` for the model: unchanged for `flow_prediction`, otherwise divided
        by `sqrt(sigma**2 + 1)` for the sigma at the current step index.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        if self.config.prediction_type == "flow_prediction":
            return sample
        sigma = self.sigmas[self._step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Take one first-order exponential-integrator step from sigma to the next sigma.

        Args:
            model_output: Raw model prediction for the current step.
            timestep: Current timestep; only used to initialise the step index.
            sample: Current sample (in sigma space, i.e. `x0 + sigma * noise`).
            return_dict: Return a `SchedulerOutput` when true, otherwise a 1-tuple.

        Returns:
            `SchedulerOutput(prev_sample=...)` or `(prev_sample,)`.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        step_index = self._step_index
        sigma = self.sigmas[step_index]
        sigma_next = self.sigmas[step_index + 1]

        # Determine denoised (x_0 prediction)
        prediction_type = self.config.prediction_type
        if prediction_type in ("epsilon", "flow_prediction"):
            x0 = sample - sigma * model_output
        elif prediction_type == "v_prediction":
            # `sample` lives in sigma space while the model saw sample / sqrt(sigma^2 + 1), so
            # x0 = sample / (sigma^2 + 1) - sigma / sqrt(sigma^2 + 1) * v. The skip coefficient
            # is 1 / (sigma^2 + 1), not 1 / sqrt(sigma^2 + 1).
            alpha_t = 1.0 / (sigma**2 + 1) ** 0.5
            x0 = sample / (sigma**2 + 1) - sigma * alpha_t * model_output
        else:
            # "sample", and any unrecognised prediction type, treat the output as x0 directly.
            x0 = model_output

        # Exponential Integrator Update (1st order)
        if sigma_next == 0:
            x_next = x0
        else:
            # h is the log-sigma step; exp(-h) == sigma_next / sigma. A non-positive sigma
            # yields h = 0, which leaves the sample unchanged.
            h = -torch.log(sigma_next / sigma) if sigma > 0 and sigma_next > 0 else torch.zeros_like(sigma)
            x_next = torch.exp(-h) * sample + (1 - torch.exp(-h)) * x0

        self._step_index += 1

        if not return_dict:
            return (x_next,)
        return SchedulerOutput(prev_sample=x_next)

    def _init_step_index(self, timestep):
        """Set `_step_index` from `begin_index` when given, otherwise by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue