diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index 2ee05f966..98bdff031 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -58,3 +58,49 @@ General app structure is:
 - Shared mutable global state can create subtle regressions; prefer narrow, explicit changes.
 - Device/backend-specific code paths (**CUDA/ROCm/IPEX/DirectML/OpenVINO**) should not assume one platform.
 - Scripts and extension loading is dynamic; failures may appear only when specific extensions or models are present.
+
+## Repo-Local Skills
+
+Use these repo-local skills for recurring SD.Next model integration work:
+
+- `port-model`
+  File: `.github/skills/port-model/SKILL.md`
+  Use when adding a new model family, porting a standalone script into a Diffusers pipeline, or wiring an upstream Diffusers model into SD.Next.
+
+- `debug-model`
+  File: `.github/skills/debug-model/SKILL.md`
+  Use when a new or existing SD.Next/Diffusers model integration fails during detection, loading, prompt encoding, sampling, or output handling.
+
+- `check-api`
+  File: `.github/skills/check-api/SKILL.md`
+  Use when auditing API routes in `modules/api/api.py` and validating endpoint parameters plus request/response signatures.
+
+- `check-schedulers`
+  File: `.github/skills/check-schedulers/SKILL.md`
+  Use when auditing scheduler registrations in `modules/sd_samplers_diffusers.py` for class loadability, config validity, and `SamplerData` mapping correctness.
+
+- `check-models`
+  File: `.github/skills/check-models/SKILL.md`
+  Use when running end-to-end model integration audits across loaders, detect/routing parity, reference catalogs, and pipeline API contracts.
+
+- `check-processing`
+  File: `.github/skills/check-processing/SKILL.md`
+  Use when validating txt2img/img2img/control processing workflows from UI submit definitions through backend processing and Diffusers execution, including parameter, type, and initialization checks.
+
+- `check-scripts`
+  File: `.github/skills/check-scripts/SKILL.md`
+  Use when auditing scripts in `scripts/*.py` for standard `Script` overrides (`__init__`, `title`, `show`) and validating `ui()` output against `run()` or `process()` parameters.
+
+- `github-issues`
+  File: `.github/skills/github-issues/SKILL.md`
+  Use when reading SD.Next GitHub issues with `[Issues]` in the title and producing a markdown report with a short summary, status, and suggested next steps for each issue.
+
+- `github-features`
+  File: `.github/skills/github-features/SKILL.md`
+  Use when reading SD.Next GitHub issues with `[Feature]` in the title and producing a markdown report with a short summary, status, and suggested next steps for each issue.
+
+- `analyze-model`
+  File: `.github/skills/analyze-model/SKILL.md`
+  Use when analyzing an external model URL to identify implementation style and estimate how difficult it is to port into SD.Next.
+
+When creating and updating skills, update this file and the index in `.github/skills/README.md` accordingly.
diff --git a/.github/skills/README.md b/.github/skills/README.md
new file mode 100644
index 000000000..04ee2d386
--- /dev/null
+++ b/.github/skills/README.md
@@ -0,0 +1,51 @@
+# Repo Skills
+
+This folder contains repo-local Copilot skills for recurring SD.Next tasks.
+
+## Available Skills
+
+- `port-model`
+  File: `port-model/SKILL.md`
+  Use when adding or porting a model family into SD.Next and Diffusers.
+
+- `debug-model`
+  File: `debug-model/SKILL.md`
+  Use when a new or existing SD.Next/Diffusers model integration fails during detection, loading, prompt encoding, sampling, or output handling.
+
+- `check-api`
+  File: `check-api/SKILL.md`
+  Use when auditing API endpoints in `modules/api/api.py` and delegated API modules for route parameter correctness plus request/response signature consistency.
+
+- `check-schedulers`
+  File: `check-schedulers/SKILL.md`
+  Use when auditing scheduler registrations from `modules/sd_samplers_diffusers.py` to verify class loadability, config validity against scheduler capabilities, and `SamplerData` correctness.
+
+- `check-models`
+  File: `check-models/SKILL.md`
+  Use when running an end-to-end model integration audit covering loaders, detect/routing parity, reference catalogs, and custom pipeline API contracts.
+
+- `check-processing`
+  File: `check-processing/SKILL.md`
+  Use when validating txt2img/img2img/control processing workflows from UI submit definitions to backend execution with parameter, type, and initialization checks.
+
+- `check-scripts`
+  File: `check-scripts/SKILL.md`
+  Use when auditing `scripts/*.py` for correct Script overrides (`__init__`, `title`, `show`) and verifying `ui()` output compatibility with `run()` or `process()` parameters.
+
+- `github-issues`
+  File: `github-issues/SKILL.md`
+  Use when reading SD.Next GitHub issues with `[Issues]` in the title and producing a markdown summary with status and suggested next steps for each issue.
+
+- `github-features`
+  File: `github-features/SKILL.md`
+  Use when reading SD.Next GitHub issues with `[Feature]` in the title and producing a markdown summary with status and suggested next steps for each issue.
+
+- `analyze-model`
+  File: `analyze-model/SKILL.md`
+  Use when analyzing an external model URL to classify implementation style and estimate SD.Next porting difficulty before coding.
+
+## Notes
+
+- Keep skills narrowly task-oriented and reusable.
+- Prefer referencing existing repo patterns over generic framework advice.
+- Update this index when adding new repo-local skills.
diff --git a/.github/skills/analyze-model/SKILL.md b/.github/skills/analyze-model/SKILL.md
new file mode 100644
index 000000000..2f9df43ff
--- /dev/null
+++ b/.github/skills/analyze-model/SKILL.md
@@ -0,0 +1,133 @@
+---
+name: analyze-model
+description: "Analyze an external model URL (typically Hugging Face) to determine implementation style and estimate SD.Next porting difficulty using the port-model workflow."
+argument-hint: "Provide model URL and optional target scope: text2img, img2img, edit, video, or full integration"
+---
+
+# Analyze External Model For SD.Next Porting
+
+Given an external model URL, inspect how the model is implemented and estimate how hard it is to port into SD.Next according to the port-model skill.
+
+## When To Use
+
+- User provides a Hugging Face model URL and asks if or how it can be ported
+- User wants effort estimation before implementation work
+- User wants to classify whether integration should reuse Diffusers, use custom Diffusers code, or require full custom implementation
+
+## Accepted Inputs
+
+- Hugging Face model URL (preferred)
+- Hugging Face repo id
+- Optional source links to custom code repos or PRs
+
+Example:
+
+- https://huggingface.co/jdopensource/JoyAI-Image-Edit
+
+## Required Outputs
+
+For each analyzed model, provide:
+
+1. Implementation classification
+2. Evidence for classification
+3. SD.Next porting path recommendation
+4. Difficulty rating and effort breakdown
+5. Main risks and blockers
+6. First implementation steps aligned with the port-model skill
+
+## Implementation Classification Buckets
+
+Classify into one of these (or the closest fit):
+
+1. Integrated into the upstream Diffusers library
+2. Custom Diffusers implementation in the model repo (custom pipeline/classes, not upstream)
+3. Fully custom implementation (non-Diffusers inference stack)
+4. Existing integration in ComfyUI (node or PR) but not in Diffusers
+5. Other (describe clearly)
+
+## Procedure
+
+### 1. Inspect Model Repository Artifacts
+
+From the provided URL/repo, collect:
+
+- model card details
+- files such as `model_index.json`, `config.json`, scheduler config, tokenizer files
+- presence of a Diffusers-style folder layout
+- references to custom Python modules or remote code requirements
+
+### 2. Determine Runtime Stack
+
+Identify whether model usage is:
+
+- a standard Diffusers pipeline call
+- a custom Diffusers pipeline class with `trust_remote_code`
+- a pure custom inference script or framework
+- a node-based integration in ComfyUI or another host
+
+### 3. Cross-Check Integration Surface
+
+Determine required SD.Next touchpoints if ported:
+
+- loader file in `pipelines/model_name.py`
+- detect and dispatch updates in `modules/sd_detect.py` and `modules/sd_models.py`
+- model type mapping in `modules/modeldata.py`
+- optional custom pipeline package under `pipelines/model/`
+- reference catalog updates and preview asset requirements
+
+### 4. Estimate Porting Difficulty
+
+Use this scale:
+
+- Low: mostly loader wiring to an existing upstream Diffusers pipeline
+- Medium: custom Diffusers classes or limited checkpoint/config adaptation
+- High: full custom architecture, major prompt/sampler/output differences, or sparse docs
+- Very High: no usable Diffusers path plus major runtime assumption mismatches
+
+Break down difficulty by:
+
+- loader complexity
+- pipeline/API contract complexity
+- scheduler/sampler compatibility
+- prompt encoding complexity
+- checkpoint conversion/remapping complexity
+- validation and testing burden
+
+### 5. Identify Risks
+
+Call out concrete risks:
+
+- missing or incompatible scheduler config
+- unclear output domain (latent vs pixel)
+- custom text encoder or processor constraints
+- nonstandard checkpoint format
+- dependency on external runtime features unavailable in SD.Next
+
+### 6. Recommend Porting Path
+
+Map the recommendation to the port-model workflow:
+
+- Upstream Diffusers reuse path
+- Custom Diffusers pipeline package path
+- Raw checkpoint plus remap path
+
+Provide a concise first-step plan with the smallest viable integration milestone.
+
+## Reporting Format
+
+Return sections in this order:
+
+1. Classification
+2. Evidence
+3. Porting difficulty
+4. Recommended SD.Next integration path
+5. Risks and unknowns
+6. Next actions
+
+If critical information is missing, explicitly list unknowns and what to inspect next.
+
+## Notes
+
+- Prefer concrete evidence from repository files and the model card over assumptions.
+- If the source suggests multiple possible paths, compare at least two and state why one is preferred.
+- Keep recommendations aligned with the conventions defined in the port-model skill.
\ No newline at end of file
diff --git a/.github/skills/check-api/SKILL.md b/.github/skills/check-api/SKILL.md
new file mode 100644
index 000000000..1c0a75daf
--- /dev/null
+++ b/.github/skills/check-api/SKILL.md
@@ -0,0 +1,121 @@
+---
+name: check-api
+description: "Audit SD.Next API route definitions and verify endpoint parameters plus request/response signatures against declared FastAPI contracts."
+argument-hint: "Optionally focus on specific route prefixes or endpoint groups"
+---
+
+# Check API Endpoints And Signatures
+
+Read `modules/api/api.py`, enumerate all registered endpoints, and validate that each endpoint has coherent request parameters and response signatures.
+
+## When To Use
+
+- The user asks to audit API correctness
+- A change touched API routes, endpoint handlers, or API models
+- OpenAPI docs look wrong or clients report schema mismatches
+- You need a pre-PR API contract sanity pass
+
+## Primary File
+
+- `modules/api/api.py`
+
+This file is the route registration hub and must be treated as the source of truth for direct `add_api_route(...)` registrations.
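A static pass over this file can be sketched with the standard-library `ast` module. The sample source below only mirrors the registration style; the route paths, handler names, and `models.*` classes are invented for illustration, not SD.Next's real registrations.

```python
import ast

# Hypothetical excerpt imitating the add_api_route style used in modules/api/api.py;
# paths, handlers, and response models here are illustrative only.
SAMPLE = '''
class Api:
    def register(self):
        self.add_api_route("/sdapi/v1/txt2img", self.post_txt2img, methods=["POST"], response_model=models.ResTxt2Img)
        self.add_api_route("/sdapi/v1/progress", self.get_progress, methods=["GET"], auth=False)
'''

def enumerate_routes(source: str):
    """Collect path, handler, and keyword metadata for every add_api_route(...) call."""
    routes = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr == "add_api_route":
            path = node.args[0].value if node.args else None          # first positional arg: route path
            handler = ast.unparse(node.args[1]) if len(node.args) > 1 else None  # second: handler symbol
            kwargs = {kw.arg: ast.unparse(kw.value) for kw in node.keywords}     # methods, response_model, auth, ...
            routes.append({"path": path, "handler": handler, **kwargs})
    return routes

for route in enumerate_routes(SAMPLE):
    print(route["path"], "->", route["handler"])
```

From the resulting table, duplicate paths, missing `response_model` entries, and `auth` overrides become straightforward to diff against the handler definitions.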
+
+## Secondary Files To Inspect
+
+- `modules/api/models.py`
+- `modules/api/endpoints.py`
+- `modules/api/generate.py`
+- `modules/api/process.py`
+- `modules/api/control.py`
+- Any module loaded via `register_api(...)` from `modules/api/*` and feature modules (caption, lora, gallery, civitai, rembg, etc.)
+
+## Audit Goals
+
+For every endpoint, verify:
+
+1. Route method and path are valid and unique after subpath handling.
+2. Handler call signature is compatible with the route declaration.
+3. Declared `response_model` is coherent with the returned payload shape.
+4. Request body or query params implied by handler type hints are consistent with expected client usage.
+5. Authentication behavior is intentional (`auth=True` default in `add_api_route`).
+6. OpenAPI schema exposure is correct (including trailing-slash duplicate suppression).
+
+## Procedure
+
+### 1. Enumerate Routes In `modules/api/api.py`
+
+- Collect every `self.add_api_route(...)` call.
+- Capture: path, methods, handler symbol, response_model, tags, auth flag.
+- Note route groups: server, generation, processing, scripts, enumerators, functional.
+
+### 2. Resolve Handler Definitions
+
+For each handler symbol:
+
+- Locate the callable definition.
+- Read its function signature and type hints.
+- Identify required positional args, optional args, and body model annotations.
+
+Flag issues such as:
+
+- Required parameters that cannot be supplied by FastAPI.
+- Mismatch between route method and handler intent (for example a body expected on GET).
+- Ambiguous or missing type hints for public API handlers.
+
+### 3. Validate Request Signatures
+
+- Confirm request model classes exist and are importable.
+- Confirm the endpoint signature reflects the expected request source (path/query/body).
+- Check optional vs required semantics for compatibility-sensitive endpoints.
+
+### 4. Validate Response Signatures
+
+- Compare declared `response_model` to handler return shape.
+- Ensure list/dict wrappers match the actual payload structure.
+- Flag obvious drift where `response_model` implies fields never returned.
+
+### 5. Include register_api(...) Modules
+
+`Api.register()` delegates additional endpoints through module-level `register_api(...)` calls.
+
+- Inspect each delegated module for routes.
+- Apply the same request/response signature checks.
+- Include these findings in the final report, not only direct routes in api.py.
+
+### 6. Verify Runtime Schema (Optional But Preferred)
+
+If feasible in the current environment:
+
+- Build the app and inspect `app.routes` metadata.
+- Generate the OpenAPI schema and spot-check key endpoints.
+- Confirm trailing-slash duplicate suppression behavior remains correct.
+
+If runtime schema checks are not feasible, explicitly state that and rely on static validation.
+
+## Reporting Format
+
+Report findings ordered by severity:
+
+1. Breaking contract mismatches
+2. Likely runtime errors
+3. Schema quality or consistency issues
+4. Minor style/typing improvements
+
+For each finding include:
+
+- Route path and method
+- File and function
+- Why it is a mismatch
+- Minimal fix recommendation
+
+If no issues are found, state that explicitly and mention residual risk (for example runtime-only behavior not executed).
+
+## Output Expectations
+
+When this skill is used, return:
+
+- Total endpoints checked
+- Direct routes vs delegated `register_api(...)` routes checked
+- Findings list with severity and locations
+- Clear pass/fail summary for request and response signature consistency
diff --git a/.github/skills/check-models/SKILL.md b/.github/skills/check-models/SKILL.md
new file mode 100644
index 000000000..88414134e
--- /dev/null
+++ b/.github/skills/check-models/SKILL.md
@@ -0,0 +1,151 @@
+---
+name: check-models
+description: "Audit SD.Next model integrations end-to-end: loaders, detect/routing, reference catalogs, and pipeline API contracts."
+argument-hint: "Optionally focus on a model family, repo id, or a subset: loader, detect-routing, references, pipeline-contracts"
+---
+
+# Check Model Integrations End-To-End
+
+Run a consolidated model-integration audit that combines loader checks, detect/routing checks, reference-catalog checks, and pipeline contract checks.
+
+## When To Use
+
+- A new model family was added and needs a completeness audit
+- Existing model support appears inconsistent across detection, loading, and UI references
+- A custom pipeline was ported and needs contract validation
+- You want a pre-PR integration quality gate for model-related changes
+
+## Combined Scope
+
+This skill combines four audit surfaces:
+
+1. Loader consistency (`check-loaders` equivalent)
+2. Detect/routing parity (`check-detect-routing` equivalent)
+3. Reference-catalog integrity (`check-reference-catalog` equivalent)
+4. Pipeline API contract conformance (`check-pipeline-contracts` equivalent)
+
+## Primary Files
+
+- `pipelines/model_*.py`
+- `modules/sd_detect.py`
+- `modules/sd_models.py`
+- `modules/modeldata.py`
+- `data/reference.json`
+- `data/reference-cloud.json`
+- `data/reference-quant.json`
+- `data/reference-distilled.json`
+- `data/reference-nunchaku.json`
+- `data/reference-community.json`
+- `models/Reference/`
+
+Pipeline files as needed:
+
+- `pipelines/<model>/pipeline.py`
+- `pipelines/<model>/model.py`
+
+## Audit A: Loader Consistency
+
+For each target model loader in `pipelines/model_*.py`, verify:
+
+- Correct `sd_models.path_to_repo(checkpoint_info)` and `sd_models.hf_auth_check(...)` usage
+- Load args built with `model_quant.get_dit_args(...)` where applicable
+- No duplicated kwargs (for example duplicate `torch_dtype`)
+- Correct component loading path (`generic.load_transformer`, `generic.load_text_encoder`, tokenizer/processor)
+- Proper post-load hooks (`sd_hijack_te`, `sd_hijack_vae`) where required
+- Correct `pipe.task_args` defaults where needed
+- Cleanup and `devices.torch_gc(...)` present
+
+Flag stale patterns, missing hooks, or conflicting load behavior.
+
+## Audit B: Detect/Routing Parity
+
+Verify model family alignment across:
+
+- `modules/sd_detect.py` detection heuristics
+- `modules/sd_models.py` load dispatch branch
+- `modules/modeldata.py` reverse classification from the loaded pipeline class
+
+Checks:
+
+- Family is detectable by name/repo conventions
+- Dispatch routes to the intended loader
+- Loaded pipeline class is classified back to the same model family
+- Branch ordering does not cause broad matches to shadow specific families
+
+## Audit C: Reference Catalog Integrity
+
+Verify references for model families intended to appear in model references.
+
+Checks:
+
+- Correct category file placement by type:
+  - base -> `data/reference.json`
+  - cloud -> `data/reference-cloud.json`
+  - quant -> `data/reference-quant.json`
+  - distilled -> `data/reference-distilled.json`
+  - nunchaku -> `data/reference-nunchaku.json`
+  - community -> `data/reference-community.json`
+- Required fields present per entry (`path`, `preview`, `desc` when expected)
+- Duplicate repo/path collisions across reference files are intentional or flagged
+- Preview filename convention is consistent
+- Referenced preview file exists in `models/Reference/` (or is explicitly a placeholder if intentional)
+- JSON validity for touched reference files
+
+## Audit D: Pipeline API Contracts
+
+For custom pipelines (`pipelines/<model>/pipeline.py`), verify:
+
+- Inherits from `diffusers.DiffusionPipeline`
+- Registers modules correctly
+- `from_pretrained` wiring is coherent with the actual artifact layout
+- `encode_prompt` semantics are consistent with the tokenizer/text encoder setup
+- `__call__` supports the expected public args for its task and does not expose unsupported generic args
+- Batch and negative prompt behavior are coherent
+- Output conversion aligns with the model output domain (latent vs pixel space)
+- `output_type` and `return_dict` behavior are consistent
+
+## Runtime Validation (Preferred)
+
+When feasible:
+
+- Import-level smoke tests for loaders and pipeline modules
+- Lightweight loader construction checks without full heavy generation where possible
+- One minimal generation/sampling pass for changed model families
+
+If runtime checks are not feasible, report limitations clearly.
+
+## Reporting Format
+
+Return findings by severity:
+
+1. Blocking integration failures
+2. Contract mismatches (load/detect/reference/pipeline)
+3. Consistency and quality issues
+4. Optional improvements
+
+For each finding include:
+
+- model family
+- layer (`loader`, `detect-routing`, `reference`, `pipeline-contract`)
+- file location
+- mismatch summary
+- minimal fix
+
+Also include summary counts:
+
+- loaders checked
+- model families checked for detect/routing parity
+- reference files checked
+- pipeline contracts checked
+- runtime checks executed vs skipped
+
+## Pass Criteria
+
+A full pass requires all of the following in the audited scope:
+
+- loader path is coherent and non-conflicting
+- detect/routing/modeldata parity holds
+- reference entries are valid, categorized correctly, and have preview files
+- custom pipeline contracts are consistent with actual model behavior
+
+If any area is intentionally out of scope, mark as a partial pass with explicit exclusions.
\ No newline at end of file
diff --git a/.github/skills/check-processing/SKILL.md b/.github/skills/check-processing/SKILL.md
new file mode 100644
index 000000000..2b81a4558
--- /dev/null
+++ b/.github/skills/check-processing/SKILL.md
@@ -0,0 +1,178 @@
+---
+name: check-processing
+description: "Validate txt2img/img2img/control/caption processing workflows from UI submit bindings to backend processing execution and confirm parameter/type/init correctness."
+argument-hint: "Optionally focus on txt2img, img2img, control, caption, or process-only and include changed files"
+---
+
+# Check Processing Workflow Contracts
+
+Trace generation workflows from UI definitions and submit bindings to backend execution, then validate that parameters are passed, typed, and initialized correctly.
+
+## When To Use
+
+- A change touched UI submit wiring for txt2img, img2img, control, or caption workflows
+- Processing code changed and regressions are suspected in argument ordering or defaults
+- A new parameter was added to UI or processing classes/functions and needs end-to-end validation
+- You want a pre-PR contract audit for generation flow integrity
+
+## Required Workflow Coverage
+
+Start from UI definitions and follow each workflow to its final implementation:
+
+1. `txt2img`: `modules/ui_txt2img.py` -> `modules/txt2img.py` -> `modules/processing.py:process_images` -> `modules/processing_diffusers.py:process_diffusers`
+2. `img2img`: `modules/ui_img2img.py` -> `modules/img2img.py` -> `modules/processing.py:process_images` -> `modules/processing_diffusers.py:process_diffusers`
+3. `control/process`: `modules/ui_control.py` -> `modules/control/run.py` (and related control processing entrypoints) -> `modules/processing.py:process_images` -> `modules/processing_diffusers.py:process_diffusers`
+4. `caption/process`: `modules/ui_caption.py` -> caption handler module(s) -> `modules/processing.py:process_images` and/or postprocess/caption execution module(s), depending on the selected caption backend
+
+Also validate script hooks when present:
+
+- `modules/scripts_manager.py` (`run`, `before_process`, `process`, `process_images`, `after`)
+
+## Primary Files
+
+- `modules/ui_txt2img.py`
+- `modules/txt2img.py`
+- `modules/ui_img2img.py`
+- `modules/img2img.py`
+- `modules/ui_control.py`
+- `modules/ui_caption.py` (and `modules/ui_captions.py` if present)
+- `modules/control/run.py`
+- `modules/processing.py`
+- `modules/processing_diffusers.py`
+- `modules/scripts_manager.py`
+
+## Audit Goals
+
+For each covered workflow, verify all three dimensions:
+
+1. Parameter pass-through correctness (name, order, semantic meaning)
+2. Type correctness (UI component output shape vs function signature expectations)
+3. Initialization correctness (defaults, `None` handling, fallback logic, and object state)
+
+## Procedure
+
+### 1. Build End-To-End Call Graph
+
+For each workflow (`txt2img`, `img2img`, `control`, `caption`):
+
+- Locate submit/click bindings in UI modules.
+- Capture the exact `inputs=[...]` list order and target function (`fn=...`).
+- Resolve wrappers (`call_queue.wrap_gradio_gpu_call`, queued wrappers) to actual function signatures.
+- Follow function flow through processing class construction and execution (`processing.process_images`, then `process_diffusers` when applicable).
+
+Produce a normalized mapping table per workflow:
+
+- UI input component name
+- UI expected output type
+- receiving argument in submit target
+- receiving processing-object field (if applicable)
+- downstream consumption point
+
+### 2. Validate Argument Order And Arity
+
+For each submit path:
+
+- Compare UI `inputs` order against function positional parameter order.
+- Validate handling of `*args` and script arguments.
+- Confirm `state`, task id, mode flags, and tab selections align with function signatures.
+- Flag positional drift where adding/removing an argument in one layer is not propagated.
+
+### 3. Validate Name And Semantic Parity
+
+Check that semantically related parameters remain coherent across layers:
+
+- sampler fields (`sampler_index`, `hr_sampler_index`, sampler name conversion)
+- guidance/cfg fields
+- size/resize fields
+- denoise/refiner/hires fields
+- detailer, hdr, grading fields
+- override settings fields
+
+Flag mismatches such as:
+
+- same concept with divergent naming and wrong destination
+- field sent but never consumed
+- required field consumed but never provided
+
+### 4. Validate Type Contracts
+
+Audit type compatibility from UI component to processing target:
+
+- `gr.Image`/`gr.File`/`gr.Video` outputs vs expected Python types (`PIL.Image`, bytes, list, path-like, etc.)
+- radios returning index/value and expected downstream representation
+- sliders/number inputs (`int` vs `float`) and conversion points
+- optional objects (`None`) and `.name`/attribute access safety
+
+Flag ambiguous or unsafe assumptions, especially for optional file inputs and mixed scalar/list values.
+
+### 5. Validate Initialization And Defaults
+
+In target modules (`txt2img.py`, `img2img.py`, `control/run.py`, processing classes):
+
+- verify defaults/fallbacks for invalid or missing inputs
+- verify guards for unset model/state and expected error paths
+- verify object fields are initialized before first use
+- verify flow-specific defaults are not leaking across workflows
+
+Include checks for common regressions:
+
+- `None` passed into required processing fields
+- missing fallback for sampler/seed/size values
+- stale fields retained from prior job state
+
+### 6. Validate Script Hook Contracts
+
+Where script systems are involved:
+
+- verify `scripts_*.run(...)` fallback behavior to `processing.process_images(...)`
+- verify `scripts_*.after(...)` receives a compatible processed object
+- ensure script args wiring matches `setup_ui(...)` order
+
+### 7. Runtime Spot Check (Preferred)
+
+If feasible, run lightweight smoke validation for each workflow:
+
+- one minimal txt2img run
+- one minimal img2img run
+- one minimal control run
+- one minimal caption run
+
+Use very small dimensions/steps to limit runtime.
+If runtime checks are not feasible, explicitly report static-only limitations.
+
+## Reporting Format
+
+Return findings by severity:
+
+1. Blocking contract failures (will break execution)
+2. High-risk mismatches (likely wrong behavior)
+3. Type/init safety issues
+4. Non-blocking consistency issues
+
+For each finding include:
+
+- workflow (`txt2img` | `img2img` | `control` | `caption`)
+- layer transition (`ui -> handler`, `handler -> processing`, `processing -> diffusers`)
+- file location
+- mismatch summary
+- minimal fix
+
+Also include summary counts:
+
+- workflows checked
+- UI bindings checked
+- function signatures checked
+- parameter mappings validated
+- runtime checks executed vs skipped
+
+## Pass Criteria
+
+A full pass requires all of the following:
+
+- UI submit input order matches target function signatures
+- all required parameters are passed end-to-end with correct semantics
+- type expectations are explicit and compatible at each boundary
+- initialization/default logic prevents unset/invalid state usage
+- the scripts fallback path to `processing.process_images` is coherent
+
+If only part of the workflow scope was checked, report a partial pass with explicit exclusions.
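The order-and-arity check in step 2 of the procedure above can be sketched as a small helper. The submit target and UI input names below are invented for illustration; SD.Next's real bindings have many more parameters.

```python
import inspect

# Hypothetical submit target mirroring the shape of a UI binding's fn=...;
# the parameter and component names are illustrative, not SD.Next's real ones.
def txt2img_submit(task_id, prompt, negative_prompt, steps, sampler_index, *script_args):
    pass

# Order of the binding's inputs=[...] list, expressed as component names.
ui_inputs = ["task_id", "prompt", "negative_prompt", "steps", "sampler_index"]

def check_binding(fn, inputs):
    """Compare a UI inputs list against fn's positional parameters; return problems."""
    params = list(inspect.signature(fn).parameters.values())
    positional = [p for p in params if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
    has_var = any(p.kind == p.VAR_POSITIONAL for p in params)  # *script_args absorbs extras
    problems = []
    if len(inputs) < len(positional):
        problems.append(f"missing args: {[p.name for p in positional[len(inputs):]]}")
    if len(inputs) > len(positional) and not has_var:
        problems.append(f"extra args: {inputs[len(positional):]}")
    # Name drift between UI component names and parameter names is a warning sign.
    for name, p in zip(inputs, positional):
        if name != p.name:
            problems.append(f"name drift: ui={name!r} vs fn={p.name!r}")
    return problems

print(check_binding(txt2img_submit, ui_inputs))
```

Running the helper against every binding in the call-graph table makes positional drift (step 2's main regression class) mechanical to spot.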
diff --git a/.github/skills/check-schedulers/SKILL.md b/.github/skills/check-schedulers/SKILL.md
new file mode 100644
index 000000000..472c9a35f
--- /dev/null
+++ b/.github/skills/check-schedulers/SKILL.md
@@ -0,0 +1,149 @@
+---
+name: check-schedulers
+description: "Audit scheduler registrations starting from modules/sd_samplers_diffusers.py and verify class loadability, config validity against scheduler capabilities, and SamplerData correctness."
+argument-hint: "Optionally focus on a scheduler subset, such as flow-matching, res4lyf, or parallel schedulers"
+---
+
+# Check Scheduler Registry And Contracts
+
+Use `modules/sd_samplers_diffusers.py` as the starting point and verify that scheduler classes, scheduler config, and `SamplerData` mappings are coherent and executable.
+
+## Required Guarantees
+
+The audit must explicitly verify all three:
+
+1. All scheduler classes can be loaded and compiled.
+2. All scheduler config entries are valid and match scheduler capabilities in `__init__`.
+3. All scheduler classes have valid associated `SamplerData` entries with correct mappings.
+
+## Scope
+
+Primary file:
+
+- `modules/sd_samplers_diffusers.py`
+
+Related files:
+
+- `modules/sd_samplers_common.py` for the `SamplerData` definition and sampler expectations
+- `modules/sd_samplers.py` for sampler selection flow and runtime wiring
+- `modules/schedulers/**/*.py` for custom scheduler implementations
+- `modules/res4lyf/**/*.py` for Res4Lyf scheduler classes (if installed/enabled)
+
+## What "Loaded And Compiled" Means
+
+Treat this as a two-level check:
+
+1. Import and class resolution:
+   every scheduler class referenced by `SamplerData(... DiffusionSampler(..., SchedulerClass, ...))` resolves without missing symbol errors.
+2. Construction and compile sanity:
+   scheduler instances can be created from their intended config path and survive a lightweight execution-level check (including the compile path when available).
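The config-vs-`__init__` validation described in the procedure below can be sketched with `inspect`. The stand-in scheduler class and its field names are illustrative; real classes come from Diffusers and `modules/schedulers`, and Diffusers schedulers may accept extra keys through their config machinery, so a real audit should document such keys rather than assume them.

```python
import inspect

# Stand-in for a flow-matching scheduler; the real classes and their
# accepted fields come from diffusers / modules/schedulers and may differ.
class FlowMatchScheduler:
    def __init__(self, num_train_timesteps=1000, shift=3.0, use_dynamic_shifting=False):
        pass

def validate_config(cls, config: dict):
    """Return config keys the scheduler __init__ does not accept."""
    sig = inspect.signature(cls.__init__)
    if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()):
        return []  # **kwargs swallows everything; document such keys, do not assume
    accepted = set(sig.parameters) - {"self"}
    return sorted(key for key in config if key not in accepted)

# A valid entry passes; a key from a different scheduler family is flagged.
assert validate_config(FlowMatchScheduler, {"shift": 3.0, "use_dynamic_shifting": True}) == []
assert validate_config(FlowMatchScheduler, {"shift": 3.0, "solver_order": 2}) == ["solver_order"]
```

Run over the joined sampler/config table from the inventory step, this catches misspelled and family-mismatched config keys statically, before any runtime smoke checks.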
+ +Notes: + +- For non-`torch.nn.Module` schedulers, "compiled" means the scheduler integration path is executable in runtime checks (not necessarily `torch.compile`). +- If the environment cannot run compile checks, report this explicitly and still complete static validation. + +## Procedure + +### 1. Build Scheduler Inventory + +- Enumerate scheduler classes imported in `modules/sd_samplers_diffusers.py`. +- Enumerate all entries in `samplers_data_diffusers`. +- Enumerate all config keys in `config`. + +Create a joined table by sampler name with: + +- sampler name +- scheduler class +- config key used +- custom scheduler category (diffusers, SD.Next custom, Res4Lyf) + +### 2. Validate Scheduler Class Resolution + +For each mapped scheduler class: + +- confirm symbol exists and is importable +- confirm class object is callable for construction paths used in SD.Next + +Flag missing imports, dead entries, or stale class names. + +### 3. Validate Config Against Scheduler __init__ Capabilities + +For each sampler config entry: + +- inspect scheduler class `__init__` signature and accepted config fields +- flag config keys that are unsupported or misspelled +- flag required scheduler init/config fields that are absent +- verify defaults and overrides are compatible with scheduler family behavior + +Special attention: + +- flow-matching schedulers: shift/base_shift/max_shift/use_dynamic_shifting +- DPM families: algorithm_type/solver_order/solver_type/final_sigmas_type +- compatibility-only keys that are intentionally ignored should be documented, not silently assumed + +### 4. 
Validate SamplerData Mapping Correctness + +For each `SamplerData` entry: + +- scheduler label matches config key intent +- callable builds `DiffusionSampler` with the expected scheduler class +- mapping is not accidentally pointing to a different named preset +- no duplicate names with conflicting class/config behavior + +Flag mismatches such as wrong display name, wrong class wired to name, or stale aliasing. + +### 5. Runtime Smoke Checks (Preferred) + +If feasible, run lightweight checks: + +- instantiate each scheduler from config +- execute minimal scheduler setup path (`set_timesteps` and a dummy step where possible) +- verify no immediate runtime contract errors + +If runtime checks are not feasible for some schedulers, mark those explicitly as unverified-at-runtime. + +### 6. Compile Path Validation + +Where scheduler runtime path supports compile-related checks in SD.Next: + +- verify scheduler integration path remains compatible with compile options +- detect obvious compile blockers introduced by signature/config mismatches + +Do not mark compile as passed if only static checks were done. + +## Reporting Format + +Return findings ordered by severity: + +1. Blocking scheduler load failures +2. Config/signature contract mismatches +3. `SamplerData` mapping inconsistencies +4. 
Non-blocking improvements + +For each finding include: + +- sampler name +- scheduler class +- file location +- mismatch reason +- minimal corrective action + +Also include summary counts: + +- total scheduler classes discovered +- total `SamplerData` entries checked +- total config entries checked +- runtime-validated count +- compile-path validated count + +## Pass Criteria + +The check passes only if all are true: + +- all referenced scheduler classes resolve +- each scheduler config entry is compatible with scheduler capabilities +- each `SamplerData` entry is correctly mapped and usable +- no blocking runtime or compile-path failures in validated scope + +If scope is partial due to environment limitations, report pass with explicit limitations, not a full pass. \ No newline at end of file diff --git a/.github/skills/check-scripts/SKILL.md b/.github/skills/check-scripts/SKILL.md new file mode 100644 index 000000000..732bb8b4f --- /dev/null +++ b/.github/skills/check-scripts/SKILL.md @@ -0,0 +1,129 @@ +--- +name: check-scripts +description: "Audit scripts/*.py and verify Script override contracts (init/title/show) plus ui() output compatibility with run() or process() parameters." +argument-hint: "Optionally focus on a subset of scripts or only run-vs-ui or process-vs-ui checks" +--- + +# Check Script Class Contracts + +Audit all Python scripts in `scripts/*.py` and validate that script class overrides and UI-to-execution parameter contracts are correct. 
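The `ui()`-to-`run()` arity comparison described under Required Checks can be sketched statically with stdlib `inspect`. `ExampleScript`, its control names, and the simplified contract below are illustrative assumptions, not real repo code:

```python
import inspect

class ExampleScript:
    """Hypothetical generation script; class and control names are illustrative only."""
    def title(self):
        return 'Example'
    def show(self, is_img2img):
        return True
    def ui(self, is_img2img):
        return ['strength_slider', 'steps_slider']  # two UI controls collected by the runner
    def run(self, p, strength, steps):  # first arg after self is the processing object `p`
        return None

def check_ui_run_arity(cls):
    """Compare ui() output count against run() positional capacity after `p`."""
    script = cls()
    outputs = script.ui(False)
    if not isinstance(outputs, (list, tuple)):
        return f'{cls.__name__}: ui() must return a list or tuple'
    params = list(inspect.signature(cls.run).parameters.values())[2:]  # drop self and p
    if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params):
        return f'{cls.__name__}: run() uses *args, arity not statically checkable'
    positional = [p for p in params if p.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD]
    required = [p for p in positional if p.default is inspect.Parameter.empty]
    if not len(required) <= len(outputs) <= len(positional):
        return f'{cls.__name__}: ui() returns {len(outputs)} values, run() expects {len(required)}-{len(positional)}'
    return f'{cls.__name__}: ok'

print(check_ui_run_arity(ExampleScript))  # ExampleScript: ok
```

A real audit would iterate the classes discovered in `scripts/*.py` and apply the same comparison to `process()` targets.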
+ +## When To Use + +- New or changed files were added under `scripts/*.py` +- A script crashes when selected or executed from UI +- A script UI was changed and runtime args no longer match +- You want a pre-PR quality gate for script API compatibility + +## Scope + +Primary audit scope: + +- `scripts/*.py` + +Contract references: + +- `modules/scripts_manager.py` (`Script` base class contracts for `title`, `show`, `ui`, `run`, `process`) +- `modules/scripts_postprocessing.py` (`ScriptPostprocessing` contracts for `ui` and `process`) + +## Required Checks + +### A. Standard Overrides: `__init__`, `title`, `show` + +For each class in `scripts/*.py` that subclasses `scripts.Script` or `scripts_manager.Script`: + +1. `title`: +- method exists +- callable signature is valid +- returns non-empty string value + +2. `show`: +- method exists +- signature is compatible with script runner usage (`show(is_img2img)` or permissive `*args/**kwargs`) +- return behavior is compatible (`bool` or `scripts.AlwaysVisible` / equivalent) + +3. `__init__` (if overridden): +- does not require mandatory constructor args that would break loader instantiation +- avoids side effects that require runtime-only globals at import time +- leaves class in a usable state before `ui()`/`run()`/`process()` are called + +Notes: +- `__init__` is optional; do not fail scripts that rely on inherited constructor. +- For dynamic patterns, flag as warning with rationale instead of hard fail. + +### B. `ui()` Output vs `run()`/`process()` Parameters + +For each script class: + +1. Determine execution target: +- Prefer `run()` if present for generation scripts +- Use `process()` if present and `run()` is absent or script is postprocessing-oriented + +2. 
Compare `ui()` output shape to target method parameter expectations: +- `ui()` list/tuple output count should match target positional argument capacity after the first processing arg (`p` or `pp`), unless target uses `*args` +- if target is strict positional (no `*args`/`**kwargs`), detect missing/extra UI values +- if target uses keyword-driven processing, ensure UI dict keys map to accepted params or `**kwargs` + +3. Validate ordering assumptions: +- UI control order should align with positional parameter order when positional binding is used +- detect obvious drift when new UI control was added but method signature was not updated + +4. Validate optionality/defaults: +- required target parameters should be satisfiable by UI outputs +- defaulted target params are acceptable even if UI omits them + +### C. Runner Compatibility + +Confirm script methods align with runner expectations in `modules/scripts_manager.py`: + +- `ui()` return type is compatible with runner collection (`list/tuple` or recognized mapping pattern where used) +- `run()`/`process()` receive args in expected form from runner slices +- no obvious mismatch between `args_from/args_to` assumptions and script method arity + +For postprocessing-style scripts in `scripts/*.py`: + +- verify compatibility with `modules/scripts_postprocessing.py` conventions (`ui()` list/dict, `process(pp, *args, **kwargs)`) + +## Procedure + +1. Enumerate all classes in `scripts/*.py` and classify by base class type. +2. For each generation script class, validate `title`, `show`, optional `__init__`, and `ui` -> `run/process` contracts. +3. For each postprocessing script class under `scripts/*.py`, validate `ui` -> `process` mapping semantics. +4. Cross-check ambiguous cases against script runner behavior from `modules/scripts_manager.py` and `modules/scripts_postprocessing.py`. +5. Report concrete mismatches with minimal fixes. + +## Reporting Format + +Return findings by severity: + +1. 
Blocking script contract failures
+2. Likely runtime arg/arity mismatches
+3. Signature/type compatibility warnings
+4. Style/consistency improvements
+
+For each finding include:
+
+- script file
+- class name
+- failing contract area (`init`, `title`, `show`, `ui->run`, `ui->process`)
+- mismatch summary
+- minimal fix
+
+Also include summary counts:
+
+- total `scripts/*.py` files checked
+- total script classes checked
+- classes with `run` contract checked
+- classes with `process` contract checked
+- override issues found (`init/title/show`)
+
+## Pass Criteria
+
+A full pass requires all of the following across audited `scripts/*.py` classes:
+
+- `title` and `show` overrides are valid and runner-compatible for generation scripts
+- overridden `__init__` methods are safely instantiable
+- `ui()` output contracts are compatible with `run()` or `process()` args
+- no blocking arity/signature mismatch remains
+
+If a class uses highly dynamic argument routing that cannot be proven statically, mark as conditional pass with explicit runtime validation recommendation.
diff --git a/.github/skills/debug-model/SKILL.md b/.github/skills/debug-model/SKILL.md
new file mode 100644
index 000000000..57ec7317f
--- /dev/null
+++ b/.github/skills/debug-model/SKILL.md
@@ -0,0 +1,217 @@
+---
+name: debug-model
+description: "Debug a broken SD.Next or Diffusers model integration. Use when a newly added or ported model fails to load, misdetects, crashes during prompt encoding or sampling, or produces incorrect outputs."
+argument-hint: "Describe the failing model, where it fails, the error message, and whether the model is upstream Diffusers, custom pipeline, or raw checkpoint based"
+---
+
+# Debug SD.Next And Diffusers Model Port
+
+Read the error, identify which integration layer is failing, isolate the smallest reproducible failure point, fix the root cause, and validate the fix without expanding scope. 
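The outside-in isolation approach can be sketched as an ordered probe runner that stops at the first failing layer. The probe names and the simulated failure below are illustrative; real probes would call the actual detection, loader, and encoding paths:

```python
def isolate_failing_layer(probes):
    """Run ordered (name, callable) probes outside-in and stop at the first failure."""
    for name, probe in probes:
        try:
            probe()
        except Exception as exc:  # any failure identifies the layer; do not keep probing
            return name, f'{type(exc).__name__}: {exc}'
    return None, 'all probed layers passed'

def fake_encode():
    # stand-in for a real prompt-encode probe; simulates a tokenizer/shape failure
    raise ValueError('hidden state shape mismatch')

probes = [
    ('detect', lambda: None),  # e.g. assert sd_detect routes to the expected loader
    ('load', lambda: None),    # e.g. loader-only instantiation, no generation
    ('encode', fake_encode),   # e.g. encode one positive/negative prompt pair
    ('sample', lambda: None),  # never reached; an earlier layer fails first
]
layer, detail = isolate_failing_layer(probes)
print(layer, detail)  # encode ValueError: hidden state shape mismatch
```

This keeps the reproduction narrow: each probe exercises exactly one layer, so the first failure points at the layer to fix rather than at the full generation stack.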
+ +## When To Use + +- A newly added SD.Next model type does not autodetect correctly +- The loader fails to instantiate a pipeline or component +- A custom pipeline imports but fails during `from_pretrained` +- Prompt encoding fails because of tokenizer, processor, or text encoder mismatch +- Sampling fails due to tensor shape, dtype, device, or scheduler issues +- The model loads but outputs corrupted images, wrong output type, or obviously incorrect results + +## Debugging Order + +Always debug from the outside in. + +1. Detection and routing +2. Loader arguments and component selection +3. Checkpoint path and artifact layout +4. Weight loading and key mapping +5. Prompt encoding +6. Sampling forward path +7. Output postprocessing and SD.Next task integration + +Do not start by rewriting the architecture if the failure is likely in detection, loader wiring, or output handling. + +## Files To Check First + +- `.github/copilot-instructions.md` +- `.github/instructions/core.instructions.md` +- `modules/sd_detect.py` +- `modules/sd_models.py` +- `modules/modeldata.py` +- `pipelines/model_.py` +- `pipelines//model.py` +- `pipelines//pipeline.py` +- `pipelines/generic.py` + +If the port is based on a standalone script, compare the failing path against the original reference implementation and identify the first semantic divergence. + +## Failure Classification + +### 1. Model Not Detected Or Misclassified + +Check: + +- Filename and repo-name heuristics in `modules/sd_detect.py` +- Loader dispatch branch in `modules/sd_models.py` +- Reverse pipeline classification in `modules/modeldata.py` + +Typical symptoms: + +- Wrong loader called +- Pipeline classified as a broader family such as `chroma` instead of a custom `zetachroma` +- Task switching behaves incorrectly because the loaded pipeline type is wrong + +### 2. 
Loader Fails Before Pipeline Construction + +Check: + +- `sd_models.path_to_repo(checkpoint_info)` output +- `generic.load_transformer(...)` and `generic.load_text_encoder(...)` arguments +- Duplicate kwargs such as `torch_dtype` +- Wrong class chosen for text encoder, tokenizer, or processor +- Whether the source is really a Diffusers repo or only a raw checkpoint + +Typical symptoms: + +- Missing subfolder errors +- `from_pretrained` argument mismatch +- Component class mismatch + +### 3. Raw Checkpoint Load Fails + +Check: + +- Checkpoint path resolution for local file, local directory, and Hub repo +- State dict load method +- Key remapping logic +- Config inference from tensor shapes +- Missing versus unexpected keys after `load_state_dict` + +Typical symptoms: + +- Key mismatch explosion +- Wrong inferred head counts, dimensions, or decoder settings +- Silent shape corruption caused by a bad remap + +### 4. Prompt Encoding Fails + +Check: + +- Tokenizer or processor choice +- `trust_remote_code` requirements +- Chat template or custom prompt formatting +- Hidden state index selection +- Padding and batch alignment between positive and negative prompts + +Typical symptoms: + +- Tokenizer attribute errors +- Hidden state shape mismatch +- CFG failures when negative prompts do not match prompt batch length + +### 5. Sampling Or Forward Pass Fails + +Check: + +- Input tensor shape and channel count +- Device and dtype alignment across all components +- Scheduler timesteps and expected timestep convention +- Classifier-free guidance concatenation and split logic +- Pixel-space versus latent-space assumptions + +Typical symptoms: + +- Shape mismatch in attention or decoder blocks +- Device mismatch between text encoder output and model tensors +- Images exploding to NaNs because timestep semantics are inverted + +### 6. 
Output Is Wrong But No Exception Is Raised + +Check: + +- Whether the model predicts `x0`, noise, or velocity +- Whether the Euler or other sampler update matches the model objective +- Final scaling and clamp path +- `output_type` handling and `pipe.task_args` +- Whether a VAE is being applied incorrectly to direct pixel-space output + +Typical symptoms: + +- Black, gray, washed-out, or heavily clipped images +- Output with correct size but obviously broken semantics +- Correct tensors but wrong SD.Next display behavior because output type is mismatched + +## Minimal Debug Procedure + +### 1. Reproduce Narrowly + +Capture the smallest failing operation. + +- Pure import failure +- Loader-only failure +- `from_pretrained` failure +- Prompt encode failure +- Single forward pass failure +- First sampler step failure + +Prefer narrow Python checks before attempting a full generation run. + +### 2. Compare Against Working Pattern + +Find the closest working in-repo analogue and compare: + +- Loader structure +- Registered module names +- Pipeline class name and module registration +- Prompt encoding path +- Output conversion path + +### 3. Fix The Root Cause + +Examples: + +- Add the missing `modeldata` branch instead of patching downstream task handling +- Fix checkpoint remapping rather than forcing `strict=False` and ignoring real mismatches +- Correct the output path for pixel-space models instead of routing through a VAE +- Make config inference fail explicitly when ambiguous instead of guessing silently + +### 4. Validate In Layers + +After each meaningful fix, validate the narrowest relevant layer first. 
+ +- `compileall` or syntax check +- `ruff` on touched files +- Import smoke test +- Loader-only smoke test +- Full run only when the lower layers are stable + +## Common Root Causes + +- `modules/modeldata.py` not updated after adding a new custom pipeline family +- `modules/sd_detect.py` branch order causes overbroad detection to win first +- Loader passes duplicated keyword args like `torch_dtype` +- Shared text encoder assumptions do not match the actual model variant +- `from_pretrained` assumes `transformer/` or `text_encoder/` subfolders that do not exist +- Key remapping merges QKV in the wrong order +- CFG path concatenates embeddings or latents incorrectly +- Direct pixel-space models are postprocessed like latent-space diffusion outputs +- Negative prompts are not padded or repeated to match prompt batch shape +- Pipeline class naming collides with broader family checks in `modeldata` + +## Validation Checklist + +When closing the task, report which of these were completed: + +1. Exact failing layer identified +2. Root cause fixed +3. Syntax check passed +4. Focused lint passed +5. Import or loader smoke test passed +6. Real generation tested, or explicitly not tested + +## Example Request Shapes + +- "The new model port fails in from_pretrained" +- "SD.Next detects my custom pipeline as the wrong model type" +- "The loader works but generation returns black images" +- "This standalone-script port loads weights but crashes in attention" \ No newline at end of file diff --git a/.github/skills/github-features/SKILL.md b/.github/skills/github-features/SKILL.md new file mode 100644 index 000000000..278506030 --- /dev/null +++ b/.github/skills/github-features/SKILL.md @@ -0,0 +1,102 @@ +--- +name: github-features +description: "Read SD.Next GitHub issues with [Feature] in the title and generate a markdown report with short summary, status, and suggested next steps per issue." 
+argument-hint: "Optionally specify state (open/closed/all), max issues, and whether to include labels/assignees" +--- + +# Summarize SD.Next [Feature] GitHub Issues + +Fetch issues from the SD.Next GitHub repository that contain `[Feature]` in the title, then produce a concise markdown report with one entry per issue. + +## When To Use + +- The user asks for periodic feature-request triage summaries +- You need an actionable status report for `[Feature]` tracker items +- You want suggested next actions for each matching issue + +## Repository + +Default target repository: + +- owner: `vladmandic` +- name: `sdnext` + +## Required Output + +Create markdown containing, for each matching issue: + +- issue link and title +- short summary (1-3 sentences) +- status (open/closed and relevant labels) +- suggested next steps (1-3 concrete actions) + +## Procedure + +### 1. Search Matching Issues + +Use GitHub search to find issues with `[Feature]` in title. + +Preferred search query template: + +- `is:issue in:title "[Feature]" repo:vladmandic/sdnext` + +State filters (when requested): + +- open only: add `is:open` (default) +- closed only: add `is:closed` + +Use `github-pull-request_doSearch` for the search step. + +### 2. Fetch Full Issue Details + +For each matched issue (within requested limit): + +- fetch details with `github-pull-request_issue_fetch` +- capture body, labels, assignees, state, updated time, and key discussion context + +### 3. Build Per-Issue Summary + +For each issue, produce: + +1. Short summary: +- describe feature request in plain language +- include current progress signal if present + +2. Status: +- open/closed +- notable labels (for example: enhancement, planned, blocked, stale) +- optional assignee and last update signal + +3. Suggested next steps: +- propose concrete, minimal actions +- tailor actions to issue state and content +- avoid generic filler + +### 4. 
Produce Markdown Report + +Return a markdown table or bullet report with one row/section per issue. + +Recommended table columns: + +- Issue +- Summary +- Status +- Suggested Next Steps + +If there are many issues, keep summaries short and prioritize clarity. + +## Reporting Rules + +- Keep each issue summary concise and actionable. +- Do not invent facts not present in issue data. +- If issue body is sparse, state assumptions explicitly. +- If no matching issues are found, output a clear "no matches" report. + +## Pass Criteria + +A successful run must: + +- search SD.Next issues with `[Feature]` in title +- include all matched issues in scope (or explicitly mention applied limit) +- provide summary, status, and suggested next steps for each issue +- return the final result as markdown diff --git a/.github/skills/github-issues/SKILL.md b/.github/skills/github-issues/SKILL.md new file mode 100644 index 000000000..befbb4bfb --- /dev/null +++ b/.github/skills/github-issues/SKILL.md @@ -0,0 +1,102 @@ +--- +name: github-issues +description: "Read SD.Next GitHub issues with [Issues] in the title and generate a markdown report with short summary, status, and suggested next steps per issue." +argument-hint: "Optionally specify state (open/closed/all), max issues, and whether to include labels/assignees" +--- + +# Summarize SD.Next [Issues] GitHub Issues + +Fetch issues from the SD.Next GitHub repository that contain `[Issues]` in the title, then produce a concise markdown report with one entry per issue. 
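The query construction and per-issue report row can be sketched as below; the helper names are hypothetical, and the actual search and fetch steps use the tools described in the procedure:

```python
def build_search_query(tag, state='open', repo='vladmandic/sdnext'):
    """Compose the search query; tag is the title marker, e.g. '[Issues]' or '[Feature]'."""
    query = f'is:issue in:title "{tag}" repo:{repo}'
    if state in ('open', 'closed'):  # 'all' adds no state filter
        query += f' is:{state}'
    return query

def issue_row(number, title, summary, status, steps, repo='vladmandic/sdnext'):
    """Render one markdown table row for the per-issue report."""
    link = f'[#{number}](https://github.com/{repo}/issues/{number})'
    return f'| {link} {title} | {summary} | {status} | {steps} |'

print(build_search_query('[Issues]'))  # is:issue in:title "[Issues]" repo:vladmandic/sdnext is:open
```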
+ +## When To Use + +- The user asks for periodic issue triage summaries +- You need an actionable status report for `[Issues]` tracker items +- You want suggested next actions for each matching issue + +## Repository + +Default target repository: + +- owner: `vladmandic` +- name: `sdnext` + +## Required Output + +Create markdown containing, for each matching issue: + +- issue link and title +- short summary (1-3 sentences) +- status (open/closed and relevant labels) +- suggested next steps (1-3 concrete actions) + +## Procedure + +### 1. Search Matching Issues + +Use GitHub search to find issues with `[Issues]` in title. + +Preferred search query template: + +- `is:issue in:title "[Issues]" repo:vladmandic/sdnext` + +State filters (when requested): + +- open only: add `is:open` (default) +- closed only: add `is:closed` + +Use `github-pull-request_doSearch` for the search step. + +### 2. Fetch Full Issue Details + +For each matched issue (within requested limit): + +- fetch details with `github-pull-request_issue_fetch` +- capture body, labels, assignees, state, updated time, and key discussion context + +### 3. Build Per-Issue Summary + +For each issue, produce: + +1. Short summary: +- describe problem/request in plain language +- include current progress signal if present + +2. Status: +- open/closed +- notable labels (for example: bug, enhancement, stale, blocked) +- optional assignee and last update signal + +3. Suggested next steps: +- propose concrete, minimal actions +- tailor actions to issue state and content +- avoid generic filler + +### 4. Produce Markdown Report + +Return a markdown table or bullet report with one row/section per issue. + +Recommended table columns: + +- Issue +- Summary +- Status +- Suggested Next Steps + +If there are many issues, keep summaries short and prioritize clarity. + +## Reporting Rules + +- Keep each issue summary concise and actionable. +- Do not invent facts not present in issue data. 
+- If issue body is sparse, state assumptions explicitly. +- If no matching issues are found, output a clear "no matches" report. + +## Pass Criteria + +A successful run must: + +- search SD.Next issues with `[Issues]` in title +- include all matched issues in scope (or explicitly mention applied limit) +- provide summary, status, and suggested next steps for each issue +- return the final result as markdown diff --git a/.github/skills/port-model/SKILL.md b/.github/skills/port-model/SKILL.md new file mode 100644 index 000000000..7d13fa517 --- /dev/null +++ b/.github/skills/port-model/SKILL.md @@ -0,0 +1,247 @@ +--- +name: port-model +description: "Port or add a model to SD.Next using existing Diffusers and custom pipeline patterns. Use when implementing a new model loader, custom pipeline, checkpoint conversion path, or SD.Next model-type integration." +argument-hint: "Describe the source model, target task, checkpoint format, and whether the model already has a Diffusers pipeline" +--- + +# Port Model To SD.Next And Diffusers + +Read the task, identify the model architecture and artifact layout, choose the narrowest integration path that matches existing SD.Next patterns, implement the loader and pipeline wiring, and validate the result. + +## When To Use + +- The user wants to add a new model family to SD.Next +- A standalone inference script needs to become a Diffusers-style pipeline +- A raw checkpoint or safetensors repo needs an SD.Next loader +- A model already exists in Diffusers but is not yet wired into SD.Next +- A custom architecture needs a repo-local `pipelines/` package and loader + +## Core Rule + +Prefer the smallest correct integration path. + +- If upstream Diffusers already supports the model cleanly, reuse the upstream pipeline and only add SD.Next wiring. +- If the model requires custom architecture or sampler behavior, add a repo-local pipeline package under `pipelines/`. 
+- If the model ships as a raw checkpoint instead of a Diffusers repo, load or remap weights explicitly instead of pretending it is a standard Diffusers layout. + +## Inputs To Collect First + +Before editing anything, determine these facts: + +- Model task: text-to-image, img2img, inpaint, editing, video, multimodal, or other +- Artifact format: Diffusers repo, local folder, single-file safetensors, ckpt, GGUF, or custom layout +- Core components: transformer or UNet, VAE or VQ model, text encoder, tokenizer or processor, scheduler, image encoder, adapters +- Whether output is latent-space or direct pixel-space +- Whether the model family has one fixed architecture or multiple variants +- Whether prompt encoding follows a normal tokenizer path or a custom chat/template path +- Whether there is an existing in-repo analogue to copy structurally + +If any of these remain unclear after reading the repo and source artifacts, ask concise clarifying questions before implementing. + +## Mandatory Category Question + +Before implementing model-reference updates, explicitly ask the user which category the model belongs to: + +- `base` +- `cloud` +- `quant` +- `distilled` +- `nunchaku` +- `community` + +Do not guess this category. Use the user answer to decide which reference JSON file(s) to update. + +## Repo Files To Check + +Start by reading the task description, then inspect the closest matching implementations. 
+ +- `.github/copilot-instructions.md` +- `.github/instructions/core.instructions.md` +- `pipelines/generic.py` +- `pipelines/model_*.py` files similar to the target model +- `modules/sd_models.py` +- `modules/sd_detect.py` +- `modules/modeldata.py` + +Useful examples by pattern: + +- Existing upstream Diffusers loader: `pipelines/model_chroma.py`, `pipelines/model_z_image.py` +- Custom in-repo pipeline package: `pipelines/f_lite/` +- Shared Qwen loader pattern: `pipelines/model_z_image.py`, `pipelines/model_flux2_klein.py` + +## Integration Decision Tree + +### 1. Upstream Diffusers Support Exists + +Use this path when the model already has a usable Diffusers pipeline and component classes. + +Implement: + +- `pipelines/model_.py` +- `modules/sd_models.py` dispatch branch +- `modules/sd_detect.py` filename autodetect branch if appropriate +- `modules/modeldata.py` model type detection branch + +Reuse: + +- `generic.load_transformer(...)` +- `generic.load_text_encoder(...)` +- Existing processor or tokenizer loading patterns +- `sd_hijack_te.init_hijack(pipe)` and `sd_hijack_vae.init_hijack(pipe)` where relevant + +### 2. Custom Pipeline Needed + +Use this path when the model architecture, sampler, or prompt encoding is not available upstream. + +Implement: + +- `pipelines//__init__.py` +- `pipelines//model.py` +- `pipelines//pipeline.py` +- `pipelines/model_.py` +- SD.Next registration in `modules/sd_models.py`, `modules/sd_detect.py`, and `modules/modeldata.py` + +Model module responsibilities: + +- Architecture classes +- Config handling and defaults +- Weight remapping or checkpoint conversion helpers +- Raw checkpoint loading helpers if needed + +Pipeline module responsibilities: + +- `DiffusionPipeline` subclass +- `__init__` +- `from_pretrained` +- `encode_prompt` or equivalent prompt preparation +- `__call__` +- Output dataclass +- Optional callback handling and output conversion + +### 3. 
Raw Checkpoint Or Single-File Weights + +Use this path when the model source is not a normal Diffusers repository. + +Requirements: + +- Resolve checkpoint path from local file, local directory, or Hub repo +- Load state dict directly +- Remap keys when training format differs from inference format +- Infer config from tensor shapes only if the model family truly varies and no config file exists +- Raise explicit errors for ambiguous or incomplete layouts + +Do not fake a `from_pretrained` implementation that silently assumes missing subfolders exist. + +## Required SD.Next Touchpoints + +Most new model families need all of these: + +- `pipelines/model_.py` + Purpose: SD.Next loader entry point +- `modules/sd_models.py` + Purpose: route detected model type to the correct loader +- `modules/sd_detect.py` + Purpose: detect model family from filename or repo name +- `modules/modeldata.py` + Purpose: classify loaded pipeline instance back into SD.Next model type + +Add only what the model actually needs. + +Reference catalog touchpoints are also required for model ports intended to appear in SD.Next model references. + +- `data/reference.json` for `base` +- `data/reference-cloud.json` for `cloud` +- `data/reference-quant.json` for `quant` +- `data/reference-distilled.json` for `distilled` +- `data/reference-nunchaku.json` for `nunchaku` +- `data/reference-community.json` for `community` + +If the model belongs to multiple categories, update each corresponding `data/reference*.json` file. 
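The key-remapping step from the raw-checkpoint path above can be sketched as a longest-prefix rename with an explicit failure on collisions. The prefix map here is a hypothetical example; real mappings depend on the specific training format:

```python
def remap_state_dict(state_dict, prefix_map):
    """Rename training-format keys to inference-format keys via longest-prefix match."""
    remapped = {}
    for key, tensor in state_dict.items():
        new_key = key
        for old, new in sorted(prefix_map.items(), key=lambda kv: -len(kv[0])):
            if key.startswith(old):
                new_key = new + key[len(old):]
                break
        if new_key in remapped:  # fail loudly instead of silently dropping weights
            raise ValueError(f'duplicate key after remap: {new_key}')
        remapped[new_key] = tensor
    return remapped

# hypothetical prefixes; a real map depends on the specific training format
prefix_map = {
    'model.diffusion_model.': 'transformer.',
    'cond_stage_model.': 'text_encoder.',
}
raw = {'model.diffusion_model.blocks.0.weight': [0.0], 'cond_stage_model.embed.weight': [0.0]}
print(sorted(remap_state_dict(raw, prefix_map)))
# ['text_encoder.embed.weight', 'transformer.blocks.0.weight']
```

Raising on duplicate target keys matches the explicit-error requirement above: a collision after remap almost always means the map is wrong, not that the weights should be merged.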
+ +Possible extra integration points: + +- `diffusers.pipelines.auto_pipeline.AUTO_*_PIPELINES_MAPPING` when task switching matters +- `pipe.task_args` when SD.Next needs default runtime kwargs such as `output_type` +- VAE hijack, TE hijack, or task-specific processors + +## Loader Conventions + +In `pipelines/model_.py`: + +- Use `sd_models.path_to_repo(checkpoint_info)` +- Call `sd_models.hf_auth_check(checkpoint_info)` +- Use `model_quant.get_dit_args(...)` for load args +- Respect `devices.dtype` +- Log meaningful loader details +- Reuse `generic.load_transformer(...)` and `generic.load_text_encoder(...)` when possible +- Set `pipe.task_args = {'output_type': 'np'}` when the pipeline should default to numpy output for SD.Next +- Clean up temporary references and call `devices.torch_gc(...)` + +Do not hardcode assumptions about CUDA-only execution, local paths, or one-off environment state. + +## Pipeline Conventions + +When building a custom pipeline: + +- Inherit from `diffusers.DiffusionPipeline` +- Register modules with `DiffusionPipeline.register_modules(...)` +- Keep prompt encoding and output conversion inside the pipeline +- Support `prompt`, `negative_prompt`, `generator`, `output_type`, and `return_dict` when the task is text-to-image-like +- Expose only parameters that are part of the model’s actual sampling surface +- Preserve direct pixel-space output if the model does not use a VAE +- Use a VAE only if the model genuinely outputs latents + +Do not add generic Stable Diffusion arguments that the model does not support. + +## Validation Checklist + +After implementation, validate in this order: + +1. Syntax or bytecode compilation for new files +2. Focused linting on touched files +3. Import-level smoke test for new modules when safe +4. Loader-path validation without doing a full generation if the model is too large +5. One real load and generation pass if feasible + +Always report what was validated and what was not. 
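The first validation layers can be sketched with stdlib `py_compile` and `importlib`; the throwaway demo file below stands in for the actual touched modules:

```python
import importlib
import os
import py_compile
import tempfile

def validate_layers(py_files, import_names):
    """Layer 1: bytecode-compile files; layer 2: import smoke test; report per item."""
    results = {}
    for path in py_files:
        try:
            py_compile.compile(path, doraise=True)
            results[f'compile:{os.path.basename(path)}'] = 'ok'
        except py_compile.PyCompileError as exc:
            results[f'compile:{os.path.basename(path)}'] = f'failed: {exc.msg}'
    for name in import_names:
        try:
            importlib.import_module(name)
            results[f'import:{name}'] = 'ok'
        except Exception as exc:
            results[f'import:{name}'] = f'failed: {type(exc).__name__}'
    return results

# throwaway demo file; a real run would target the touched pipelines/ and modules/ files
fd, path = tempfile.mkstemp(suffix='.py')
with os.fdopen(fd, 'w') as handle:
    handle.write('x = 1\n')
report = validate_layers([path], ['json'])
print(report['import:json'])  # ok
```

Running the cheap layers first keeps failure reports specific: a syntax error is caught before an import error, and both before any model load is attempted.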
+ +## Reference Asset Update Requirements + +When the user asks to add or port a model for references, also perform these steps: + +1. Update the correct `data/reference*.json` file(s) based on the user-confirmed category. +2. Create a placeholder zero-byte thumbnail file in `models/Reference` for the new model. + +Notes: + +- In this repo, the folder is `models/Reference` (capital `R`). +- Use a deterministic filename that matches the model entry naming convention used in the target reference JSON. +- If a real thumbnail already exists, do not overwrite it with a zero-byte file. + +## Common Failure Modes + +- Model type is added to `sd_models.py` but not `sd_detect.py` +- Loader exists but `modules/modeldata.py` still classifies the pipeline incorrectly +- A custom pipeline is named in a way that collides with a broader existing branch such as `Chroma` +- `torch_dtype` or other loader args are passed twice +- The code assumes a Diffusers repo layout when the source is really a single checkpoint file +- Negative prompt handling does not match prompt batch size +- Output is postprocessed as latents even though the model is pixel-space +- Custom config inference uses fragile defaults without clear failure paths + +## Output Expectations + +When using this skill, the final implementation should usually include: + +- A clear integration path choice +- The minimal set of new files required +- SD.Next routing updates +- Targeted validation results +- Any remaining runtime risks called out explicitly + +## Example Request Shapes + +- "Port this standalone inference script into an SD.Next Diffusers pipeline" +- "Add support for this Hugging Face model repo to SD.Next" +- "Wire this upstream Diffusers pipeline into SD.Next autodetect and loading" +- "Convert this single-file checkpoint model into a custom Diffusers pipeline for SD.Next" \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index 2ee05f966..eddab0ef8 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,60 
+1,8 @@ # SD.Next: AGENTS.md Project Guidelines -SD.Next is a complex codebase with specific patterns and conventions. -General app structure is: -- Python backend server - Uses Torch for model inference, FastAPI for API routes and Gradio for creation of UI components. -- JavaScript/CSS frontend +**SD.Next** is a complex codebase with specific patterns and conventions: +- Read main instructions file `.github/copilot-instructions.md` for general project guidelines, tools, structure, style, and conventions. +- For core tasks, also review instructions `.github/instructions/core.instructions.md` +- For UI tasks, also review instructions `.github/instructions/ui.instructions.md` -## Tools - -- `venv` for Python environment management, activated with `source venv/bin/activate` (Linux) or `venv\Scripts\activate` (Windows). - venv MUST be activated before running any Python commands or scripts to ensure correct dependencies and environment variables. -- `python` 3.10+. -- `pyproject.toml` for Python configuration, including linting and type checking settings. -- `eslint` configured for both core and UI code. -- `pnpm` for managing JavaScript dependencies and scripts, with key commands defined in `package.json`. -- `ruff` and `pylint` for Python linting, with configurations in `pyproject.toml` and executed via `pnpm ruff` and `pnpm pylint`. -- `pre-commit` hooks which also check line-endings and other formatting issues, configured in `.pre-commit-config.yaml`. - -## Project Structure - -- Entry/startup flow: `webui.sh` -> `launch.py` -> `webui.py` -> modules under `modules/`. -- Install: `installer.py` takes care of installing dependencies and setting up the environment. -- Core runtime state is centralized in `modules/shared.py` (shared.opts, model state, backend/device state). -- API/server routes are under `modules/api/`. -- UI codebase is split between base JS in `javascript/` and actual UI in `extensions-builtin/sdnext-modernui/`. 
-- Model and pipeline logic is split between `modules/sd_*` and `pipelines/`. -- Additional plug-ins live in `scripts/` and are used only when specified. -- Extensions live in `extensions-builtin/` and `extensions/` and are loaded dynamically. -- Tests and CLI scripts are under `test/` and `cli/`, with some API smoke checks in `test/full-test.sh`. - -## Code Style - -- Prefer existing project patterns over strict generic style rules; - this codebase intentionally allows patterns often flagged in default linters such as allowing long lines, etc. - -## Build And Test - -- Activate environment: `source venv/bin/activate` (always ensure this is active when working with Python code). -- Test startup: `python launch.py --test` -- Full startup: `python launch.py` -- Full lint sequence: `pnpm lint` -- Python checks individually: `pnpm ruff`, `pnpm pylint` -- JS checks: `pnpm eslint` and `pnpm eslint-ui` - -## Conventions - -- Keep PR-ready changes targeted to `dev` branch. -- Use conventions from `CONTRIBUTING`. -- Do not include unrelated edits or submodule changes when preparing contributions. -- Use existing CLI/API tool patterns in `cli/` and `test/` when adding automation scripts. -- Respect environment-driven behavior (`SD_*` flags and options) instead of hardcoding platform/model assumptions. -- For startup/init edits, preserve error handling and partial-failure tolerance in parallel scans and extension loading. - -## Pitfalls - -- Initialization order matters: startup paths in `launch.py` and `webui.py` are sensitive to import/load timing. -- Shared mutable global state can create subtle regressions; prefer narrow, explicit changes. -- Device/backend-specific code paths (**CUDA/ROCm/IPEX/DirectML/OpenVINO**) should not assume one platform. -- Scripts and extension loading is dynamic; failures may appear only when specific extensions or models are present. 
+For specific SKILLS, also review the relevant skill files specified in `.github/skills/README.md` and the individual `.github/skills/*/SKILL.md` files diff --git a/CHANGELOG.md b/CHANGELOG.md index 971a8c397..28873bc0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,19 @@ # Change Log for SD.Next -## Update for 2026-04-09 +## Update for 2026-04-10 - **Models** - [AiArtLab SDXS-1B](https://huggingface.co/AiArtLab/sdxs-1b) Simple Diffusion XS *(training still in progress)* this model combines Qwen3.5-1.8B text encoder with SDXL-style UNET with only 1.6B parameters and custom 32ch VAE - [Anima Preview-v3](https://huggingface.co/circlestone-labs/Anima) new version of Anima +- **Agents** + framework for AI agent-based work in `.github/` + - general instructions: `copilot-instructions.md` (not copilot specific) + - additional instructions in `instructions/`: `core.instructions.md`, `ui.instructions.md` + - skills in `skills/README.md`: + *validation skills*: `check-models`, `check-api`, `check-schedulers`, `check-processing`, `check-scripts` + *model skills*: `port-model`, `debug-model`, `analyze-model` + *github skills*: `github-issues`, `github-features` - **Caption & Prompt Enhance** - [Google Gemma 4] in *E2B* and *E4B* variants - **Compute** @@ -18,8 +26,19 @@ - enhanced filename pattern processing allows for any *processing* property name (as defined in `modules/processing_class.py` and saved to `ui-config.json`) allows for any *settings* property name (as defined in `modules/ui_definitions.py` and saved to `config.json`) +- **Agents** + created framework for AI agent-based work in `.github/` + *note*: all skills are agent-model agnostic + - general instructions: + `AGENTS.md`, `copilot-instructions.md` + - additional instructions in `instructions/`: + `core.instructions.md`, `ui.instructions.md` + - skills in `skills/README.md`: + *validation skills*: `check-models`, `check-api`, `check-schedulers`, `check-processing`, `check-scripts` + *model skills*: 
`port-model`, `debug-model`, `analyze-model` + *github skills*: `github-issues`, `github-features` - **Obsoleted** - - remove *system-info* from *extensions-builtin* + - removed *system-info* from *extensions-builtin* - **Internal** - additional typing and typechecks, thanks @awsr - wrap hf download methods diff --git a/TODO.md b/TODO.md index b522a9011..609a4af8f 100644 --- a/TODO.md +++ b/TODO.md @@ -46,7 +46,8 @@ TODO: Investigate which models are diffusers-compatible and prioritize! ### Image-Base -- [NucleusMoe]( torch.Tensor: + assert dim % 2 == 0 + scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=pos.device) + omega = 1.0 / (theta ** scale) + out = torch.einsum("...n,d->...nd", pos.to(torch.float64), omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + rot = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + return rot.reshape(*out.shape, 2, 2).float() + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: float, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(ids.shape[-1])], dim=-3) + return emb.unsqueeze(1) + + +def _apply_rope_single(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out = x_out + freqs_cis[..., 1] * x_[..., 1] + return x_out.reshape(*x.shape).type_as(x) + + +def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor): + return _apply_rope_single(xq, freqs_cis), _apply_rope_single(xk, freqs_cis) + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size: int, frequency_embedding_size: int = 256, output_size: int | None = None): + super().__init__() + if output_size is None: + output_size = hidden_size + self.mlp = nn.Sequential( + 
nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, output_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype) + return self.mlp(t_freq) + + +class JointAttention(nn.Module): + def __init__(self, dim: int, n_heads: int, n_kv_heads: int | None = None, qk_norm: bool = True): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.head_dim = dim // n_heads + self.n_rep = n_heads // self.n_kv_heads + self.qkv = nn.Linear(dim, (n_heads + 2 * self.n_kv_heads) * self.head_dim, bias=False) + self.out = nn.Linear(n_heads * self.head_dim, dim, bias=False) + + if qk_norm: + self.q_norm = nn.RMSNorm(self.head_dim, elementwise_affine=True) + self.k_norm = nn.RMSNorm(self.head_dim, elementwise_affine=True) + else: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + + def forward(self, x: torch.Tensor, mask: torch.Tensor | None, freqs_cis: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, _ = x.shape + qkv = self.qkv(x) + xq, xk, xv = torch.split( + qkv, + [ + self.n_heads * self.head_dim, + self.n_kv_heads * self.head_dim, + self.n_kv_heads * self.head_dim, + ], + dim=-1, + ) + + xq = self.q_norm(xq.view(batch_size, sequence_length, self.n_heads, self.head_dim)) + xk = self.k_norm(xk.view(batch_size, sequence_length, self.n_kv_heads, self.head_dim)) + xv = xv.view(batch_size, sequence_length, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rope(xq, xk, freqs_cis) + xq = xq.transpose(1, 2) + xk = xk.transpose(1, 2) + xv = xv.transpose(1, 2) + + if self.n_rep > 1: + xk = xk.unsqueeze(2).repeat(1, 1, self.n_rep, 1, 1).flatten(1, 2) + xv = xv.unsqueeze(2).repeat(1, 1, self.n_rep, 1, 1).flatten(1, 2) + + out = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask) + 
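The `n_rep` branch in `JointAttention.forward` implements grouped-query attention: each of the `n_kv_heads` key/value heads is shared by `n_rep` query heads, so the key/value tensors are expanded before `scaled_dot_product_attention`. A minimal standalone sketch (tensor sizes are illustrative, not the real model dimensions) confirming that the unsqueeze/repeat/flatten sequence is equivalent to a single `repeat_interleave` on the head dimension:

```python
import torch

batch, n_kv_heads, seq_len, head_dim = 2, 2, 5, 8
n_rep = 3  # number of query heads served by each kv head

# xk after transpose(1, 2): (batch, n_kv_heads, seq_len, head_dim)
xk = torch.randn(batch, n_kv_heads, seq_len, head_dim)

# expansion exactly as written in JointAttention.forward
expanded = xk.unsqueeze(2).repeat(1, 1, n_rep, 1, 1).flatten(1, 2)

# equivalent single-call form: repeat each kv head n_rep times consecutively
reference = xk.repeat_interleave(n_rep, dim=1)

assert expanded.shape == (batch, n_kv_heads * n_rep, seq_len, head_dim)
assert torch.equal(expanded, reference)
```

Both forms produce head order `[k0, k0, k0, k1, k1, k1]`, matching the query head layout produced by splitting the fused `qkv` projection.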
return self.out(out.transpose(1, 2).reshape(batch_size, sequence_length, -1)) + + +class FeedForward(nn.Module): + def __init__(self, dim: int, inner_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, inner_dim, bias=False) + self.w2 = nn.Linear(inner_dim, dim, bias=False) + self.w3 = nn.Linear(dim, inner_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class JointTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: int | None, + ffn_hidden_dim: int, + norm_eps: float, + qk_norm: bool, + modulation: bool = True, + z_image_modulation: bool = False, + ): + super().__init__() + self.modulation = modulation + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm) + self.feed_forward = FeedForward(dim, ffn_hidden_dim) + self.attention_norm1 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True) + self.ffn_norm1 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True) + self.attention_norm2 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True) + self.ffn_norm2 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True) + + if modulation: + if z_image_modulation: + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, 256), 4 * dim, bias=True)) + else: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(min(dim, 1024), 4 * dim, bias=True)) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor | None, + freqs_cis: torch.Tensor, + adaln_input: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) + + x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + self.attention(self.attention_norm1(x) * (1 + scale_msa.unsqueeze(1)), mask, freqs_cis) + ) + x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + self.feed_forward(self.ffn_norm1(x) * (1 + 
scale_mlp.unsqueeze(1))) + ) + return x + + x = x + self.attention_norm2(self.attention(self.attention_norm1(x), mask, freqs_cis)) + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + return x + + +class NerfEmbedder(nn.Module): + def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int = 8): + super().__init__() + self.max_freqs = max_freqs + self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size_input)) + self._pos_cache: dict[tuple[int, str, str], torch.Tensor] = {} + + def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + cache_key = (patch_size, str(device), str(dtype)) + if cache_key in self._pos_cache: + return self._pos_cache[cache_key] + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + coeffs = (1 + freqs_x * freqs_y) ** -1 + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2) + self._pos_cache[cache_key] = dct + return dct + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + batch_size, patch_area, channels = inputs.shape + _ = channels + patch_size = int(patch_area**0.5) + dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype).repeat(batch_size, 1, 1) + return self.embedder(torch.cat((inputs, dct), dim=-1)) + + +class PixelResBlock(nn.Module): + def __init__(self, channels: int): + super().__init__() + self.in_ln = nn.LayerNorm(channels, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(channels, channels, bias=True), + nn.SiLU(), + nn.Linear(channels, 
channels, bias=True), + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 3 * channels, bias=True), + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) + hidden = self.in_ln(x) * (1 + scale) + shift + return x + gate * self.mlp(hidden) + + +class DCTFinalLayer(nn.Module): + def __init__(self, model_channels: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(model_channels, out_channels, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.norm_final(x)) + + +class SimpleMLPAdaLN(nn.Module): + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + z_channels: int, + num_res_blocks: int, + max_freqs: int = 8, + ): + super().__init__() + self.cond_embed = nn.Linear(z_channels, model_channels) + self.input_embedder = NerfEmbedder(in_channels, model_channels, max_freqs) + self.res_blocks = nn.ModuleList([PixelResBlock(model_channels) for _ in range(num_res_blocks)]) + self.final_layer = DCTFinalLayer(model_channels, out_channels) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + x = self.input_embedder(x) + y = self.cond_embed(c).unsqueeze(1) + for block in self.res_blocks: + x = block(x, y) + return self.final_layer(x) + + +def pad_to_patch_size(x: torch.Tensor, patch_size: tuple[int, int]) -> torch.Tensor: + _, _, height, width = x.shape + patch_h, patch_w = patch_size + pad_h = (patch_h - height % patch_h) % patch_h + pad_w = (patch_w - width % patch_w) % patch_w + if pad_h or pad_w: + x = F.pad(x, (0, pad_w, 0, pad_h), mode="constant", value=0) + return x + + +def pad_zimage(feats: torch.Tensor, pad_token: torch.Tensor, pad_tokens_multiple: int) -> tuple[torch.Tensor, int]: + pad_extra = (-feats.shape[1]) % pad_tokens_multiple + if pad_extra > 0: + 
feats = torch.cat( + [ + feats, + pad_token.to(device=feats.device, dtype=feats.dtype).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1), + ], + dim=1, + ) + return feats, pad_extra + + +class NextDiTPixelSpace(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): # type: ignore[misc] + @register_to_config + def __init__( + self, + patch_size: int = 32, + in_channels: int = 3, + dim: int = 3840, + n_layers: int = 30, + n_refiner_layers: int = 2, + n_heads: int = 30, + n_kv_heads: int | None = None, + ffn_hidden_dim: int = 10240, + norm_eps: float = 1e-5, + qk_norm: bool = True, + cap_feat_dim: int = 2560, + axes_dims: list[int] | None = None, + axes_lens: list[int] | None = None, + rope_theta: float = 256.0, + time_scale: float = 1000.0, + pad_tokens_multiple: int | None = 128, + decoder_hidden_size: int = 3840, + decoder_num_res_blocks: int = 4, + decoder_max_freqs: int = 8, + ): + super().__init__() + axes_dims = axes_dims or [32, 48, 48] + axes_lens = axes_lens or [1536, 512, 512] + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + self.dim = dim + self.n_heads = n_heads + self.time_scale = time_scale + self.pad_tokens_multiple = pad_tokens_multiple + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.x_embedder = nn.Linear(patch_size * patch_size * in_channels, dim, bias=True) + self.noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + dim, + n_heads, + n_kv_heads, + ffn_hidden_dim, + norm_eps, + qk_norm, + modulation=True, + z_image_modulation=True, + ) + for _ in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + JointTransformerBlock( + dim, + n_heads, + n_kv_heads, + ffn_hidden_dim, + norm_eps, + qk_norm, + modulation=False, + ) + for _ in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder( + hidden_size=min(dim, 1024), + frequency_embedding_size=256, + output_size=256, + ) + self.cap_embedder = nn.Sequential( + 
nn.RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True), + nn.Linear(cap_feat_dim, dim, bias=True), + ) + self.layers = nn.ModuleList( + [ + JointTransformerBlock( + dim, + n_heads, + n_kv_heads, + ffn_hidden_dim, + norm_eps, + qk_norm, + modulation=True, + z_image_modulation=True, + ) + for _ in range(n_layers) + ] + ) + + if pad_tokens_multiple is not None: + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + assert dim // n_heads == sum(axes_dims) + self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=list(axes_dims)) + dec_in_ch = patch_size**2 * in_channels + self.dec_net = SimpleMLPAdaLN( + in_channels=dec_in_ch, + model_channels=decoder_hidden_size, + out_channels=dec_in_ch, + z_channels=dim, + num_res_blocks=decoder_num_res_blocks, + max_freqs=decoder_max_freqs, + ) + self.register_buffer("__x0__", torch.tensor([])) + + def embed_cap(self, cap_feats: torch.Tensor, offset: float = 0, bsz: int = 1, device=None, dtype=None): + _ = dtype + cap_feats = self.cap_embedder(cap_feats) + cap_feats_len = cap_feats.shape[1] + if self.pad_tokens_multiple is not None: + cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple) + + cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset + freqs_cis = self.rope_embedder(cap_pos_ids).movedim(1, 2) + return cap_feats, freqs_cis, cap_feats_len + + def pos_ids_x(self, start_t: float, h_tokens: int, w_tokens: int, batch_size: int, device): + x_pos_ids = torch.zeros((batch_size, h_tokens * w_tokens, 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = start_t + x_pos_ids[:, :, 1] = ( + torch.arange(h_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, w_tokens).flatten() + ) + x_pos_ids[:, :, 2] = ( + torch.arange(w_tokens, dtype=torch.float32, 
device=device).view(1, -1).repeat(h_tokens, 1).flatten() + ) + return x_pos_ids + + def unpatchify(self, x: torch.Tensor, img_size: list[tuple[int, int]], cap_size: list[int]) -> torch.Tensor: + patch_h = patch_w = self.patch_size + imgs = [] + for index in range(x.size(0)): + height, width = img_size[index] + begin = cap_size[index] + end = begin + (height // patch_h) * (width // patch_w) + imgs.append( + x[index][begin:end] + .view(height // patch_h, width // patch_w, patch_h, patch_w, self.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) + ) + return torch.stack(imgs, dim=0) + + def _forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, num_tokens: int | None) -> torch.Tensor: + _ = num_tokens + timesteps = 1.0 - timesteps + _, _, original_h, original_w = x.shape + x = pad_to_patch_size(x, (self.patch_size, self.patch_size)) + batch_size, channels, height, width = x.shape + t_emb = self.t_embedder(timesteps * self.time_scale, dtype=x.dtype) + patch_h = patch_w = self.patch_size + + pixel_patches = ( + x.view(batch_size, channels, height // patch_h, patch_h, width // patch_w, patch_w) + .permute(0, 2, 4, 3, 5, 1) + .flatten(3) + .flatten(1, 2) + ) + num_image_tokens = pixel_patches.shape[1] + pixel_values = pixel_patches.reshape(batch_size * num_image_tokens, 1, patch_h * patch_w * channels) + + cap_feats_emb, cap_freqs_cis, _ = self.embed_cap(context, offset=0, bsz=batch_size, device=x.device, dtype=x.dtype) + + x_tokens = self.x_embedder( + x.view(batch_size, channels, height // patch_h, patch_h, width // patch_w, patch_w) + .permute(0, 2, 4, 3, 5, 1) + .flatten(3) + .flatten(1, 2) + ) + cap_feats_len_total = cap_feats_emb.shape[1] + x_pos_ids = self.pos_ids_x(cap_feats_len_total + 1, height // patch_h, width // patch_w, batch_size, x.device) + if self.pad_tokens_multiple is not None: + x_tokens, x_pad_extra = pad_zimage(x_tokens, self.x_pad_token, self.pad_tokens_multiple) + x_pos_ids = F.pad(x_pos_ids, (0, 0, 
0, x_pad_extra)) + x_freqs_cis = self.rope_embedder(x_pos_ids).movedim(1, 2) + + for layer in self.context_refiner: + cap_feats_emb = layer(cap_feats_emb, None, cap_freqs_cis) + + for layer in self.noise_refiner: + x_tokens = layer(x_tokens, None, x_freqs_cis, t_emb) + + padded_full_embed = torch.cat([cap_feats_emb, x_tokens], dim=1) + full_freqs_cis = torch.cat([cap_freqs_cis, x_freqs_cis], dim=1) + img_len = x_tokens.shape[1] + cap_size = [padded_full_embed.shape[1] - img_len] * batch_size + + hidden_states = padded_full_embed + for layer in self.layers: + hidden_states = layer(hidden_states, None, full_freqs_cis.to(hidden_states.device), t_emb) + + img_hidden = hidden_states[:, cap_size[0] : cap_size[0] + num_image_tokens, :] + decoder_cond = img_hidden.reshape(batch_size * num_image_tokens, self.dim) + output = self.dec_net(pixel_values, decoder_cond).reshape(batch_size, num_image_tokens, -1) + cap_placeholder = torch.zeros( + batch_size, + cap_size[0], + output.shape[-1], + device=output.device, + dtype=output.dtype, + ) + img_size = [(height, width)] * batch_size + img_out = self.unpatchify(torch.cat([cap_placeholder, output], dim=1), img_size, cap_size) + return -img_out[:, :, :original_h, :original_w] + + def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, num_tokens: int | None = None): + x0_pred = self._forward(x, timesteps, context, num_tokens) + return (x - x0_pred) / timesteps.view(-1, 1, 1, 1) + + +def remap_checkpoint_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + new_sd: dict[str, torch.Tensor] = {} + qkv_parts: dict[str, dict[str, torch.Tensor]] = {} + + for key, value in state_dict.items(): + if key == "__x0__": + new_sd[key] = value + continue + if ".attention.to_q." in key: + prefix = key.split(".attention.to_q.")[0] + suffix = key.split(".attention.to_q.")[1] + qkv_parts.setdefault(f"{prefix}.attention.{suffix}", {})["q"] = value + continue + if ".attention.to_k." 
in key: + prefix = key.split(".attention.to_k.")[0] + suffix = key.split(".attention.to_k.")[1] + qkv_parts.setdefault(f"{prefix}.attention.{suffix}", {})["k"] = value + continue + if ".attention.to_v." in key: + prefix = key.split(".attention.to_v.")[0] + suffix = key.split(".attention.to_v.")[1] + qkv_parts.setdefault(f"{prefix}.attention.{suffix}", {})["v"] = value + continue + if ".attention.to_out.0." in key: + new_sd[key.replace(".attention.to_out.0.", ".attention.out.")] = value + continue + if ".attention.norm_q." in key: + new_sd[key.replace(".attention.norm_q.", ".attention.q_norm.")] = value + continue + if ".attention.norm_k." in key: + new_sd[key.replace(".attention.norm_k.", ".attention.k_norm.")] = value + continue + new_sd[key] = value + + for combined_key, parts in qkv_parts.items(): + if not {"q", "k", "v"}.issubset(parts): + raise ValueError(f"Incomplete QKV for {combined_key}: got {list(parts.keys())}") + prefix = combined_key.rsplit(".", 1)[0] + suffix = combined_key.rsplit(".", 1)[1] + new_sd[f"{prefix}.qkv.{suffix}"] = torch.cat([parts["q"], parts["k"], parts["v"]], dim=0) + + return new_sd + + +def default_axes_dims_for_head_dim(head_dim: int) -> list[int]: + if head_dim % 8 == 0: + first = head_dim // 4 + second = (head_dim * 3) // 8 + third = head_dim - first - second + if first % 2 == 0 and second % 2 == 0 and third % 2 == 0: + return [first, second, third] + if head_dim == 128: + return [32, 48, 48] + if head_dim % 6 == 0: + split = head_dim // 3 + return [split, split, split] + raise ValueError(f"Unable to infer axes_dims for head_dim={head_dim}; provide axes_dims explicitly") + + +def _count_module_blocks(state_dict: dict[str, torch.Tensor], prefix: str) -> int: + indices = set() + needle = f"{prefix}." 
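`remap_checkpoint_keys` above fuses separate `to_q`/`to_k`/`to_v` weights into the single `qkv` projection that `JointAttention.forward` later splits back apart. A small sketch (weight sizes are illustrative) verifying that concatenating the weights along the output dimension and splitting the fused result reproduces the three independent projections:

```python
import torch
import torch.nn.functional as F

dim, q_out, kv_out = 16, 16, 8  # illustrative sizes only
x = torch.randn(3, dim)
wq = torch.randn(q_out, dim)
wk = torch.randn(kv_out, dim)
wv = torch.randn(kv_out, dim)

# fused weight as built by remap_checkpoint_keys: cat along dim 0 (output features)
w_qkv = torch.cat([wq, wk, wv], dim=0)

# forward pass through the fused projection, then split as in JointAttention.forward
fused = F.linear(x, w_qkv)
q, k, v = torch.split(fused, [q_out, kv_out, kv_out], dim=-1)

assert torch.allclose(q, F.linear(x, wq))
assert torch.allclose(k, F.linear(x, wk))
assert torch.allclose(v, F.linear(x, wv))
```

This is why the remap must raise on an incomplete q/k/v triple: a partial concatenation would silently misalign the split offsets.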
+ prefix_parts = prefix.count(".") + 1 + for key in state_dict: + if not key.startswith(needle): + continue + parts = key.split(".") + if len(parts) <= prefix_parts: + continue + try: + indices.add(int(parts[prefix_parts])) + except ValueError: + continue + return max(indices) + 1 if indices else 0 + + +def infer_model_config( + state_dict: dict[str, torch.Tensor], + axes_dims: list[int] | None = None, + axes_lens: list[int] | None = None, + rope_theta: float = 256.0, + time_scale: float = 1000.0, + pad_tokens_multiple: int | None = 128, + norm_eps: float = 1e-5, +) -> dict: + dim = state_dict["x_embedder.weight"].shape[0] + patch_channels = state_dict["x_embedder.weight"].shape[1] + in_channels = 3 if patch_channels % 3 == 0 else 1 + patch_size = int(round((patch_channels / in_channels) ** 0.5)) + head_dim = state_dict["layers.0.attention.q_norm.weight"].shape[0] + n_heads = dim // head_dim + total_proj_heads = state_dict["layers.0.attention.qkv.weight"].shape[0] // head_dim + n_kv_heads = (total_proj_heads - n_heads) // 2 + if n_kv_heads == n_heads: + n_kv_heads = None + qk_norm = "layers.0.attention.q_norm.weight" in state_dict + decoder_input = state_dict["dec_net.input_embedder.embedder.0.weight"].shape[1] + decoder_max_freqs = int(round((decoder_input - patch_channels) ** 0.5)) + config = { + "patch_size": patch_size, + "in_channels": in_channels, + "dim": dim, + "n_layers": _count_module_blocks(state_dict, "layers"), + "n_refiner_layers": _count_module_blocks(state_dict, "noise_refiner"), + "n_heads": n_heads, + "n_kv_heads": n_kv_heads, + "ffn_hidden_dim": state_dict["layers.0.feed_forward.w1.weight"].shape[0], + "norm_eps": norm_eps, + "qk_norm": qk_norm, + "cap_feat_dim": state_dict["cap_embedder.1.weight"].shape[1], + "axes_dims": axes_dims or default_axes_dims_for_head_dim(head_dim), + "axes_lens": axes_lens or [1536, 512, 512], + "rope_theta": rope_theta, + "time_scale": time_scale, + "pad_tokens_multiple": pad_tokens_multiple, + "decoder_hidden_size": 
state_dict["dec_net.cond_embed.weight"].shape[0], + "decoder_num_res_blocks": _count_module_blocks(state_dict, "dec_net.res_blocks"), + "decoder_max_freqs": decoder_max_freqs, + } + return config + + +def resolve_checkpoint_path( + pretrained_model_name_or_path: str, + checkpoint_filename: str | None = None, + cache_dir: str | None = None, + local_files_only: bool | None = None, +) -> str: + if os.path.isfile(pretrained_model_name_or_path): + return pretrained_model_name_or_path + + if os.path.isdir(pretrained_model_name_or_path): + if checkpoint_filename is not None: + candidate = os.path.join(pretrained_model_name_or_path, checkpoint_filename) + if os.path.isfile(candidate): + return candidate + candidates = [ + os.path.join(pretrained_model_name_or_path, name) + for name in os.listdir(pretrained_model_name_or_path) + if name.lower().endswith(".safetensors") + ] + if len(candidates) == 1: + return candidates[0] + if not candidates: + raise FileNotFoundError(f"No .safetensors checkpoint found in {pretrained_model_name_or_path}") + raise FileNotFoundError( + f"Multiple .safetensors checkpoints found in {pretrained_model_name_or_path}; provide checkpoint_filename" + ) + + filename = checkpoint_filename + if filename is None: + repo_files = [name for name in list_repo_files(pretrained_model_name_or_path) if name.lower().endswith(".safetensors")] + if not repo_files and DEFAULT_CHECKPOINT_FILENAME: + repo_files = [DEFAULT_CHECKPOINT_FILENAME] + if len(repo_files) == 1: + filename = repo_files[0] + elif DEFAULT_CHECKPOINT_FILENAME in repo_files: + filename = DEFAULT_CHECKPOINT_FILENAME + else: + raise FileNotFoundError( + f"Unable to determine checkpoint filename for {pretrained_model_name_or_path}; provide checkpoint_filename" + ) + + return hf_hub_download( + repo_id=pretrained_model_name_or_path, + filename=filename, + cache_dir=cache_dir, + local_files_only=local_files_only, + ) + + +def load_zetachroma_transformer( + pretrained_model_name_or_path: str, + 
checkpoint_filename: str | None = None, + torch_dtype: torch.dtype | None = None, + cache_dir: str | None = None, + local_files_only: bool | None = None, + model_config: dict | None = None, + axes_dims: list[int] | None = None, + axes_lens: list[int] | None = None, + rope_theta: float = 256.0, + time_scale: float = 1000.0, + pad_tokens_multiple: int | None = 128, +) -> NextDiTPixelSpace: + checkpoint_path = resolve_checkpoint_path( + pretrained_model_name_or_path, + checkpoint_filename=checkpoint_filename, + cache_dir=cache_dir, + local_files_only=local_files_only, + ) + raw_state_dict = load_file(checkpoint_path, device="cpu") + state_dict = remap_checkpoint_keys(raw_state_dict) + del raw_state_dict + config = infer_model_config( + state_dict, + axes_dims=axes_dims, + axes_lens=axes_lens, + rope_theta=rope_theta, + time_scale=time_scale, + pad_tokens_multiple=pad_tokens_multiple, + ) + if model_config is not None: + config.update(model_config) + + model = NextDiTPixelSpace(**config) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + ignored_missing = {"__x0__"} + relevant_missing = [key for key in missing if key not in ignored_missing] + if relevant_missing or unexpected: + raise ValueError( + "Zeta-Chroma checkpoint load mismatch: " + f"missing={relevant_missing[:20]} unexpected={unexpected[:20]}" + ) + if torch_dtype is not None: + model = model.to(dtype=torch_dtype) + return model diff --git a/pipelines/zetachroma/pipeline.py b/pipelines/zetachroma/pipeline.py new file mode 100644 index 000000000..b56e14a90 --- /dev/null +++ b/pipelines/zetachroma/pipeline.py @@ -0,0 +1,302 @@ +import logging +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers import DiffusionPipeline +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import randn_tensor +from PIL import Image +from tqdm.auto import tqdm +from transformers 
import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase, Qwen3ForCausalLM
+
+from .model import DEFAULT_TEXT_ENCODER_REPO, NextDiTPixelSpace, load_zetachroma_transformer
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ZetaChromaPipelineOutput(BaseOutput):
+    images: Union[List[Image.Image], np.ndarray, torch.Tensor]
+
+
+def get_schedule(
+    num_steps: int,
+    image_seq_len: int,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+    shift: bool = True,
+) -> torch.Tensor:
+    timesteps = torch.linspace(1, 0, num_steps + 1)
+    if shift:
+        # resolution-dependent timestep shift: t' = e^mu / (e^mu + 1/t - 1)
+        mu = base_shift + (image_seq_len / 4096) * (max_shift - base_shift)
+        timesteps = math.exp(mu) / (math.exp(mu) + (1.0 / timesteps.clamp(min=1e-9) - 1.0))
+        timesteps[0] = 1.0
+        timesteps[-1] = 0.0
+    return timesteps
+
+
+class ZetaChromaPipeline(DiffusionPipeline):
+    model_cpu_offload_seq = "text_encoder->dit_model"
+
+    dit_model: NextDiTPixelSpace
+    text_encoder: PreTrainedModel
+    tokenizer: PreTrainedTokenizerBase
+    _progress_bar_config: Dict[str, Any]
+
+    def __init__(self, dit_model: NextDiTPixelSpace, text_encoder: PreTrainedModel, tokenizer: PreTrainedTokenizerBase):
+        super().__init__()
+        self.register_modules(dit_model=dit_model, text_encoder=text_encoder, tokenizer=tokenizer)
+        if hasattr(self.text_encoder, "requires_grad_"):
+            self.text_encoder.requires_grad_(False)
+        if hasattr(self.dit_model, "requires_grad_"):
+            self.dit_model.requires_grad_(False)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str,
+        dit_model: Optional[NextDiTPixelSpace] = None,
+        text_encoder: Optional[PreTrainedModel] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        text_encoder_repo: str = DEFAULT_TEXT_ENCODER_REPO,
+        checkpoint_filename: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        trust_remote_code: bool = True,
+        torch_dtype: Optional[torch.dtype] = None,
+        local_files_only: Optional[bool] = None,
+        model_config: Optional[dict] = None,
+        axes_dims: Optional[list[int]] = None,
+        axes_lens: Optional[list[int]] = None,
+        rope_theta: float = 256.0,
+        time_scale: float = 1000.0,
+        pad_tokens_multiple: int | None = 128,
+        **kwargs,
+    ):
+        _ = kwargs
+        if tokenizer is None:
+            tokenizer = AutoTokenizer.from_pretrained(
+                text_encoder_repo,
+                subfolder="tokenizer",
+                cache_dir=cache_dir,
+                trust_remote_code=trust_remote_code,
+                local_files_only=local_files_only,
+            )
+        if text_encoder is None:
+            text_encoder = Qwen3ForCausalLM.from_pretrained(
+                text_encoder_repo,
+                subfolder="text_encoder",
+                cache_dir=cache_dir,
+                trust_remote_code=trust_remote_code,
+                local_files_only=local_files_only,
+                torch_dtype=torch_dtype,
+            )
+        if dit_model is None:
+            dit_model = load_zetachroma_transformer(
+                pretrained_model_name_or_path,
+                checkpoint_filename=checkpoint_filename,
+                torch_dtype=torch_dtype,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                model_config=model_config,
+                axes_dims=axes_dims,
+                axes_lens=axes_lens,
+                rope_theta=rope_theta,
+                time_scale=time_scale,
+                pad_tokens_multiple=pad_tokens_multiple,
+            )
+        pipe = cls(dit_model=dit_model, text_encoder=text_encoder, tokenizer=tokenizer)
+        if torch_dtype is not None:
+            pipe.to(torch_dtype=torch_dtype)
+        return pipe
+
+    def set_progress_bar_config(self, **kwargs):
+        self._progress_bar_config = kwargs
+
+    def progress_bar(self, iterable=None, **kwargs):  # pylint: disable=arguments-differ
+        self._progress_bar_config = getattr(self, "_progress_bar_config", None) or {}
+        config = {**self._progress_bar_config, **kwargs}
+        return tqdm(iterable, **config)
+
+    def to(self, torch_device=None, torch_dtype=None, silence_dtype_warnings=False):  # pylint: disable=arguments-differ
+        _ = silence_dtype_warnings
+        if hasattr(self, "text_encoder"):
+            self.text_encoder.to(device=torch_device, dtype=torch_dtype)
+        if hasattr(self, "dit_model"):
+            self.dit_model.to(device=torch_device, dtype=torch_dtype)
+        return self
+
+    def _encode_single_prompt(
+        self,
+        prompt: str,
+        device: torch.device,
+        dtype: torch.dtype,
+        max_sequence_length: int,
+    ) -> torch.Tensor:
+        messages = [{"role": "user", "content": prompt}]
+        text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=True,
+        )
+        inputs = self.tokenizer(
+            text,
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+        ).to(device)
+
+        outputs = self.text_encoder(
+            input_ids=inputs.input_ids,
+            attention_mask=inputs.attention_mask,
+            output_hidden_states=True,
+        )
+        # use the penultimate hidden layer, zero padded positions, and trim trailing padding
+        hidden = outputs.hidden_states[-2].to(dtype)
+        hidden = hidden * inputs.attention_mask.unsqueeze(-1).to(dtype)
+        actual_len = int(inputs.attention_mask.sum(dim=1).max().item())
+        return hidden[:, :actual_len, :]
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 512,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if isinstance(prompt, str):
+            prompt = [prompt]
+        if negative_prompt is None:
+            negative_prompt = [""] * len(prompt)
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * len(prompt)
+        elif len(negative_prompt) != len(prompt):
+            raise ValueError("negative_prompt must have the same batch length as prompt")
+
+        device = device or self._execution_device
+        dtype = dtype or next(self.text_encoder.parameters()).dtype
+
+        positive = [self._encode_single_prompt(text, device, dtype, max_sequence_length) for text in prompt]
+        negative = [self._encode_single_prompt(text, device, dtype, max_sequence_length) for text in negative_prompt]
+        max_len = max(max(tensor.shape[1] for tensor in positive), max(tensor.shape[1] for tensor in negative))
+
+        def pad_sequence(sequence: torch.Tensor) -> torch.Tensor:
+            if sequence.shape[1] == max_len:
+                return sequence
+            return torch.cat(
+                [
+                    sequence,
+                    torch.zeros(
+                        sequence.shape[0],
+                        max_len - sequence.shape[1],
+                        sequence.shape[2],
+                        dtype=sequence.dtype,
+                        device=sequence.device,
+                    ),
+                ],
+                dim=1,
+            )
+
+        prompt_embeds = torch.cat([pad_sequence(tensor) for tensor in positive], dim=0)
+        negative_embeds = torch.cat([pad_sequence(tensor) for tensor in negative], dim=0)
+        return prompt_embeds, negative_embeds
+
+    def _postprocess_images(self, images: torch.Tensor, output_type: str):
+        images = images.float().clamp(-1, 1) * 0.5 + 0.5
+        if output_type == "pt":
+            return images
+
+        images = images.permute(0, 2, 3, 1).cpu().numpy()
+        if output_type == "np":
+            return images
+        if output_type == "pil":
+            # numpy_to_pil expects floats in [0, 1] and scales to uint8 itself;
+            # pre-converting to uint8 here would make it multiply by 255 twice
+            return self.numpy_to_pil(images)
+        raise ValueError(f"Unsupported output_type: {output_type}")
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        height: Optional[int] = 1024,
+        width: Optional[int] = 1024,
+        num_inference_steps: int = 30,
+        guidance_scale: float = 0.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: int = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 512,
+        shift_schedule: bool = True,
+        base_shift: float = 0.5,
+        max_shift: float = 1.15,
+        output_type: str = "pil",
+        return_dict: bool = True,
+        callback_on_step_end=None,
+        **kwargs,
+    ):
+        _ = kwargs
+        height = height or 1024
+        width = width or 1024
+        dtype = dtype or next(self.dit_model.parameters()).dtype
+        device = self._execution_device
+
+        prompt_embeds, negative_embeds = self.encode_prompt(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            device=device,
+            dtype=dtype,
+            max_sequence_length=max_sequence_length,
+        )
+        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        negative_embeds = negative_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        batch_size = prompt_embeds.shape[0]
+
+        # pixel-space model: latents are full-resolution RGB noise
+        latents = randn_tensor((batch_size, 3, height, width), generator=generator, device=device, dtype=dtype)
+        image_seq_len = (height // self.dit_model.patch_size) * (width // self.dit_model.patch_size)
+        sigmas = get_schedule(
+            num_inference_steps,
+            image_seq_len,
+            base_shift=base_shift,
+            max_shift=max_shift,
+            shift=shift_schedule,
+        ).to(device=device, dtype=dtype)
+
+        self.dit_model.eval()
+        do_classifier_free_guidance = guidance_scale > 1.0
+        for step, timestep in enumerate(self.progress_bar(range(num_inference_steps), desc="Sampling")):
+            _ = timestep
+            t_curr = sigmas[step]
+            t_next = sigmas[step + 1]
+            t_batch = torch.full((batch_size,), t_curr, device=device, dtype=dtype)
+
+            if do_classifier_free_guidance:
+                # conditional batch first, unconditional second; chunk order below must match
+                latent_input = torch.cat([latents, latents], dim=0)
+                timestep_input = torch.cat([t_batch, t_batch], dim=0)
+                context_input = torch.cat([prompt_embeds, negative_embeds], dim=0)
+                velocity = self.dit_model(latent_input, timestep_input, context_input, context_input.shape[1])
+                v_cond, v_uncond = velocity.chunk(2, dim=0)
+                velocity = v_uncond + guidance_scale * (v_cond - v_uncond)
+            else:
+                velocity = self.dit_model(latents, t_batch, prompt_embeds, prompt_embeds.shape[1])
+
+            # Euler step along the rectified-flow trajectory (t decreases toward 0)
+            latents = latents + (t_next - t_curr) * velocity
+
+            if callback_on_step_end is not None:
+                callback_kwargs = {
+                    "latents": latents,
+                    "prompt_embeds": prompt_embeds,
+                    "negative_prompt_embeds": negative_embeds,
+                }
+                callback_result = callback_on_step_end(self, step, int(t_curr.item() * 1000), callback_kwargs)
+                if isinstance(callback_result, dict) and "latents" in callback_result:
+                    latents = callback_result["latents"]
+
+        images = self._postprocess_images(latents, output_type=output_type)
+        if not return_dict:
+            return (images,)
+        return ZetaChromaPipelineOutput(images=images)