add builtin framepack

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4039/head
Vladimir Mandic 2025-07-08 15:47:07 -04:00
parent 239c3d6dd9
commit c559e26616
53 changed files with 3996 additions and 232 deletions

View File

@ -49,15 +49,12 @@
"node/shebang": "off"
},
"globals": {
// assets
"panzoom": "readonly",
// logger.js
"log": "readonly",
"debug": "readonly",
"error": "readonly",
"xhrGet": "readonly",
"xhrPost": "readonly",
// script.js
"gradioApp": "readonly",
"executeCallbacks": "readonly",
"onAfterUiUpdate": "readonly",
@ -73,11 +70,8 @@
"getUICurrentTabContent": "readonly",
"waitForFlag": "readonly",
"logFn": "readonly",
// contextmenus.js
"generateForever": "readonly",
// contributors.js
"showContributors": "readonly",
// ui.js
"opts": "writable",
"sortUIElements": "readonly",
"all_gallery_buttons": "readonly",
@ -97,40 +91,29 @@
"toggleCompact": "readonly",
"setFontSize": "readonly",
"setTheme": "readonly",
// settings.js
"registerDragDrop": "readonly",
// extraNetworks.js
"getENActiveTab": "readonly",
"quickApplyStyle": "readonly",
"quickSaveStyle": "readonly",
"setupExtraNetworks": "readonly",
"showNetworks": "readonly",
// from python
"localization": "readonly",
// progressbar.js
"randomId": "readonly",
"requestProgress": "readonly",
"setRefreshInterval": "readonly",
// imageviewer.js
"modalPrevImage": "readonly",
"modalNextImage": "readonly",
"galleryClickEventHandler": "readonly",
"getExif": "readonly",
// logMonitor.js
"jobStatusEl": "readonly",
// loader.js
"removeSplash": "readonly",
// nvml.js
"initNVML": "readonly",
"disableNVML": "readonly",
// indexdb.js
"idbGet": "readonly",
"idbPut": "readonly",
"idbDel": "readonly",
"idbAdd": "readonly",
// changelog.js
"initChangelog": "readonly",
// notification.js
"sendNotification": "readonly"
},
"ignorePatterns": [

View File

@ -26,7 +26,12 @@ repos:
- id: check-merge-conflict
- id: check-symlinks
- id: check-yaml
args: ["--allow-multiple-documents"]
- id: check-builtin-literals
- id: check-case-conflict
- id: check-json
- id: check-symlinks
- id: check-toml
- id: check-xml
- id: end-of-file-fixer
- id: mixed-line-ending
- id: trailing-whitespace

View File

@ -23,6 +23,7 @@ ignore-paths=/usr/lib/.*$,
modules/hijack/ddpm_edit.py,
modules/intel,
modules/intel/ipex,
modules/framepack/pipeline,
modules/ldsr,
modules/onnx_impl,
modules/pag,

3
.vscode/launch.json vendored
View File

@ -16,8 +16,7 @@
"--docs",
"--api-log",
"--log", "vscode.log",
"${command:pickArgs}",
]
"${command:pickArgs}"]
}
]
}

View File

@ -1,10 +1,5 @@
{
"python.analysis.extraPaths": [
".",
"./modules",
"./scripts",
"./pipelines",
],
"python.analysis.extraPaths": [".", "./modules", "./scripts", "./pipelines"],
"python.analysis.typeCheckingMode": "off",
"editor.formatOnSave": false,
"python.REPL.enableREPLSmartSend": false

View File

@ -38,6 +38,7 @@ Although upgrades and existing installations are tested and should work fine!
- Support **FLUX.1** all-in-one safetensors
- Support **TAESD** preview and remote VAE for **HunyuanDit**
- Support for [Gemma 3n](https://huggingface.co/google/gemma-3n-E4B-it) E2B and E4B LLM/VLM models in **prompt enhance** and process **captioning**
- **FramePack** support is now fully integrated instead of being a separate extension
- **UI**
- major update to modernui layout
- redesign of the Flat UI theme

View File

@ -4,13 +4,10 @@ Main ToDo list can be found at [GitHub projects](https://github.com/users/vladma
## Current
- Bug: FramePack with SQND
## Future Candidates
- Feature: Common repo for `T5` and `CLiP`
- Feature: LoRA add OMI format support for SD35/FLUX.1
- Feature: Merge FramePack into core
- Refactor: sampler options
- Video: API support

View File

@ -488,7 +488,6 @@ function setupExtraNetworksForTab(tabname) {
en.style.top = '13em';
en.style.transition = 'width 0.3s ease';
en.style.zIndex = 100;
// gradioApp().getElementById(`${tabname}_settings`).parentNode.style.width = `${100 - 2 - window.opts.extra_networks_sidebar_width}vw`;
gradioApp().getElementById(`${tabname}_settings`).parentNode.style.width = `calc(100vw - 2em - min(${window.opts.extra_networks_sidebar_width}vw, 50vw))`;
} else {
en.style.position = 'relative';
@ -506,6 +505,7 @@ function setupExtraNetworksForTab(tabname) {
if (window.opts.extra_networks_card_cover === 'sidebar') en.style.width = 0;
gradioApp().getElementById(`${tabname}_settings`).parentNode.style.width = 'unset';
}
if (tabname === 'video') gradioApp().getElementById('framepack_settings').parentNode.style.width = gradioApp().getElementById(`${tabname}_settings`).parentNode.style.width;
}
});
intersectionObserver.observe(en); // monitor visibility

View File

@ -264,6 +264,25 @@ function submit_video(...args) {
return res;
}
// Prepare a FramePack generation request: allocate a fresh task id, start
// progress monitoring for it, clear the pending submit state, and return the
// argument list with the task id injected as the first element.
function submit_framepack(...args) {
  const taskId = randomId();
  log('submitFramepack', taskId);
  requestProgress(taskId, null, null);
  window.submit_state = '';
  args[0] = taskId; // backend reads the task id from the first argument
  return args;
}
// Forward a video submit to the tab-specific generate button: args[0] names
// the tab whose `<tab>_generate_btn` element should be clicked. No-op when the
// button does not exist.
function submit_video_wrapper(...args) {
  log('submitVideoWrapper', args);
  if (args.length === 0) {
    log('submitVideoWrapper: no args');
    return;
  }
  gradioApp().getElementById(`${args[0]}_generate_btn`)?.click();
}
function submit_postprocessing(...args) {
log('SubmitExtras');
clearGallery('extras');

View File

@ -58,8 +58,8 @@ _ACT_LAYER_ME = dict(
hard_sigmoid=HardSigmoidMe
)
_OVERRIDE_FN = dict()
_OVERRIDE_LAYER = dict()
_OVERRIDE_FN = {}
_OVERRIDE_LAYER = {}
def add_override_act_fn(name, fn):
@ -75,7 +75,7 @@ def update_override_act_fn(overrides):
def clear_override_act_fn():
global _OVERRIDE_FN
_OVERRIDE_FN = dict()
_OVERRIDE_FN = {}
def add_override_act_layer(name, fn):
@ -90,7 +90,7 @@ def update_override_act_layer(overrides):
def clear_override_act_layer():
global _OVERRIDE_LAYER
_OVERRIDE_LAYER = dict()
_OVERRIDE_LAYER = {}
def get_act_fn(name='relu'):

View File

@ -143,7 +143,7 @@ class EasyDict(dict):
__setitem__ = __setattr__
def update(self, e=None, **f):
d = e or dict()
d = e or {}
d.update(f)
for k in d:
setattr(self, k, d[k])

View File

@ -62,7 +62,7 @@ class HCounter(PDH_HCOUNTER):
itemBuffer = cast(malloc(c_size_t(bufferSize.value)), PPDH_FMT_COUNTERVALUE_ITEM_W)
if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), itemBuffer) != PDH_OK:
raise PDHError("Couldn't get formatted counter array.")
result: dict[str, T] = dict()
result: dict[str, T] = {}
for i in range(0, itemCount.value):
item = itemBuffer[i]
result[item.szName] = getattr(item.FmtValue.u, attr_name)

View File

@ -189,7 +189,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
for j, a in enumerate(args):
try:
args[j] = eval(a) if isinstance(a, str) else a # eval strings
except:
except Exception:
pass
n = max(round(n * gd), 1) if n > 1 else n # depth gain

117
modules/framepack/create-video.py Executable file
View File

@ -0,0 +1,117 @@
#!/usr/bin/env python
import os
import io
import base64
import logging
import argparse
import requests
import urllib3
from PIL import Image
# Server connection settings, overridable via environment variables
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
sd_username = os.environ.get('SDAPI_USR', None)  # optional HTTP basic-auth user
sd_password = os.environ.get('SDAPI_PWD', None)  # optional HTTP basic-auth password
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
log = logging.getLogger(__name__)
# requests are made with verify=False, so silence the TLS warning spam
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
def auth():
    """Return HTTP basic-auth credentials built from env vars, or None when not configured."""
    if sd_username is None or sd_password is None:
        return None
    return requests.auth.HTTPBasicAuth(sd_username, sd_password)
def get(endpoint: str, dct: dict = None):
    """GET a server endpoint; return parsed JSON on success or an error descriptor dict."""
    resp = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
    if resp.status_code == 200:
        return resp.json()
    return { 'error': resp.status_code, 'reason': resp.reason, 'url': resp.url }
def post(endpoint: str, dct: dict = None):
    """POST a JSON payload to a server endpoint with no timeout; return JSON or an error descriptor."""
    resp = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=None, verify=False, auth=auth())
    if resp.status_code == 200:
        return resp.json()
    return { 'error': resp.status_code, 'reason': resp.reason, 'url': resp.url }
def encode(f):
    """Load an image file, flatten alpha if present, and return it as base64-encoded JPEG bytes.

    Exits the process when the file does not exist.
    """
    if not os.path.exists(f):
        log.error(f'file not found: {f}')
        os._exit(1)
    img = Image.open(f)
    if img.mode == 'RGBA':
        img = img.convert('RGB')  # JPEG cannot carry an alpha channel
    with io.BytesIO() as buffer:
        img.save(buffer, 'JPEG')
        img.close()
        data = buffer.getvalue()
    return base64.b64encode(data).decode()
def generate(args): # pylint: disable=redefined-outer-name
    """Submit a FramePack video request to the server, then poll progress and task history for outputs."""
    payload = {
        'variant': args.variant,
        'prompt': args.prompt,
        'section_prompt': args.sections,
        'init_image': encode(args.init),
        'end_image': encode(args.end) if args.end else None,
        'resolution': int(args.resolution),
        'duration': float(args.duration),
        'mp4_fps': int(args.fps),
        'seed': int(args.seed),
        'steps': int(args.steps),
        'shift': float(args.shift),
        'cfg_scale': float(args.scale),
        'cfg_rescale': float(args.rescale),
        'cfg_distilled': float(args.distilled),
        'use_teacache': bool(args.teacache),
        'vlm_enhance': bool(args.enhance),
    }
    log.info(f'request: {args}')
    result = post('/sdapi/v1/framepack', payload) # can abandon request here and not wait for response or wait synchronously
    log.info(f'response: {result}')
    progress = get('/sdapi/v1/progress?skip_current_image=true', None) # monitor progress of the current task
    task_id = progress.get('id', None)
    log.info(f'id: {task_id}')
    log.info(f'progress: {progress}')
    outputs = []
    history = get(f'/sdapi/v1/history?id={task_id}') # get history for the task
    for event in history:
        log.info(f'history: {event}')
        outputs = event.get('outputs', [])
    log.info(f'outputs: {outputs}') # output files can be downloaded via the server /file= endpoint
if __name__ == "__main__":
parser = argparse.ArgumentParser(description = 'api-framepack')
parser.add_argument('--init', required=True, help='init image')
parser.add_argument('--end', required=False, help='init image')
parser.add_argument('--prompt', required=False, default='', help='prompt text')
parser.add_argument('--sections', required=False, default='', help='per-section prompts')
parser.add_argument('--resolution', type=int, required=False, default=640, help='video resolution')
parser.add_argument('--duration', type=float, required=False, default=4.0, help='video duration')
parser.add_argument('--fps', type=int, required=False, default=30, help='video frames per second')
parser.add_argument('--seed', type=int, required=False, default=-1, help='random seed')
parser.add_argument('--enhance', required=False, action='store_true', help='enable prompt enhancer')
parser.add_argument('--teacache', required=False, action='store_true', help='enable teacache')
parser.add_argument('--steps', type=int, default=25, help='steps')
parser.add_argument('--scale', type=float, default=1.0, help='cfg scale')
parser.add_argument('--rescale', type=float, default=0.0, help='cfg rescale')
parser.add_argument('--distilled', type=float, default=10.0, help='cfg distilled')
parser.add_argument('--shift', type=float, default=3.0, help='sampler shift')
parser.add_argument('--variant', type=str, default='bi-directional', choices=['bi-directional', 'forward-only'], help='model variant')
args = parser.parse_args()
log.info(f'api-framepack: {args}')
generate(args)

View File

@ -0,0 +1,55 @@
#!/usr/bin/env python
import os
import logging
import argparse
import cv2
import torch
import torchvision
from safetensors.torch import safe_open
from tqdm.rich import trange
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
log = logging.getLogger("sd")


if __name__ == "__main__":
    # CLI tool: read a frames tensor from a safetensors file and export it as
    # images and/or encode it to video via cv2 or torchvision.
    parser = argparse.ArgumentParser(description = 'framepack-cli')
    parser.add_argument('--input', required=True, help='input safetensors')
    parser.add_argument('--cv2', required=False, help='encode video file using cv2')
    parser.add_argument('--tv', required=False, help='encode video file using torchvision')
    parser.add_argument('--codec', default='libx264', help='specify video codec')
    parser.add_argument('--export', required=False, help='export frames as images to folder')
    parser.add_argument('--fps', type=int, default=30, help='frames-per-second')  # fix: without type=int a CLI-supplied value stays a str and breaks the video writers
    args = parser.parse_args()
    log.info(f'framepack-cli: {args}')
    log.info(f'torch={torch.__version__} torchvision={torchvision.__version__}')
    with safe_open(args.input, framework="pt", device="cpu") as f:
        frames = f.get_tensor('frames')  # expected layout: (frames, height, width, channels)
        metadata = f.metadata()
    n, h, w, _c = frames.shape
    log.info(f'file: metadata={metadata}')
    log.info(f'tensor: frames={n} shape={frames.shape} dtype={frames.dtype} device={frames.device}')
    fn = os.path.splitext(os.path.basename(args.input))[0]  # output file prefix derived from input name
    if args.export:
        log.info(f'export: folder="{args.export}" prefix="{fn}" frames={n} width={w} height={h}')
        os.makedirs(args.export, exist_ok=True)
        for i in trange(n):
            image = cv2.cvtColor(frames[i].numpy(), cv2.COLOR_RGB2BGR)  # cv2 expects BGR ordering
            cv2.imwrite(os.path.join(args.export, f'{fn}-{i:05d}.jpg'), image)
    if args.cv2:
        log.info(f'encode: file={args.cv2} frames={n} width={w} height={h} fps={args.fps} method=cv2')
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(args.cv2, fourcc, args.fps, (w, h))
        for i in trange(n):
            image = cv2.cvtColor(frames[i].numpy(), cv2.COLOR_RGB2BGR)
            video.write(image)
        video.release()
    if args.tv:
        log.info(f'encode: file={args.tv} frames={n} width={w} height={h} fps={args.fps} method=tv ')
        torchvision.io.write_video(args.tv, video_array=frames, fps=args.fps, video_codec=args.codec)

View File

@ -0,0 +1,131 @@
from typing import Optional, List
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
from fastapi.exceptions import HTTPException
from modules import shared
class ReqFramepack(BaseModel):
    """Request body for the POST /sdapi/v1/framepack endpoint.

    Images are passed base64-encoded; the remaining fields mirror the keyword
    arguments forwarded to run_framepack by framepack_post.
    """
    # model selection and prompt inputs
    variant: str = Field(default=None, title="Model variant", description="Model variant to use")
    prompt: str = Field(default=None, title="Prompt", description="Prompt for the model")
    init_image: str = Field(default=None, title="Initial image", description="Base64 encoded initial image")
    end_image: Optional[str] = Field(default=None, title="End image", description="Base64 encoded end image")
    start_weight: Optional[float] = Field(default=1.0, title="Start weight", description="Weight of the initial image")
    end_weight: Optional[float] = Field(default=1.0, title="End weight", description="Weight of the end image")
    vision_weight: Optional[float] = Field(default=1.0, title="Vision weight", description="Weight of the vision model")
    system_prompt: Optional[str] = Field(default=None, title="System prompt", description="System prompt for the model")
    optimized_prompt: Optional[bool] = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model")
    section_prompt: Optional[str] = Field(default=None, title="Section prompt", description="Prompt for each section")
    negative_prompt: Optional[str] = Field(default=None, title="Negative prompt", description="Negative prompt for the model")
    styles: Optional[List[str]] = Field(default=None, title="Styles", description="Styles for the model")
    # sampling parameters
    seed: Optional[int] = Field(default=None, title="Seed", description="Seed for the model")
    resolution: Optional[int] = Field(default=640, title="Resolution", description="Resolution of the image")
    duration: Optional[float] = Field(default=4, title="Duration", description="Duration of the video in seconds")
    latent_ws: Optional[int] = Field(default=9, title="Latent window size", description="Size of the latent window")
    steps: Optional[int] = Field(default=25, title="Steps", description="Number of steps for the model")
    cfg_scale: Optional[float] = Field(default=1.0, title="CFG scale", description="CFG scale for the model")
    cfg_distilled: Optional[float] = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model")
    cfg_rescale: Optional[float] = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model")
    shift: Optional[float] = Field(default=0, title="Sampler shift", description="Shift for the sampler")
    use_teacache: Optional[bool] = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model")
    use_cfgzero: Optional[bool] = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model")
    # output/video encoding options
    mp4_fps: Optional[int] = Field(default=30, title="FPS", description="Frames per second for the video")
    mp4_codec: Optional[str] = Field(default="libx264", title="Codec", description="Codec for the video")
    mp4_sf: Optional[bool] = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video")
    mp4_video: Optional[bool] = Field(default=True, title="Save Video", description="Save video")
    mp4_frames: Optional[bool] = Field(default=False, title="Save Frames", description="Save frames for the video")
    mp4_opt: Optional[str] = Field(default="crf:16", title="Options", description="Options for the video codec")
    mp4_ext: Optional[str] = Field(default="mp4", title="Format", description="Format for the video")
    mp4_interpolate: Optional[int] = Field(default=0, title="Interpolation", description="Interpolation for the video")
    # backend and enhancement options
    attention: Optional[str] = Field(default="Default", title="Attention", description="Attention type for the model")
    vae_type: Optional[str] = Field(default="Local", title="VAE", description="VAE type for the model")
    vlm_enhance: Optional[bool] = Field(default=False, title="VLM enhance", description="Enable VLM enhance")
    vlm_model: Optional[str] = Field(default=None, title="VLM model", description="VLM model to use")
    vlm_system_prompt: Optional[str] = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model")
class ResFramepack(BaseModel):
    """Response body for the POST /sdapi/v1/framepack endpoint."""
    # fix: filename/message previously carried copy-pasted "TaskID" metadata
    id: str = Field(title="TaskID", description="Task ID")
    filename: str = Field(title="Filename", description="Output video filename")
    message: str = Field(title="Message", description="Status message")
def framepack_post(request: ReqFramepack):
    """Handle POST /sdapi/v1/framepack: decode request images, run the FramePack
    generator to completion, and return the final filename/message.

    Raises HTTPException(500) when either image fails to decode.
    """
    import numpy as np
    from modules.api import helpers
    from framepack_wrappers import run_framepack  # NOTE(review): other modules import modules.framepack.framepack_wrappers — confirm sys.path makes this bare import resolvable
    task_id = shared.state.get_id()
    try:
        if request.init_image is not None:
            init_image = np.array(helpers.decode_base64_to_image(request.init_image)) if request.init_image else None
        else:
            init_image = None
    except Exception as e:
        shared.log.error(f"API FramePack: id={task_id} cannot decode init image: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    try:
        if request.end_image is not None:
            end_image = np.array(helpers.decode_base64_to_image(request.end_image)) if request.end_image else None
        else:
            end_image = None
    except Exception as e:
        shared.log.error(f"API FramePack: id={task_id} cannot decode end image: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    # drop the (potentially large) base64 payloads before logging/forwarding
    del request.init_image
    del request.end_image
    # fix: compare against None explicitly — `if end_image` on a multi-element
    # ndarray raises "truth value of an array is ambiguous", and init_image may
    # itself be None (which made `.shape` raise AttributeError)
    shared.log.trace(f"API FramePack: id={task_id} init={init_image.shape if init_image is not None else None} end={end_image.shape if end_image is not None else None} {request}")
    generator = run_framepack(
        task_id=f'task({task_id})',
        variant=request.variant,
        init_image=init_image,
        end_image=end_image,
        start_weight=request.start_weight,
        end_weight=request.end_weight,
        vision_weight=request.vision_weight,
        prompt=request.prompt,
        system_prompt=request.system_prompt,
        optimized_prompt=request.optimized_prompt,
        section_prompt=request.section_prompt,
        negative_prompt=request.negative_prompt,
        styles=request.styles,
        seed=request.seed,
        resolution=request.resolution,
        duration=request.duration,
        latent_ws=request.latent_ws,
        steps=request.steps,
        cfg_scale=request.cfg_scale,
        cfg_distilled=request.cfg_distilled,
        cfg_rescale=request.cfg_rescale,
        shift=request.shift,
        use_teacache=request.use_teacache,
        use_cfgzero=request.use_cfgzero,
        use_preview=False,
        mp4_fps=request.mp4_fps,
        mp4_codec=request.mp4_codec,
        mp4_sf=request.mp4_sf,
        mp4_video=request.mp4_video,
        mp4_frames=request.mp4_frames,
        mp4_opt=request.mp4_opt,
        mp4_ext=request.mp4_ext,
        mp4_interpolate=request.mp4_interpolate,
        attention=request.attention,
        vae_type=request.vae_type,
        vlm_enhance=request.vlm_enhance,
        vlm_model=request.vlm_model,
        vlm_system_prompt=request.vlm_system_prompt,
    )
    # drain the generator synchronously; the last (filename, _, message) tuple wins
    response = ResFramepack(id=task_id, filename='', message='')
    for message in generator:
        if isinstance(message, tuple) and len(message) == 3:
            if isinstance(message[0], str):
                response.filename = message[0]
            if isinstance(message[2], str):
                response.message = message[2]
    return response
def create_api(_fastapi, _gradioapp):
    # Register the FramePack generation endpoint on the shared API router;
    # the fastapi/gradio arguments are part of the plugin signature but unused here.
    shared.api.add_api_route("/sdapi/v1/framepack", framepack_post, methods=["POST"], response_model=ResFramepack)

View File

@ -0,0 +1,74 @@
# Reference system prompts used when encoding user prompts for HunyuanVideo.
# The template strings are reproduced verbatim from upstream (including the
# <|eot_id|> marker placement between items 4 and 5) — do not "fix" them.
DEFAULT_PROMPT_TEMPLATE = { # hunyuanvideo reference prompt template
    "template": (
        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
        "1. The main content and theme of the video."
        "2. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
        "3. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
        "4. camera angles, movements, and transitions used in the video:<|eot_id|>"
        "5. background environment, light, style and atmosphere."
        "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    ),
    # number of template tokens cropped from the encoder output — assumed to
    # match the upstream tokenizer for this template; TODO confirm
    "crop_start": 95,
}
# System-prompt body used when optimized_prompt is disabled (see set_prompt_template)
ORIGINAL_PROMPT_TEMPLATE = (
    "\n\nDescribe the video by detailing the following aspects: "
    "1. The main content and theme of the video."
    "2. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "3. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "4. camera angles, movements, and transitions used in the video:<|eot_id|>"
    "5. background environment, light, style and atmosphere."
)
def set_progress_bar_config():
    # Monkey-patch the UniPC sampler entry point so the pipeline's internal
    # progress bar is always disabled (disable_pbar=True); the outer UI owns
    # progress reporting instead.
    from modules.framepack.pipeline import uni_pc_fm
    def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): # pylint: disable=unused-argument
        return uni_pc_fm.FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=True)
    uni_pc_fm.sample_unipc = sample_unipc
def set_prompt_template(prompt, system_prompt:str=None, optimized_prompt:bool=True, unmodified_prompt:bool=False):
    # Select the system prompt (custom > original > optimized), measure its
    # token length with the loaded tokenizer, and install the resulting
    # template (with the correct crop_start offset) into the hunyuan pipeline.
    from modules import shared
    from modules.framepack.pipeline import hunyuan
    mode = 'unknown'
    if system_prompt is not None and len(system_prompt) > 0:
        # caller-supplied system prompt wins
        mode = 'custom'
        system_prompt = system_prompt.strip()
    elif not optimized_prompt:
        # fall back to the upstream reference prompt body
        mode = 'original'
        system_prompt = ORIGINAL_PROMPT_TEMPLATE
    # elif unmodified_prompt:
    #     mode = 'enhanced'
    #     system_prompt = "Repeat the user prompt without any modification. Do not add any additional information or context. "
    else:
        # default: locally tuned director-style system prompt
        mode = 'optimized'
        system_prefix = (
            "You are a video director. "
            "Both subject and camera movement should be dynamic and creative. Do not specify duration and assume video is short. Avoid slow-motion and prefer faster movements. "
            "Do not alter the content specified in the user prompt, but use it as a reference to create a video. "
            "NSFW content and nudity are allowed and must be fully described if present. "
            "Based on provided user prompt, describe the video by detailing the following aspects: \n"
        )
        system_desc = (
            # "Describe the video by detailing the following aspects: \n"
            "1. Main content, style and theme of the video.\n"
            "2. Actions, events, behaviors, temporal relationships, physical movement, and changes of the subjects or objects.\n"
            "3. Camera angles, camera movements, and transitions used in the video.\n"
            "4. Details of the scene and background environment, light, style, and atmosphere.\n"
        )
        system_prompt = system_prefix + system_desc
        # system_prompt = DEFAULT_PROMPT_TEMPLATE["template"]
    # count system-prompt tokens excluding BOS/EOS so crop_start skips exactly the template portion
    inputs = shared.sd_model.tokenizer(system_prompt, max_length=256, truncation=True, return_tensors="pt", return_length=True, return_overflowing_tokens=False, return_attention_mask=False)
    tokens_system = inputs['length'].item() - int(shared.sd_model.tokenizer.bos_token_id is not None) - int(shared.sd_model.tokenizer.eos_token_id is not None)
    inputs = shared.sd_model.tokenizer(prompt, max_length=256, truncation=True, return_tensors="pt", return_length=True, return_overflowing_tokens=False, return_attention_mask=False)
    hunyuan.DEFAULT_PROMPT_TEMPLATE = {
        "template": (
            f"<|start_header_id|>system<|end_header_id|>{system_prompt}\n<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>{}<|eot_id|>"
        ),
        "crop_start": tokens_system,
    }
    # user-prompt token count is logged for diagnostics only
    tokens_user = inputs['length'].item() - int(shared.sd_model.tokenizer.bos_token_id is not None) - int(shared.sd_model.tokenizer.eos_token_id is not None)
    shared.log.trace(f'FramePack prompt: system={tokens_system} user={tokens_user} optimized={optimized_prompt} unmodified={unmodified_prompt} mode={mode}')

View File

@ -0,0 +1,75 @@
import os
import shutil
import git as gitpython
from installer import install, git
from modules.shared import log
def rename(src:str, dst:str):
    """Rename src to dst, falling back to shutil.move when the target is on a different device."""
    import errno
    try:
        os.rename(src, dst)
    except OSError as err:
        if err.errno != errno.EXDEV:
            raise err
        shutil.move(src, dst)  # cross-device: copy then delete instead of rename
def install_requirements(attention:str='SDPA'):
    """Install optional runtime packages required by the selected attention backend."""
    install('av')
    import av
    import torchvision
    torchvision.io.video.av = av  # point torchvision video IO at the freshly installed av
    packages = {
        'Xformers': 'xformers',
        'FlashAttention': 'flash-attn',
        'SageAttention': 'sageattention',
    }
    package = packages.get(attention)
    if package is not None:
        log.debug(f'FramePack install: {package}')
        install(package)
def git_clone(git_repo:str, git_dir:str, tmp_dir:str):
    """Shallow-clone git_repo into tmp_dir (including submodules) and move it into git_dir.

    No-op when git_dir already exists; tmp_dir is cleaned up on failure.
    """
    if os.path.exists(git_dir):
        return
    try:
        shutil.rmtree(tmp_dir, True)  # start from a clean staging directory
        args = {
            'url': git_repo,
            'to_path': tmp_dir,
            'allow_unsafe_protocols': True,
            'allow_unsafe_options': True,
            'filter': ['blob:none'],  # blobless clone keeps the initial download small
        }
        ssh = os.environ.get('GIT_SSH_COMMAND', None)
        if ssh:
            args['env'] = {'GIT_SSH_COMMAND':ssh}
        log.info(f'FramePack install: url={git_repo} path={git_dir}')  # fix: url/path were swapped (url showed the args dict, path the repo url)
        with gitpython.Repo.clone_from(**args) as repo:
            repo.remote().fetch(verbose=True)
            for submodule in repo.submodules:
                submodule.update()
        rename(tmp_dir, git_dir)  # atomic-ish publish: only move into place once complete
    except Exception as e:
        log.error(f'FramePack install: {e}')
        shutil.rmtree(tmp_dir, True)
def git_update(git_dir:str, git_commit:str):
    # Sync an existing clone to the pinned commit; no-op when git_dir is missing
    # or already at the target sha. Errors are logged, never raised.
    if not os.path.exists(git_dir):
        return
    try:
        with gitpython.Repo(git_dir) as repo:
            commit = repo.commit()  # current HEAD commit
            if f'{commit}' != git_commit:
                log.info(f'FramePack update: path={repo.git_dir} current={commit} target={git_commit}')
                repo.git.fetch(all=True)
                repo.git.reset('origin', hard=True)  # discard any local changes before checkout
                # checkout via the installer helper so failures stay non-fatal
                git(f'checkout {git_commit}', folder=git_dir, ignore=True, optional=True)
            else:
                log.debug(f'FramePack version: sha={commit}')
    except Exception as e:
        log.error(f'FramePack update: {e}')

View File

@ -0,0 +1,199 @@
import time
from modules import shared, devices, errors, sd_models, sd_checkpoint, model_quant
# Known transformer variants selectable by name
models = {
    'bi-directional': 'lllyasviel/FramePackI2V_HY',
    'forward-only': 'lllyasviel/FramePack_F1_I2V_HY_20250503',
}
# Default component receipe: each entry maps a pipeline component to its
# HuggingFace repo and subfolder. Commented alternatives are kept for reference.
default_model = {
    'pipeline': { 'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': '' },
    'vae': { 'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': 'vae' },
    'text_encoder': { 'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': 'text_encoder' },
    'tokenizer': {'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': 'tokenizer' },
    # 'text_encoder': { 'repo': 'Kijai/llava-llama-3-8b-text-encoder-tokenizer', 'subfolder': '' },
    # 'tokenizer': { 'repo': 'Kijai/llava-llama-3-8b-text-encoder-tokenizer', 'subfolder': '' },
    # 'text_encoder': { 'repo': 'xtuner/llava-llama-3-8b-v1_1-transformers', 'subfolder': '' },
    # 'tokenizer': {'repo': 'xtuner/llava-llama-3-8b-v1_1-transformers', 'subfolder': '' },
    'text_encoder_2': { 'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': 'text_encoder_2' },
    'tokenizer_2': { 'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': 'tokenizer_2' },
    'feature_extractor': { 'repo': 'lllyasviel/flux_redux_bfl', 'subfolder': 'feature_extractor' },
    'image_encoder': { 'repo': 'lllyasviel/flux_redux_bfl', 'subfolder': 'image_encoder' },
    'transformer': { 'repo': models.get('bi-directional'), 'subfolder': '' },
}
# fix: copy each nested entry too — a plain .copy() shared the inner dicts, so
# load_model mutating model['transformer']['repo'] also corrupted default_model
# and reset_model could never restore the defaults
model = {key: dict(entry) for key, entry in default_model.items()}
def split_url(url):
    """Split an 'org/repo[/subfolder]' string into {'repo', 'subfolder'} parts.

    The subfolder may be omitted; raises ValueError for any other shape.
    """
    if url.count('/') == 1:
        url += '/'  # allow omitting the subfolder part
    if url.count('/') != 2:
        raise ValueError(f'Invalid URL: {url}')
    org, repo, subfolder = (part.strip() for part in url.split('/'))
    return { 'repo': f'{org}/{repo}', 'subfolder': subfolder }
def set_model(receipe: str=None):
    """Apply a model receipe: one 'key: org/repo[/subfolder]' entry per line.

    Unknown keys are warned about and skipped; empty/None receipes are a no-op.
    """
    if receipe is None or receipe == '':
        return
    lines = [line.strip() for line in receipe.split('\n') if line.strip() != '' and ':' in line]
    for line in lines:
        k, v = line.split(':', 1)
        k = k.strip()
        if k not in default_model.keys():
            shared.log.warning(f'FramePack receipe: key={k} invalid')
            continue  # fix: invalid keys were warned about but still applied to the model table
        model[k] = split_url(v)
        shared.log.debug(f'FramePack receipe: set {k}={model[k]}')
def get_model():
    """Serialize the active model receipe into 'key: repo/subfolder' lines."""
    lines = [f'{key}: {entry["repo"]}/{entry["subfolder"]}' for key, entry in model.items()]
    return '\n'.join(lines)
def reset_model():
    """Restore the default model receipe; returns an empty receipe string for the UI."""
    global model # pylint: disable=global-statement
    # fix: copy the nested entries too — a shallow .copy() left the inner dicts
    # shared with default_model, so later in-place edits leaked into the defaults
    model = {key: dict(entry) for key, entry in default_model.items()}
    shared.log.debug('FramePack receipe: reset')
    return ''
def load_model(variant:str=None, pipeline:str=None, text_encoder:str=None, text_encoder_2:str=None, feature_extractor:str=None, image_encoder:str=None, transformer:str=None):
    # Load the full FramePack/HunyuanVideo pipeline into shared.sd_model.
    # Any component argument ('org/repo[/subfolder]') overrides the current
    # receipe entry before loading. Returns the variant name on success, or
    # None on failure (errors are logged, not raised).
    shared.state.begin('Load')
    if variant is not None:
        if variant not in models.keys():
            raise ValueError(f'FramePack: variant="{variant}" invalid')
        model['transformer']['repo'] = models[variant]
    if pipeline is not None:
        model['pipeline'] = split_url(pipeline)
    if text_encoder is not None:
        model['text_encoder'] = split_url(text_encoder)
    if text_encoder_2 is not None:
        model['text_encoder_2'] = split_url(text_encoder_2)
    if feature_extractor is not None:
        model['feature_extractor'] = split_url(feature_extractor)
    if image_encoder is not None:
        model['image_encoder'] = split_url(image_encoder)
    if transformer is not None:
        model['transformer'] = split_url(transformer)
    # shared.log.trace(f'FramePack load: {model}')
    try:
        import diffusers
        from diffusers import HunyuanVideoImageToVideoPipeline, AutoencoderKLHunyuanVideo
        from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer, SiglipImageProcessor, SiglipVisionModel
        from modules.framepack.pipeline.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked

        class FramepackHunyuanVideoPipeline(HunyuanVideoImageToVideoPipeline): # inherit and override
            # Subclass that additionally registers feature_extractor as a
            # pipeline module (the stock pipeline __init__ does not accept it)
            def __init__(
                self,
                text_encoder: LlamaModel,
                tokenizer: LlamaTokenizerFast,
                text_encoder_2: CLIPTextModel,
                tokenizer_2: CLIPTokenizer,
                vae: AutoencoderKLHunyuanVideo,
                feature_extractor: SiglipImageProcessor,
                image_processor: SiglipVisionModel,
                transformer: HunyuanVideoTransformer3DModelPacked,
                scheduler,
            ):
                super().__init__(
                    text_encoder=text_encoder,
                    tokenizer=tokenizer,
                    text_encoder_2=text_encoder_2,
                    tokenizer_2=tokenizer_2,
                    vae=vae,
                    transformer=transformer,
                    image_processor=image_processor,
                    scheduler=scheduler,
                )
                # re-register all modules including feature_extractor so offload/save see it
                self.register_modules(
                    text_encoder=text_encoder,
                    tokenizer=tokenizer,
                    text_encoder_2=text_encoder_2,
                    tokenizer_2=tokenizer_2,
                    vae=vae,
                    feature_extractor=feature_extractor,
                    image_processor=image_processor,
                    transformer=transformer,
                    scheduler=scheduler,
                )

        sd_models.unload_model_weights()  # free the previously loaded model first
        t0 = time.time()
        # llama text encoder (may be quantized via model_quant)
        shared.log.debug(f'FramePack load: module=llm {model["text_encoder"]}')
        load_args, quant_args = model_quant.get_dit_args({}, module='TE', device_map=True)
        text_encoder = LlamaModel.from_pretrained(model["text_encoder"]["repo"], subfolder=model["text_encoder"]["subfolder"], cache_dir=shared.opts.hfcache_dir, **load_args, **quant_args)
        tokenizer = LlamaTokenizerFast.from_pretrained(model["tokenizer"]["repo"], subfolder=model["tokenizer"]["subfolder"], cache_dir=shared.opts.hfcache_dir)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        sd_models.move_model(text_encoder, devices.cpu)  # park on CPU until needed
        # clip text encoder
        shared.log.debug(f'FramePack load: module=te {model["text_encoder_2"]}')
        text_encoder_2 = CLIPTextModel.from_pretrained(model["text_encoder_2"]["repo"], subfolder=model["text_encoder_2"]["subfolder"], torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
        # NOTE(review): tokenizer_2 loads from model["pipeline"] with a hardcoded subfolder, not from model["tokenizer_2"] — confirm this is intentional
        tokenizer_2 = CLIPTokenizer.from_pretrained(model["pipeline"]["repo"], subfolder='tokenizer_2', cache_dir=shared.opts.hfcache_dir)
        text_encoder_2.requires_grad_(False)
        text_encoder_2.eval()
        sd_models.move_model(text_encoder_2, devices.cpu)
        # vae with slicing/tiling enabled to bound decode memory
        shared.log.debug(f'FramePack load: module=vae {model["vae"]}')
        vae = AutoencoderKLHunyuanVideo.from_pretrained(model["vae"]["repo"], subfolder=model["vae"]["subfolder"], torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
        vae.requires_grad_(False)
        vae.eval()
        vae.enable_slicing()
        vae.enable_tiling()
        sd_models.move_model(vae, devices.cpu)
        # siglip vision encoder + its preprocessor
        shared.log.debug(f'FramePack load: module=encoder {model["feature_extractor"]} model={model["image_encoder"]}')
        feature_extractor = SiglipImageProcessor.from_pretrained(model["feature_extractor"]["repo"], subfolder=model["feature_extractor"]["subfolder"], cache_dir=shared.opts.hfcache_dir)
        image_encoder = SiglipVisionModel.from_pretrained(model["image_encoder"]["repo"], subfolder=model["image_encoder"]["subfolder"], torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
        image_encoder.requires_grad_(False)
        image_encoder.eval()
        sd_models.move_model(image_encoder, devices.cpu)
        # packed video transformer (may be quantized via model_quant)
        shared.log.debug(f'FramePack load: module=transformer {model["transformer"]}')
        dit_repo = model["transformer"]["repo"]
        load_args, quant_args = model_quant.get_dit_args({}, module='Model', device_map=True)
        transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(dit_repo, subfolder=model["transformer"]["subfolder"], cache_dir=shared.opts.hfcache_dir, **load_args, **quant_args)
        transformer.high_quality_fp32_output_for_inference = False
        transformer.requires_grad_(False)
        transformer.eval()
        sd_models.move_model(transformer, devices.cpu)
        # assemble the pipeline; scheduler is None because sampling is driven externally
        shared.sd_model = FramepackHunyuanVideoPipeline(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            vae=vae,
            feature_extractor=feature_extractor,
            image_processor=image_encoder,
            transformer=transformer,
            scheduler=None,
        )
        shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(dit_repo) # pylint: disable=attribute-defined-outside-init
        shared.sd_model.sd_model_checkpoint = dit_repo # pylint: disable=attribute-defined-outside-init
        shared.sd_model = model_quant.do_post_load_quant(shared.sd_model, allow=False)
        t1 = time.time()
        # identity LoRA scale mapping so peft does not reject the packed transformer class
        diffusers.loaders.peft._SET_ADAPTER_SCALE_FN_MAPPING['HunyuanVideoTransformer3DModelPacked'] = lambda model_cls, weights: weights # pylint: disable=protected-access
        shared.log.info(f'FramePack load: model={shared.sd_model.__class__.__name__} variant="{variant}" type={shared.sd_model_type} time={t1-t0:.2f}')
        sd_models.apply_balanced_offload(shared.sd_model)
        devices.torch_gc(force=True)
    except Exception as e:
        shared.log.error(f'FramePack load: {e}')
        errors.display(e, 'FramePack')
        shared.state.end()
        return None
    shared.state.end()
    return variant
def unload_model():
    """Release the currently loaded FramePack model weights via the shared model manager.

    Returns None; the UI wires this to outputs, so gradio receives no update values --
    NOTE(review): confirm gradio tolerates a None return for the three wired outputs.
    """
    sd_models.unload_model_weights()

View File

@ -0,0 +1,133 @@
import gradio as gr
from modules import ui_sections, ui_common, ui_video_vlm
from modules.framepack import framepack_load
from modules.framepack.framepack_worker import get_latent_paddings
from modules.framepack.framepack_wrappers import get_codecs, load_model, unload_model
from modules.framepack.framepack_wrappers import run_framepack # pylint: disable=wrong-import-order
def change_sections(duration, mp4_fps, mp4_interpolate, latent_ws, variant):
    """Recompute the section/frame summary whenever duration, fps, interpolation or window size changes.

    Returns gradio updates for the summary HTML element and for the
    section-prompt textbox height (roughly two lines per three sections).
    """
    paddings = get_latent_paddings(mp4_fps, mp4_interpolate, latent_ws, duration, variant)
    num_sections = len(paddings)
    frames_per_section = latent_ws * 4 - 3
    num_frames = frames_per_section * num_sections + 1
    html_update = gr.update(value=f'Target video: {num_frames} frames in {num_sections} sections')
    prompt_update = gr.update(lines=max(2, 2*num_sections//3))
    return html_update, prompt_update
def create_ui(prompt, negative, styles, _overrides):
    """Build the FramePack tab: settings column (model, inputs, video, advanced, model management)
    plus an output column (video/preview tabs), and wire all gradio events.

    Fixes vs previous revision:
    - `vae_type` default was 'Local', which is not one of its choices; now 'Full'
    - `mp4_opt` reused elem_id "framepack_mp4_ext" already taken by `mp4_ext`; now "framepack_mp4_opt"
    """
    with gr.Row():
        with gr.Column(variant='compact', elem_id="framepack_settings", elem_classes=['settings-column'], scale=1):
            with gr.Row():
                generate = gr.Button('Generate', elem_id="framepack_generate_btn", variant='primary', visible=False)
            with gr.Row():
                variant = gr.Dropdown(label="Model variant", choices=list(framepack_load.models), value='bi-directional', type='value')
            with gr.Row():
                resolution = gr.Slider(label="Resolution", minimum=240, maximum=1088, value=640, step=16)
                duration = gr.Slider(label="Duration", minimum=1, maximum=120, value=4, step=0.1)
                mp4_fps = gr.Slider(label="FPS", minimum=1, maximum=60, value=24, step=1)
                mp4_interpolate = gr.Slider(label="Interpolation", minimum=0, maximum=10, value=0, step=1)
            with gr.Row():
                section_html = gr.HTML(show_label=False, elem_id="framepack_section_html")
            with gr.Accordion(label="Inputs", open=True):
                with gr.Row():
                    input_image = gr.Image(sources='upload', type="numpy", label="Init image", width=256, height=256, interactive=True, tool="editor", image_mode='RGB', elem_id="framepack_input_image")
                    end_image = gr.Image(sources='upload', type="numpy", label="End image", width=256, height=256, interactive=True, tool="editor", image_mode='RGB', elem_id="framepack_end_image")
                with gr.Row():
                    start_weight = gr.Slider(label="Init strength", value=1.0, minimum=0.0, maximum=2.0, step=0.05, elem_id="framepack_start_weight")
                    end_weight = gr.Slider(label="End strength", value=1.0, minimum=0.0, maximum=2.0, step=0.05, elem_id="framepack_end_weight")
                    vision_weight = gr.Slider(label="Vision strength", value=1.0, minimum=0.0, maximum=2.0, step=0.05, elem_id="framepack_vision_weight")
            with gr.Accordion(label="Sections", open=False):
                section_prompt = gr.Textbox(label="Section prompts", elem_id="framepack_section_prompt", lines=2, placeholder="Optional one-line prompt suffix per each video section", interactive=True)
            with gr.Accordion(label="Video", open=False):
                with gr.Row():
                    mp4_codec = gr.Dropdown(label="Codec", choices=['none', 'libx264'], value='libx264', type='value')
                    ui_common.create_refresh_button(mp4_codec, get_codecs)
                    mp4_ext = gr.Textbox(label="Format", value='mp4', elem_id="framepack_mp4_ext")
                    mp4_opt = gr.Textbox(label="Options", value='crf:16', elem_id="framepack_mp4_opt")  # fixed: was duplicate elem_id "framepack_mp4_ext"
                with gr.Row():
                    mp4_video = gr.Checkbox(label='Save Video', value=True, elem_id="framepack_mp4_video")
                    mp4_frames = gr.Checkbox(label='Save Frames', value=False, elem_id="framepack_mp4_frames")
                    mp4_sf = gr.Checkbox(label='Save SafeTensors', value=False, elem_id="framepack_mp4_sf")
            with gr.Accordion(label="Advanced", open=False):
                seed = ui_sections.create_seed_inputs('control', reuse_visible=False, subseed_visible=False, accordion=False)[0]
                latent_ws = gr.Slider(label="Latent window size", minimum=1, maximum=33, value=9, step=1)
                with gr.Row():
                    steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
                    shift = gr.Slider(label="Sampler shift", minimum=0.0, maximum=10.0, value=3.0, step=0.01)
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01)
                    cfg_distilled = gr.Slider(label="Distilled CFG scale", minimum=1.0, maximum=32.0, value=10.0, step=0.01)
                    cfg_rescale = gr.Slider(label="CFG re-scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01)
            vlm_enhance, vlm_model, vlm_system_prompt = ui_video_vlm.create_ui(prompt_element=prompt, image_element=input_image)
            with gr.Accordion(label="Model", open=False):
                with gr.Row():
                    btn_load = gr.Button(value="Load model", elem_id="framepack_btn_load", interactive=True)
                    btn_unload = gr.Button(value="Unload model", elem_id="framepack_btn_unload", interactive=True)
                with gr.Row():
                    system_prompt = gr.Textbox(label="System prompt", elem_id="framepack_system_prompt", lines=6, placeholder="Optional system prompt for the model", interactive=True)
                with gr.Row():
                    receipe = gr.Textbox(label="Model receipe", elem_id="framepack_model_receipe", lines=6, placeholder="Model receipe", interactive=True)
                with gr.Row():
                    receipe_get = gr.Button(value="Get receipe", elem_id="framepack_btn_get_model", interactive=True)
                    receipe_set = gr.Button(value="Set receipe", elem_id="framepack_btn_set_model", interactive=True)
                    receipe_reset = gr.Button(value="Reset receipe", elem_id="framepack_btn_reset_model", interactive=True)
                use_teacache = gr.Checkbox(label='Enable TeaCache', value=True)
                optimized_prompt = gr.Checkbox(label='Use optimized system prompt', value=True)
                use_cfgzero = gr.Checkbox(label='Enable CFGZero', value=False)
                use_preview = gr.Checkbox(label='Enable Preview', value=True)
                attention = gr.Dropdown(label="Attention", choices=['Default', 'Xformers', 'FlashAttention', 'SageAttention'], value='Default', type='value')
                vae_type = gr.Dropdown(label="VAE", choices=['Full', 'Tiny', 'Remote'], value='Full', type='value')  # fixed: default was 'Local', not a valid choice
        with gr.Column(elem_id='framepack-output-column', scale=2) as _column_output:
            with gr.Tabs():
                with gr.TabItem("Video"):
                    result_video = gr.Video(label="Video", autoplay=True, show_share_button=False, height=512, loop=True, show_label=False, elem_id="framepack_result_video")
                with gr.Tab("Preview"):
                    preview_image = gr.Image(label="Current", height=512, show_label=False, elem_id="framepack_preview_image")
                    progress_desc = gr.HTML('', show_label=False, elem_id="framepack_progress_desc")
    # hidden fields used by the js submit handler
    task_id = gr.Textbox(visible=False, value='')
    ui_state = gr.Textbox(visible=False, value='')
    state_inputs = [task_id, ui_state]
    framepack_outputs = [
        result_video,
        preview_image,
        progress_desc,
    ]
    # keep the section summary in sync with any input affecting section count
    duration.change(fn=change_sections, inputs=[duration, mp4_fps, mp4_interpolate, latent_ws, variant], outputs=[section_html, section_prompt])
    mp4_fps.change(fn=change_sections, inputs=[duration, mp4_fps, mp4_interpolate, latent_ws, variant], outputs=[section_html, section_prompt])
    mp4_interpolate.change(fn=change_sections, inputs=[duration, mp4_fps, mp4_interpolate, latent_ws, variant], outputs=[section_html, section_prompt])
    btn_load.click(fn=load_model, inputs=[variant, attention], outputs=framepack_outputs)
    btn_unload.click(fn=unload_model, outputs=framepack_outputs)
    receipe_get.click(fn=framepack_load.get_model, inputs=[], outputs=receipe)
    receipe_set.click(fn=framepack_load.set_model, inputs=[receipe], outputs=[])
    receipe_reset.click(fn=framepack_load.reset_model, inputs=[], outputs=[receipe])
    # order must match run_framepack's signature (after the two state inputs)
    framepack_inputs = [
        input_image, end_image,
        start_weight, end_weight, vision_weight,
        prompt, system_prompt, optimized_prompt, section_prompt, negative, styles,
        seed,
        resolution,
        duration,
        latent_ws,
        steps,
        cfg_scale, cfg_distilled, cfg_rescale,
        shift,
        use_teacache, use_cfgzero, use_preview,
        mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
        attention, vae_type, variant,
        vlm_enhance, vlm_model, vlm_system_prompt,
    ]
    framepack_dict = dict(
        fn=run_framepack,
        _js="submit_framepack",
        inputs=state_inputs + framepack_inputs,
        outputs=framepack_outputs,
        show_progress=False,
    )
    generate.click(**framepack_dict)

View File

@ -0,0 +1,96 @@
import torch
import einops
from modules import shared, devices
# 16x3 linear projection mapping the 16 latent channels to RGB; used by
# vae_decode_simple for cheap previews without running the real VAE
latent_rgb_factors = [ # from comfyui
    [-0.0395, -0.0331, 0.0445],
    [0.0696, 0.0795, 0.0518],
    [0.0135, -0.0945, -0.0282],
    [0.0108, -0.0250, -0.0765],
    [-0.0209, 0.0032, 0.0224],
    [-0.0804, -0.0254, -0.0639],
    [-0.0991, 0.0271, -0.0669],
    [-0.0646, -0.0422, -0.0400],
    [-0.0696, -0.0595, -0.0894],
    [-0.0799, -0.0208, -0.0375],
    [0.1166, 0.1627, 0.0962],
    [0.1165, 0.0432, 0.0407],
    [-0.2315, -0.1920, -0.1355],
    [-0.0270, 0.0401, -0.0821],
    [-0.0616, -0.0997, -0.0727],
    [0.0249, -0.0469, -0.1703]
]
latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]  # per-channel RGB bias for the preview projection
vae_weight = None  # cached conv3d weight built from latent_rgb_factors on first preview decode
vae_bias = None  # cached conv3d bias built from latent_rgb_factors_bias
taesd = None  # lazily-loaded tiny VAE instance, see vae_decode_tiny
def vae_decode_simple(latents):
    """Fast approximate latent-to-RGB preview using a fixed linear projection instead of the VAE.

    Expects latents shaped (b, c, t, h, w); returns a uint8 numpy array with all
    frames tiled into a single (b*h, t*w, 3) grid. Projection weights are cached
    in module globals on first use.
    """
    global vae_weight, vae_bias # pylint: disable=global-statement
    with devices.inference_context():
        if vae_weight is None or vae_bias is None: # lazily build conv3d weights from the rgb factor table
            vae_weight = torch.tensor(latent_rgb_factors, device=devices.device, dtype=devices.dtype).transpose(0, 1)[:, :, None, None, None]
            vae_bias = torch.tensor(latent_rgb_factors_bias, device=devices.device, dtype=devices.dtype)
        images = torch.nn.functional.conv3d(latents, weight=vae_weight, bias=vae_bias, stride=1, padding=0, dilation=1, groups=1)
        images = (images + 1.2) * 100 # sort-of normalized
        images = einops.rearrange(images, 'b c t h w -> (b h) (t w) c')
        # fixed: clamp before casting to uint8 -- casting first wraps out-of-range
        # values, which made the previous post-cast clip(0, 255) a no-op
        images = images.clamp(0, 255).to(torch.uint8).detach().cpu().numpy()
    return images
def vae_decode_tiny(latents):
    """Decode latents with the TAESD tiny VAE (fast, lower quality than the full VAE).

    Loads the TAE HunyuanVideo model on first use and caches it in the module-level
    `taesd` global; the model is moved to the compute device for the decode and
    moved back to CPU afterwards to free VRAM.
    """
    global taesd # pylint: disable=global-statement
    if taesd is None:
        from modules import sd_vae_taesd
        taesd = sd_vae_taesd.get_model(variant='TAE HunyuanVideo')
        shared.log.debug(f'Video VAE: type=Tiny cls={taesd.__class__.__name__} latents={latents.shape}')
    with devices.inference_context():
        taesd = taesd.to(device=devices.device, dtype=devices.dtype)
        latents = latents.transpose(1, 2) # pipe produces NCTHW and tae wants NTCHW
        images = taesd.decode_video(latents, parallel=False, show_progress_bar=False)
        # back to NCTHW, then rescale in-place from [0,1] to [-1,1] to match the full-VAE output range
        images = images.transpose(1, 2).mul_(2).sub_(1) # normalize
        taesd = taesd.to(device=devices.cpu, dtype=devices.dtype) # release VRAM until next decode
    return images
def vae_decode_remote(latents):
    """Decode latents via the HuggingFace hosted remote-VAE endpoint for HunyuanVideo.

    Offloads the VAE decode entirely (no local VRAM cost); requires network access.
    Returns a torch tensor (`output_type`/`return_type` both 'pt').
    """
    # from modules.sd_vae_remote import remote_decode
    # images = remote_decode(latents, model_type='hunyuanvideo')
    from diffusers.utils.remote_utils import remote_decode
    images = remote_decode(
        tensor=latents.contiguous(),
        endpoint='https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud',
        output_type='pt',
        return_type='pt',
    )
    return images
def vae_decode_full(latents):
    """Decode latents with the pipeline's full VAE for final-quality output."""
    with devices.inference_context():
        vae = shared.sd_model.vae
        scaled = latents / vae.config.scaling_factor
        scaled = scaled.to(device=vae.device, dtype=vae.dtype)
        decoded = vae.decode(scaled)
        return decoded.sample
def vae_decode(latents, vae_type):
    """Dispatch latent decoding to the selected backend.

    vae_type: 'Tiny' (TAESD), 'Preview' (linear projection), 'Remote' (hosted endpoint);
    any other value falls back to the full local VAE.
    """
    latents = latents.to(device=devices.device, dtype=devices.dtype)
    decoders = {
        'Tiny': vae_decode_tiny,
        'Preview': vae_decode_simple,
        'Remote': vae_decode_remote,
    }
    decoder = decoders.get(vae_type, vae_decode_full)
    return decoder(latents)
def vae_encode(image):
    """Encode an image tensor to scaled latents using the pipeline's VAE."""
    with devices.inference_context():
        vae = shared.sd_model.vae
        pixels = image.to(device=vae.device, dtype=vae.dtype)
        posterior = vae.encode(pixels).latent_dist
        return posterior.sample() * vae.config.scaling_factor

View File

@ -0,0 +1,121 @@
import os
import time
import datetime
import cv2
import torch
import einops
from modules import shared, errors ,timer, rife
def atomic_save_video(filename, tensor:torch.Tensor, fps:float=24, codec:str='libx264', pix_fmt:str='yuv420p', options:str='', metadata:dict=None, pbar=None):
    """Encode a (frames, height, width, channels) uint8 tensor to a video file using PyAV.

    Args:
        filename: output path
        tensor: frame data, THWC layout, uint8 range
        fps: target frame rate (rounded to int)
        codec/pix_fmt: encoder name and pixel format passed to PyAV
        options: comma-separated encoder options, each 'key=value' or 'key:value'
        metadata: container-level metadata key/value pairs
        pbar: optional rich progress bar to advance per encoded frame

    Returns None; logs and aborts if PyAV is unavailable.
    """
    try:
        import av
        av.logging.set_level(av.logging.ERROR) # pylint: disable=c-extension-no-member
    except Exception as e:
        shared.log.error(f'FramePack video: {e}')
        return
    metadata = metadata if metadata is not None else {} # fixed: avoid shared mutable default argument
    frames, height, width, _channels = tensor.shape
    rate = round(fps)
    # parse 'key=value' / 'key:value' pairs into the encoder options dict
    encode_options = {}
    for option in [opt.strip() for opt in options.split(',')]:
        if '=' in option:
            key, value = option.split('=', 1)
        elif ':' in option:
            key, value = option.split(':', 1)
        else:
            continue
        encode_options[key.strip()] = value.strip()
    shared.log.info(f'FramePack video: file="{filename}" codec={codec} frames={frames} width={width} height={height} fps={rate} options={encode_options}') # fixed: log the actual filename
    video_array = torch.as_tensor(tensor, dtype=torch.uint8).numpy(force=True)
    task = pbar.add_task('encoding', total=frames) if pbar is not None else None
    if task is not None:
        pbar.update(task, description='video encoding')
    with av.open(filename, mode="w") as container:
        for k, v in metadata.items():
            container.metadata[k] = v
        stream: av.VideoStream = container.add_stream(codec, rate=rate, options=encode_options)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = pix_fmt
        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            for packet in stream.encode_lazy(frame):
                container.mux(packet)
            if task is not None:
                pbar.update(task, advance=1)
        for packet in stream.encode(): # flush
            container.mux(packet)
    shared.state.outputs(filename)
def save_video(
    pixels:torch.Tensor,
    mp4_fps:int=24,
    mp4_codec:str='libx264',
    mp4_opt:str='',
    mp4_ext:str='mp4',
    mp4_sf:bool=False, # save safetensors
    mp4_video:bool=True, # save video
    mp4_frames:bool=False, # save frames
    mp4_interpolate:int=0, # rife interpolation
    stream=None, # async progress reporting stream
    metadata:dict=None, # metadata for video (fixed: was a shared mutable default {})
    pbar=None, # progress bar for video
):
    """Post-process and persist a generated clip: optional RIFE interpolation,
    then any combination of safetensors dump, per-frame jpg export and encoded video.

    Args:
        pixels: (n, c, t, h, w) tensor in [-1, 1]
        others: see inline comments in the signature
    Returns:
        number of frames in the (possibly interpolated) clip, or 0 when pixels is None
    """
    if pixels is None:
        return 0
    metadata = metadata if metadata is not None else {}
    t_save = time.time()
    n, _c, t, h, w = pixels.shape
    size = pixels.element_size() * pixels.numel()
    shared.log.debug(f'FramePack video: video={mp4_video} export={mp4_frames} safetensors={mp4_sf} interpolate={mp4_interpolate}')
    shared.log.debug(f'FramePack video: encode={t} raw={size} latent={pixels.shape} fps={mp4_fps} codec={mp4_codec} ext={mp4_ext} options="{mp4_opt}"')
    try:
        if stream is not None:
            stream.output_queue.push(('progress', (None, 'Saving video...')))
        if mp4_interpolate > 0:
            # RIFE works on NCHW frame batches; fold the temporal axis in and out again
            x = pixels.squeeze(0).permute(1, 0, 2, 3)
            interpolated = rife.interpolate_nchw(x, count=mp4_interpolate+1)
            pixels = torch.stack(interpolated, dim=0)
            pixels = pixels.permute(1, 2, 0, 3, 4)
            n, _c, t, h, w = pixels.shape
        # [-1, 1] float -> [0, 255] uint8, all batch frames tiled into one strip per timestep
        x = torch.clamp(pixels.float(), -1., 1.) * 127.5 + 127.5
        x = x.detach().cpu().to(torch.uint8)
        x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=n)
        x = x.contiguous()
        timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
        output_filename = os.path.join(shared.opts.outdir_video, f'{timestamp}-{mp4_codec}-f{t}')
        if mp4_sf:
            fn = f'{output_filename}.safetensors'
            shared.log.info(f'FramePack export: file="{fn}" type=savetensors shape={x.shape}')
            from safetensors.torch import save_file
            shared.state.outputs(fn)
            save_file({ 'frames': x }, fn, metadata={'format': 'video', 'frames': str(t), 'width': str(w), 'height': str(h), 'fps': str(mp4_fps), 'codec': mp4_codec, 'options': mp4_opt, 'ext': mp4_ext, 'interpolate': str(mp4_interpolate)})
        if mp4_frames:
            shared.log.info(f'FramePack frames: files="{output_filename}-00000.jpg" frames={t} width={w} height={h}')
            for i in range(t):
                image = cv2.cvtColor(x[i].numpy(), cv2.COLOR_RGB2BGR) # cv2 expects BGR
                fn = f'{output_filename}-{i:05d}.jpg'
                shared.state.outputs(fn)
                cv2.imwrite(fn, image)
        if mp4_video and (mp4_codec != 'none'):
            fn = f'{output_filename}.{mp4_ext}'
            atomic_save_video(fn, tensor=x, fps=mp4_fps, codec=mp4_codec, options=mp4_opt, metadata=metadata, pbar=pbar)
            if stream is not None:
                stream.output_queue.push(('progress', (None, f'Video {os.path.basename(fn)} | Codec {mp4_codec} | Size {w}x{h}x{t} | FPS {mp4_fps}')))
                stream.output_queue.push(('file', fn))
        else:
            if stream is not None:
                stream.output_queue.push(('progress', (None, '')))
    except Exception as e:
        shared.log.error(f'FramePack video: raw={size} {e}')
        errors.display(e, 'FramePack video')
    timer.process.add('save', time.time()-t_save)
    return t

View File

@ -0,0 +1,319 @@
import time
import torch
import rich.progress as rp
from modules import shared, errors ,devices, sd_models, timer, memstats
from modules.framepack import framepack_vae # pylint: disable=wrong-import-order
from modules.framepack import framepack_hijack # pylint: disable=wrong-import-order
from modules.framepack import framepack_video # pylint: disable=wrong-import-order
stream = None # AsyncStream
def get_latent_paddings(mp4_fps, mp4_interpolate, latent_window_size, total_second_length, variant):
    """Compute the per-section latent padding schedule for a generation run.

    The number of sections derives from the requested duration, the effective
    fps (output fps divided by the interpolation factor) and the latent window
    size. The 'forward-only' variant gets an increasing schedule; the default
    reverse variant gets a decreasing one, with a fixed [3, 2..2, 1, 0] shape
    for long runs. Falls back to a single section on any arithmetic error.
    """
    try:
        frames_per_section = latent_window_size * 4
        effective_fps = mp4_fps / (mp4_interpolate + 1)
        section_estimate = (total_second_length * effective_fps) / frames_per_section
        if variant == 'forward-only':
            section_count = int(max(round(section_estimate), 1))
            schedule = list(range(section_count))
        else:
            section_count = int(max(section_estimate, 1))
            if section_count > 4: # extra padding for better quality
                schedule = [3] + [2] * (section_count - 3) + [1, 0]
            else:
                schedule = list(reversed(range(section_count)))
    except Exception:
        schedule = [0]
    return schedule
def worker(
    input_image, end_image,
    start_weight, end_weight, vision_weight,
    prompts, n_prompt, system_prompt, optimized_prompt, unmodified_prompt,
    seed,
    total_second_length,
    latent_window_size,
    steps,
    cfg_scale, cfg_distilled, cfg_rescale,
    shift,
    use_teacache, use_cfgzero, use_preview,
    mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
    vae_type,
    variant,
    metadata:dict=None,
):
    """FramePack generation worker: encodes prompts and images, samples the video
    section-by-section with the packed HunyuanVideo transformer, decodes and saves it.

    Communicates with the UI through the module-level `stream` queue
    ('progress', 'file' and 'end' messages) and through `shared.state`.
    Fixes vs previous revision: no AttributeError when `stream` is None; f1
    branch passed the VAE object where vae_decode expects the vae_type string;
    `metadata` no longer a shared mutable default; `t0` initialized before the
    try block so the final log cannot raise NameError on early failure.
    """
    timer.process.reset()
    memstats.reset_stats()
    if stream is None or shared.state.interrupted or shared.state.skipped:
        shared.log.error('FramePack: stream is None')
        if stream is not None: # fixed: pushing to a missing stream raised AttributeError
            stream.output_queue.push(('end', None))
        return
    metadata = metadata if metadata is not None else {} # fixed: avoid shared mutable default argument
    from modules.framepack.pipeline import hunyuan
    from modules.framepack.pipeline import utils
    from modules.framepack.pipeline.k_diffusion_hunyuan import sample_hunyuan
    is_f1 = variant == 'forward-only' # f1 generates sections forward in time; default variant generates in reverse
    total_generated_frames = 0
    total_generated_latent_frames = 0
    latent_paddings = get_latent_paddings(mp4_fps, mp4_interpolate, latent_window_size, total_second_length, variant)
    num_frames = latent_window_size * 4 - 3 # number of frames to generate in each section
    metadata['title'] = 'sdnext framepack'
    metadata['description'] = f'variant:{variant} seed:{seed} steps:{steps} scale:{cfg_scale} distilled:{cfg_distilled} rescale:{cfg_rescale} shift:{shift} start:{start_weight} end:{end_weight} vision:{vision_weight}'
    shared.state.begin('Video')
    shared.state.job_count = 1
    # pipeline sub-models; device placement is managed explicitly via balanced offload
    text_encoder = shared.sd_model.text_encoder
    text_encoder_2 = shared.sd_model.text_encoder_2
    tokenizer = shared.sd_model.tokenizer
    tokenizer_2 = shared.sd_model.tokenizer_2
    vae = shared.sd_model.vae
    feature_extractor = shared.sd_model.feature_extractor
    image_encoder = shared.sd_model.image_processor
    transformer = shared.sd_model.transformer
    sd_models.apply_balanced_offload(shared.sd_model)
    pbar = rp.Progress(rp.TextColumn('[cyan]Video'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
    task = pbar.add_task('starting', total=steps * len(latent_paddings))
    t_last = time.time()
    if not is_f1:
        prompts = list(reversed(prompts)) # reverse variant consumes section prompts last-to-first

    def text_encode(prompt, i:int=None):
        """Encode one section prompt (and negative prompt when cfg>1); returns padded embeds and masks."""
        pbar.update(task, description=f'text encode section={i}')
        t0 = time.time()
        torch.manual_seed(seed)
        shared.state.textinfo = 'Text encode'
        stream.output_queue.push(('progress', (None, 'Text encoding...')))
        sd_models.apply_balanced_offload(shared.sd_model)
        sd_models.move_model(text_encoder, devices.device, force=True) # required as hunyuan.encode_prompt_conds checks device before calling model
        sd_models.move_model(text_encoder_2, devices.device, force=True)
        framepack_hijack.set_prompt_template(prompt, system_prompt, optimized_prompt, unmodified_prompt)
        llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        metadata['comment'] = prompt
        if cfg_scale > 1 and n_prompt is not None and len(n_prompt) > 0:
            llama_vec_n, clip_l_pooler_n = hunyuan.encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
        else: # negative conditioning is unused at cfg<=1 so zeros suffice
            llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
        llama_vec, llama_attention_mask = utils.crop_or_pad_yield_mask(llama_vec, length=512)
        llama_vec_n, llama_attention_mask_n = utils.crop_or_pad_yield_mask(llama_vec_n, length=512)
        timer.process.add('prompt', time.time()-t0)
        return llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n

    def latents_encode(input_image, end_image):
        """VAE-encode init (and optional end) image; init latent is noise-blended when start_weight<1."""
        pbar.update(task, description='image encode')
        t0 = time.time()
        torch.manual_seed(seed)
        stream.output_queue.push(('progress', (None, 'VAE encoding...')))
        sd_models.apply_balanced_offload(shared.sd_model)
        sd_models.move_model(vae, devices.device, force=True)
        if input_image is not None: # NOTE(review): start_latent is undefined when input_image is None; callers always supply an init image
            input_image_pt = torch.from_numpy(input_image).float() / 127.5 - 1
            input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
            start_latent = framepack_vae.vae_encode(input_image_pt)
            if start_weight < 1:
                noise = torch.randn_like(start_latent)
                start_latent = start_latent * start_weight + noise * (1 - start_weight)
        if end_image is not None:
            end_image_pt = torch.from_numpy(end_image).float() / 127.5 - 1
            end_image_pt = end_image_pt.permute(2, 0, 1)[None, :, None]
            end_latent = framepack_vae.vae_encode(end_image_pt)
        else:
            end_latent = None
        timer.process.add('encode', time.time()-t0)
        return start_latent, end_latent

    def vision_encode(input_image, end_image):
        """Vision-encode init (and optional end) image; blends hidden states by start/end weights, scales by vision_weight."""
        pbar.update(task, description='vision encode')
        t0 = time.time()
        shared.state.textinfo = 'Vision encode'
        stream.output_queue.push(('progress', (None, 'Vision encoding...')))
        sd_models.apply_balanced_offload(shared.sd_model)
        sd_models.move_model(feature_extractor, devices.device, force=True)
        sd_models.move_model(image_encoder, devices.device, force=True)
        preprocessed = feature_extractor.preprocess(images=input_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
        image_encoder_output = image_encoder(**preprocessed)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
        if end_image is not None:
            preprocessed = feature_extractor.preprocess(images=end_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
            end_image_encoder_output = image_encoder(**preprocessed)
            end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state
            image_encoder_last_hidden_state = (image_encoder_last_hidden_state * start_weight) + (end_image_encoder_last_hidden_state * end_weight) / (start_weight + end_weight) # use weighted approach
        timer.process.add('vision', time.time()-t0)
        image_encoder_last_hidden_state = image_encoder_last_hidden_state * vision_weight
        return image_encoder_last_hidden_state

    def step_callback(d):
        """Per-step sampler callback: CFGZero first-step zeroing, interrupt/pause handling, progress + optional preview."""
        if use_cfgzero and is_first_section and d['i'] == 0:
            d['denoised'] = d['denoised'] * 0
        t_current = time.time()
        if stream.input_queue.top() == 'end' or shared.state.interrupted or shared.state.skipped:
            stream.output_queue.push(('progress', (None, 'Interrupted...')))
            stream.output_queue.push(('end', None))
            raise AssertionError('Interrupted...') # unwinds out of sample_hunyuan; handled below
        if shared.state.paused:
            shared.log.debug('Sampling paused')
            while shared.state.paused:
                if shared.state.interrupted or shared.state.skipped:
                    raise AssertionError('Interrupted...')
                time.sleep(0.1)
        nonlocal total_generated_frames, t_last
        t_preview = time.time()
        current_step = d['i'] + 1
        shared.state.textinfo = ''
        shared.state.sampling_step = ((lattent_padding_loop-1) * steps) + current_step
        shared.state.sampling_steps = steps * len(latent_paddings)
        progress = shared.state.sampling_step / shared.state.sampling_steps
        total_generated_frames = int(max(0, total_generated_latent_frames * 4 - 3))
        pbar.update(task, advance=1, description=f'its={1/(t_current-t_last):.2f} sample={d["i"]+1}/{steps} section={lattent_padding_loop}/{len(latent_paddings)} frames={total_generated_frames}/{num_frames*len(latent_paddings)}')
        desc = f'Step {shared.state.sampling_step}/{shared.state.sampling_steps} | Current {current_step}/{steps} | Section {lattent_padding_loop}/{len(latent_paddings)} | Progress {progress:.2%}'
        if use_preview:
            preview = framepack_vae.vae_decode(d['denoised'], 'Preview')
            stream.output_queue.push(('progress', (preview, desc)))
        else:
            stream.output_queue.push(('progress', (None, desc)))
        timer.process.add('preview', time.time() - t_preview)
        t_last = t_current

    t0 = time.time() # fixed: initialized before try so the final log cannot hit NameError on early failure
    try:
        with devices.inference_context(), pbar:
            height, width, _C = input_image.shape
            start_latent, end_latent = latents_encode(input_image, end_image)
            image_encoder_last_hidden_state = vision_encode(input_image, end_image)
            # sample loop
            shared.state.textinfo = 'Sample'
            stream.output_queue.push(('progress', (None, 'Start sampling...')))
            generator = torch.Generator("cpu").manual_seed(seed)
            if is_f1:
                history_latents = torch.zeros(size=(1, 16, 16 + 2 + 1, height // 8, width // 8), dtype=torch.float32).cpu()
            else:
                history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=devices.dtype).cpu()
            history_pixels = None
            lattent_padding_loop = 0
            last_prompt = None
            for latent_padding in latent_paddings:
                current_prompt = prompts[lattent_padding_loop]
                if current_prompt != last_prompt: # re-encode only when the section prompt changes
                    llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n = text_encode(current_prompt, i=lattent_padding_loop+1)
                    last_prompt = current_prompt
                lattent_padding_loop += 1
                if is_f1:
                    is_first_section, is_last_section = False, False
                else:
                    is_first_section, is_last_section = latent_padding == latent_paddings[0], latent_padding == 0
                if stream.input_queue.top() == 'end' or shared.state.interrupted or shared.state.skipped:
                    stream.output_queue.push(('end', None))
                    return
                if is_f1: # forward variant: condition on the most recent history
                    indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
                    clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
                    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
                    clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
                    clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
                else: # reverse variant: pad ahead of the window, condition on the start of history
                    latent_padding_size = latent_padding * latent_window_size
                    indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
                    clean_latent_indices_pre, _blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
                    clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
                    clean_latents_pre = start_latent.to(history_latents)
                    clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
                    clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
                if end_image is not None and is_first_section: # blend the end-image latent into the chronologically-last section
                    clean_latents_post = (clean_latents_post * start_weight / len(latent_paddings)) + (end_weight * end_latent.to(history_latents)) / (start_weight/len(latent_paddings) + end_weight) # pylint: disable=possibly-used-before-assignment
                    clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
                sd_models.apply_balanced_offload(shared.sd_model)
                transformer.initialize_teacache(enable_teacache=use_teacache, num_steps=steps, rel_l1_thresh=shared.opts.teacache_thresh)
                t_sample = time.time()
                generated_latents = sample_hunyuan(
                    transformer=transformer,
                    sampler='unipc',
                    width=width,
                    height=height,
                    frames=num_frames,
                    num_inference_steps=steps,
                    real_guidance_scale=cfg_scale,
                    distilled_guidance_scale=cfg_distilled,
                    guidance_rescale=cfg_rescale,
                    shift=shift if shift > 0 else None,
                    generator=generator,
                    prompt_embeds=llama_vec, # pylint: disable=possibly-used-before-assignment
                    prompt_embeds_mask=llama_attention_mask, # pylint: disable=possibly-used-before-assignment
                    prompt_poolers=clip_l_pooler, # pylint: disable=possibly-used-before-assignment
                    negative_prompt_embeds=llama_vec_n, # pylint: disable=possibly-used-before-assignment
                    negative_prompt_embeds_mask=llama_attention_mask_n, # pylint: disable=possibly-used-before-assignment
                    negative_prompt_poolers=clip_l_pooler_n, # pylint: disable=possibly-used-before-assignment
                    image_embeddings=image_encoder_last_hidden_state,
                    latent_indices=latent_indices,
                    clean_latents=clean_latents,
                    clean_latent_indices=clean_latent_indices,
                    clean_latents_2x=clean_latents_2x,
                    clean_latent_2x_indices=clean_latent_2x_indices,
                    clean_latents_4x=clean_latents_4x,
                    clean_latent_4x_indices=clean_latent_4x_indices,
                    device=devices.device,
                    dtype=devices.dtype,
                    callback=step_callback,
                )
                timer.process.add('sample', time.time()-t_sample)
                if is_last_section: # prepend the init-image latent so the video starts exactly on it
                    generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
                total_generated_latent_frames += int(generated_latents.shape[2])
                if is_f1:
                    history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
                    real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
                else:
                    history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
                    real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
                t_vae = time.time()
                sd_models.apply_balanced_offload(shared.sd_model)
                sd_models.move_model(vae, devices.device, force=True)
                if history_pixels is None:
                    history_pixels = framepack_vae.vae_decode(real_history_latents, vae_type=vae_type).cpu()
                else:
                    overlapped_frames = latent_window_size * 4 - 3
                    if is_f1:
                        section_latent_frames = latent_window_size * 2
                        current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, -section_latent_frames:], vae_type=vae_type).cpu() # fixed: was passing the vae object where vae_decode expects the vae_type string
                        history_pixels = utils.soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
                    else:
                        section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
                        current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, :section_latent_frames], vae_type=vae_type).cpu()
                        history_pixels = utils.soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
                timer.process.add('vae', time.time()-t_vae)
                if is_last_section:
                    break
            total_generated_frames = framepack_video.save_video(history_pixels, mp4_fps, mp4_codec, mp4_opt, mp4_ext, mp4_sf, mp4_video, mp4_frames, mp4_interpolate, pbar=pbar, stream=stream, metadata=metadata)
    except AssertionError: # raised by step_callback on interrupt
        shared.log.info('FramePack: interrupted')
        if shared.opts.keep_incomplete:
            framepack_video.save_video(history_pixels, mp4_fps, mp4_codec, mp4_opt, mp4_ext, mp4_sf, mp4_video, mp4_frames, mp4_interpolate=0, stream=stream, metadata=metadata)
    except Exception as e:
        shared.log.error(f'FramePack: {e}')
        errors.display(e, 'FramePack')
    sd_models.apply_balanced_offload(shared.sd_model)
    stream.output_queue.push(('end', None))
    t1 = time.time()
    shared.log.info(f'Processed: frames={total_generated_frames} fps={total_generated_frames/(t1-t0):.2f} its={(shared.state.sampling_step)/(t1-t0):.2f} time={t1-t0:.2f} timers={timer.process.dct()} memory={memstats.memory_stats()}')
    shared.state.end()

View File

@ -0,0 +1,234 @@
import os
import re
import random
import threading
import numpy as np
import torch
import gradio as gr
from modules import shared, processing, timer, paths, extra_networks, progress, ui_video_vlm
from modules.framepack import framepack_install # pylint: disable=wrong-import-order
from modules.framepack import framepack_load # pylint: disable=wrong-import-order
from modules.framepack import framepack_worker # pylint: disable=wrong-import-order
from modules.framepack import framepack_hijack # pylint: disable=wrong-import-order
# scratch directory for FramePack intermediates
tmp_dir = os.path.join(paths.data_path, 'tmp', 'framepack')
# location and pinned commit of the upstream FramePack repo (used by the currently-disabled git-clone path in load_model)
git_dir = os.path.join(os.path.dirname(__file__), 'framepack')
git_repo = 'https://github.com/lllyasviel/framepack'
git_commit = 'c5d375661a2557383f0b8da9d11d14c23b0c4eaf'
# serializes generation requests coming from the UI
queue_lock = threading.Lock()
# model variant currently loaded into shared.sd_model (None until first load)
loaded_variant = None
def check_av():
    """Import and return the `av` (PyAV) package, or None when unavailable.

    Callers (`get_codecs`, `run_framepack`) test the result with `is None`,
    so failure must return None — the previous `return False` made those
    guards dead code and callers crashed on `False.codecs_available`.
    """
    try:
        import av
    except Exception as e:
        shared.log.error(f'av package: {e}')
        return None
    return av
def get_codecs():
    """Enumerate usable video encoders via PyAV.

    Returns ['none'] followed by encoder names (hardware-accelerated first),
    or [] when the av package is unavailable.
    """
    av = check_av()
    if not av:  # truthiness guard: check_av signals failure with a falsy value
        return []
    codecs = []
    for codec in av.codecs_available:
        try:
            c = av.Codec(codec, mode='w')
            if c.type == 'video' and c.is_encoder and len(c.video_formats) > 0:
                # de-duplicate by encoder name
                if not any(c.name == ca.name for ca in codecs):
                    codecs.append(c)
        except Exception:
            pass  # best-effort: some codec names cannot be opened in write mode
    # capability bits 0x40000/0x80000 mark hardware-backed encoders
    hw_codecs = [c for c in codecs if (c.capabilities & 0x40000 > 0) or (c.capabilities & 0x80000 > 0)]
    sw_codecs = [c for c in codecs if c not in hw_codecs]
    shared.log.debug(f'Video codecs: hardware={len(hw_codecs)} software={len(sw_codecs)}')
    return ['none'] + [c.name for c in hw_codecs + sw_codecs]
def prepare_image(image, resolution):
    """Resize/center-crop a numpy HWC image to the nearest aspect bucket at `resolution`."""
    from modules.framepack.pipeline.utils import resize_and_center_crop
    aspect_buckets = [
        (416, 960), (448, 864), (480, 832), (512, 768), (544, 704), (576, 672), (608, 640),
        (640, 608), (672, 576), (704, 544), (768, 512), (832, 480), (864, 448), (960, 416),
    ]
    src_h, src_w, _channels = image.shape
    scale_factor = resolution / 640.0
    best_metric = float('inf')
    target_h, target_w = src_h, src_w
    for bucket_h, bucket_w in aspect_buckets:
        # aspect-ratio distance between the source and this bucket (cross-multiplied)
        metric = abs(src_h * bucket_w - src_w * bucket_h)
        if metric <= best_metric:
            best_metric = metric
            # scale the winning bucket to the requested resolution, snapped to multiples of 16
            target_h = round(bucket_h * scale_factor / 16) * 16
            target_w = round(bucket_w * scale_factor / 16) * 16
    image = resize_and_center_crop(image, target_height=target_h, target_width=target_w)
    out_h, out_w, _channels = image.shape
    shared.log.debug(f'FramePack prepare: input="{src_w}x{src_h}" resized="{out_w}x{out_h}" resolution={resolution} scale={scale_factor}')
    return image
def interpolate_prompts(prompts, steps):
    """Spread per-section prompts (list, or comma/newline separated string) across `steps` slots."""
    result = [''] * steps
    if prompts is None:
        return result
    if isinstance(prompts, str):
        prompts = re.split(r'[,\n]', prompts)
    prompts = [chunk.strip() for chunk in prompts]
    if not prompts:
        return result
    if len(prompts) == steps:
        return prompts
    # stretch (or compress) the prompt list evenly over the step count
    stride = steps / len(prompts)
    return [prompts[int(index / stride)] for index in range(steps)]
def prepare_prompts(p, init_image, prompt:str, section_prompt:str, num_sections:int, vlm_enhance:bool, vlm_model:str, vlm_system_prompt:str):
    # Build one (optionally VLM-enhanced) prompt per video section.
    # Side effects on p: applies styles, parses and activates extra networks.
    section_prompts = interpolate_prompts(section_prompt, num_sections)
    # apply style presets to positive/negative prompts before network parsing
    p.prompt = shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)
    p.negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
    shared.prompt_styles.apply_styles_to_extra(p)
    p.prompts, p.network_data = extra_networks.parse_prompts([p.prompt])
    extra_networks.activate(p)
    prompt = p.prompts[0]  # styled prompt with extra-network tags parsed out
    generated_prompts = [''] * num_sections
    previous_prompt = None
    for i in range(num_sections):
        current_prompt = (prompt + ' ' + section_prompts[i]).strip()
        if current_prompt == previous_prompt:
            generated_prompts[i] = generated_prompts[i - 1]  # reuse result; avoids re-running the VLM on identical input
        else:
            generated_prompts[i] = ui_video_vlm.enhance_prompt(
                enable=vlm_enhance,
                model=vlm_model,
                image=init_image,
                prompt=current_prompt,
                system_prompt=vlm_system_prompt,
            )
        previous_prompt = current_prompt
    return generated_prompts
def load_model(variant, attention):
    """Generator: install requirements and load the requested FramePack variant.

    Yields (video_update, preview_update, status_text) triples for the UI.
    Fix: the 'Model loading...' yield previously emitted a 4-tuple while every
    other yield in this generator family emits 3-tuples, which mismatches the
    consumer's output count.
    """
    global loaded_variant  # pylint: disable=global-statement
    # reload only when the loaded model type or variant differs
    if (shared.sd_model_type != 'hunyuanvideo') or (loaded_variant != variant):
        yield gr.update(), gr.update(), 'Verifying FramePack'
        framepack_install.install_requirements(attention)
        framepack_hijack.set_progress_bar_config()
        yield gr.update(), gr.update(), 'Model loading...'
        loaded_variant = framepack_load.load_model(variant)
        if loaded_variant is not None:
            yield gr.update(), gr.update(), 'Model loaded'
        else:
            yield gr.update(), gr.update(), 'Model load failed'
def unload_model():
    # Generator: unload the FramePack model and report status to the UI as a
    # (video_update, preview_update, status_text) triple.
    shared.log.debug('FramePack unload')
    framepack_load.unload_model()
    yield gr.update(), gr.update(), 'Model unloaded'
def run_framepack(task_id, _ui_state, init_image, end_image, start_weight, end_weight, vision_weight, prompt, system_prompt, optimized_prompt, section_prompt, negative_prompt, styles, seed, resolution, duration, latent_ws, steps, cfg_scale, cfg_distilled, cfg_rescale, shift, use_teacache, use_cfgzero, use_preview, mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate, attention, vae_type, variant, vlm_enhance, vlm_model, vlm_system_prompt):
    # Top-level UI entry point: queue, load model, prepare inputs, launch the
    # async worker, then stream its progress/file/end events back to Gradio as
    # (video, preview, status) yields.
    variant = variant or 'bi-directional'
    # derive generation mode from supplied images: none -> t2v, first+last -> flf2v, first only -> i2v
    if init_image is None:
        init_image = np.zeros((resolution, resolution, 3), dtype=np.uint8)
        mode = 't2v'
    elif end_image is not None:
        mode = 'flf2v'
    else:
        mode = 'i2v'
    av = check_av()
    if av is None:  # abort early when PyAV is unavailable (check_av signals failure)
        yield gr.update(), gr.update(), 'AV package not installed'
        return
    progress.add_task_to_queue(task_id)
    with queue_lock:  # only one generation at a time
        progress.start_task(task_id)
        yield from load_model(variant, attention)
        if shared.sd_model_type != 'hunyuanvideo':  # load_model failed to produce the expected model
            progress.finish_task(task_id)
            yield gr.update(), gr.update(), 'Model load failed'
            return
        yield gr.update(), gr.update(), 'Generate starting...'
        from modules.framepack.pipeline.thread_utils import AsyncStream, async_run
        framepack_worker.stream = AsyncStream()
        # resolve random seed when unset/-1
        if seed is None or seed == '' or seed == -1:
            random.seed()
            seed = random.randrange(4294967294)
        seed = int(seed)
        torch.manual_seed(seed)
        num_sections = len(framepack_worker.get_latent_paddings(mp4_fps, mp4_interpolate, latent_ws, duration, variant))
        num_frames = (latent_ws * 4 - 3) * num_sections + 1
        shared.log.info(f'FramePack start: mode={mode} variant="{variant}" frames={num_frames} sections={num_sections} resolution={resolution} seed={seed} duration={duration} teacache={use_teacache} thres={shared.opts.teacache_thresh} cfgzero={use_cfgzero}')
        shared.log.info(f'FramePack params: steps={steps} start={start_weight} end={end_weight} vision={vision_weight} scale={cfg_scale} distilled={cfg_distilled} rescale={cfg_rescale} shift={shift}')
        init_image = prepare_image(init_image, resolution)
        if end_image is not None:
            end_image = prepare_image(end_image, resolution)
        # NOTE(review): image.shape is (height, width, channels), so this binds w=height and h=width;
        # prepare_image unpacks the same shape as h, w — confirm whether this swap is intentional
        w, h, _c = init_image.shape
        p = processing.StableDiffusionProcessingVideo(
            sd_model=shared.sd_model,
            prompt=prompt,
            negative_prompt=negative_prompt,
            styles=styles,
            steps=steps,
            seed=seed,
            width=w,
            height=h,
        )
        prompts = prepare_prompts(p, init_image, prompt, section_prompt, num_sections, vlm_enhance, vlm_model, vlm_system_prompt)
        # hand generation off to the background worker thread
        async_run(
            framepack_worker.worker,
            init_image, end_image,
            start_weight, end_weight, vision_weight,
            prompts, p.negative_prompt, system_prompt, optimized_prompt, vlm_enhance,
            seed,
            duration,
            latent_ws,
            p.steps,
            cfg_scale, cfg_distilled, cfg_rescale,
            shift,
            use_teacache, use_cfgzero, use_preview,
            mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
            vae_type, variant,
        )
        output_filename = None
        # pump worker events until the 'end' sentinel arrives
        while True:
            flag, data = framepack_worker.stream.output_queue.next()
            if flag == 'file':
                output_filename = data
                yield output_filename, gr.update(), gr.update()
            if flag == 'progress':
                preview, text = data
                summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ')
                memory = shared.mem_mon.summary()
                stats = f"<div class='performance'><p>{summary} {memory}</p></div>"
                yield gr.update(), gr.update(value=preview), f'{text} {stats}'
            if flag == 'end':
                yield output_filename, gr.update(value=None), gr.update()
                break
        progress.finish_task(task_id)
    yield gr.update(), gr.update(), 'Generate finished'
    return

View File

@ -0,0 +1,29 @@
# Aspect-ratio buckets (height, width) available per base resolution.
bucket_options = {
    640: [
        (416, 960), (448, 864), (480, 832), (512, 768), (544, 704), (576, 672), (608, 640),
        (640, 608), (672, 576), (704, 544), (768, 512), (832, 480), (864, 448), (960, 416),
    ],
}


def find_nearest_bucket(h, w, resolution=640):
    """Return the (height, width) bucket whose aspect ratio is closest to h:w."""
    best_metric = float('inf')
    best_bucket = None
    for candidate_h, candidate_w in bucket_options[resolution]:
        # cross-multiplied aspect-ratio distance; later buckets win ties
        metric = abs(h * candidate_w - w * candidate_h)
        if metric <= best_metric:
            best_metric = metric
            best_bucket = (candidate_h, candidate_w)
    return best_bucket

View File

@ -0,0 +1,12 @@
import numpy as np
def hf_clip_vision_encode(image, feature_extractor, image_encoder):
    """Encode a HWC uint8 RGB numpy image with a HuggingFace CLIP vision encoder.

    Returns the raw image-encoder output object.
    """
    assert isinstance(image, np.ndarray)
    assert image.ndim == 3 and image.shape[2] == 3
    assert image.dtype == np.uint8
    # preprocess to pixel tensors on the encoder's device/dtype
    inputs = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
    return image_encoder(**inputs)

View File

@ -0,0 +1,53 @@
import torch
import accelerate.accelerator
from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
# Identity-patch accelerate so model outputs keep their native dtype instead of being cast to fp32
accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
def LayerNorm_forward(self, x):
    # standard layer-norm with the result cast back to x's dtype/device
    return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
LayerNorm.forward = LayerNorm_forward
# also patches torch.nn.LayerNorm globally (every instance process-wide), not only diffusers'
torch.nn.LayerNorm.forward = LayerNorm_forward
def FP32LayerNorm_forward(self, x):
    # compute the norm in float32 for numerical stability, then restore the original dtype
    origin_dtype = x.dtype
    return torch.nn.functional.layer_norm(
        x.float(),
        self.normalized_shape,
        self.weight.float() if self.weight is not None else None,
        self.bias.float() if self.bias is not None else None,
        self.eps,
    ).to(origin_dtype)
FP32LayerNorm.forward = FP32LayerNorm_forward
def RMSNorm_forward(self, hidden_states):
    # RMS normalization: variance computed in float32, result cast back to the input dtype
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
    if self.weight is None:
        return hidden_states.to(input_dtype)
    return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
RMSNorm.forward = RMSNorm_forward
def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
    # adaptive norm: scale/shift modulation derived from the conditioning embedding
    emb = self.linear(self.silu(conditioning_embedding))
    scale, shift = emb.chunk(2, dim=1)
    x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
    return x
AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward

View File

@ -0,0 +1,109 @@
import torch
from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
@torch.no_grad()
def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
    """Encode a single prompt with HunyuanVideo's dual text encoders.

    Returns (llama_vec, clip_l_pooler): per-token LLaMA hidden states with the
    instruction-template tokens cropped off, and the CLIP-L pooled embedding.
    """
    assert isinstance(prompt, str)
    prompt = [prompt]
    # LLAMA
    # wrap the prompt in the HunyuanVideo instruction template
    prompt_llama = [DEFAULT_PROMPT_TEMPLATE["template"].format(p) for p in prompt]
    crop_start = DEFAULT_PROMPT_TEMPLATE["crop_start"]  # number of leading template tokens to drop
    llama_inputs = tokenizer(
        prompt_llama,
        padding="max_length",
        max_length=max_length + crop_start,
        truncation=True,
        return_tensors="pt",
        return_length=False,
        return_overflowing_tokens=False,
        return_attention_mask=True,
    )
    llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
    llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
    llama_attention_length = int(llama_attention_mask.sum())
    llama_outputs = text_encoder(
        input_ids=llama_input_ids,
        attention_mask=llama_attention_mask,
        output_hidden_states=True,
    )
    # third-from-last hidden layer, keeping only real (attended, non-template) tokens
    llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
    # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
    llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
    assert torch.all(llama_attention_mask.bool())  # the kept region must be fully attended
    # CLIP
    clip_l_input_ids = tokenizer_2(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    ).input_ids
    clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
    return llama_vec, clip_l_pooler
@torch.no_grad()
def vae_decode_fake(latents):
    """Cheap latent preview: project 16-channel video latents to RGB with a fixed linear map."""
    # per-latent-channel RGB projection factors (From comfyui)
    rgb_factors = [
        [-0.0395, -0.0331, 0.0445],
        [0.0696, 0.0795, 0.0518],
        [0.0135, -0.0945, -0.0282],
        [0.0108, -0.0250, -0.0765],
        [-0.0209, 0.0032, 0.0224],
        [-0.0804, -0.0254, -0.0639],
        [-0.0991, 0.0271, -0.0669],
        [-0.0646, -0.0422, -0.0400],
        [-0.0696, -0.0595, -0.0894],
        [-0.0799, -0.0208, -0.0375],
        [0.1166, 0.1627, 0.0962],
        [0.1165, 0.0432, 0.0407],
        [-0.2315, -0.1920, -0.1355],
        [-0.0270, 0.0401, -0.0821],
        [-0.0616, -0.0997, -0.0727],
        [0.0249, -0.0469, -0.1703]
    ]
    rgb_bias = [0.0259, -0.0192, -0.0761]
    # build a 1x1x1 conv3d kernel: (3 out, 16 in)
    weight = torch.tensor(rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
    bias = torch.tensor(rgb_bias, device=latents.device, dtype=latents.dtype)
    images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
    return images.clamp(0.0, 1.0)
@torch.no_grad()
def vae_decode(latents, vae, image_mode=False):
    """Decode scaled latents with the VAE; image_mode decodes frame-by-frame to save memory."""
    scaled = latents / vae.config.scaling_factor
    if image_mode:
        # decode each temporal frame independently, then re-stack along the time axis
        frames = scaled.to(device=vae.device, dtype=vae.dtype).unbind(2)
        decoded = [vae.decode(frame.unsqueeze(2)).sample for frame in frames]
        return torch.cat(decoded, dim=2)
    return vae.decode(scaled.to(device=vae.device, dtype=vae.dtype)).sample


@torch.no_grad()
def vae_encode(image, vae):
    """Encode pixels to scaled VAE latents (sampled from the posterior)."""
    latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
    return latents * vae.config.scaling_factor

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,119 @@
import math
import torch
from modules.framepack.pipeline.uni_pc_fm import sample_unipc
from modules.framepack.pipeline.wrapper import fm_wrapper
from modules.framepack.pipeline.utils import repeat_to_batch_size
def flux_time_shift(t, mu=1.15, sigma=1.0):
    """Flux sigma-shift curve: maps t in (0,1] through exp(mu)/(exp(mu)+(1/t-1)^sigma)."""
    e_mu = math.exp(mu)
    return e_mu / (e_mu + (1 / t - 1) ** sigma)


def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
    """Linearly interpolate mu from the context length, capped at log(exp_max)."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1
    return min(slope * context_length + intercept, math.log(exp_max))


def get_flux_sigmas_from_mu(n, mu):
    """Build n+1 descending sigmas by pushing a linear 1->0 ramp through flux_time_shift."""
    ramp = torch.linspace(1, 0, steps=n + 1)
    return flux_time_shift(ramp, mu=mu)
@torch.inference_mode()
def sample_hunyuan(
    transformer,
    sampler='unipc',
    initial_latent=None,
    concat_latent=None,
    strength=1.0,
    width=512,
    height=512,
    frames=16,
    real_guidance_scale=1.0,
    distilled_guidance_scale=6.0,
    guidance_rescale=0.0,
    shift=None,
    num_inference_steps=25,
    batch_size=None,
    generator=None,
    prompt_embeds=None,
    prompt_embeds_mask=None,
    prompt_poolers=None,
    negative_prompt_embeds=None,
    negative_prompt_embeds_mask=None,
    negative_prompt_poolers=None,
    dtype=torch.bfloat16,
    device=None,
    negative_kwargs=None,
    callback=None,
    **kwargs,
):
    # Flow-matching sampling loop for the HunyuanVideo transformer.
    # Builds noise latents, a flux sigma schedule, and CFG-paired conditioning,
    # then delegates the solver loop to sample_unipc.
    device = device or transformer.device
    if batch_size is None:
        batch_size = int(prompt_embeds.shape[0])
    # latent grid: 16 channels, (frames+3)//4 temporal slots, spatial /8
    # NOTE(review): assumes a torch.Generator is always supplied; generator=None would raise on generator.device — confirm callers
    latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32)
    _B, _C, T, H, W = latents.shape
    seq_length = T * H * W // 4
    # mu either derived from sequence length or taken from an explicit shift
    if shift is None:
        mu = calculate_flux_mu(seq_length, exp_max=7.0)
    else:
        mu = math.log(shift)
    sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)
    k_model = fm_wrapper(transformer)
    if initial_latent is not None:
        # img2img-style start: scale the schedule by strength and blend noise into the initial latent
        sigmas = sigmas * strength
        first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
        initial_latent = initial_latent.to(device=device, dtype=torch.float32)
        latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma
    if concat_latent is not None:
        concat_latent = concat_latent.to(latents)
    # distilled guidance is passed as an embedding input (scaled by 1000)
    distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)
    # tile all conditioning tensors up to the batch size
    prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
    prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
    prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
    negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
    negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
    negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
    concat_latent = repeat_to_batch_size(concat_latent, batch_size)
    sampler_kwargs = dict(
        dtype=dtype,
        cfg_scale=real_guidance_scale,
        cfg_rescale=guidance_rescale,
        concat_latent=concat_latent,
        positive=dict(
            pooled_projections=prompt_poolers,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_embeds_mask,
            guidance=distilled_guidance,
            **kwargs,
        ),
        negative=dict(
            pooled_projections=negative_prompt_poolers,
            encoder_hidden_states=negative_prompt_embeds,
            encoder_attention_mask=negative_prompt_embeds_mask,
            guidance=distilled_guidance,
            **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
        )
    )
    if sampler == 'unipc':
        results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
    else:
        raise NotImplementedError(f'Sampler {sampler} is not supported.')
    return results

View File

@ -0,0 +1,76 @@
import time
from threading import Thread, Lock
class Listener:
    """Single background daemon thread draining a FIFO of (func, args, kwargs) tasks."""
    task_queue = []
    lock = Lock()
    thread = None

    @classmethod
    def _process_tasks(cls):
        # Runs forever on the daemon thread: pop one task under the lock, run it outside.
        while True:
            with cls.lock:
                task = cls.task_queue.pop(0) if cls.task_queue else None
            if task is None:
                time.sleep(0.001)  # idle poll
                continue
            func, args, kwargs = task
            try:
                func(*args, **kwargs)
            except Exception as e:
                print(f"Error in listener thread: {e}")

    @classmethod
    def add_task(cls, func, *args, **kwargs):
        # Queue a task; lazily start the worker thread on first use.
        with cls.lock:
            cls.task_queue.append((func, args, kwargs))
        if cls.thread is None:
            cls.thread = Thread(target=cls._process_tasks, daemon=True)
            cls.thread.start()


def async_run(func, *args, **kwargs):
    """Fire-and-forget execution of func on the shared Listener thread."""
    Listener.add_task(func, *args, **kwargs)
class FIFOQueue:
    """Minimal thread-safe FIFO with a blocking next()."""

    def __init__(self):
        self.queue = []
        self.lock = Lock()

    def push(self, item):
        """Append an item at the tail."""
        with self.lock:
            self.queue.append(item)

    def pop(self):
        """Remove and return the head, or None when empty."""
        with self.lock:
            return self.queue.pop(0) if self.queue else None

    def top(self):
        """Peek at the head without removing it; None when empty."""
        with self.lock:
            return self.queue[0] if self.queue else None

    def next(self):
        """Busy-wait until an item is available, then pop and return it."""
        while True:
            with self.lock:
                if self.queue:
                    return self.queue.pop(0)
            time.sleep(0.001)
class AsyncStream:
    # Pair of FIFO queues linking the UI side (input) with the worker thread (output).
    def __init__(self):
        self.input_queue = FIFOQueue()
        self.output_queue = FIFOQueue()

View File

@ -0,0 +1,141 @@
# Better Flow Matching UniPC by Lvmin Zhang
# (c) 2025
# CC BY-SA 4.0
# Attribution-ShareAlike 4.0 International Licence
import torch
from tqdm.auto import trange
def expand_dims(v, dims):
    """Append (dims - 1) trailing singleton axes to v (intended for 1-D/0-D sigma tensors)."""
    index = (Ellipsis,) + (None,) * (dims - 1)
    return v[index]
class FlowMatchUniPC:
    # UniPC multistep predictor-corrector solver adapted for flow matching
    # (sigmas descend from 1 to 0; model predicts the flow target).
    def __init__(self, model, extra_args, variant='bh1'):
        self.model = model
        self.variant = variant  # B(h) choice: 'bh1' or 'bh2'
        self.extra_args = extra_args
    def model_fn(self, x, t):
        # evaluate the wrapped model with the shared conditioning kwargs
        return self.model(x, t, **self.extra_args)
    def update_fn(self, x, model_prev_list, t_prev_list, t, order):
        # One UniPC step of the given order using the stored model/t history.
        assert order <= len(model_prev_list)
        dims = x.dim()
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = - torch.log(t_prev_0)
        lambda_t = - torch.log(t)
        model_prev_0 = model_prev_list[-1]
        h = lambda_t - lambda_prev_0
        rks = []
        D1s = []
        # normalized log-sigma offsets and divided differences from older steps
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = - torch.log(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)
        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)
        R = []
        b = []
        hh = -h[0]
        h_phi_1 = torch.expm1(hh)
        h_phi_k = h_phi_1 / hh - 1
        factorial_i = 1
        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError('Bad variant!')
        # assemble the small linear system that yields the UniPC coefficients
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i
        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)
        use_predictor = len(D1s) > 0
        if use_predictor:
            D1s = torch.stack(D1s, dim=1)
            if order == 2:
                rhos_p = torch.tensor([0.5], device=b.device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None
            rhos_p = None
        if order == 1:
            rhos_c = torch.tensor([0.5], device=b.device)
        else:
            rhos_c = torch.linalg.solve(R, b)
        # predictor: extrapolate x at t from previous model outputs
        x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0
        if use_predictor:
            pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
        else:
            pred_res = 0
        x_t = x_t_ - expand_dims(B_h, dims) * pred_res
        # corrector: re-evaluate the model at the predicted point and refine x_t
        model_t = self.model_fn(x_t, t)
        if D1s is not None:
            corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
        else:
            corr_res = 0
        D1_t = model_t - model_prev_0
        x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        return x_t, model_t
    def sample(self, x, sigmas, callback=None, disable_pbar=False):
        # Run the solver across the sigma schedule; returns the final model estimate.
        order = min(3, len(sigmas) - 2)
        model_prev_list, t_prev_list = [], []
        for i in trange(len(sigmas) - 1, disable=disable_pbar):
            vec_t = sigmas[i].expand(x.shape[0])
            if i == 0:
                # seed the history with the first model evaluation
                model_prev_list = [self.model_fn(x, vec_t)]
                t_prev_list = [vec_t]
            elif i < order:
                # warm-up: use as many history points as are available so far
                init_order = i
                x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
                model_prev_list.append(model_x)
                t_prev_list.append(vec_t)
            else:
                x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
                model_prev_list.append(model_x)
                t_prev_list.append(vec_t)
            # retain only the history window needed for the configured order
            model_prev_list = model_prev_list[-order:]
            t_prev_list = t_prev_list[-order:]
            if callback is not None:
                callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
        return model_prev_list[-1]
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
    """Convenience wrapper: construct a FlowMatchUniPC solver and run it over the schedule."""
    assert variant in ['bh1', 'bh2']
    solver = FlowMatchUniPC(model, extra_args=extra_args, variant=variant)
    return solver.sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)

View File

@ -0,0 +1,567 @@
import os
import cv2
import json
import random
import glob
import torch
import einops
import numpy as np
import datetime
import torchvision
import safetensors.torch as sf
from PIL import Image
def min_resize(x, m):
    """Resize image x so its shorter side equals m, preserving aspect ratio."""
    if x.shape[0] < x.shape[1]:
        s0, s1 = m, int(float(m) / float(x.shape[0]) * float(x.shape[1]))
    else:
        s0, s1 = int(float(m) / float(x.shape[1]) * float(x.shape[0])), m
    # area interpolation when shrinking, Lanczos when enlarging
    shrinking = max(s1, s0) < max(x.shape[0], x.shape[1])
    interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
    return cv2.resize(x, (s1, s0), interpolation=interpolation)


def d_resize(x, y):
    """Resize x to match the spatial size of y."""
    H, W, _C = y.shape
    # area interpolation when shrinking, Lanczos when enlarging
    shrinking = min(H, W) < min(x.shape[0], x.shape[1])
    interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
    return cv2.resize(x, (W, H), interpolation=interpolation)
def resize_and_center_crop(image, target_width, target_height):
    """Scale a numpy image up to cover the target box, then center-crop to exactly fit."""
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image  # already the requested size
    pil_image = Image.fromarray(image)
    src_w, src_h = pil_image.size
    scale = max(target_width / src_w, target_height / src_h)
    new_w = int(round(src_w * scale))
    new_h = int(round(src_h * scale))
    resized = pil_image.resize((new_w, new_h), Image.LANCZOS)
    box = (
        (new_w - target_width) / 2,
        (new_h - target_height) / 2,
        (new_w + target_width) / 2,
        (new_h + target_height) / 2,
    )
    return np.array(resized.crop(box))


def resize_and_center_crop_pytorch(image, target_width, target_height):
    """Torch BCHW analogue of resize_and_center_crop (bilinear resize + center crop)."""
    _B, _C, H, W = image.shape
    if H == target_height and W == target_width:
        return image
    scale = max(target_width / W, target_height / H)
    new_w = int(round(W * scale))
    new_h = int(round(H * scale))
    resized = torch.nn.functional.interpolate(image, size=(new_h, new_w), mode='bilinear', align_corners=False)
    top = (new_h - target_height) // 2
    left = (new_w - target_width) // 2
    return resized[:, :, top:top + target_height, left:left + target_width]


def resize_without_crop(image, target_width, target_height):
    """Plain (aspect-distorting) resize of a numpy image via PIL Lanczos."""
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image
    resized = Image.fromarray(image).resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized)


def just_crop(image, w, h):
    """Center-crop a numpy image to the w:h aspect ratio without resizing."""
    if h == image.shape[0] and w == image.shape[1]:
        return image
    src_h, src_w = image.shape[:2]
    k = min(src_h / h, src_w / w)
    crop_w = int(round(w * k))
    crop_h = int(round(h * k))
    x_start = (src_w - crop_w) // 2
    y_start = (src_h - crop_h) // 2
    return image[y_start:y_start + crop_h, x_start:x_start + crop_w]
def write_to_json(data, file_path):
    """Atomically write data as JSON: dump to a sibling temp file, then os.replace over the target."""
    tmp_path = file_path + ".tmp"
    with open(tmp_path, 'wt', encoding='utf-8') as handle:
        json.dump(data, handle, indent=4)
    os.replace(tmp_path, file_path)
    return


def read_from_json(file_path):
    """Load and return the JSON content of file_path."""
    with open(file_path, 'rt', encoding='utf-8') as handle:
        return json.load(handle)
def get_active_parameters(m):
    """Return {name: parameter} for all trainable (requires_grad) parameters of module m."""
    return {name: param for name, param in m.named_parameters() if param.requires_grad}


def cast_training_params(m, dtype=torch.float32):
    """Cast every trainable parameter of m to dtype in place; return affected params by name."""
    affected = {}
    for name, param in m.named_parameters():
        if not param.requires_grad:
            continue
        param.data = param.to(dtype)
        affected[name] = param
    return affected
def separate_lora_AB(parameters, B_patterns=None):
    """Split a parameter dict into (normal, B-side) by key substring patterns."""
    if B_patterns is None:
        B_patterns = ['.lora_B.', '__zero__']
    parameters_normal = {}
    parameters_B = {}
    for key, value in parameters.items():
        bucket = parameters_B if any(pattern in key for pattern in B_patterns) else parameters_normal
        bucket[key] = value
    return parameters_normal, parameters_B


def set_attr_recursive(obj, attr, value):
    """Set a dotted attribute path (e.g. 'a.b.c') on obj."""
    *parents, leaf = attr.split(".")
    for name in parents:
        obj = getattr(obj, name)
    setattr(obj, leaf, value)
    return
@torch.no_grad()
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
    """Per-sample mix of batches a and b; mask_a True selects a (random Bernoulli mask if omitted)."""
    batch = a.size(0)
    if b is None:
        b = torch.zeros_like(a)
    if mask_a is None:
        # independent Bernoulli(probability_a) choice per batch element
        mask_a = torch.rand(batch) < probability_a
    mask = mask_a.to(a.device).reshape((batch,) + (1,) * (a.dim() - 1))
    return torch.where(mask, a, b)
@torch.no_grad()
def zero_module(module):
    """Zero every parameter of module in place; returns the module."""
    for param in module.parameters():
        param.detach().zero_()
    return module


@torch.no_grad()
def supress_lower_channels(m, k, alpha=0.01):
    """Scale the first k input channels of m's weight by alpha (in place); returns m."""
    data = m.weight.data.clone()
    assert int(data.shape[1]) >= k
    data[:, :k] *= alpha
    m.weight.data = data.contiguous().clone()
    return m


def freeze_module(m):
    """Freeze module m: keep the original forward, disable grads, wrap forward in no_grad."""
    if not hasattr(m, '_forward_inside_frozen_module'):
        m._forward_inside_frozen_module = m.forward
    m.requires_grad_(False)
    m.forward = torch.no_grad()(m.forward)
    return m
def get_latest_safetensors(folder_path):
    """Return the absolute real path of the most recently modified .safetensors file.

    Raises ValueError when the folder contains none.
    """
    candidates = glob.glob(os.path.join(folder_path, '*.safetensors'))
    if not candidates:
        raise ValueError('No file to resume!')
    newest = max(candidates, key=os.path.getmtime)
    return os.path.abspath(os.path.realpath(newest))


def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Build a prompt from a random subset of comma-separated tags."""
    tags = tags_str.split(', ')
    count = min(random.randint(min_length, max_length), len(tags))
    return ', '.join(random.sample(tags, k=count))
def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
    """n values from a to b along a gamma-warped linear ramp, returned as a Python list."""
    ramp = np.linspace(0, 1, n) ** gamma
    values = a + (b - a) * ramp
    if round_to_int:
        values = np.round(values).astype(int)
    return values.tolist()


def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
    """Draw one uniform sample from each of n equal sub-intervals of [inclusive, exclusive)."""
    edges = np.linspace(0, 1, n + 1)
    samples = np.random.uniform(edges[:-1], edges[1:])
    values = inclusive + (exclusive - inclusive) * samples
    if round_to_int:
        values = np.round(values).astype(int)
    return values.tolist()
def soft_append_bcthw(history, current, overlap=0):
    """Concatenate two BCTHW clips along time, linearly cross-fading `overlap` frames."""
    if overlap <= 0:
        return torch.cat([history, current], dim=2)
    assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
    assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
    # fade weights run 1 -> 0 across the overlap (history fades out, current fades in)
    fade = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
    crossfaded = fade * history[:, :, -overlap:] + (1 - fade) * current[:, :, :overlap]
    merged = torch.cat([history[:, :, :-overlap], crossfaded, current[:, :, overlap:]], dim=2)
    return merged.to(history)
def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
    """Tile a [-1,1] BCTHW batch into a per-frame grid and write an H.264 mp4; returns the frames."""
    b, _c, _t, _h, _w = x.shape
    # grid columns: largest divisor of the batch size between 6 and 2, else the batch itself
    per_row = b
    for candidate in [6, 5, 4, 3, 2]:
        if b % candidate == 0:
            per_row = candidate
            break
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    frames = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5  # [-1,1] -> [0,255]
    frames = frames.detach().cpu().to(torch.uint8)
    frames = einops.rearrange(frames, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, frames, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
    return frames


def save_bcthw_as_png(x, output_filename):
    """Flatten a [-1,1] BCTHW tensor into one (batch x time) grid and save as PNG."""
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    grid = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    grid = grid.detach().cpu().to(torch.uint8)
    grid = einops.rearrange(grid, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(grid, output_filename)
    return output_filename


def save_bchw_as_png(x, output_filename):
    """Lay a [-1,1] BCHW batch side-by-side horizontally and save as PNG."""
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    grid = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    grid = grid.detach().cpu().to(torch.uint8)
    grid = einops.rearrange(grid, 'b c h w -> c h (b w)')
    torchvision.io.write_png(grid, output_filename)
    return output_filename
def add_tensors_with_padding(tensor1, tensor2):
    """Add two same-rank tensors, zero-padding each up to the elementwise max shape."""
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2
    union_shape = tuple(max(d1, d2) for d1, d2 in zip(tensor1.shape, tensor2.shape))
    padded1 = torch.zeros(union_shape)
    padded2 = torch.zeros(union_shape)
    padded1[tuple(slice(0, d) for d in tensor1.shape)] = tensor1
    padded2[tuple(slice(0, d) for d in tensor2.shape)] = tensor2
    return padded1 + padded2
def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
    """Render text on a white RGB canvas with greedy word wrapping; returns a numpy array."""
    from PIL import Image, ImageDraw, ImageFont
    canvas = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.truetype(font_path, size=size)
    if text == '':
        return np.array(canvas)
    # Greedy wrap: keep appending words while the candidate line still fits the width
    lines = []
    words = text.split()
    line = words[0]
    for word in words[1:]:
        candidate = f"{line} {word}"
        if draw.textbbox((0, 0), candidate, font=font)[2] <= width:
            line = candidate
        else:
            lines.append(line)
            line = word
    lines.append(line)
    # Render line by line until the canvas runs out of vertical space
    y = 0
    line_height = draw.textbbox((0, 0), "A", font=font)[3]
    for rendered_line in lines:
        if y + line_height > height:
            break  # stop drawing if the next line would fall outside the image
        draw.text((0, y), rendered_line, fill="black", font=font)
        y += line_height
    return np.array(canvas)
def blue_mark(x):
    """Visual marker: unsharp-mask the blue channel of a [-1,1] image; returns a copy."""
    x = x.copy()
    channel = x[:, :, 2]
    blurred = cv2.blur(channel, (9, 9))
    x[:, :, 2] = ((channel - blurred) * 16.0 + blurred).clip(-1, 1)
    return x


def green_mark(x):
    """Visual marker: force red and blue channels to -1, leaving green; returns a copy."""
    x = x.copy()
    x[:, :, 2] = -1
    x[:, :, 0] = -1
    return x


def frame_mark(x):
    """Visual marker: black 64px top/bottom bands and white 8px left/right bands; returns a copy."""
    x = x.copy()
    x[:64] = -1
    x[-64:] = -1
    x[:, :8] = 1
    x[:, -8:] = 1
    return x
@torch.inference_mode()
def pytorch2numpy(imgs):
    """Convert a batch of CHW [-1,1] tensors to a list of HWC uint8 numpy images."""
    converted = []
    for img in imgs:
        hwc = img.movedim(0, -1) * 127.5 + 127.5  # [-1,1] -> [0,255]
        converted.append(hwc.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8))
    return converted


@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack HWC uint8 numpy images into a BCHW float tensor scaled to [-1,1]."""
    batch = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
    return batch.movedim(-1, 1)
@torch.no_grad()
def duplicate_prefix_to_suffix(x, count, zero_out=False):
    """Append the first `count` entries of `x` along dim 0 (or zeros of that shape when `zero_out`)."""
    prefix = x[:count]
    suffix = torch.zeros_like(prefix) if zero_out else prefix
    return torch.cat([x, suffix], dim=0)
def weighted_mse(a, b, weight):
    """Mean of weight * (a - b)^2, with all operands promoted to float32."""
    diff = a.float() - b.float()
    return (weight.float() * diff.square()).mean()
def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
    """Map `x` from [x_min, x_max] to [y_min, y_max], clamped outside the range.

    The normalized position is raised to `sigma` before interpolating, giving a
    power-curve easing (sigma=1.0 is plain linear).
    A degenerate range (x_max == x_min) no longer divides by zero: x at or
    above the point maps to y_max, below it to y_min.
    """
    if x_max == x_min:
        return y_max if x >= x_max else y_min
    t = (x - x_min) / (x_max - x_min)
    t = max(0.0, min(t, 1.0))
    t = t ** sigma
    return y_min + t * (y_max - y_min)
def expand_to_dims(x, target_dims):
    """Append trailing singleton dims to `x` until it has `target_dims` dims (no-op if already there)."""
    extra = max(0, target_dims - x.dim())
    return x.view(*x.shape, *((1,) * extra))
def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
    """Tile `tensor` along dim 0 so its first dimension equals `batch_size`.

    Returns None for None input and the tensor unchanged when it already
    matches; raises ValueError when `batch_size` is not a multiple of the
    current first dimension.
    """
    if tensor is None:
        return None
    current = tensor.shape[0]
    if current == batch_size:
        return tensor
    if batch_size % current != 0:
        raise ValueError(f"Cannot evenly repeat first dim {current} to match batch_size {batch_size}.")
    reps = [batch_size // current] + [1] * (tensor.dim() - 1)
    return tensor.repeat(*reps)
def dim5(x):
    # Pad `x` with trailing singleton dims up to 5 dimensions.
    return expand_to_dims(x, 5)
def dim4(x):
    # Pad `x` with trailing singleton dims up to 4 dimensions.
    return expand_to_dims(x, 4)
def dim3(x):
    # Pad `x` with trailing singleton dims up to 3 dimensions.
    return expand_to_dims(x, 3)
def crop_or_pad_yield_mask(x, length):
    """Crop or zero-pad a (B, F, C) tensor along dim 1 to exactly `length`.

    Returns (tensor, mask) where mask is a (B, length) bool tensor, True at
    positions holding real (non-padding) data.
    """
    batch, frames, channels = x.shape
    if frames >= length:
        mask = torch.ones((batch, length), dtype=torch.bool, device=x.device)
        return x[:, :length, :], mask
    padded = torch.zeros((batch, length, channels), dtype=x.dtype, device=x.device)
    padded[:, :frames, :] = x
    mask = torch.zeros((batch, length), dtype=torch.bool, device=x.device)
    mask[:, :frames] = True
    return padded, mask
def extend_dim(x, dim, minimal_length, zero_pad=False):
    """Grow `x` along `dim` to at least `minimal_length`.

    Padding is zeros when `zero_pad` is True, otherwise the last slice along
    `dim` repeated. Returns `x` unchanged when it is already long enough.
    """
    current = int(x.shape[dim])
    missing = minimal_length - current
    if missing <= 0:
        return x
    if zero_pad:
        pad_shape = list(x.shape)
        pad_shape[dim] = missing
        pad = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
    else:
        selector = [slice(None)] * len(x.shape)
        selector[dim] = slice(-1, None)
        pad = x[tuple(selector)].repeat_interleave(missing, dim=dim)
    return torch.cat([x, pad], dim=dim)
def lazy_positional_encoding(t, repeats=None):
    # Sinusoidal timestep embedding (dim 256) for a scalar or list of values.
    # Imports diffusers lazily so the module loads without it.
    # Returns (len(t), 256), or (len(t), repeats, 256) when `repeats` is given.
    if not isinstance(t, list):
        t = [t]
    from diffusers.models.embeddings import get_timestep_embedding
    te = torch.tensor(t)
    te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
    if repeats is None:
        return te
    te = te[:, None, :].expand(-1, repeats, -1)
    return te
def state_dict_offset_merge(A, B, C=None):
    """Per-key tensor merge: A + B, or A + B - C when C is given.

    B's (and C's) values are cast/moved to match A's dtype and device.
    Keys are taken from A; all dicts are expected to share them.
    """
    merged = {}
    for key, a_val in A.items():
        b_val = B[key].to(a_val)
        if C is None:
            merged[key] = a_val + b_val
        else:
            merged[key] = a_val + b_val - C[key].to(a_val)
    return merged
def state_dict_weighted_merge(state_dicts, weights):
    """Weighted average of several state dicts (weights normalized to sum to 1).

    Raises ValueError on a length mismatch or a zero weight sum; returns {}
    for empty input. Later dicts' tensors are cast/moved to match the first's.
    """
    if len(state_dicts) != len(weights):
        raise ValueError("Number of state dictionaries must match number of weights")
    if not state_dicts:
        return {}
    total = sum(weights)
    if total == 0:
        raise ValueError("Sum of weights cannot be zero")
    factors = [w / total for w in weights]
    merged = {}
    for key in state_dicts[0].keys():
        acc = state_dicts[0][key] * factors[0]
        for sd, factor in zip(state_dicts[1:], factors[1:]):
            acc += sd[key].to(acc) * factor
        merged[key] = acc
    return merged
def group_files_by_folder(all_files):
    """Group file paths by their immediate parent-folder name; return the groups as a list of lists,
    preserving first-seen folder order."""
    groups = {}
    for path in all_files:
        folder = os.path.basename(os.path.dirname(path))
        groups.setdefault(folder, []).append(path)
    return list(groups.values())
def generate_timestamp():
    """Unique-ish id string: YYMMDD_HHMMSS_mmm_<random 0-9999>."""
    now = datetime.datetime.now()
    stamp = now.strftime('%y%m%d_%H%M%S')
    millis = int(now.microsecond / 1000)
    suffix = random.randint(0, 9999)
    return f'{stamp}_{millis:03d}_{suffix}'
def write_PIL_image_with_png_info(image, metadata, path):
    # Save a PIL image as PNG with `metadata` (a str->str dict) embedded as
    # PNG text chunks; returns the image unchanged so calls can be chained.
    from PIL.PngImagePlugin import PngInfo
    png_info = PngInfo()
    for key, value in metadata.items():
        png_info.add_text(key, value)
    image.save(path, "PNG", pnginfo=png_info)
    return image
def torch_safe_save(content, path):
    # Crash-safe save: write to a temp file first, then os.replace() it into
    # place so an interrupted write cannot leave a truncated file at `path`.
    torch.save(content, path + '_tmp')
    os.replace(path + '_tmp', path)
    return path
def move_optimizer_to_device(optimizer, device):
    """Move every tensor held in the optimizer's per-parameter state (e.g. momentum buffers)
    to `device`, in place."""
    for state in optimizer.state.values():
        for key, value in list(state.items()):
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)

View File

@ -0,0 +1,51 @@
import torch
def append_dims(x, target_dims):
    """Append trailing singleton dims to `x` until it has `target_dims` dims (no-op when already there)."""
    missing = target_dims - x.ndim
    return x[(...,) + (None,) * missing]
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
    """Rescale the CFG output toward the text prediction's per-sample std.

    `guidance_rescale` blends between the raw CFG result (0, returned as-is)
    and the fully std-rescaled result (1).
    """
    if guidance_rescale == 0:
        return noise_cfg
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
def fm_wrapper(transformer, t_scale=1000.0):
    # Wrap `transformer` into a k-diffusion-style denoiser: the returned
    # k_model(x, sigma, **extra_args) runs classifier-free guidance (with
    # optional rescale) and converts the prediction into an x0 estimate via
    # x0 = x - pred * sigma.
    def k_model(x, sigma, **extra_args):
        dtype = extra_args['dtype']
        cfg_scale = extra_args['cfg_scale']
        cfg_rescale = extra_args['cfg_rescale']
        concat_latent = extra_args['concat_latent']
        original_dtype = x.dtype
        sigma = sigma.float()
        x = x.to(dtype)
        # scale sigma into the model's timestep range (t_scale defaults to 1000)
        timestep = (sigma * t_scale).to(dtype)
        if concat_latent is None:
            hidden_states = x
        else:
            # extra conditioning latents are concatenated channel-wise
            hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)
        pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()
        if cfg_scale == 1.0:
            # negative pass skipped; zeros cancel out in the CFG formula below
            pred_negative = torch.zeros_like(pred_positive)
        else:
            pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()
        pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
        pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)
        x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)
        return x0.to(dtype=original_dtype)
    return k_model

View File

@ -8,7 +8,7 @@ from diffusers.pipelines import auto_pipeline
current_steps = 50
def sd15_hidiffusion_key():
modified_key = dict()
modified_key = {}
modified_key['down_module_key'] = ['down_blocks.0.downsamplers.0.conv']
modified_key['down_module_key_extra'] = ['down_blocks.1']
modified_key['up_module_key'] = ['up_blocks.2.upsamplers.0.conv']
@ -22,7 +22,7 @@ def sd15_hidiffusion_key():
return modified_key
def sdxl_hidiffusion_key():
modified_key = dict()
modified_key = {}
modified_key['down_module_key'] = ['down_blocks.1']
modified_key['down_module_key_extra'] = ['down_blocks.1.downsamplers.0.conv']
modified_key['up_module_key'] = ['up_blocks.1']
@ -42,7 +42,7 @@ def sdxl_hidiffusion_key():
def sdxl_turbo_hidiffusion_key():
modified_key = dict()
modified_key = {}
modified_key['down_module_key'] = ['down_blocks.1']
modified_key['up_module_key'] = ['up_blocks.1']
modified_key['windown_attn_module_key'] = [

View File

@ -164,14 +164,11 @@ def qwen(question: str, image: Image.Image, repo: str = None, system_prompt: str
def gemma(question: str, image: Image.Image, repo: str = None, system_prompt: str = None):
global processor, model, loaded # pylint: disable=global-statement
if not hasattr(transformers, 'Gemma3ForConditionalGeneration'):
shared.log.error(f'Interrogate: vlm="{repo}" gemma is not available')
return ''
if model is None or loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
model = None
if '3n' in repo:
cls = transformers.Gemma3nForConditionalGeneration
cls = transformers.Gemma3nForConditionalGeneration # pylint: disable=no-member
else:
cls = transformers.Gemma3ForConditionalGeneration
model = cls.from_pretrained(

View File

@ -20,9 +20,10 @@ try:
import numpy.random # pylint: disable=W0611,C0411 # this causes failure if numpy version changed
def obj2sctype(obj):
return np.dtype(obj).type
np.obj2sctype = obj2sctype # noqa: NPY201
np.bool8 = np.bool
np.float_ = np.float64 # noqa: NPY201
if np.__version__.startswith('2.'): # monkeypatch for np==1.2 compatibility
np.obj2sctype = obj2sctype # noqa: NPY201
np.bool8 = np.bool
np.float_ = np.float64 # noqa: NPY201
except Exception as e:
errors.log.error(f'Loader: numpy=={np.__version__ if np is not None else None} {e}')
errors.log.error('Please restart the app to fix this issue')

View File

@ -47,7 +47,7 @@ def edge_detect_for_pixelart(image: PipelineImageInput, image_weight: float = 1.
block_size_sq = block_size * block_size
new_image = process_image_input(image).to(device, dtype=torch.float32) / 255
new_image = new_image.permute(0,3,1,2)
batch_size, channels, height, width = new_image.shape
batch_size, _channels, height, width = new_image.shape
min_pool = -torch.nn.functional.max_pool2d(-new_image, block_size, 1, block_size//2, 1, False, False)
min_pool = min_pool[:, :, :height, :width]
@ -203,6 +203,7 @@ class JPEGEncoder(ImageProcessingMixin, ConfigMixin):
self.norm = norm
self.latents_std = latents_std
self.latents_mean = latents_mean
super().__init__()
def encode(self, images: PipelineImageInput, device: str="cpu") -> torch.FloatTensor:

View File

@ -171,7 +171,7 @@ def apply_styles_to_extra(p, style: Style):
k = name_map[k]
if k in name_exclude: # exclude some fields
continue
elif hasattr(p, k):
if hasattr(p, k):
orig = getattr(p, k)
if (type(orig) != type(v)) and (orig is not None):
if not (type(orig) == int and type(v) == float): # dont convert float to int

View File

@ -129,7 +129,7 @@ def create_ui(_blocks: gr.Blocks=None):
with gr.Row(elem_id='control_status'):
result_txt = gr.HTML(elem_classes=['control-result'], elem_id='control-result')
with gr.Row(elem_id='control_settings'):
with gr.Row(elem_id='control_settings', elem_classes=['settings-column']):
state = gr.Textbox(value='', visible=False)

View File

@ -1001,5 +1001,5 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
return ui
def setup_ui(ui, gallery):
def setup_ui(ui, gallery: gr.Gallery = None):
ui.gallery = gallery

View File

@ -47,7 +47,7 @@ def create_ui():
timer.startup.record('ui-networks')
with gr.Row(elem_id="img2img_interface", equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
with gr.Column(variant='compact', elem_id="img2img_settings", elem_classes=['settings-column']):
copy_image_buttons = []
copy_image_destinations = {}

View File

@ -4,7 +4,7 @@ from modules.ui_components import ToolButton
from modules.interrogate import interrogate
def create_toprow(is_img2img: bool = False, id_part: str = None, negative_visible: bool = True, reprocess_visible: bool = True):
def create_toprow(is_img2img: bool = False, id_part: str = None, generate_visible: bool = True, negative_visible: bool = True, reprocess_visible: bool = True):
def apply_styles(prompt, prompt_neg, styles):
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles, wildcards=not shared.opts.extra_networks_apply_unparsed)
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles, wildcards=not shared.opts.extra_networks_apply_unparsed)
@ -29,7 +29,7 @@ def create_toprow(is_img2img: bool = False, id_part: str = None, negative_visibl
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
with gr.Row(elem_id=f"{id_part}_generate_box"):
reprocess = []
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary', visible=generate_visible)
if reprocess_visible:
reprocess.append(gr.Button('Reprocess', elem_id=f"{id_part}_reprocess", variant='primary', visible=True))
reprocess.append(gr.Button('Reprocess decode', elem_id=f"{id_part}_reprocess_decode", variant='primary', visible=False))

View File

@ -21,7 +21,7 @@ def create_ui():
timer.startup.record('ui-networks')
with gr.Row(elem_id="txt2img_interface", equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
with gr.Column(variant='compact', elem_id="txt2img_settings", elem_classes=['settings-column']):
with gr.Row():
width, height = ui_sections.create_resolution_inputs('txt2img')

View File

@ -1,190 +1,49 @@
import os
import gradio as gr
from modules import shared, sd_models, timer, images, ui_common, ui_sections, ui_symbols, call_queue, generation_parameters_copypaste
from modules.ui_components import ToolButton
from modules.video_models import models_def, video_utils
from modules import shared, timer, images, ui_common, ui_sections, generation_parameters_copypaste
debug = shared.log.trace if os.environ.get('SD_VIDEO_DEBUG', None) is not None else lambda *args, **kwargs: None
def engine_change(engine):
debug(f'Video change: engine="{engine}"')
found = [model.name for model in models_def.models.get(engine, [])]
return gr.update(choices=found, value=found[0] if len(found) > 0 else None)
def get_selected(engine, model):
found = [model.name for model in models_def.models.get(engine, [])]
if len(models_def.models[engine]) > 0 and len(found) > 0:
selected = [m for m in models_def.models[engine] if m.name == model][0]
return selected
return None
def model_change(engine, model):
debug(f'Video change: engine="{engine}" model="{model}"')
found = [model.name for model in models_def.models.get(engine, [])]
selected = [m for m in models_def.models[engine] if m.name == model][0] if len(found) > 0 else None
url = video_utils.get_url(selected.url if selected else None)
i2v = 'i2v' in selected.name.lower() if selected else False
return url, gr.update(visible=i2v)
def model_load(engine, model):
debug(f'Video load: engine="{engine}" model="{model}"')
selected = get_selected(engine, model)
yield f'Video model loading: {selected.name}'
if selected:
if 'None' in selected.name:
sd_models.unload_model_weights()
msg = 'Video model unloaded'
else:
from modules.video_models import video_load
msg = video_load.load_model(selected)
else:
sd_models.unload_model_weights()
msg = 'Video model unloaded'
yield msg
return msg
def run_video(*args):
engine, model = args[2], args[3]
debug(f'Video run: engine="{engine}" model="{model}"')
selected = get_selected(engine, model)
if not selected or engine is None or model is None or engine == 'None' or model == 'None':
return video_utils.queue_err('model not selected')
debug(f'Video run: {str(selected)}')
from modules.video_models import video_run
if selected and 'Hunyuan' in selected.name:
return video_run.generate(*args)
elif selected and 'LTX' in selected.name:
return video_run.generate(*args)
elif selected and 'Mochi' in selected.name:
return video_run.generate(*args)
elif selected and 'Cog' in selected.name:
return video_run.generate(*args)
elif selected and 'Allegro' in selected.name:
return video_run.generate(*args)
elif selected and 'WAN' in selected.name:
return video_run.generate(*args)
elif selected and 'Latte' in selected.name:
return video_run.generate(*args)
elif selected and 'anisora' in selected.name.lower():
return video_run.generate(*args)
return video_utils.queue_err(f'model not found: engine="{engine}" model="{model}"')
def create_ui():
shared.log.debug('UI initialize: video')
with gr.Blocks(analytics_enabled=False) as _video_interface:
prompt, styles, negative, generate, _reprocess, paste, networks_button, _token_counter, _token_button, _token_counter_negative, _token_button_negative = ui_sections.create_toprow(is_img2img=False, id_part="video", negative_visible=True, reprocess_visible=False)
prompt, styles, negative, generate_btn, _reprocess, paste, networks_button, _token_counter, _token_button, _token_counter_negative, _token_button_negative = ui_sections.create_toprow(
is_img2img=False,
id_part="video",
negative_visible=True,
reprocess_visible=False,
)
prompt_image = gr.File(label="", elem_id="video_prompt_image", file_count="single", type="binary", visible=False)
prompt_image.change(fn=images.image_data, inputs=[prompt_image], outputs=[prompt, prompt_image])
with gr.Row(variant='compact', elem_id="video_extra_networks", elem_classes=["extra_networks_root"], visible=False) as extra_networks_ui:
from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks.create_ui(extra_networks_ui, networks_button, 'video', skip_indexing=shared.opts.extra_network_skip_indexing)
ui_extra_networks.setup_ui(extra_networks_ui)
timer.startup.record('ui-networks')
with gr.Row(elem_id="video_interface", equal_height=False):
with gr.Column(variant='compact', elem_id="video_settings", scale=1):
with gr.Tabs(elem_classes=['video-tabs'], elem_id='video-tabs'):
overrides = ui_common.create_override_inputs('video')
with gr.Tab('Video', id='video-tab') as video_tab:
from modules.video_models import video_ui
video_ui.create_ui(prompt, negative, styles, overrides)
with gr.Tab('FramePack', id='framepack-tab') as framepack_tab:
from modules.framepack import framepack_ui
framepack_ui.create_ui(prompt, negative, styles, overrides)
with gr.Row():
engine = gr.Dropdown(label='Engine', choices=list(models_def.models), value='None', elem_id="video_engine")
model = gr.Dropdown(label='Model', choices=[''], value=None, elem_id="video_model")
btn_load = ToolButton(ui_symbols.loading, elem_id="video_model_load")
with gr.Row():
url = gr.HTML(label='Model URL', elem_id='video_model_url', value='<br><br>')
with gr.Accordion(open=True, label="Size", elem_id='video_size_accordion'):
with gr.Row():
width, height = ui_sections.create_resolution_inputs('video', default_width=832, default_height=480)
with gr.Row():
frames = gr.Slider(label='Frames', minimum=1, maximum=1024, step=1, value=15, elem_id="video_frames")
seed = gr.Number(label='Initial seed', value=-1, elem_id="video_seed", container=True)
random_seed = ToolButton(ui_symbols.random, elem_id="video_random_seed")
reuse_seed = ToolButton(ui_symbols.reuse, elem_id="video_reuse_seed")
with gr.Accordion(open=True, label="Parameters", elem_id='video_parameters_accordion'):
steps, sampler_index = ui_sections.create_sampler_and_steps_selection(None, "video")
with gr.Row():
sampler_shift = gr.Slider(label='Sampler shift', minimum=-1.0, maximum=20.0, step=0.1, value=-1.0, elem_id="video_scheduler_shift")
dynamic_shift = gr.Checkbox(label='Dynamic shift', value=False, elem_id="video_dynamic_shift")
with gr.Row():
guidance_scale = gr.Slider(label='Guidance scale', minimum=-1.0, maximum=14.0, step=0.1, value=-1.0, elem_id="video_guidance_scale")
guidance_true = gr.Slider(label='True guidance', minimum=-1.0, maximum=14.0, step=0.1, value=-1.0, elem_id="video_guidance_true")
with gr.Accordion(open=True, label="Decode", elem_id='video_decode_accordion'):
with gr.Row():
vae_type = gr.Dropdown(label='VAE decode', choices=['Default', 'Tiny', 'Remote'], value='Default', elem_id="video_vae_type")
vae_tile_frames = gr.Slider(label='Tile frames', minimum=1, maximum=64, step=1, value=16, elem_id="video_vae_tile_frames")
with gr.Accordion(open=False, label="Init image", elem_id='video_init_accordion', visible=False) as init_accordion:
init_strength = gr.Slider(label='Init strength', minimum=0.0, maximum=1.0, step=0.01, value=0.5, elem_id="video_denoising_strength")
gr.HTML("<br>&nbsp Init image")
init_image = gr.Image(elem_id="video_image", show_label=False, type="pil", image_mode="RGB", height=512)
gr.HTML("<br>&nbsp Last image")
last_image = gr.Image(elem_id="video_last", show_label=False, type="pil", image_mode="RGB", height=512)
with gr.Accordion(open=True, label="Output", elem_id='video_output_accordion'):
with gr.Row():
save_frames = gr.Checkbox(label='Save image frames', value=False, elem_id="video_save_frames")
with gr.Row():
video_type, video_duration, video_loop, video_pad, video_interpolate = ui_sections.create_video_inputs(tab='video', show_always=True)
override_settings = ui_common.create_override_inputs('video')
# output panel with gallery and video tabs
with gr.Column(elem_id='video-output-column', scale=2) as _column_output:
with gr.Tabs(elem_classes=['video-output-tabs'], elem_id='video-output-tabs'):
with gr.Tab('Frames', id='out-gallery'):
gallery, gen_info, html_info, _html_info_formatted, html_log = ui_common.create_output_panel("video", prompt=prompt, preview=False, transfer=False, scale=2)
with gr.Tab('Video', id='out-video'):
video = gr.Video(label="Output", show_label=False, elem_id='control_output_video', elem_classes=['control-image'], height=512, autoplay=False)
# connect reuse seed button
ui_common.connect_reuse_seed(seed, reuse_seed, gen_info, is_subseed=False)
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
# handle engine and model change
engine.change(fn=engine_change, inputs=[engine], outputs=[model])
model.change(fn=model_change, inputs=[engine, model], outputs=[url, init_accordion])
btn_load.click(fn=model_load, inputs=[engine, model], outputs=[html_log])
# setup extra networks
ui_extra_networks.setup_ui(extra_networks_ui, gallery)
# handle restore fields
paste_fields = [
(prompt, "Prompt"),
(width, "Size-1"),
(height, "Size-2"),
(frames, "Frames"),
(prompt, "Prompt"), # cannot add more fields as they are not defined yet
]
generation_parameters_copypaste.add_paste_fields("video", None, paste_fields, override_settings)
generation_parameters_copypaste.add_paste_fields("video", None, paste_fields, overrides)
bindings = generation_parameters_copypaste.ParamBinding(paste_button=paste, tabname="video", source_text_component=prompt, source_image_component=None)
generation_parameters_copypaste.register_paste_params_button(bindings)
# hidden fields
task_id = gr.Textbox(visible=False, value='')
ui_state = gr.Textbox(visible=False, value='')
# generate args
video_args = [
task_id, ui_state,
engine, model,
prompt, negative, styles,
width, height,
frames,
steps, sampler_index,
sampler_shift, dynamic_shift,
seed,
guidance_scale, guidance_true,
init_image, init_strength, last_image,
vae_type, vae_tile_frames,
save_frames,
video_type, video_duration, video_loop, video_pad, video_interpolate,
override_settings,
]
# generate function
video_dict = dict(
fn=call_queue.wrap_gradio_gpu_call(run_video, extra_outputs=[None, '', ''], name='Video'),
_js="submit_video",
inputs=video_args,
outputs=[gallery, video, gen_info, html_info, html_log],
show_progress=False,
)
prompt.submit(**video_dict)
generate.click(**video_dict)
current_tab = gr.Textbox(visible=False, value='video')
video_tab.select(fn=lambda: 'video', inputs=[], outputs=[current_tab])
framepack_tab.select(fn=lambda: 'framepack', inputs=[], outputs=[current_tab])
generate_btn.click(fn=None, _js='submit_video_wrapper', inputs=[current_tab], outputs=[])
# from framepack_api import create_api # pylint: disable=wrong-import-order

View File

@ -92,7 +92,6 @@ class UpscalerLatent(Upscaler):
mode, antialias = 'bicubic', True
else:
raise log.error(f"Upscale: type=latent model={selected_model} unknown")
return img
return F.interpolate(img, size=(h, w), mode=mode, antialias=antialias)
@ -170,7 +169,6 @@ class UpscalerVIPS(Upscaler):
if selected_model is None:
return img
from installer import install
from modules.shared import log
install('pyvips')
try:
import pyvips

View File

@ -0,0 +1,173 @@
import os
import gradio as gr
from modules import shared, sd_models, ui_common, ui_sections, ui_symbols, call_queue
from modules.ui_components import ToolButton
from modules.video_models import models_def, video_utils
from modules.video_models import video_run
debug = shared.log.trace if os.environ.get('SD_VIDEO_DEBUG', None) is not None else lambda *args, **kwargs: None
def engine_change(engine):
    """Refresh the model dropdown with the models available for the selected engine."""
    debug(f'Video change: engine="{engine}"')
    names = [m.name for m in models_def.models.get(engine, [])]
    default = names[0] if names else None
    return gr.update(choices=names, value=default)
def get_selected(engine, model):
    """Return the model definition named `model` for `engine`, or None.

    Fixes two lookup crashes in the original: a KeyError when `engine` is not
    in models_def.models (it indexed the dict directly after using .get), and
    an IndexError when `model` names none of the engine's entries.
    """
    for candidate in models_def.models.get(engine, []):
        if candidate.name == model:
            return candidate
    return None
def model_change(engine, model):
    """Update the model-URL html and toggle the init-image accordion (shown for i2v models)."""
    debug(f'Video change: engine="{engine}" model="{model}"')
    names = [m.name for m in models_def.models.get(engine, [])]
    if names:
        selected = [m for m in models_def.models[engine] if m.name == model][0]
    else:
        selected = None
    url = video_utils.get_url(selected.url if selected else None)
    is_i2v = 'i2v' in selected.name.lower() if selected else False
    return url, gr.update(visible=is_i2v)
def model_load(engine, model):
    """Generator that loads or unloads the selected video model, yielding status messages for the UI.

    Guarded: the original yielded `selected.name` before checking for None,
    raising AttributeError when no model matched the selection.
    """
    debug(f'Video load: engine="{engine}" model="{model}"')
    selected = get_selected(engine, model)
    yield f'Video model loading: {selected.name if selected else model}'
    if selected:
        if 'None' in selected.name:
            sd_models.unload_model_weights()
            msg = 'Video model unloaded'
        else:
            from modules.video_models import video_load
            msg = video_load.load_model(selected)
    else:
        sd_models.unload_model_weights()
        msg = 'Video model unloaded'
    yield msg
    return msg
def run_video(*args):
    """Dispatch a video generation request to video_run.generate for supported model families."""
    engine, model = args[2], args[3]
    debug(f'Video run: engine="{engine}" model="{model}"')
    selected = get_selected(engine, model)
    if not selected or engine is None or model is None or engine == 'None' or model == 'None':
        return video_utils.queue_err('model not selected')
    debug(f'Video run: {str(selected)}')
    # all supported families route to the same generate(); the original spelled
    # this out as a chain of identical elif branches
    known_families = ('Hunyuan', 'LTX', 'Mochi', 'Cog', 'Allegro', 'WAN', 'Latte')
    supported = any(family in selected.name for family in known_families) or 'anisora' in selected.name.lower()
    if supported:
        return video_run.generate(*args)
    return video_utils.queue_err(f'model not found: engine="{engine}" model="{model}"')
def create_ui(prompt, negative, styles, overrides):
    # Build the classic video-models tab: settings column (engine/model pickers,
    # size, sampler, decode, init-image, output options) plus the gallery/video
    # output column, and wire the generate pipeline. The prompt/negative/styles
    # components and the override inputs are created by the caller and shared
    # across tabs.
    with gr.Row():
        with gr.Column(variant='compact', elem_id="video_settings", elem_classes=['settings-column'], scale=1):
            with gr.Row():
                # hidden; triggered indirectly via the shared toprow generate button
                generate = gr.Button('Generate', elem_id="video_generate_btn", variant='primary', visible=False)
            with gr.Row():
                engine = gr.Dropdown(label='Engine', choices=list(models_def.models), value='None', elem_id="video_engine")
                model = gr.Dropdown(label='Model', choices=[''], value=None, elem_id="video_model")
                btn_load = ToolButton(ui_symbols.loading, elem_id="video_model_load")
            with gr.Row():
                url = gr.HTML(label='Model URL', elem_id='video_model_url', value='<br><br>')
            with gr.Accordion(open=True, label="Size", elem_id='video_size_accordion'):
                with gr.Row():
                    width, height = ui_sections.create_resolution_inputs('video', default_width=832, default_height=480)
                with gr.Row():
                    frames = gr.Slider(label='Frames', minimum=1, maximum=1024, step=1, value=15, elem_id="video_frames")
                    seed = gr.Number(label='Initial seed', value=-1, elem_id="video_seed", container=True)
                    random_seed = ToolButton(ui_symbols.random, elem_id="video_random_seed")
                    reuse_seed = ToolButton(ui_symbols.reuse, elem_id="video_reuse_seed")
            with gr.Accordion(open=True, label="Parameters", elem_id='video_parameters_accordion'):
                steps, sampler_index = ui_sections.create_sampler_and_steps_selection(None, "video")
                with gr.Row():
                    # -1.0 sentinel values mean "use the model default"
                    sampler_shift = gr.Slider(label='Sampler shift', minimum=-1.0, maximum=20.0, step=0.1, value=-1.0, elem_id="video_scheduler_shift")
                    dynamic_shift = gr.Checkbox(label='Dynamic shift', value=False, elem_id="video_dynamic_shift")
                with gr.Row():
                    guidance_scale = gr.Slider(label='Guidance scale', minimum=-1.0, maximum=14.0, step=0.1, value=-1.0, elem_id="video_guidance_scale")
                    guidance_true = gr.Slider(label='True guidance', minimum=-1.0, maximum=14.0, step=0.1, value=-1.0, elem_id="video_guidance_true")
            with gr.Accordion(open=True, label="Decode", elem_id='video_decode_accordion'):
                with gr.Row():
                    vae_type = gr.Dropdown(label='VAE decode', choices=['Default', 'Tiny', 'Remote'], value='Default', elem_id="video_vae_type")
                    vae_tile_frames = gr.Slider(label='Tile frames', minimum=1, maximum=64, step=1, value=16, elem_id="video_vae_tile_frames")
            # shown only for i2v models; toggled by model_change()
            with gr.Accordion(open=False, label="Init image", elem_id='video_init_accordion', visible=False) as init_accordion:
                init_strength = gr.Slider(label='Init strength', minimum=0.0, maximum=1.0, step=0.01, value=0.5, elem_id="video_denoising_strength")
                gr.HTML("<br>&nbsp Init image")
                init_image = gr.Image(elem_id="video_image", show_label=False, type="pil", image_mode="RGB", width=256, height=256)
                gr.HTML("<br>&nbsp Last image")
                last_image = gr.Image(elem_id="video_last", show_label=False, type="pil", image_mode="RGB", width=256, height=256)
            with gr.Accordion(open=True, label="Output", elem_id='video_output_accordion'):
                with gr.Row():
                    save_frames = gr.Checkbox(label='Save image frames', value=False, elem_id="video_save_frames")
                with gr.Row():
                    video_type, video_duration, video_loop, video_pad, video_interpolate = ui_sections.create_video_inputs(tab='video', show_always=True)
        # output panel with gallery and video tabs
        with gr.Column(elem_id='video-output-column', scale=2) as _column_output:
            with gr.Tabs(elem_classes=['video-output-tabs'], elem_id='video-output-tabs'):
                with gr.Tab('Frames', id='out-gallery'):
                    gallery, gen_info, html_info, _html_info_formatted, html_log = ui_common.create_output_panel("video", prompt=prompt, preview=False, transfer=False, scale=2)
                with gr.Tab('Video', id='out-video'):
                    video = gr.Video(label="Output", show_label=False, elem_id='control_output_video', elem_classes=['control-image'], height=512, autoplay=False)
    # connect reuse seed button
    ui_common.connect_reuse_seed(seed, reuse_seed, gen_info, is_subseed=False)
    random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
    # handle engine and model change
    engine.change(fn=engine_change, inputs=[engine], outputs=[model])
    model.change(fn=model_change, inputs=[engine, model], outputs=[url, init_accordion])
    btn_load.click(fn=model_load, inputs=[engine, model], outputs=[html_log])
    # hidden fields
    task_id = gr.Textbox(visible=False, value='')
    ui_state = gr.Textbox(visible=False, value='')
    state_inputs = [task_id, ui_state]
    # generate args
    video_inputs = [
        engine, model,
        prompt, negative, styles,
        width, height,
        frames,
        steps, sampler_index,
        sampler_shift, dynamic_shift,
        seed,
        guidance_scale, guidance_true,
        init_image, init_strength, last_image,
        vae_type, vae_tile_frames,
        save_frames,
        video_type, video_duration, video_loop, video_pad, video_interpolate,
        overrides,
    ]
    video_outputs = [
        gallery,
        video,
        gen_info,
        html_info,
        html_log,
    ]
    # generate function: submit_video (js) collects state before dispatching
    video_dict = dict(
        fn=call_queue.wrap_gradio_gpu_call(video_run.generate, extra_outputs=[None, '', ''], name='Video'),
        _js="submit_video",
        inputs=state_inputs + video_inputs,
        outputs=video_outputs,
        show_progress=False,
    )
    generate.click(**video_dict)

View File

@ -362,7 +362,7 @@ class InstantIRPipeline(
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
lora_state_dict = dict()
lora_state_dict = {}
for k, v in unet_state_dict.items():
if "ip" in k:
k = k.replace("attn2", "attn2.processor")

View File

@ -24,7 +24,7 @@ class ScriptPixelArt(scripts_postprocessing.ScriptPostprocessing):
"pixelart_sharpen_amount": pixelart_sharpen_amount,
}
def process(self, pp: scripts_postprocessing.PostprocessedImage, pixelart_enabled: bool, pixelart_use_edge_detection: bool, pixelart_block_size: int, pixelart_edge_block_size: int, pixelart_image_weight: float, pixelart_sharpen_amount: float):
def process(self, pp: scripts_postprocessing.PostprocessedImage, pixelart_enabled: bool, pixelart_use_edge_detection: bool, pixelart_block_size: int, pixelart_edge_block_size: int, pixelart_image_weight: float, pixelart_sharpen_amount: float): # pylint: disable=arguments-differ
if not pixelart_enabled:
return
from modules.postprocess.pixelart import img_to_pixelart, edge_detect_for_pixelart

View File

@ -9,7 +9,7 @@ import torch.nn as nn
import torch.nn.functional as F
try:
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
except:
except Exception:
from timm.layers import drop_path, to_2tuple, trunc_normal_
from .transformer import PatchDropout
@ -18,7 +18,7 @@ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
if os.getenv('ENV_TYPE') == 'deepspeed':
try:
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
except:
except Exception:
from torch.utils.checkpoint import checkpoint
else:
from torch.utils.checkpoint import checkpoint
@ -27,7 +27,7 @@ try:
import xformers
import xformers.ops as xops
XFORMERS_IS_AVAILBLE = True
except:
except Exception:
XFORMERS_IS_AVAILBLE = False
class DropPath(nn.Module):

View File

@ -14,7 +14,7 @@ from torch import nn
try:
from .hf_model import HFTextEncoder
except:
except Exception:
HFTextEncoder = None
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
@ -23,7 +23,7 @@ from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, Tex
try:
from apex.normalization import FusedLayerNorm
except:
except Exception:
FusedLayerNorm = LayerNorm
@dataclass

View File

@ -116,7 +116,7 @@ class SimpleTokenizer(object):
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
except Exception:
new_word.extend(word[i:])
break

View File

@ -10,7 +10,7 @@ from torch.nn import functional as F
try:
from timm.models.layers import trunc_normal_
except:
except Exception:
from timm.layers import trunc_normal_
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast