There was an issue where the resample split was too early and dropped one
of the rolling convolutions a frame early. This is most noticable as a
lighting/color change between pixel frames 5->6 (latent 2->3), or as a
lighting change between the first and last frame in an FLF wan flow.
* sd: soft_empty_cache on tiler fallback
This doesnt cost a lot and creates the expected VRAM reduction in
resource monitors when you fallback to tiler.
* wan: vae: Don't recursion in local fns (move run_up)
Moved Decoder3d’s recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage of tensors which in turn delays
VRAM release before tiled fallback.
* ltx: vae: Don't recursion in local fns (move run_up)
Mov the recursive run_up out of forward into a class
method to avoid nested closure self-reference cycles. This avoids
cyclic garbage that delays garbage of tensors which in turn delays
VRAM release before tiled fallback.
* ltx: vae: add cache state to downsample block
* ltx: vae: Add time stride awareness to causal_conv_3d
* ltx: vae: Automate truncation for encoder
Other VAEs just truncate without error. Do the same.
* sd/ltx: Make chunked_io a flag in its own right
Taking this bi-direcitonal, so make it a for-purpose named flag.
* ltx: vae: implement chunked encoder + CPU IO chunking
People are doing things with big frame counts in LTX including V2V
flows. Implement the time-chunked encoder to keep the VRAM down, with
the converse of the new CPU pre-allocation technique, where the chunks
are brought from the CPU JIT.
* ltx: vae-encode: round chunk sizes more strictly
Only powers of 2 and multiple of 8 are valid due to cache slicing.
* wan: vae: encoder: Add feature cache layer that corks singles
If a downsample only gives you a single frame, save it to the feature
cache and return nothing to the top level. This increases the
efficiency of cacheability, but also prepares support for going two
by two rather than four by four on the frames.
* wan: remove all concatentation with the feature cache
The loopers are now responsible for ensuring that non-final frames are
processes at least two-by-two, elimiating the need for this cat case.
* wan: vae: recurse and chunk for 2+2 frames on decode
Avoid having to clone off slices of 4 frame chunks and reduce the size
of the big 6 frame convolutions down to 4. Save the VRAMs.
* wan: encode frames 2x2.
Reduce VRAM usage greatly by encoding frames 2 at a time rather than
4.
* wan: vae: remove cloning
The loopers now control the chunking such there is noever more than 2
frames, so just cache these slices directly and avoid the clone
allocations completely.
* wan: vae: free consumer caller tensors on recursion
* wan: vae: restyle a little to match LTX
* ltx: vae: scale the chunk size with the users VRAM
Scale this linearly down for users with low VRAM.
* ltx: vae: free non-chunking recursive intermediates
* ltx: vae: cleanup some intermediates
The conv layer can be the VRAM peak and it does a torch.cat. So cleanup
the pieces of the cat. Also clear our the cache ASAP as each layer detect
its end as this VAE surges in VRAM at the end due to the ended padding
increasing the size of the final frame convolutions off-the-books to
the chunker. So if all the earlier layers free up their cache it can
offset that surge.
Its a fragmentation nightmare, and the chance of it having to recache the
pyt allocator is very high, but you wont OOM.
Pytorch only filters for OOMs in its own allocators however there are
paths that can OOM on allocators made outside the pytorch allocators.
These manifest as an AllocatorError as pytorch does not have universal
error translation to its OOM type on exception. Handle it. A log I have
for this also shows a double report of the error async, so call the
async discarder to cleanup and make these OOMs look like OOMs.
* draft zeta (z-image pixel space)
* revert gitignore
* model loaded and able to run however vector direction still wrong tho
* flip the vector direction to original again this time
* Move wrongly positioned Z image pixel space class
* inherit Radiance LatentFormat class
* Fix parameters in classes for Zeta x0 dino
* remove arbitrary nn.init instances
* Remove unused import of lru_cache
---------
Co-authored-by: silveroxides <ishimarukaito@gmail.com>
Implements per-guide attention attenuation via log-space additive bias
in self-attention. Each guide reference tracks its own strength and
optional spatial mask in conditioning metadata (guide_attention_entries).
* Fix bypass dtype/device moving
* Force offloading mode for training
* training context var
* offloading implementation in training node
* fix wrong input type
* Support bypass load lora model, correct adapter/offloading handling
The code throughout is None safe to just skip the feature cache saving
step if none. Set it none in single frame use so qwen doesn't burn VRAM
on the unused cache.
* ops: introduce autopad for conv3d
This works around pytorch missing ability to causal pad as part of the
kernel and avoids massive weight duplications for padding.
* wan-vae: rework causal padding
This currently uses F.pad which takes a full deep copy and is liable to
be the VRAM peak. Instead, kick spatial padding back to the op and
consolidate the temporal padding with the cat for the cache.
* wan-vae: implement zero pad fast path
The WAN VAE is also QWEN where it is used single-image. These
convolutions are however zero padded 3d convolutions, which means the
VAE is actually just 2D down the last element of the conv weight in
the temporal dimension. Fast path this, to avoid adding zeros that
then just evaporate in convoluton math but cost computation.