mirror of https://github.com/vladmandic/automatic
add builtin framepack
Signed-off-by: Vladimir Mandic <mandic00@live.com>pull/4039/head
parent
239c3d6dd9
commit
c559e26616
|
|
@ -49,15 +49,12 @@
|
|||
"node/shebang": "off"
|
||||
},
|
||||
"globals": {
|
||||
// assets
|
||||
"panzoom": "readonly",
|
||||
// logger.js
|
||||
"log": "readonly",
|
||||
"debug": "readonly",
|
||||
"error": "readonly",
|
||||
"xhrGet": "readonly",
|
||||
"xhrPost": "readonly",
|
||||
// script.js
|
||||
"gradioApp": "readonly",
|
||||
"executeCallbacks": "readonly",
|
||||
"onAfterUiUpdate": "readonly",
|
||||
|
|
@ -73,11 +70,8 @@
|
|||
"getUICurrentTabContent": "readonly",
|
||||
"waitForFlag": "readonly",
|
||||
"logFn": "readonly",
|
||||
// contextmenus.js
|
||||
"generateForever": "readonly",
|
||||
// contributors.js
|
||||
"showContributors": "readonly",
|
||||
// ui.js
|
||||
"opts": "writable",
|
||||
"sortUIElements": "readonly",
|
||||
"all_gallery_buttons": "readonly",
|
||||
|
|
@ -97,40 +91,29 @@
|
|||
"toggleCompact": "readonly",
|
||||
"setFontSize": "readonly",
|
||||
"setTheme": "readonly",
|
||||
// settings.js
|
||||
"registerDragDrop": "readonly",
|
||||
// extraNetworks.js
|
||||
"getENActiveTab": "readonly",
|
||||
"quickApplyStyle": "readonly",
|
||||
"quickSaveStyle": "readonly",
|
||||
"setupExtraNetworks": "readonly",
|
||||
"showNetworks": "readonly",
|
||||
// from python
|
||||
"localization": "readonly",
|
||||
// progressbar.js
|
||||
"randomId": "readonly",
|
||||
"requestProgress": "readonly",
|
||||
"setRefreshInterval": "readonly",
|
||||
// imageviewer.js
|
||||
"modalPrevImage": "readonly",
|
||||
"modalNextImage": "readonly",
|
||||
"galleryClickEventHandler": "readonly",
|
||||
"getExif": "readonly",
|
||||
// logMonitor.js
|
||||
"jobStatusEl": "readonly",
|
||||
// loader.js
|
||||
"removeSplash": "readonly",
|
||||
// nvml.js
|
||||
"initNVML": "readonly",
|
||||
"disableNVML": "readonly",
|
||||
// indexdb.js
|
||||
"idbGet": "readonly",
|
||||
"idbPut": "readonly",
|
||||
"idbDel": "readonly",
|
||||
"idbAdd": "readonly",
|
||||
// changelog.js
|
||||
"initChangelog": "readonly",
|
||||
// notification.js
|
||||
"sendNotification": "readonly"
|
||||
},
|
||||
"ignorePatterns": [
|
||||
|
|
|
|||
|
|
@ -26,7 +26,12 @@ repos:
|
|||
- id: check-merge-conflict
|
||||
- id: check-symlinks
|
||||
- id: check-yaml
|
||||
args: ["--allow-multiple-documents"]
|
||||
- id: check-builtin-literals
|
||||
- id: check-case-conflict
|
||||
- id: check-json
|
||||
- id: check-symlinks
|
||||
- id: check-toml
|
||||
- id: check-xml
|
||||
- id: end-of-file-fixer
|
||||
- id: mixed-line-ending
|
||||
- id: trailing-whitespace
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ ignore-paths=/usr/lib/.*$,
|
|||
modules/hijack/ddpm_edit.py,
|
||||
modules/intel,
|
||||
modules/intel/ipex,
|
||||
modules/framepack/pipeline,
|
||||
modules/ldsr,
|
||||
modules/onnx_impl,
|
||||
modules/pag,
|
||||
|
|
|
|||
|
|
@ -16,8 +16,7 @@
|
|||
"--docs",
|
||||
"--api-log",
|
||||
"--log", "vscode.log",
|
||||
"${command:pickArgs}",
|
||||
]
|
||||
"${command:pickArgs}"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,5 @@
|
|||
{
|
||||
"python.analysis.extraPaths": [
|
||||
".",
|
||||
"./modules",
|
||||
"./scripts",
|
||||
"./pipelines",
|
||||
],
|
||||
"python.analysis.extraPaths": [".", "./modules", "./scripts", "./pipelines"],
|
||||
"python.analysis.typeCheckingMode": "off",
|
||||
"editor.formatOnSave": false,
|
||||
"python.REPL.enableREPLSmartSend": false
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ Although upgrades and existing installations are tested and should work fine!
|
|||
- Support **FLUX.1** all-in-one safetensors
|
||||
- Support **TAESD** preview and remote VAE for **HunyuanDit**
|
||||
- Support for [Gemma 3n](https://huggingface.co/google/gemma-3n-E4B-it) E2B and E4B LLM/VLM models in **prompt enhance** and process **captioning**
|
||||
- **FramePack** support is now fully integrated instead of being a separate extension
|
||||
- **UI**
|
||||
- major update to modernui layout
|
||||
- redesign of the Flat UI theme
|
||||
|
|
|
|||
3
TODO.md
3
TODO.md
|
|
@ -4,13 +4,10 @@ Main ToDo list can be found at [GitHub projects](https://github.com/users/vladma
|
|||
|
||||
## Current
|
||||
|
||||
- Bug: FramePack with SQND
|
||||
|
||||
## Future Candidates
|
||||
|
||||
- Feature: Common repo for `T5` and `CLiP`
|
||||
- Feature: LoRA add OMI format support for SD35/FLUX.1
|
||||
- Feature: Merge FramePack into core
|
||||
- Refactor: sampler options
|
||||
- Video: API support
|
||||
|
||||
|
|
|
|||
|
|
@ -488,7 +488,6 @@ function setupExtraNetworksForTab(tabname) {
|
|||
en.style.top = '13em';
|
||||
en.style.transition = 'width 0.3s ease';
|
||||
en.style.zIndex = 100;
|
||||
// gradioApp().getElementById(`${tabname}_settings`).parentNode.style.width = `${100 - 2 - window.opts.extra_networks_sidebar_width}vw`;
|
||||
gradioApp().getElementById(`${tabname}_settings`).parentNode.style.width = `calc(100vw - 2em - min(${window.opts.extra_networks_sidebar_width}vw, 50vw))`;
|
||||
} else {
|
||||
en.style.position = 'relative';
|
||||
|
|
@ -506,6 +505,7 @@ function setupExtraNetworksForTab(tabname) {
|
|||
if (window.opts.extra_networks_card_cover === 'sidebar') en.style.width = 0;
|
||||
gradioApp().getElementById(`${tabname}_settings`).parentNode.style.width = 'unset';
|
||||
}
|
||||
if (tabname === 'video') gradioApp().getElementById('framepack_settings').parentNode.style.width = gradioApp().getElementById(`${tabname}_settings`).parentNode.style.width;
|
||||
}
|
||||
});
|
||||
intersectionObserver.observe(en); // monitor visibility
|
||||
|
|
|
|||
|
|
@ -264,6 +264,25 @@ function submit_video(...args) {
|
|||
return res;
|
||||
}
|
||||
|
||||
/**
 * Prepare a FramePack submission: allocate a task id, start progress polling
 * for it and place the id in the first argument slot.
 * @param {...*} args - values collected from the UI; the first slot is replaced by the task id
 * @returns {Array} argument list with the task id as its first element
 */
function submit_framepack(...args) {
  const taskId = randomId();
  log('submitFramepack', taskId);
  requestProgress(taskId, null, null);
  window.submit_state = '';
  const [, ...remaining] = args; // drop the placeholder in the first slot
  return [taskId, ...remaining];
}
|
||||
|
||||
/**
 * Trigger generation for a tab by clicking its `<tab>_generate_btn` element.
 * @param {...*} args - the first argument is the tab name used to build the button id
 * @returns {undefined}
 */
function submit_video_wrapper(...args) {
  log('submitVideoWrapper', args);
  const hasArgs = Boolean(args) && args.length !== 0;
  if (!hasArgs) {
    log('submitVideoWrapper: no args');
    return;
  }
  const [tabName] = args;
  const button = gradioApp().getElementById(`${tabName}_generate_btn`);
  if (button) button.click(); // silently ignore tabs without a generate button
}
|
||||
|
||||
function submit_postprocessing(...args) {
|
||||
log('SubmitExtras');
|
||||
clearGallery('extras');
|
||||
|
|
|
|||
|
|
@ -58,8 +58,8 @@ _ACT_LAYER_ME = dict(
|
|||
hard_sigmoid=HardSigmoidMe
|
||||
)
|
||||
|
||||
_OVERRIDE_FN = dict()
|
||||
_OVERRIDE_LAYER = dict()
|
||||
_OVERRIDE_FN = {}
|
||||
_OVERRIDE_LAYER = {}
|
||||
|
||||
|
||||
def add_override_act_fn(name, fn):
|
||||
|
|
@ -75,7 +75,7 @@ def update_override_act_fn(overrides):
|
|||
|
||||
def clear_override_act_fn():
|
||||
global _OVERRIDE_FN
|
||||
_OVERRIDE_FN = dict()
|
||||
_OVERRIDE_FN = {}
|
||||
|
||||
|
||||
def add_override_act_layer(name, fn):
|
||||
|
|
@ -90,7 +90,7 @@ def update_override_act_layer(overrides):
|
|||
|
||||
def clear_override_act_layer():
|
||||
global _OVERRIDE_LAYER
|
||||
_OVERRIDE_LAYER = dict()
|
||||
_OVERRIDE_LAYER = {}
|
||||
|
||||
|
||||
def get_act_fn(name='relu'):
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ class EasyDict(dict):
|
|||
__setitem__ = __setattr__
|
||||
|
||||
def update(self, e=None, **f):
|
||||
d = e or dict()
|
||||
d = e or {}
|
||||
d.update(f)
|
||||
for k in d:
|
||||
setattr(self, k, d[k])
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class HCounter(PDH_HCOUNTER):
|
|||
itemBuffer = cast(malloc(c_size_t(bufferSize.value)), PPDH_FMT_COUNTERVALUE_ITEM_W)
|
||||
if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), itemBuffer) != PDH_OK:
|
||||
raise PDHError("Couldn't get formatted counter array.")
|
||||
result: dict[str, T] = dict()
|
||||
result: dict[str, T] = {}
|
||||
for i in range(0, itemCount.value):
|
||||
item = itemBuffer[i]
|
||||
result[item.szName] = getattr(item.FmtValue.u, attr_name)
|
||||
|
|
|
|||
|
|
@ -189,7 +189,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
|||
for j, a in enumerate(args):
|
||||
try:
|
||||
args[j] = eval(a) if isinstance(a, str) else a # eval strings
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
n = max(round(n * gd), 1) if n > 1 else n # depth gain
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import logging
|
||||
import argparse
|
||||
import requests
|
||||
import urllib3
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# connection settings are taken from the environment, with a local server as default
sd_url = os.getenv('SDAPI_URL', "http://127.0.0.1:7860")
sd_username = os.getenv('SDAPI_USR')
sd_password = os.getenv('SDAPI_PWD')


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
)
log = logging.getLogger(__name__)
# requests below are made with verify=False, so silence the resulting warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
def auth():
    """Return HTTP basic-auth credentials when both username and password are configured, otherwise None."""
    credentials_missing = sd_username is None or sd_password is None
    if credentials_missing:
        return None
    return requests.auth.HTTPBasicAuth(sd_username, sd_password)
|
||||
|
||||
|
||||
def get(endpoint: str, dct: dict = None):
    """Issue a GET request against the server.

    Args:
        endpoint: path (including any query string) appended to the server url.
        dct: optional payload sent as the JSON body.

    Returns:
        The decoded JSON response on HTTP 200, otherwise a dict with `error`, `reason` and `url` keys.
    """
    response = requests.get(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
    if response.status_code == 200:
        return response.json()
    return { 'error': response.status_code, 'reason': response.reason, 'url': response.url }
|
||||
|
||||
|
||||
def post(endpoint: str, dct: dict = None):
    """Issue a POST request against the server and wait for the response without a timeout.

    Args:
        endpoint: path appended to the server url.
        dct: optional payload sent as the JSON body.

    Returns:
        The decoded JSON response on HTTP 200, otherwise a dict with `error`, `reason` and `url` keys.
    """
    response = requests.post(f'{sd_url}{endpoint}', json = dct, timeout=None, verify=False, auth=auth())
    if response.status_code == 200:
        return response.json()
    return { 'error': response.status_code, 'reason': response.reason, 'url': response.url }
|
||||
|
||||
|
||||
def encode(f):
    """Load an image file and return its content re-encoded as a base64 JPEG string.

    Args:
        f: path to the image file.

    Returns:
        Base64 text of the JPEG-encoded image.

    Note:
        Terminates the process with exit code 1 when the file does not exist.
    """
    if not os.path.exists(f):
        log.error(f'file not found: {f}')
        os._exit(1)
    jpeg_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr') # modes the JPEG writer accepts as-is
    # context managers guarantee that both the file handle and the buffer are released, even if saving fails
    with Image.open(f) as source, io.BytesIO() as stream:
        # RGBA, palette, LA, 16-bit and similar images cannot be written as JPEG directly
        image = source if source.mode in jpeg_modes else source.convert('RGB')
        image.save(stream, 'JPEG')
        values = stream.getvalue()
    return base64.b64encode(values).decode()
|
||||
|
||||
|
||||
def generate(args): # pylint: disable=redefined-outer-name
    """Submit a FramePack job built from parsed command-line arguments and report its outputs.

    Args:
        args: argparse namespace with the options defined by this script.

    Returns:
        The `outputs` list of the last history event for the task, or an empty list when none are available.
    """
    request = {
        'variant': args.variant,
        'prompt': args.prompt,
        'section_prompt': args.sections,
        'init_image': encode(args.init),
        'end_image': encode(args.end) if args.end else None,
        'resolution': int(args.resolution),
        'duration': float(args.duration),
        'mp4_fps': int(args.fps),
        'seed': int(args.seed),
        'steps': int(args.steps),
        'shift': float(args.shift),
        'cfg_scale': float(args.scale),
        'cfg_rescale': float(args.rescale),
        'cfg_distilled': float(args.distilled),
        'use_teacache': bool(args.teacache),
        'vlm_enhance': bool(args.enhance),
    }
    log.info(f'request: {args}')
    result = post('/sdapi/v1/framepack', request) # can abandon request here and not wait for response or wait synchronously
    log.info(f'response: {result}')

    progress = get('/sdapi/v1/progress?skip_current_image=true', None) # monitor progress of the current task
    task_id = progress.get('id', None)
    log.info(f'id: {task_id}')
    log.info(f'progress: {progress}')

    outputs = []
    history = get(f'/sdapi/v1/history?id={task_id}') # get history for the task
    if isinstance(history, dict):
        # get() reports failures as a dict; iterating it would yield its string keys instead of events
        log.error(f'history: {history}')
        history = []
    for event in history:
        log.info(f'history: {event}')
        outputs = event.get('outputs', [])

    log.info(f'outputs: {outputs}') # you can download output files using /file={filename} endpoint
    return outputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # command-line entry point: parse options and submit a single framepack request
    parser = argparse.ArgumentParser(description = 'api-framepack')
    parser.add_argument('--init', required=True, help='init image')
    parser.add_argument('--end', required=False, help='end image') # help text previously duplicated the --init description
    parser.add_argument('--prompt', required=False, default='', help='prompt text')
    parser.add_argument('--sections', required=False, default='', help='per-section prompts')
    parser.add_argument('--resolution', type=int, required=False, default=640, help='video resolution')
    parser.add_argument('--duration', type=float, required=False, default=4.0, help='video duration')
    parser.add_argument('--fps', type=int, required=False, default=30, help='video frames per second')
    parser.add_argument('--seed', type=int, required=False, default=-1, help='random seed')
    parser.add_argument('--enhance', required=False, action='store_true', help='enable prompt enhancer')
    parser.add_argument('--teacache', required=False, action='store_true', help='enable teacache')
    parser.add_argument('--steps', type=int, default=25, help='steps')
    parser.add_argument('--scale', type=float, default=1.0, help='cfg scale')
    parser.add_argument('--rescale', type=float, default=0.0, help='cfg rescale')
    parser.add_argument('--distilled', type=float, default=10.0, help='cfg distilled')
    parser.add_argument('--shift', type=float, default=3.0, help='sampler shift')
    parser.add_argument('--variant', type=str, default='bi-directional', choices=['bi-directional', 'forward-only'], help='model variant')
    args = parser.parse_args()
    log.info(f'api-framepack: {args}')
    generate(args)
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
import cv2
|
||||
import torch
|
||||
import torchvision
|
||||
from safetensors.torch import safe_open
|
||||
from tqdm.rich import trange
|
||||
|
||||
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
log = logging.getLogger("sd")


if __name__ == "__main__":
    # reads the 'frames' tensor from a safetensors file and optionally exports images and/or encodes a video
    parser = argparse.ArgumentParser(description = 'framepack-cli')
    parser.add_argument('--input', required=True, help='input safetensors')
    parser.add_argument('--cv2', required=False, help='encode video file using cv2')
    parser.add_argument('--tv', required=False, help='encode video file using torchvision')
    parser.add_argument('--codec', default='libx264', help='specify video codec')
    parser.add_argument('--export', required=False, help='export frames as images to folder')
    # without an explicit type a user-supplied value stays a string and is rejected by the video writers
    parser.add_argument('--fps', type=float, default=30, help='frames-per-second')
    args = parser.parse_args()

    log.info(f'framepack-cli: {args}')
    log.info(f'torch={torch.__version__} torchvision={torchvision.__version__}')

    with safe_open(args.input, framework="pt", device="cpu") as f:
        frames = f.get_tensor('frames')
        metadata = f.metadata()
    n, h, w, _c = frames.shape # unpacked as (frames, height, width, channels)
    log.info(f'file: metadata={metadata}')
    log.info(f'tensor: frames={n} shape={frames.shape} dtype={frames.dtype} device={frames.device}')
    fn = os.path.splitext(os.path.basename(args.input))[0] # input name without extension, used as export prefix

    if args.export:
        log.info(f'export: folder="{args.export}" prefix="{fn}" frames={n} width={w} height={h}')
        os.makedirs(args.export, exist_ok=True)
        for i in trange(n):
            image = cv2.cvtColor(frames[i].numpy(), cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(args.export, f'{fn}-{i:05d}.jpg'), image)

    if args.cv2:
        log.info(f'encode: file={args.cv2} frames={n} width={w} height={h} fps={args.fps} method=cv2')
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video = cv2.VideoWriter(args.cv2, fourcc, args.fps, (w, h))
        for i in trange(n):
            image = cv2.cvtColor(frames[i].numpy(), cv2.COLOR_RGB2BGR)
            video.write(image)
        video.release()

    if args.tv:
        log.info(f'encode: file={args.tv} frames={n} width={w} height={h} fps={args.fps} method=tv ')
        torchvision.io.write_video(args.tv, video_array=frames, fps=args.fps, video_codec=args.codec)
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
|
||||
from fastapi.exceptions import HTTPException
|
||||
from modules import shared
|
||||
|
||||
|
||||
class ReqFramepack(BaseModel):
    # request body of the framepack endpoint; every field is optional and falls back to its declared default
    # fields defaulting to None are annotated Optional so the declared type matches the None checks in the handler
    variant: Optional[str] = Field(default=None, title="Model variant", description="Model variant to use")
    prompt: Optional[str] = Field(default=None, title="Prompt", description="Prompt for the model")
    init_image: Optional[str] = Field(default=None, title="Initial image", description="Base64 encoded initial image")
    end_image: Optional[str] = Field(default=None, title="End image", description="Base64 encoded end image")
    start_weight: Optional[float] = Field(default=1.0, title="Start weight", description="Weight of the initial image")
    end_weight: Optional[float] = Field(default=1.0, title="End weight", description="Weight of the end image")
    vision_weight: Optional[float] = Field(default=1.0, title="Vision weight", description="Weight of the vision model")
    system_prompt: Optional[str] = Field(default=None, title="System prompt", description="System prompt for the model")
    optimized_prompt: Optional[bool] = Field(default=True, title="Optimized system prompt", description="Use optimized system prompt for the model")
    section_prompt: Optional[str] = Field(default=None, title="Section prompt", description="Prompt for each section")
    negative_prompt: Optional[str] = Field(default=None, title="Negative prompt", description="Negative prompt for the model")
    styles: Optional[List[str]] = Field(default=None, title="Styles", description="Styles for the model")
    seed: Optional[int] = Field(default=None, title="Seed", description="Seed for the model")
    resolution: Optional[int] = Field(default=640, title="Resolution", description="Resolution of the image")
    duration: Optional[float] = Field(default=4, title="Duration", description="Duration of the video in seconds")
    latent_ws: Optional[int] = Field(default=9, title="Latent window size", description="Size of the latent window")
    steps: Optional[int] = Field(default=25, title="Steps", description="Number of steps for the model")
    cfg_scale: Optional[float] = Field(default=1.0, title="CFG scale", description="CFG scale for the model")
    cfg_distilled: Optional[float] = Field(default=10.0, title="Distilled CFG scale", description="Distilled CFG scale for the model")
    cfg_rescale: Optional[float] = Field(default=0.0, title="CFG re-scale", description="CFG re-scale for the model")
    shift: Optional[float] = Field(default=0, title="Sampler shift", description="Shift for the sampler")
    use_teacache: Optional[bool] = Field(default=True, title="Enable TeaCache", description="Use TeaCache for the model")
    use_cfgzero: Optional[bool] = Field(default=False, title="Enable CFGZero", description="Use CFGZero for the model")
    mp4_fps: Optional[int] = Field(default=30, title="FPS", description="Frames per second for the video")
    mp4_codec: Optional[str] = Field(default="libx264", title="Codec", description="Codec for the video")
    mp4_sf: Optional[bool] = Field(default=False, title="Save SafeTensors", description="Save SafeTensors for the video")
    mp4_video: Optional[bool] = Field(default=True, title="Save Video", description="Save video")
    mp4_frames: Optional[bool] = Field(default=False, title="Save Frames", description="Save frames for the video")
    mp4_opt: Optional[str] = Field(default="crf:16", title="Options", description="Options for the video codec")
    mp4_ext: Optional[str] = Field(default="mp4", title="Format", description="Format for the video")
    mp4_interpolate: Optional[int] = Field(default=0, title="Interpolation", description="Interpolation for the video")
    attention: Optional[str] = Field(default="Default", title="Attention", description="Attention type for the model")
    vae_type: Optional[str] = Field(default="Local", title="VAE", description="VAE type for the model")
    vlm_enhance: Optional[bool] = Field(default=False, title="VLM enhance", description="Enable VLM enhance")
    vlm_model: Optional[str] = Field(default=None, title="VLM model", description="VLM model to use")
    vlm_system_prompt: Optional[str] = Field(default=None, title="VLM system prompt", description="System prompt for the VLM model")
|
||||
|
||||
|
||||
class ResFramepack(BaseModel):
    # response body of the framepack endpoint
    # filename and message previously carried the copy-pasted "TaskID" title and description
    id: str = Field(title="TaskID", description="Task ID")
    filename: str = Field(title="Filename", description="Filename reported by the generator")
    message: str = Field(title="Message", description="Last message reported by the generator")
|
||||
|
||||
|
||||
def framepack_post(request: ReqFramepack):
    """Handle a POST to the framepack endpoint: decode the images, run generation and report the result.

    Args:
        request: validated request body.

    Returns:
        ResFramepack with the task id plus the last filename and message yielded by the generator.

    Raises:
        HTTPException: status 500 when the init or end image cannot be decoded.
    """
    import numpy as np
    from modules.api import helpers
    from framepack_wrappers import run_framepack
    task_id = shared.state.get_id()

    # missing or empty image strings are treated as "no image"
    try:
        init_image = np.array(helpers.decode_base64_to_image(request.init_image)) if request.init_image else None
    except Exception as e:
        shared.log.error(f"API FramePack: id={task_id} cannot decode init image: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        end_image = np.array(helpers.decode_base64_to_image(request.end_image)) if request.end_image else None
    except Exception as e:
        shared.log.error(f"API FramePack: id={task_id} cannot decode end image: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # drop the base64 payloads so they are not included when the request is logged
    del request.init_image
    del request.end_image
    # compare against None explicitly: truth-testing a numpy array raises ValueError and init_image may be absent
    init_shape = init_image.shape if init_image is not None else None
    end_shape = end_image.shape if end_image is not None else None
    shared.log.trace(f"API FramePack: id={task_id} init={init_shape} end={end_shape} {request}")

    generator = run_framepack(
        task_id=f'task({task_id})',
        variant=request.variant,
        init_image=init_image,
        end_image=end_image,
        start_weight=request.start_weight,
        end_weight=request.end_weight,
        vision_weight=request.vision_weight,
        prompt=request.prompt,
        system_prompt=request.system_prompt,
        optimized_prompt=request.optimized_prompt,
        section_prompt=request.section_prompt,
        negative_prompt=request.negative_prompt,
        styles=request.styles,
        seed=request.seed,
        resolution=request.resolution,
        duration=request.duration,
        latent_ws=request.latent_ws,
        steps=request.steps,
        cfg_scale=request.cfg_scale,
        cfg_distilled=request.cfg_distilled,
        cfg_rescale=request.cfg_rescale,
        shift=request.shift,
        use_teacache=request.use_teacache,
        use_cfgzero=request.use_cfgzero,
        use_preview=False,
        mp4_fps=request.mp4_fps,
        mp4_codec=request.mp4_codec,
        mp4_sf=request.mp4_sf,
        mp4_video=request.mp4_video,
        mp4_frames=request.mp4_frames,
        mp4_opt=request.mp4_opt,
        mp4_ext=request.mp4_ext,
        mp4_interpolate=request.mp4_interpolate,
        attention=request.attention,
        vae_type=request.vae_type,
        vlm_enhance=request.vlm_enhance,
        vlm_model=request.vlm_model,
        vlm_system_prompt=request.vlm_system_prompt,
    )
    response = ResFramepack(id=task_id, filename='', message='')
    # consume the generator to completion, keeping the last string values seen in slots 0 and 2 of its 3-tuples
    for message in generator:
        if isinstance(message, tuple) and len(message) == 3:
            if isinstance(message[0], str):
                response.filename = message[0]
            if isinstance(message[2], str):
                response.message = message[2]
    return response
|
||||
|
||||
|
||||
def create_api(_fastapi, _gradioapp):
    """Register the framepack POST route on the shared API object; both arguments are unused."""
    route = "/sdapi/v1/framepack"
    shared.api.add_api_route(route, framepack_post, methods=["POST"], response_model=ResFramepack)
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
# system instructions shared by the reference template below and the non-optimized prompt mode
ORIGINAL_PROMPT_TEMPLATE = (
    "\n\nDescribe the video by detailing the following aspects: "
    "1. The main content and theme of the video."
    "2. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "3. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "4. camera angles, movements, and transitions used in the video:<|eot_id|>"
    "5. background environment, light, style and atmosphere."
)
DEFAULT_PROMPT_TEMPLATE = { # hunyuanvideo reference prompt template
    # system header + shared instructions + user section with a '{}' placeholder for the prompt
    "template": "".join([
        "<|start_header_id|>system<|end_header_id|>",
        ORIGINAL_PROMPT_TEMPLATE,
        "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>",
    ]),
    "crop_start": 95,
}
|
||||
|
||||
|
||||
def set_progress_bar_config():
    """Replace `uni_pc_fm.sample_unipc` with a wrapper that always runs the sampler with its progress bar disabled."""
    from modules.framepack.pipeline import uni_pc_fm

    def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): # pylint: disable=unused-argument
        # the caller's `disable` flag is ignored on purpose: the progress bar is always off
        sampler = uni_pc_fm.FlowMatchUniPC(model, extra_args=extra_args, variant=variant)
        return sampler.sample(noise, sigmas=sigmas, callback=callback, disable_pbar=True)

    uni_pc_fm.sample_unipc = sample_unipc
|
||||
|
||||
|
||||
def set_prompt_template(prompt, system_prompt:str=None, optimized_prompt:bool=True, unmodified_prompt:bool=False):
    """Select a system prompt and install it as `hunyuan.DEFAULT_PROMPT_TEMPLATE`.

    Args:
        prompt: user prompt; only tokenized here to report its length.
        system_prompt: custom system prompt, used (stripped) when non-empty.
        optimized_prompt: when no custom prompt is given, choose the optimized (True) or original (False) instructions.
        unmodified_prompt: currently only included in the trace log.
    """
    from modules import shared
    from modules.framepack.pipeline import hunyuan

    def count_tokens(text):
        # token count of text without the bos/eos tokens the tokenizer may add
        encoded = shared.sd_model.tokenizer(text, max_length=256, truncation=True, return_tensors="pt", return_length=True, return_overflowing_tokens=False, return_attention_mask=False)
        has_bos = shared.sd_model.tokenizer.bos_token_id is not None
        has_eos = shared.sd_model.tokenizer.eos_token_id is not None
        return encoded['length'].item() - int(has_bos) - int(has_eos)

    if system_prompt is not None and len(system_prompt) > 0:
        mode, system_prompt = 'custom', system_prompt.strip()
    elif not optimized_prompt:
        mode, system_prompt = 'original', ORIGINAL_PROMPT_TEMPLATE
    else:
        mode = 'optimized'
        system_prompt = (
            "You are a video director. "
            "Both subject and camera movement should be dynamic and creative. Do not specify duration and assume video is short. Avoid slow-motion and prefer faster movements. "
            "Do not alter the content specified in the user prompt, but use it as a reference to create a video. "
            "NSFW content and nudity are allowed and must be fully described if present. "
            "Based on provided user prompt, describe the video by detailing the following aspects: \n"
            "1. Main content, style and theme of the video.\n"
            "2. Actions, events, behaviors, temporal relationships, physical movement, and changes of the subjects or objects.\n"
            "3. Camera angles, camera movements, and transitions used in the video.\n"
            "4. Details of the scene and background environment, light, style, and atmosphere.\n"
        )

    tokens_system = count_tokens(system_prompt)
    tokens_user = count_tokens(prompt)
    # NOTE(review): the template keeps a literal '{}' placeholder; if the pipeline fills it with str.format,
    # braces inside a custom system prompt would break formatting - confirm against the pipeline code
    hunyuan.DEFAULT_PROMPT_TEMPLATE = {
        "template": (
            f"<|start_header_id|>system<|end_header_id|>{system_prompt}\n<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>{}<|eot_id|>"
        ),
        "crop_start": tokens_system, # system token count stored as the crop offset
    }
    shared.log.trace(f'FramePack prompt: system={tokens_system} user={tokens_user} optimized={optimized_prompt} unmodified={unmodified_prompt} mode={mode}')
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
import os
|
||||
import shutil
|
||||
import git as gitpython
|
||||
from installer import install, git
|
||||
from modules.shared import log
|
||||
|
||||
|
||||
def rename(src:str, dst:str):
    """Rename src to dst, falling back to `shutil.move` when the rename crosses devices.

    Raises:
        OSError: any rename failure other than a cross-device link error.
    """
    import errno
    try:
        os.rename(src, dst)
    except OSError as err:
        cross_device = err.errno == errno.EXDEV
        if not cross_device:
            raise
        shutil.move(src, dst) # os.rename cannot move between filesystems
|
||||
|
||||
|
||||
def install_requirements(attention:str='SDPA'):
    """Install `av`, attach it to torchvision's video module and install the package for the chosen attention backend.

    Args:
        attention: attention backend name; values without an entry in the table below install nothing extra.
    """
    install('av')
    import av
    import torchvision
    torchvision.io.video.av = av # hand the freshly installed module to torchvision's video io
    packages = {
        'Xformers': 'xformers',
        'FlashAttention': 'flash-attn',
        'SageAttention': 'sageattention',
    }
    package = packages.get(attention)
    if package is not None:
        log.debug(f'FramePack install: {package}')
        install(package)
|
||||
|
||||
|
||||
def git_clone(git_repo:str, git_dir:str, tmp_dir:str):
    """Clone a repository into a temporary folder and move it to its final location.

    Does nothing when git_dir already exists. Failures are logged and the temporary folder is removed.

    Args:
        git_repo: url of the repository to clone.
        git_dir: final location of the clone.
        tmp_dir: scratch location used while cloning; removed before the clone and after a failure.
    """
    if os.path.exists(git_dir):
        return
    try:
        shutil.rmtree(tmp_dir, True) # remove leftovers of a previous attempt, ignoring errors
        args = {
            'url': git_repo,
            'to_path': tmp_dir,
            'allow_unsafe_protocols': True,
            'allow_unsafe_options': True,
            'filter': ['blob:none'],
        }
        ssh = os.environ.get('GIT_SSH_COMMAND', None)
        if ssh:
            args['env'] = {'GIT_SSH_COMMAND':ssh}
        # labels previously printed the args dict as url and the repo url as path
        log.info(f'FramePack install: url={git_repo} path={git_dir}')
        with gitpython.Repo.clone_from(**args) as repo:
            repo.remote().fetch(verbose=True)
            for submodule in repo.submodules:
                submodule.update()
        # move into place only after the repo handle has been closed
        rename(tmp_dir, git_dir)
    except Exception as e:
        log.error(f'FramePack install: {e}')
        shutil.rmtree(tmp_dir, True)
|
||||
|
||||
|
||||
def git_update(git_dir:str, git_commit:str):
    """Bring the clone at git_dir to the requested commit when it is currently on a different one.

    Does nothing when git_dir does not exist. Failures are logged, not raised.

    Args:
        git_dir: location of the clone.
        git_commit: commit the clone is expected to be at.
    """
    if not os.path.exists(git_dir):
        return
    try:
        with gitpython.Repo(git_dir) as repo:
            commit = repo.commit()
            up_to_date = f'{commit}' == git_commit
            if up_to_date:
                log.debug(f'FramePack version: sha={commit}')
            else:
                log.info(f'FramePack update: path={repo.git_dir} current={commit} target={git_commit}')
                repo.git.fetch(all=True)
                repo.git.reset('origin', hard=True)
                git(f'checkout {git_commit}', folder=git_dir, ignore=True, optional=True)
    except Exception as e:
        log.error(f'FramePack update: {e}')
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
import time
|
||||
from modules import shared, devices, errors, sd_models, sd_checkpoint, model_quant
|
||||
|
||||
|
||||
# transformer repo for each supported model variant
models = {
    'bi-directional': 'lllyasviel/FramePackI2V_HY',
    'forward-only': 'lllyasviel/FramePack_F1_I2V_HY_20250503',
}
# default repo/subfolder for every pipeline component
default_model = {
    'pipeline': { 'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': '' },
    'vae': { 'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': 'vae' },
    'text_encoder': { 'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': 'text_encoder' },
    'tokenizer': {'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': 'tokenizer' },
    # 'text_encoder': { 'repo': 'Kijai/llava-llama-3-8b-text-encoder-tokenizer', 'subfolder': '' },
    # 'tokenizer': { 'repo': 'Kijai/llava-llama-3-8b-text-encoder-tokenizer', 'subfolder': '' },
    # 'text_encoder': { 'repo': 'xtuner/llava-llama-3-8b-v1_1-transformers', 'subfolder': '' },
    # 'tokenizer': {'repo': 'xtuner/llava-llama-3-8b-v1_1-transformers', 'subfolder': '' },
    'text_encoder_2': { 'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': 'text_encoder_2' },
    'tokenizer_2': { 'repo': 'hunyuanvideo-community/HunyuanVideo', 'subfolder': 'tokenizer_2' },
    'feature_extractor': { 'repo': 'lllyasviel/flux_redux_bfl', 'subfolder': 'feature_extractor' },
    'image_encoder': { 'repo': 'lllyasviel/flux_redux_bfl', 'subfolder': 'image_encoder' },
    'transformer': { 'repo': models.get('bi-directional'), 'subfolder': '' },
}
# active configuration: copy each entry so in-place edits of model[...] never leak into default_model
# (a plain .copy() is shallow and would share the nested dicts)
model = {component: dict(entry) for component, entry in default_model.items()}
|
||||
|
||||
|
||||
def split_url(url):
    """
    Parse a component location of the form 'owner/name' or 'owner/name/subfolder'.

    Returns a dict with 'repo' ('owner/name') and 'subfolder' (empty string when omitted);
    whitespace around each section is stripped.
    Raises ValueError when the string does not contain exactly one or two slashes.
    """
    slashes = url.count('/')
    if slashes == 1: # no subfolder given: append an empty one
        url = url + '/'
        slashes = 2
    if slashes != 2:
        raise ValueError(f'Invalid URL: {url}')
    owner, name, subfolder = (part.strip() for part in url.split('/'))
    return { 'repo': f'{owner}/{name}', 'subfolder': subfolder }
|
||||
|
||||
|
||||
def set_model(receipe: str=None):
    """
    Update the active recipe from a multi-line 'key: owner/name[/subfolder]' text.

    Blank lines and lines without ':' are ignored. Keys that are not present in `default_model`
    are reported with a warning and skipped. A malformed location raises ValueError from split_url().
    Does nothing when `receipe` is None or empty.
    """
    if receipe is None or receipe == '':
        return
    lines = [line.strip() for line in receipe.split('\n') if line.strip() != '' and ':' in line]
    for line in lines:
        k, v = line.split(':', 1)
        k = k.strip()
        if k not in default_model:
            shared.log.warning(f'FramePack receipe: key={k} invalid')
            continue # do not store entries for unknown components
        model[k] = split_url(v)
        shared.log.debug(f'FramePack receipe: set {k}={model[k]}')
|
||||
|
||||
|
||||
def get_model():
    """Serialize the active recipe as text, one 'key: repo/subfolder' line per component."""
    entries = [f'{name}: {location["repo"]}/{location["subfolder"]}' for name, location in model.items()]
    return '\n'.join(entries).strip()
|
||||
|
||||
|
||||
def reset_model():
    """
    Restore the active recipe to the defaults.

    Every entry is copied individually so later in-place edits of the active recipe
    cannot modify `default_model` (a shallow dict.copy() would share the nested dicts).
    Returns an empty string, which the UI uses to clear the recipe textbox.
    """
    global model # pylint: disable=global-statement
    model = {component: dict(location) for component, location in default_model.items()}
    shared.log.debug('FramePack receipe: reset')
    return ''
|
||||
|
||||
|
||||
def load_model(variant:str=None, pipeline:str=None, text_encoder:str=None, text_encoder_2:str=None, feature_extractor:str=None, image_encoder:str=None, transformer:str=None):
    """
    Load every FramePack component named by the active recipe and install the result as `shared.sd_model`.

    Args:
        variant: optional key of `models`; selects the transformer repo.
        pipeline, text_encoder, text_encoder_2, feature_extractor, image_encoder, transformer:
            optional 'owner/name[/subfolder]' overrides, parsed by split_url() and stored in the active recipe.
            An explicit `transformer` takes precedence over `variant`.

    Returns:
        `variant` on success, None when loading failed (the error is logged and displayed).

    Raises:
        ValueError: unknown `variant` or malformed override string.
    """
    # Overrides are resolved before the job is marked as started, so a ValueError raised here
    # cannot leave shared.state stuck in the 'Load' job.
    if variant is not None:
        if variant not in models.keys():
            raise ValueError(f'FramePack: variant="{variant}" invalid')
        # replace the entry instead of editing it in place: the nested dict may be shared with default_model
        model['transformer'] = { **model['transformer'], 'repo': models[variant] }
    if pipeline is not None:
        model['pipeline'] = split_url(pipeline)
    if text_encoder is not None:
        model['text_encoder'] = split_url(text_encoder)
    if text_encoder_2 is not None:
        model['text_encoder_2'] = split_url(text_encoder_2)
    if feature_extractor is not None:
        model['feature_extractor'] = split_url(feature_extractor)
    if image_encoder is not None:
        model['image_encoder'] = split_url(image_encoder)
    if transformer is not None:
        model['transformer'] = split_url(transformer)
    # shared.log.trace(f'FramePack load: {model}')

    shared.state.begin('Load')
    try:
        import diffusers
        from diffusers import HunyuanVideoImageToVideoPipeline, AutoencoderKLHunyuanVideo
        from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer, SiglipImageProcessor, SiglipVisionModel
        from modules.framepack.pipeline.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked

        class FramepackHunyuanVideoPipeline(HunyuanVideoImageToVideoPipeline): # inherit and override
            def __init__(
                self,
                text_encoder: LlamaModel,
                tokenizer: LlamaTokenizerFast,
                text_encoder_2: CLIPTextModel,
                tokenizer_2: CLIPTokenizer,
                vae: AutoencoderKLHunyuanVideo,
                feature_extractor: SiglipImageProcessor,
                image_processor: SiglipVisionModel,
                transformer: HunyuanVideoTransformer3DModelPacked,
                scheduler,
            ):
                super().__init__(
                    text_encoder=text_encoder,
                    tokenizer=tokenizer,
                    text_encoder_2=text_encoder_2,
                    tokenizer_2=tokenizer_2,
                    vae=vae,
                    transformer=transformer,
                    image_processor=image_processor,
                    scheduler=scheduler,
                )
                # register again including feature_extractor, which is not passed to the parent constructor above
                self.register_modules(
                    text_encoder=text_encoder,
                    tokenizer=tokenizer,
                    text_encoder_2=text_encoder_2,
                    tokenizer_2=tokenizer_2,
                    vae=vae,
                    feature_extractor=feature_extractor,
                    image_processor=image_processor,
                    transformer=transformer,
                    scheduler=scheduler,
                )

        sd_models.unload_model_weights()
        t0 = time.time()

        shared.log.debug(f'FramePack load: module=llm {model["text_encoder"]}')
        load_args, quant_args = model_quant.get_dit_args({}, module='TE', device_map=True)
        text_encoder = LlamaModel.from_pretrained(model["text_encoder"]["repo"], subfolder=model["text_encoder"]["subfolder"], cache_dir=shared.opts.hfcache_dir, **load_args, **quant_args)
        tokenizer = LlamaTokenizerFast.from_pretrained(model["tokenizer"]["repo"], subfolder=model["tokenizer"]["subfolder"], cache_dir=shared.opts.hfcache_dir)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        sd_models.move_model(text_encoder, devices.cpu)

        shared.log.debug(f'FramePack load: module=te {model["text_encoder_2"]}')
        text_encoder_2 = CLIPTextModel.from_pretrained(model["text_encoder_2"]["repo"], subfolder=model["text_encoder_2"]["subfolder"], torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
        # NOTE(review): tokenizer_2 is taken from the pipeline repo with a fixed subfolder and ignores model["tokenizer_2"] - confirm this is intended
        tokenizer_2 = CLIPTokenizer.from_pretrained(model["pipeline"]["repo"], subfolder='tokenizer_2', cache_dir=shared.opts.hfcache_dir)
        text_encoder_2.requires_grad_(False)
        text_encoder_2.eval()
        sd_models.move_model(text_encoder_2, devices.cpu)

        shared.log.debug(f'FramePack load: module=vae {model["vae"]}')
        vae = AutoencoderKLHunyuanVideo.from_pretrained(model["vae"]["repo"], subfolder=model["vae"]["subfolder"], torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
        vae.requires_grad_(False)
        vae.eval()
        vae.enable_slicing()
        vae.enable_tiling()
        sd_models.move_model(vae, devices.cpu)

        shared.log.debug(f'FramePack load: module=encoder {model["feature_extractor"]} model={model["image_encoder"]}')
        feature_extractor = SiglipImageProcessor.from_pretrained(model["feature_extractor"]["repo"], subfolder=model["feature_extractor"]["subfolder"], cache_dir=shared.opts.hfcache_dir)
        image_encoder = SiglipVisionModel.from_pretrained(model["image_encoder"]["repo"], subfolder=model["image_encoder"]["subfolder"], torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
        image_encoder.requires_grad_(False)
        image_encoder.eval()
        sd_models.move_model(image_encoder, devices.cpu)

        shared.log.debug(f'FramePack load: module=transformer {model["transformer"]}')
        dit_repo = model["transformer"]["repo"]
        load_args, quant_args = model_quant.get_dit_args({}, module='Model', device_map=True)
        transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(dit_repo, subfolder=model["transformer"]["subfolder"], cache_dir=shared.opts.hfcache_dir, **load_args, **quant_args)
        transformer.high_quality_fp32_output_for_inference = False
        transformer.requires_grad_(False)
        transformer.eval()
        sd_models.move_model(transformer, devices.cpu)

        # the vision model is stored in the pipeline's `image_processor` slot
        shared.sd_model = FramepackHunyuanVideoPipeline(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            vae=vae,
            feature_extractor=feature_extractor,
            image_processor=image_encoder,
            transformer=transformer,
            scheduler=None,
        )
        shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(dit_repo) # pylint: disable=attribute-defined-outside-init
        shared.sd_model.sd_model_checkpoint = dit_repo # pylint: disable=attribute-defined-outside-init

        shared.sd_model = model_quant.do_post_load_quant(shared.sd_model, allow=False)
        t1 = time.time()

        # register a pass-through adapter-scale function for the packed transformer class
        diffusers.loaders.peft._SET_ADAPTER_SCALE_FN_MAPPING['HunyuanVideoTransformer3DModelPacked'] = lambda model_cls, weights: weights # pylint: disable=protected-access
        shared.log.info(f'FramePack load: model={shared.sd_model.__class__.__name__} variant="{variant}" type={shared.sd_model_type} time={t1-t0:.2f}')
        sd_models.apply_balanced_offload(shared.sd_model)
        devices.torch_gc(force=True)

    except Exception as e:
        shared.log.error(f'FramePack load: {e}')
        errors.display(e, 'FramePack')
        shared.state.end()
        return None

    shared.state.end()
    return variant
|
||||
|
||||
|
||||
def unload_model():
    """Unload the current model by delegating to `sd_models.unload_model_weights()`."""
    sd_models.unload_model_weights()
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
import gradio as gr
|
||||
from modules import ui_sections, ui_common, ui_video_vlm
|
||||
from modules.framepack import framepack_load
|
||||
from modules.framepack.framepack_worker import get_latent_paddings
|
||||
from modules.framepack.framepack_wrappers import get_codecs, load_model, unload_model
|
||||
from modules.framepack.framepack_wrappers import run_framepack # pylint: disable=wrong-import-order
|
||||
|
||||
|
||||
def change_sections(duration, mp4_fps, mp4_interpolate, latent_ws, variant):
    """
    Recompute the section summary shown in the UI.

    Returns two gradio updates: the summary text (total frames and section count)
    and the line count of the section-prompt textbox (at least 2).
    """
    paddings = get_latent_paddings(mp4_fps, mp4_interpolate, latent_ws, duration, variant)
    num_sections = len(paddings)
    frames_per_section = latent_ws * 4 - 3
    num_frames = frames_per_section * num_sections + 1
    summary = gr.update(value=f'Target video: {num_frames} frames in {num_sections} sections')
    prompt_box = gr.update(lines=max(2, (2 * num_sections) // 3))
    return summary, prompt_box
|
||||
|
||||
|
||||
def create_ui(prompt, negative, styles, _overrides):
    """
    Build the FramePack tab: settings column, output column and all event wiring.

    Args:
        prompt, negative, styles: existing gradio components created by the caller; they are only
            used as inputs of the generate handler (and `prompt` for the VLM helper).
        _overrides: unused.

    The function creates components inside the currently active gradio context and returns nothing.
    """
    with gr.Row():
        with gr.Column(variant='compact', elem_id="framepack_settings", elem_classes=['settings-column'], scale=1):
            with gr.Row():
                generate = gr.Button('Generate', elem_id="framepack_generate_btn", variant='primary', visible=False)
            with gr.Row():
                variant = gr.Dropdown(label="Model variant", choices=list(framepack_load.models), value='bi-directional', type='value')
            with gr.Row():
                resolution = gr.Slider(label="Resolution", minimum=240, maximum=1088, value=640, step=16)
                duration = gr.Slider(label="Duration", minimum=1, maximum=120, value=4, step=0.1)
                mp4_fps = gr.Slider(label="FPS", minimum=1, maximum=60, value=24, step=1)
                mp4_interpolate = gr.Slider(label="Interpolation", minimum=0, maximum=10, value=0, step=1)
            with gr.Row():
                section_html = gr.HTML(show_label=False, elem_id="framepack_section_html")
            with gr.Accordion(label="Inputs", open=True):
                with gr.Row():
                    input_image = gr.Image(sources='upload', type="numpy", label="Init image", width=256, height=256, interactive=True, tool="editor", image_mode='RGB', elem_id="framepack_input_image")
                    end_image = gr.Image(sources='upload', type="numpy", label="End image", width=256, height=256, interactive=True, tool="editor", image_mode='RGB', elem_id="framepack_end_image")
                with gr.Row():
                    start_weight = gr.Slider(label="Init strength", value=1.0, minimum=0.0, maximum=2.0, step=0.05, elem_id="framepack_start_weight")
                    end_weight = gr.Slider(label="End strength", value=1.0, minimum=0.0, maximum=2.0, step=0.05, elem_id="framepack_end_weight")
                    vision_weight = gr.Slider(label="Vision strength", value=1.0, minimum=0.0, maximum=2.0, step=0.05, elem_id="framepack_vision_weight")
            with gr.Accordion(label="Sections", open=False):
                section_prompt = gr.Textbox(label="Section prompts", elem_id="framepack_section_prompt", lines=2, placeholder="Optional one-line prompt suffix per each video section", interactive=True)
            with gr.Accordion(label="Video", open=False):
                with gr.Row():
                    mp4_codec = gr.Dropdown(label="Codec", choices=['none', 'libx264'], value='libx264', type='value')
                    ui_common.create_refresh_button(mp4_codec, get_codecs)
                    mp4_ext = gr.Textbox(label="Format", value='mp4', elem_id="framepack_mp4_ext")
                    # own element id: previously reused "framepack_mp4_ext", producing a duplicate DOM id
                    mp4_opt = gr.Textbox(label="Options", value='crf:16', elem_id="framepack_mp4_opt")
                with gr.Row():
                    mp4_video = gr.Checkbox(label='Save Video', value=True, elem_id="framepack_mp4_video")
                    mp4_frames = gr.Checkbox(label='Save Frames', value=False, elem_id="framepack_mp4_frames")
                    mp4_sf = gr.Checkbox(label='Save SafeTensors', value=False, elem_id="framepack_mp4_sf")
            with gr.Accordion(label="Advanced", open=False):
                seed = ui_sections.create_seed_inputs('control', reuse_visible=False, subseed_visible=False, accordion=False)[0]
                latent_ws = gr.Slider(label="Latent window size", minimum=1, maximum=33, value=9, step=1)
                with gr.Row():
                    steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1)
                    shift = gr.Slider(label="Sampler shift", minimum=0.0, maximum=10.0, value=3.0, step=0.01)
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01)
                    cfg_distilled = gr.Slider(label="Distilled CFG scale", minimum=1.0, maximum=32.0, value=10.0, step=0.01)
                    cfg_rescale = gr.Slider(label="CFG re-scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01)

            vlm_enhance, vlm_model, vlm_system_prompt = ui_video_vlm.create_ui(prompt_element=prompt, image_element=input_image)

            with gr.Accordion(label="Model", open=False):
                with gr.Row():
                    btn_load = gr.Button(value="Load model", elem_id="framepack_btn_load", interactive=True)
                    btn_unload = gr.Button(value="Unload model", elem_id="framepack_btn_unload", interactive=True)
                with gr.Row():
                    system_prompt = gr.Textbox(label="System prompt", elem_id="framepack_system_prompt", lines=6, placeholder="Optional system prompt for the model", interactive=True)
                with gr.Row():
                    receipe = gr.Textbox(label="Model receipe", elem_id="framepack_model_receipe", lines=6, placeholder="Model receipe", interactive=True)
                with gr.Row():
                    receipe_get = gr.Button(value="Get receipe", elem_id="framepack_btn_get_model", interactive=True)
                    receipe_set = gr.Button(value="Set receipe", elem_id="framepack_btn_set_model", interactive=True)
                    receipe_reset = gr.Button(value="Reset receipe", elem_id="framepack_btn_reset_model", interactive=True)
                use_teacache = gr.Checkbox(label='Enable TeaCache', value=True)
                optimized_prompt = gr.Checkbox(label='Use optimized system prompt', value=True)
                use_cfgzero = gr.Checkbox(label='Enable CFGZero', value=False)
                use_preview = gr.Checkbox(label='Enable Preview', value=True)
                attention = gr.Dropdown(label="Attention", choices=['Default', 'Xformers', 'FlashAttention', 'SageAttention'], value='Default', type='value')
                # NOTE(review): the default 'Local' is not one of the listed choices - confirm the intended default before changing it
                vae_type = gr.Dropdown(label="VAE", choices=['Full', 'Tiny', 'Remote'], value='Local', type='value')

        with gr.Column(elem_id='framepack-output-column', scale=2) as _column_output:
            with gr.Tabs():
                with gr.TabItem("Video"):
                    result_video = gr.Video(label="Video", autoplay=True, show_share_button=False, height=512, loop=True, show_label=False, elem_id="framepack_result_video")
                with gr.Tab("Preview"):
                    preview_image = gr.Image(label="Current", height=512, show_label=False, elem_id="framepack_preview_image")
            progress_desc = gr.HTML('', show_label=False, elem_id="framepack_progress_desc")

        # hidden fields
        task_id = gr.Textbox(visible=False, value='')
        ui_state = gr.Textbox(visible=False, value='')
        state_inputs = [task_id, ui_state]

        framepack_outputs = [
            result_video,
            preview_image,
            progress_desc,
        ]

        # every control that feeds change_sections refreshes the summary, including latent window size and variant
        section_inputs = [duration, mp4_fps, mp4_interpolate, latent_ws, variant]
        section_outputs = [section_html, section_prompt]
        for control in section_inputs:
            control.change(fn=change_sections, inputs=section_inputs, outputs=section_outputs)
        btn_load.click(fn=load_model, inputs=[variant, attention], outputs=framepack_outputs)
        btn_unload.click(fn=unload_model, outputs=framepack_outputs)
        receipe_get.click(fn=framepack_load.get_model, inputs=[], outputs=receipe)
        receipe_set.click(fn=framepack_load.set_model, inputs=[receipe], outputs=[])
        receipe_reset.click(fn=framepack_load.reset_model, inputs=[], outputs=[receipe])

        framepack_inputs=[
            input_image, end_image,
            start_weight, end_weight, vision_weight,
            prompt, system_prompt, optimized_prompt, section_prompt, negative, styles,
            seed,
            resolution,
            duration,
            latent_ws,
            steps,
            cfg_scale, cfg_distilled, cfg_rescale,
            shift,
            use_teacache, use_cfgzero, use_preview,
            mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
            attention, vae_type, variant,
            vlm_enhance, vlm_model, vlm_system_prompt,
        ]

        framepack_dict = dict(
            fn=run_framepack,
            _js="submit_framepack",
            inputs=state_inputs + framepack_inputs,
            outputs=framepack_outputs,
            show_progress=False,
        )
        generate.click(**framepack_dict)
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import torch
|
||||
import einops
|
||||
from modules import shared, devices
|
||||
|
||||
|
||||
# Linear projection used by vae_decode_simple(): 16 rows of 3 coefficients.
# The table is transposed there and used as a 1x1x1 conv3d weight, so each row belongs to one
# input (latent) channel and each column to one of the 3 output channels.
latent_rgb_factors = [ # from comfyui
    [-0.0395, -0.0331, 0.0445],
    [0.0696, 0.0795, 0.0518],
    [0.0135, -0.0945, -0.0282],
    [0.0108, -0.0250, -0.0765],
    [-0.0209, 0.0032, 0.0224],
    [-0.0804, -0.0254, -0.0639],
    [-0.0991, 0.0271, -0.0669],
    [-0.0646, -0.0422, -0.0400],
    [-0.0696, -0.0595, -0.0894],
    [-0.0799, -0.0208, -0.0375],
    [0.1166, 0.1627, 0.0962],
    [0.1165, 0.0432, 0.0407],
    [-0.2315, -0.1920, -0.1355],
    [-0.0270, 0.0401, -0.0821],
    [-0.0616, -0.0997, -0.0727],
    [0.0249, -0.0469, -0.1703]
]
# per-output-channel bias passed to the same conv3d call
latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
# lazily built tensor versions of the two tables above (created on first vae_decode_simple() call)
vae_weight = None
vae_bias = None
# lazily loaded tiny VAE used by vae_decode_tiny()
taesd = None
|
||||
|
||||
|
||||
def vae_decode_simple(latents):
    """
    Cheap preview decode: project latents to 3 channels with a fixed 1x1x1 convolution.

    The projection weight/bias are built once from `latent_rgb_factors` / `latent_rgb_factors_bias`
    and cached in module globals. The result is rearranged from 'b c t h w' to a single
    '(b h) (t w) c' uint8 numpy array, i.e. frames are laid out side by side.
    """
    global vae_weight, vae_bias # pylint: disable=global-statement
    with devices.inference_context():
        if vae_weight is None or vae_bias is None:
            vae_weight = torch.tensor(latent_rgb_factors, device=devices.device, dtype=devices.dtype).transpose(0, 1)[:, :, None, None, None]
            vae_bias = torch.tensor(latent_rgb_factors_bias, device=devices.device, dtype=devices.dtype)
        images = torch.nn.functional.conv3d(latents, weight=vae_weight, bias=vae_bias, stride=1, padding=0, dilation=1, groups=1)
        images = (images + 1.2) * 100 # sort-of normalized
        images = einops.rearrange(images, 'b c t h w -> (b h) (t w) c')
        # saturate BEFORE the cast: clipping a uint8 array afterwards is a no-op and out-of-range values would not be limited to 0..255
        images = images.clamp(0, 255).to(torch.uint8).detach().cpu().numpy()
    return images
|
||||
|
||||
|
||||
def vae_decode_tiny(latents):
    """
    Decode latents with the tiny video VAE ('TAE HunyuanVideo').

    The model is fetched on first use and cached in the module-level `taesd`;
    it is moved to the compute device for the decode and back to CPU afterwards.
    """
    global taesd # pylint: disable=global-statement
    if taesd is None:
        from modules import sd_vae_taesd
        taesd = sd_vae_taesd.get_model(variant='TAE HunyuanVideo')
        shared.log.debug(f'Video VAE: type=Tiny cls={taesd.__class__.__name__} latents={latents.shape}')
    with devices.inference_context():
        taesd = taesd.to(device=devices.device, dtype=devices.dtype)
        frames_first = latents.transpose(1, 2) # pipe produces NCTHW and tae wants NTCHW
        decoded = taesd.decode_video(frames_first, parallel=False, show_progress_bar=False)
        images = decoded.transpose(1, 2).mul_(2).sub_(1) # normalize
        taesd = taesd.to(device=devices.cpu, dtype=devices.dtype)
    return images
|
||||
|
||||
|
||||
def vae_decode_remote(latents):
    """Decode latents through the hosted remote-decode endpoint and return the result as a tensor ('pt')."""
    # alternative kept for reference:
    # from modules.sd_vae_remote import remote_decode
    # images = remote_decode(latents, model_type='hunyuanvideo')
    from diffusers.utils.remote_utils import remote_decode
    request = {
        'tensor': latents.contiguous(),
        'endpoint': 'https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud',
        'output_type': 'pt',
        'return_type': 'pt',
    }
    return remote_decode(**request)
|
||||
|
||||
|
||||
def vae_decode_full(latents):
    """Decode latents with the VAE of the loaded pipeline, after dividing by its scaling factor."""
    with devices.inference_context():
        vae = shared.sd_model.vae
        unscaled = latents / vae.config.scaling_factor
        unscaled = unscaled.to(device=vae.device, dtype=vae.dtype)
        return vae.decode(unscaled).sample
|
||||
|
||||
|
||||
def vae_decode(latents, vae_type):
    """
    Decode latents with the decoder selected by `vae_type`.

    'Tiny', 'Preview' and 'Remote' select the matching decoder; any other value
    (including 'Full') uses the full VAE of the loaded pipeline.
    """
    latents = latents.to(device=devices.device, dtype=devices.dtype)
    decoders = (
        ('Tiny', vae_decode_tiny),
        ('Preview', vae_decode_simple),
        ('Remote', vae_decode_remote),
    )
    for name, decoder in decoders:
        if vae_type == name:
            return decoder(latents)
    return vae_decode_full(latents) # vae_type == 'Full'
|
||||
|
||||
|
||||
def vae_encode(image):
    """Encode an image tensor with the pipeline VAE: sample the latent distribution and multiply by the scaling factor."""
    with devices.inference_context():
        vae = shared.sd_model.vae
        pixels = image.to(device=vae.device, dtype=vae.dtype)
        posterior = vae.encode(pixels).latent_dist
        return posterior.sample() * vae.config.scaling_factor
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
import os
|
||||
import time
|
||||
import datetime
|
||||
import cv2
|
||||
import torch
|
||||
import einops
|
||||
from modules import shared, errors ,timer, rife
|
||||
|
||||
|
||||
def atomic_save_video(filename, tensor:torch.Tensor, fps:float=24, codec:str='libx264', pix_fmt:str='yuv420p', options:str='', metadata:dict=None, pbar=None):
    """
    Encode a stack of RGB frames into a video file using PyAV.

    Args:
        filename: output path; the container format is chosen by PyAV from it.
        tensor: frames indexed as (frame, height, width, channel); converted to uint8 before encoding.
        fps: frame rate, rounded to an integer.
        codec: encoder name passed to `container.add_stream`.
        pix_fmt: pixel format of the encoded stream.
        options: comma-separated encoder options, each written as 'key=value' or 'key:value';
            entries without a separator are ignored. None is treated as empty.
        metadata: optional key/value pairs written to the container metadata.
        pbar: optional rich progress bar; one step is reported per encoded frame.

    Returns nothing. If PyAV cannot be imported the error is logged and no file is written.
    """
    try:
        import av
        av.logging.set_level(av.logging.ERROR) # pylint: disable=c-extension-no-member
    except Exception as e:
        shared.log.error(f'FramePack video: {e}')
        return

    metadata = metadata if metadata is not None else {} # avoid a shared mutable default argument
    frames, height, width, _channels = tensor.shape
    rate = round(fps)
    encoder_options = {}
    for option in [option.strip() for option in (options or '').split(',')]:
        if '=' in option:
            key, value = option.split('=', 1)
        elif ':' in option:
            key, value = option.split(':', 1)
        else:
            continue
        encoder_options[key.strip()] = value.strip()
    shared.log.info(f'FramePack video: file="{filename}" codec={codec} frames={frames} width={width} height={height} fps={rate} options={encoder_options}')
    video_array = torch.as_tensor(tensor, dtype=torch.uint8).numpy(force=True)
    task = pbar.add_task('encoding', total=frames) if pbar is not None else None
    if task is not None:
        pbar.update(task, description='video encoding')
    with av.open(filename, mode="w") as container:
        for k, v in metadata.items():
            container.metadata[k] = v
        stream: av.VideoStream = container.add_stream(codec, rate=rate, options=encoder_options)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = pix_fmt
        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            for packet in stream.encode_lazy(frame):
                container.mux(packet)
            if task is not None:
                pbar.update(task, advance=1)
        for packet in stream.encode(): # flush
            container.mux(packet)
    shared.state.outputs(filename)
|
||||
|
||||
|
||||
def save_video(
        pixels:torch.Tensor,
        mp4_fps:int=24,
        mp4_codec:str='libx264',
        mp4_opt:str='',
        mp4_ext:str='mp4',
        mp4_sf:bool=False, # save safetensors
        mp4_video:bool=True, # save video
        mp4_frames:bool=False, # save frames
        mp4_interpolate:int=0, # rife interpolation
        stream=None, # async progress reporting stream
        metadata:dict=None, # metadata for video
        pbar=None, # progress bar for video
    ):
    """
    Convert decoded pixels to uint8 frames and export them as safetensors, JPEG frames and/or an encoded video.

    `pixels` is unpacked as a 5-dimensional tensor with the frame count on the third axis and values
    clamped to [-1, 1] before conversion. When `mp4_interpolate` > 0 the frames are first passed
    through RIFE interpolation. Outputs are written to `shared.opts.outdir_video` with a
    timestamp-based name. Progress and the resulting file are pushed to `stream.output_queue` when given.

    Returns the number of frames (0 when `pixels` is None). Errors during export are logged and
    displayed, not raised.
    """
    if pixels is None:
        return 0
    metadata = metadata if metadata is not None else {} # avoid a shared mutable default argument
    t_save = time.time()
    n, _c, t, h, w = pixels.shape
    size = pixels.element_size() * pixels.numel()
    shared.log.debug(f'FramePack video: video={mp4_video} export={mp4_frames} safetensors={mp4_sf} interpolate={mp4_interpolate}')
    shared.log.debug(f'FramePack video: encode={t} raw={size} latent={pixels.shape} fps={mp4_fps} codec={mp4_codec} ext={mp4_ext} options="{mp4_opt}"')
    try:
        if stream is not None:
            stream.output_queue.push(('progress', (None, 'Saving video...')))
        if mp4_interpolate > 0:
            x = pixels.squeeze(0).permute(1, 0, 2, 3)
            interpolated = rife.interpolate_nchw(x, count=mp4_interpolate+1)
            pixels = torch.stack(interpolated, dim=0)
            pixels = pixels.permute(1, 2, 0, 3, 4)

        n, _c, t, h, w = pixels.shape # refresh: interpolation changes the frame count
        x = torch.clamp(pixels.float(), -1., 1.) * 127.5 + 127.5
        x = x.detach().cpu().to(torch.uint8)
        x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=n)
        x = x.contiguous()

        if shared.opts.outdir_video:
            os.makedirs(shared.opts.outdir_video, exist_ok=True) # writers below do not create missing folders
        timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
        output_filename = os.path.join(shared.opts.outdir_video, f'{timestamp}-{mp4_codec}-f{t}')

        if mp4_sf:
            fn = f'{output_filename}.safetensors'
            shared.log.info(f'FramePack export: file="{fn}" type=safetensors shape={x.shape}')
            from safetensors.torch import save_file
            shared.state.outputs(fn)
            save_file({ 'frames': x }, fn, metadata={'format': 'video', 'frames': str(t), 'width': str(w), 'height': str(h), 'fps': str(mp4_fps), 'codec': mp4_codec, 'options': mp4_opt, 'ext': mp4_ext, 'interpolate': str(mp4_interpolate)})

        if mp4_frames:
            shared.log.info(f'FramePack frames: files="{output_filename}-00000.jpg" frames={t} width={w} height={h}')
            for i in range(t):
                image = cv2.cvtColor(x[i].numpy(), cv2.COLOR_RGB2BGR)
                fn = f'{output_filename}-{i:05d}.jpg'
                shared.state.outputs(fn)
                cv2.imwrite(fn, image)

        if mp4_video and (mp4_codec != 'none'):
            fn = f'{output_filename}.{mp4_ext}'
            atomic_save_video(fn, tensor=x, fps=mp4_fps, codec=mp4_codec, options=mp4_opt, metadata=metadata, pbar=pbar)
            if stream is not None:
                stream.output_queue.push(('progress', (None, f'Video {os.path.basename(fn)} | Codec {mp4_codec} | Size {w}x{h}x{t} | FPS {mp4_fps}')))
                stream.output_queue.push(('file', fn))
        else:
            if stream is not None:
                stream.output_queue.push(('progress', (None, '')))

    except Exception as e:
        shared.log.error(f'FramePack video: raw={size} {e}')
        errors.display(e, 'FramePack video')
    timer.process.add('save', time.time()-t_save)
    return t
|
||||
|
|
@ -0,0 +1,319 @@
|
|||
import time
|
||||
import torch
|
||||
import rich.progress as rp
|
||||
from modules import shared, errors ,devices, sd_models, timer, memstats
|
||||
from modules.framepack import framepack_vae # pylint: disable=wrong-import-order
|
||||
from modules.framepack import framepack_hijack # pylint: disable=wrong-import-order
|
||||
from modules.framepack import framepack_video # pylint: disable=wrong-import-order
|
||||
|
||||
|
||||
# module-level handle of the async progress stream read by worker(); None until something assigns it
# (presumably the caller sets it before starting worker() - confirm against the wrapper module)
stream = None # AsyncStream
|
||||
|
||||
|
||||
def get_latent_paddings(mp4_fps, mp4_interpolate, latent_window_size, total_second_length, variant):
    """
    Compute the per-section latent padding list; its length is the number of video sections.

    The section count is derived from duration * (fps / (interpolate + 1)) / (window * 4):
    rounded for the 'forward-only' variant (ascending paddings), truncated otherwise
    (descending paddings, or a fixed 3/2.../1/0 pattern when there are more than 4 sections).
    Any error in the computation yields the single-section fallback [0].
    """
    try:
        effective_fps = mp4_fps / (mp4_interpolate + 1)
        raw_sections = (total_second_length * effective_fps) / (latent_window_size * 4)
        if variant == 'forward-only':
            count = int(max(round(raw_sections), 1))
            return list(range(count))
        count = int(max(raw_sections, 1))
        if count > 4: # extra padding for better quality
            return [3] + [2] * (count - 3) + [1, 0]
        return list(range(count - 1, -1, -1))
    except Exception:
        return [0]
|
||||
|
||||
|
||||
def worker(
|
||||
input_image, end_image,
|
||||
start_weight, end_weight, vision_weight,
|
||||
prompts, n_prompt, system_prompt, optimized_prompt, unmodified_prompt,
|
||||
seed,
|
||||
total_second_length,
|
||||
latent_window_size,
|
||||
steps,
|
||||
cfg_scale, cfg_distilled, cfg_rescale,
|
||||
shift,
|
||||
use_teacache, use_cfgzero, use_preview,
|
||||
mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
|
||||
vae_type,
|
||||
variant,
|
||||
metadata:dict={},
|
||||
):
|
||||
timer.process.reset()
|
||||
memstats.reset_stats()
|
||||
if stream is None or shared.state.interrupted or shared.state.skipped:
|
||||
shared.log.error('FramePack: stream is None')
|
||||
stream.output_queue.push(('end', None))
|
||||
return
|
||||
|
||||
from modules.framepack.pipeline import hunyuan
|
||||
from modules.framepack.pipeline import utils
|
||||
from modules.framepack.pipeline.k_diffusion_hunyuan import sample_hunyuan
|
||||
|
||||
is_f1 = variant == 'forward-only'
|
||||
total_generated_frames = 0
|
||||
total_generated_latent_frames = 0
|
||||
latent_paddings = get_latent_paddings(mp4_fps, mp4_interpolate, latent_window_size, total_second_length, variant)
|
||||
num_frames = latent_window_size * 4 - 3 # number of frames to generate in each section
|
||||
|
||||
metadata['title'] = 'sdnext framepack'
|
||||
metadata['description'] = f'variant:{variant} seed:{seed} steps:{steps} scale:{cfg_scale} distilled:{cfg_distilled} rescale:{cfg_rescale} shift:{shift} start:{start_weight} end:{end_weight} vision:{vision_weight}'
|
||||
|
||||
shared.state.begin('Video')
|
||||
shared.state.job_count = 1
|
||||
|
||||
text_encoder = shared.sd_model.text_encoder
|
||||
text_encoder_2 = shared.sd_model.text_encoder_2
|
||||
tokenizer = shared.sd_model.tokenizer
|
||||
tokenizer_2 = shared.sd_model.tokenizer_2
|
||||
vae = shared.sd_model.vae
|
||||
feature_extractor = shared.sd_model.feature_extractor
|
||||
image_encoder = shared.sd_model.image_processor
|
||||
transformer = shared.sd_model.transformer
|
||||
sd_models.apply_balanced_offload(shared.sd_model)
|
||||
pbar = rp.Progress(rp.TextColumn('[cyan]Video'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
|
||||
task = pbar.add_task('starting', total=steps * len(latent_paddings))
|
||||
t_last = time.time()
|
||||
if not is_f1:
|
||||
prompts = list(reversed(prompts))
|
||||
|
||||
def text_encode(prompt, i:int=None):
|
||||
pbar.update(task, description=f'text encode section={i}')
|
||||
t0 = time.time()
|
||||
torch.manual_seed(seed)
|
||||
# shared.log.debug(f'FramePack: section={i} prompt="{prompt}"')
|
||||
shared.state.textinfo = 'Text encode'
|
||||
stream.output_queue.push(('progress', (None, 'Text encoding...')))
|
||||
sd_models.apply_balanced_offload(shared.sd_model)
|
||||
sd_models.move_model(text_encoder, devices.device, force=True) # required as hunyuan.encode_prompt_conds checks device before calling model
|
||||
sd_models.move_model(text_encoder_2, devices.device, force=True)
|
||||
framepack_hijack.set_prompt_template(prompt, system_prompt, optimized_prompt, unmodified_prompt)
|
||||
llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
|
||||
metadata['comment'] = prompt
|
||||
if cfg_scale > 1 and n_prompt is not None and len(n_prompt) > 0:
|
||||
llama_vec_n, clip_l_pooler_n = hunyuan.encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
|
||||
else:
|
||||
llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
|
||||
llama_vec, llama_attention_mask = utils.crop_or_pad_yield_mask(llama_vec, length=512)
|
||||
llama_vec_n, llama_attention_mask_n = utils.crop_or_pad_yield_mask(llama_vec_n, length=512)
|
||||
timer.process.add('prompt', time.time()-t0)
|
||||
return llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n
|
||||
|
||||
def latents_encode(input_image, end_image):
    """VAE-encode the start image and the optional end image into latents.

    Args:
        input_image: start frame; converted with torch.from_numpy, so presumably an HxWxC numpy array - confirm against caller.
        end_image: optional end frame in the same format, or None.

    Returns:
        Tuple (start_latent, end_latent); end_latent is None when no end image is given.
    """
    pbar.update(task, description='image encode')
    # shared.log.debug(f'FramePack: image encode init={input_image.shape} end={end_image.shape if end_image is not None else None}')
    t0 = time.time()
    torch.manual_seed(seed)
    stream.output_queue.push(('progress', (None, 'VAE encoding...')))
    sd_models.apply_balanced_offload(shared.sd_model)
    sd_models.move_model(vae, devices.device, force=True)
    # NOTE(review): start_latent is only assigned inside this branch; a None input_image would fail at the return below
    if input_image is not None:
        # map 0..255 pixel values to -1..1, then reorder to channels-first and add batch and frame axes
        input_image_pt = torch.from_numpy(input_image).float() / 127.5 - 1
        input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
        start_latent = framepack_vae.vae_encode(input_image_pt)
        # a start weight below 1 blends the start latent with random noise
        if start_weight < 1:
            noise = torch.randn_like(start_latent)
            start_latent = start_latent * start_weight + noise * (1 - start_weight)
    if end_image is not None:
        end_image_pt = torch.from_numpy(end_image).float() / 127.5 - 1
        end_image_pt = end_image_pt.permute(2, 0, 1)[None, :, None]
        end_latent = framepack_vae.vae_encode(end_image_pt)
    else:
        end_latent = None
    timer.process.add('encode', time.time()-t0)
    return start_latent, end_latent
|
||||
|
||||
def vision_encode(input_image, end_image):
    """Run the image encoder on the start image and, when given, blend in the end image.

    Args:
        input_image: start frame passed to the feature extractor.
        end_image: optional end frame, or None.

    Returns:
        The last hidden state of the image encoder, scaled by vision_weight. When an end
        image is present the start and end hidden states are combined as a weighted
        average using start_weight and end_weight.
    """
    pbar.update(task, description='vision encode')
    # shared.log.debug(f'FramePack: vision encode init={input_image.shape} end={end_image.shape if end_image is not None else None}')
    t0 = time.time()
    shared.state.textinfo = 'Vision encode'
    stream.output_queue.push(('progress', (None, 'Vision encoding...')))
    sd_models.apply_balanced_offload(shared.sd_model)
    sd_models.move_model(feature_extractor, devices.device, force=True)
    sd_models.move_model(image_encoder, devices.device, force=True)
    preprocessed = feature_extractor.preprocess(images=input_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
    image_encoder_output = image_encoder(**preprocessed)
    image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
    if end_image is not None:
        preprocessed = feature_extractor.preprocess(images=end_image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
        end_image_encoder_output = image_encoder(**preprocessed)
        end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state
        # weighted average: the sum of both weighted terms is divided by the total weight
        # (without the outer parentheses only the end term would be normalized)
        weighted_sum = (image_encoder_last_hidden_state * start_weight) + (end_image_encoder_last_hidden_state * end_weight)
        image_encoder_last_hidden_state = weighted_sum / (start_weight + end_weight)
    timer.process.add('vision', time.time()-t0)
    image_encoder_last_hidden_state = image_encoder_last_hidden_state * vision_weight
    return image_encoder_last_hidden_state
|
||||
|
||||
def step_callback(d):
    """Per-step sampler callback: handles interrupt/pause and reports progress and previews.

    Args:
        d: callback payload; this function reads d['i'] (zero-based step index) and d['denoised'].

    Raises:
        AssertionError: used as the interruption signal; the enclosing worker catches it.
    """
    # with cfgzero enabled, the first step of the first section has its denoised output zeroed
    if use_cfgzero and is_first_section and d['i'] == 0:
        d['denoised'] = d['denoised'] * 0
    t_current = time.time()
    if stream.input_queue.top() == 'end' or shared.state.interrupted or shared.state.skipped:
        stream.output_queue.push(('progress', (None, 'Interrupted...')))
        stream.output_queue.push(('end', None))
        raise AssertionError('Interrupted...')
    if shared.state.paused:
        shared.log.debug('Sampling paused')
        # poll until resumed, while still honoring interrupt/skip requests
        while shared.state.paused:
            if shared.state.interrupted or shared.state.skipped:
                raise AssertionError('Interrupted...')
            time.sleep(0.1)
    nonlocal total_generated_frames, t_last
    t_preview = time.time()
    current_step = d['i'] + 1
    shared.state.textinfo = ''
    # global step counter across sections: completed sections * steps + step within the current section
    shared.state.sampling_step = ((lattent_padding_loop-1) * steps) + current_step
    shared.state.sampling_steps = steps * len(latent_paddings)
    progress = shared.state.sampling_step / shared.state.sampling_steps
    total_generated_frames = int(max(0, total_generated_latent_frames * 4 - 3))
    # its = iterations per second, measured from the previous callback invocation
    pbar.update(task, advance=1, description=f'its={1/(t_current-t_last):.2f} sample={d["i"]+1}/{steps} section={lattent_padding_loop}/{len(latent_paddings)} frames={total_generated_frames}/{num_frames*len(latent_paddings)}')
    desc = f'Step {shared.state.sampling_step}/{shared.state.sampling_steps} | Current {current_step}/{steps} | Section {lattent_padding_loop}/{len(latent_paddings)} | Progress {progress:.2%}'
    if use_preview:
        preview = framepack_vae.vae_decode(d['denoised'], 'Preview')
        stream.output_queue.push(('progress', (preview, desc)))
    else:
        stream.output_queue.push(('progress', (None, desc)))
    timer.process.add('preview', time.time() - t_preview)
    t_last = t_current
|
||||
|
||||
try:
|
||||
with devices.inference_context(), pbar:
|
||||
t0 = time.time()
|
||||
|
||||
height, width, _C = input_image.shape
|
||||
start_latent, end_latent = latents_encode(input_image, end_image)
|
||||
image_encoder_last_hidden_state = vision_encode(input_image, end_image)
|
||||
|
||||
# Sample loop
|
||||
shared.state.textinfo = 'Sample'
|
||||
stream.output_queue.push(('progress', (None, 'Start sampling...')))
|
||||
generator = torch.Generator("cpu").manual_seed(seed)
|
||||
if is_f1:
|
||||
history_latents = torch.zeros(size=(1, 16, 16 + 2 + 1, height // 8, width // 8), dtype=torch.float32).cpu()
|
||||
else:
|
||||
history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=devices.dtype).cpu()
|
||||
history_pixels = None
|
||||
lattent_padding_loop = 0
|
||||
last_prompt = None
|
||||
|
||||
for latent_padding in latent_paddings:
|
||||
current_prompt = prompts[lattent_padding_loop]
|
||||
if current_prompt != last_prompt:
|
||||
llama_vec, llama_vec_n, llama_attention_mask, llama_attention_mask_n, clip_l_pooler, clip_l_pooler_n = text_encode(current_prompt, i=lattent_padding_loop+1)
|
||||
last_prompt = current_prompt
|
||||
|
||||
lattent_padding_loop += 1
|
||||
# shared.log.trace(f'FramePack: op=sample section={lattent_padding_loop}/{len(latent_paddings)} frames={total_generated_frames}/{num_frames*len(latent_paddings)} window={latent_window_size} size={num_frames}')
|
||||
if is_f1:
|
||||
is_first_section, is_last_section = False, False
|
||||
else:
|
||||
is_first_section, is_last_section = latent_padding == latent_paddings[0], latent_padding == 0
|
||||
if stream.input_queue.top() == 'end' or shared.state.interrupted or shared.state.skipped:
|
||||
stream.output_queue.push(('end', None))
|
||||
return
|
||||
if is_f1:
|
||||
indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
|
||||
clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
|
||||
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
|
||||
clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
|
||||
clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
|
||||
else:
|
||||
latent_padding_size = latent_padding * latent_window_size
|
||||
indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
|
||||
clean_latent_indices_pre, _blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
|
||||
clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
|
||||
clean_latents_pre = start_latent.to(history_latents)
|
||||
clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
|
||||
clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
|
||||
if end_image is not None and is_first_section:
|
||||
clean_latents_post = (clean_latents_post * start_weight / len(latent_paddings)) + (end_weight * end_latent.to(history_latents)) / (start_weight/len(latent_paddings) + end_weight) # pylint: disable=possibly-used-before-assignment
|
||||
clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
|
||||
|
||||
sd_models.apply_balanced_offload(shared.sd_model)
|
||||
transformer.initialize_teacache(enable_teacache=use_teacache, num_steps=steps, rel_l1_thresh=shared.opts.teacache_thresh)
|
||||
|
||||
t_sample = time.time()
|
||||
generated_latents = sample_hunyuan(
|
||||
transformer=transformer,
|
||||
sampler='unipc',
|
||||
width=width,
|
||||
height=height,
|
||||
frames=num_frames,
|
||||
num_inference_steps=steps,
|
||||
real_guidance_scale=cfg_scale,
|
||||
distilled_guidance_scale=cfg_distilled,
|
||||
guidance_rescale=cfg_rescale,
|
||||
shift=shift if shift > 0 else None,
|
||||
generator=generator,
|
||||
prompt_embeds=llama_vec, # pylint: disable=possibly-used-before-assignment
|
||||
prompt_embeds_mask=llama_attention_mask, # pylint: disable=possibly-used-before-assignment
|
||||
prompt_poolers=clip_l_pooler, # pylint: disable=possibly-used-before-assignment
|
||||
negative_prompt_embeds=llama_vec_n, # pylint: disable=possibly-used-before-assignment
|
||||
negative_prompt_embeds_mask=llama_attention_mask_n, # pylint: disable=possibly-used-before-assignment
|
||||
negative_prompt_poolers=clip_l_pooler_n, # pylint: disable=possibly-used-before-assignment
|
||||
image_embeddings=image_encoder_last_hidden_state,
|
||||
latent_indices=latent_indices,
|
||||
clean_latents=clean_latents,
|
||||
clean_latent_indices=clean_latent_indices,
|
||||
clean_latents_2x=clean_latents_2x,
|
||||
clean_latent_2x_indices=clean_latent_2x_indices,
|
||||
clean_latents_4x=clean_latents_4x,
|
||||
clean_latent_4x_indices=clean_latent_4x_indices,
|
||||
device=devices.device,
|
||||
dtype=devices.dtype,
|
||||
callback=step_callback,
|
||||
)
|
||||
timer.process.add('sample', time.time()-t_sample)
|
||||
|
||||
if is_last_section:
|
||||
generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
|
||||
total_generated_latent_frames += int(generated_latents.shape[2])
|
||||
|
||||
if is_f1:
|
||||
history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
|
||||
real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
|
||||
else:
|
||||
history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
|
||||
real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
|
||||
|
||||
t_vae = time.time()
|
||||
sd_models.apply_balanced_offload(shared.sd_model)
|
||||
sd_models.move_model(vae, devices.device, force=True)
|
||||
if history_pixels is None:
|
||||
history_pixels = framepack_vae.vae_decode(real_history_latents, vae_type=vae_type).cpu()
|
||||
else:
|
||||
overlapped_frames = latent_window_size * 4 - 3
|
||||
if is_f1:
|
||||
section_latent_frames = latent_window_size * 2
|
||||
current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu()
|
||||
history_pixels = utils.soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
|
||||
else:
|
||||
section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
|
||||
current_pixels = framepack_vae.vae_decode(real_history_latents[:, :, :section_latent_frames], vae_type=vae_type).cpu()
|
||||
history_pixels = utils.soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
|
||||
timer.process.add('vae', time.time()-t_vae)
|
||||
|
||||
if is_last_section:
|
||||
break
|
||||
|
||||
total_generated_frames = framepack_video.save_video(history_pixels, mp4_fps, mp4_codec, mp4_opt, mp4_ext, mp4_sf, mp4_video, mp4_frames, mp4_interpolate, pbar=pbar, stream=stream, metadata=metadata)
|
||||
|
||||
except AssertionError:
|
||||
shared.log.info('FramePack: interrupted')
|
||||
if shared.opts.keep_incomplete:
|
||||
framepack_video.save_video(history_pixels, mp4_fps, mp4_codec, mp4_opt, mp4_ext, mp4_sf, mp4_video, mp4_frames, mp4_interpolate=0, stream=stream, metadata=metadata)
|
||||
except Exception as e:
|
||||
shared.log.error(f'FramePack: {e}')
|
||||
errors.display(e, 'FramePack')
|
||||
|
||||
sd_models.apply_balanced_offload(shared.sd_model)
|
||||
stream.output_queue.push(('end', None))
|
||||
t1 = time.time()
|
||||
shared.log.info(f'Processed: frames={total_generated_frames} fps={total_generated_frames/(t1-t0):.2f} its={(shared.state.sampling_step)/(t1-t0):.2f} time={t1-t0:.2f} timers={timer.process.dct()} memory={memstats.memory_stats()}')
|
||||
shared.state.end()
|
||||
|
|
@ -0,0 +1,234 @@
|
|||
import os
|
||||
import re
|
||||
import random
|
||||
import threading
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
from modules import shared, processing, timer, paths, extra_networks, progress, ui_video_vlm
|
||||
from modules.framepack import framepack_install # pylint: disable=wrong-import-order
|
||||
from modules.framepack import framepack_load # pylint: disable=wrong-import-order
|
||||
from modules.framepack import framepack_worker # pylint: disable=wrong-import-order
|
||||
from modules.framepack import framepack_hijack # pylint: disable=wrong-import-order
|
||||
|
||||
|
||||
tmp_dir = os.path.join(paths.data_path, 'tmp', 'framepack') # scratch location below the data path
git_dir = os.path.join(os.path.dirname(__file__), 'framepack') # 'framepack' folder next to this module
git_repo = 'https://github.com/lllyasviel/framepack' # upstream repository; only referenced by the commented-out clone step in load_model
git_commit = 'c5d375661a2557383f0b8da9d11d14c23b0c4eaf' # upstream commit; only referenced by the commented-out update step in load_model
queue_lock = threading.Lock() # acquired by run_framepack around task processing
loaded_variant = None # value returned by the last framepack_load.load_model call made from load_model
|
||||
|
||||
|
||||
def check_av():
    """Import the `av` package on demand.

    Returns:
        The imported `av` module, or None when the import fails. None is used so that
        the `is None` checks performed by the callers in this module actually detect
        the failure; the value stays falsy for callers that only test truthiness.
    """
    try:
        import av
    except Exception as e:
        shared.log.error(f'av package: {e}')
        return None
    return av
|
||||
|
||||
|
||||
def get_codecs():
    """List the video encoders offered by the `av` package.

    Returns:
        An empty list when `av` is unavailable; otherwise ['none'] followed by the
        encoder names, hardware-flagged encoders first, software encoders after.
    """
    av = check_av()
    if not av: # covers both None and False as failure markers
        return []
    hw_flags = 0x40000 | 0x80000 # presumably FFmpeg AV_CODEC_CAP_HARDWARE / AV_CODEC_CAP_HYBRID bits - confirm
    codecs = []
    seen_names = set()
    for codec in av.codecs_available:
        try:
            c = av.Codec(codec, mode='w')
            if c.type == 'video' and c.is_encoder and len(c.video_formats) > 0 and c.name not in seen_names:
                seen_names.add(c.name)
                codecs.append(c)
        except Exception:
            # best effort: entries that cannot be opened or inspected as encoders are skipped
            continue
    hw_codecs = [c for c in codecs if c.capabilities & hw_flags]
    sw_codecs = [c for c in codecs if not c.capabilities & hw_flags]
    shared.log.debug(f'Video codecs: hardware={len(hw_codecs)} software={len(sw_codecs)}')
    return ['none'] + [c.name for c in hw_codecs + sw_codecs]
|
||||
|
||||
|
||||
def prepare_image(image, resolution):
    """Resize and center-crop an image to the aspect-ratio bucket closest to its own shape.

    Args:
        image: array whose shape unpacks as (height, width, channels).
        resolution: target resolution; bucket sizes are defined for 640 and scaled by resolution / 640.

    Returns:
        The image as returned by resize_and_center_crop, with both sides rounded to multiples of 16.
    """
    from modules.framepack.pipeline.utils import resize_and_center_crop
    buckets = [
        (416, 960), (448, 864), (480, 832), (512, 768), (544, 704), (576, 672), (608, 640),
        (640, 608), (672, 576), (704, 544), (768, 512), (832, 480), (864, 448), (960, 416),
    ]
    h, w, _c = image.shape
    scale_factor = resolution / 640.0

    def aspect_distance(bucket):
        # cross-multiplied difference between the image and bucket aspect ratios
        return abs(h * bucket[1] - w * bucket[0])

    # scanning in reverse makes the later bucket win on equal distance
    best_h, best_w = min(reversed(buckets), key=aspect_distance)
    scaled_h = round(best_h * scale_factor / 16) * 16
    scaled_w = round(best_w * scale_factor / 16) * 16

    image = resize_and_center_crop(image, target_height=scaled_h, target_width=scaled_w)
    h0, w0, _c = image.shape
    shared.log.debug(f'FramePack prepare: input="{w}x{h}" resized="{w0}x{h0}" resolution={resolution} scale={scale_factor}')
    return image
|
||||
|
||||
|
||||
def interpolate_prompts(prompts, steps):
    """Spread a list of prompts evenly over a number of sections.

    Args:
        prompts: None, a list of prompts, or a string that is split on commas and newlines.
        steps: number of sections to produce prompts for.

    Returns:
        A list with one prompt per section. Empty strings are used when no prompts are
        given; the input list itself is returned when its length already equals steps.
    """
    if prompts is None:
        return [''] * steps
    if isinstance(prompts, str):
        prompts = [part.strip() for part in re.split(r'[,\n]', prompts)]
    count = len(prompts)
    if count == 0:
        return [''] * steps
    if count == steps:
        return prompts
    # each prompt covers `factor` consecutive sections
    factor = steps / count
    return [prompts[int(section / factor)] for section in range(steps)]
|
||||
|
||||
|
||||
def prepare_prompts(p, init_image, prompt:str, section_prompt:str, num_sections:int, vlm_enhance:bool, vlm_model:str, vlm_system_prompt:str):
    """Build the final prompt for every section.

    Styles and extra networks are applied to `p`, the resulting main prompt is combined
    with each section prompt, and every distinct combination is passed through
    ui_video_vlm.enhance_prompt. The incoming `prompt` argument is not read; the prompt
    is taken from `p` after style and extra-network processing.

    Returns:
        List of num_sections prompts.
    """
    section_prompts = interpolate_prompts(section_prompt, num_sections)
    styles = shared.prompt_styles
    p.prompt = styles.apply_styles_to_prompt(p.prompt, p.styles)
    p.negative_prompt = styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
    styles.apply_styles_to_extra(p)
    p.prompts, p.network_data = extra_networks.parse_prompts([p.prompt])
    extra_networks.activate(p)
    prompt = p.prompts[0]

    generated_prompts = []
    last_text = None
    for section in range(num_sections):
        text = (prompt + ' ' + section_prompts[section]).strip()
        if text == last_text:
            # same text as the previous section: reuse its result instead of enhancing again
            generated_prompts.append(generated_prompts[-1])
            continue
        enhanced = ui_video_vlm.enhance_prompt(
            enable=vlm_enhance,
            model=vlm_model,
            image=init_image,
            prompt=text,
            system_prompt=vlm_system_prompt,
        )
        generated_prompts.append(enhanced)
        last_text = text
    return generated_prompts
|
||||
|
||||
|
||||
def load_model(variant, attention):
    """Generator that loads the requested FramePack variant when it is not already active.

    Nothing happens (and nothing is yielded) when the current model type is
    'hunyuanvideo' and the requested variant is already loaded.

    Args:
        variant: model variant passed to framepack_load.load_model.
        attention: attention option passed to framepack_install.install_requirements.

    Yields:
        3-tuples (update, update, status text), matching every other UI generator in this module.
    """
    global loaded_variant # pylint: disable=global-statement
    if (shared.sd_model_type != 'hunyuanvideo') or (loaded_variant != variant):
        yield gr.update(), gr.update(), 'Verifying FramePack'
        framepack_install.install_requirements(attention)
        # framepack_install.git_clone(git_repo=git_repo, git_dir=git_dir, tmp_dir=tmp_dir)
        # framepack_install.git_update(git_dir=git_dir, git_commit=git_commit)
        # sys.path.append(git_dir)
        framepack_hijack.set_progress_bar_config()
        # status tuples always carry exactly three values; a stray fourth element was removed here
        yield gr.update(), gr.update(), 'Model loading...'
        loaded_variant = framepack_load.load_model(variant)
        if loaded_variant is not None:
            yield gr.update(), gr.update(), 'Model loaded'
        else:
            yield gr.update(), gr.update(), 'Model load failed'
|
||||
|
||||
|
||||
def unload_model():
    """Generator that unloads the FramePack model and yields a single status tuple."""
    shared.log.debug('FramePack unload')
    framepack_load.unload_model()
    status = 'Model unloaded'
    yield gr.update(), gr.update(), status
|
||||
|
||||
|
||||
def run_framepack(task_id, _ui_state, init_image, end_image, start_weight, end_weight, vision_weight, prompt, system_prompt, optimized_prompt, section_prompt, negative_prompt, styles, seed, resolution, duration, latent_ws, steps, cfg_scale, cfg_distilled, cfg_rescale, shift, use_teacache, use_cfgzero, use_preview, mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate, attention, vae_type, variant, vlm_enhance, vlm_model, vlm_system_prompt):
    """UI generator: load the model if needed, start the FramePack worker and relay its output.

    The mode is 't2v' when no start image is given (a black image is substituted),
    'flf2v' when both start and end images are given, and 'i2v' otherwise.

    Yields:
        3-tuples (video file or update, preview update, status text or update).
    """
    variant = variant or 'bi-directional'
    if init_image is None:
        init_image = np.zeros((resolution, resolution, 3), dtype=np.uint8)
        mode = 't2v'
    elif end_image is not None:
        mode = 'flf2v'
    else:
        mode = 'i2v'

    av = check_av()
    if not av: # covers both None and False as failure markers
        yield gr.update(), gr.update(), 'AV package not installed'
        return

    progress.add_task_to_queue(task_id)
    with queue_lock:
        progress.start_task(task_id)

        yield from load_model(variant, attention)
        if shared.sd_model_type != 'hunyuanvideo':
            progress.finish_task(task_id)
            yield gr.update(), gr.update(), 'Model load failed'
            return

        yield gr.update(), gr.update(), 'Generate starting...'
        from modules.framepack.pipeline.thread_utils import AsyncStream, async_run
        framepack_worker.stream = AsyncStream()

        # draw a random seed when none was requested
        if seed is None or seed == '' or seed == -1:
            random.seed()
            seed = random.randrange(4294967294)
        seed = int(seed)
        torch.manual_seed(seed)
        num_sections = len(framepack_worker.get_latent_paddings(mp4_fps, mp4_interpolate, latent_ws, duration, variant))
        num_frames = (latent_ws * 4 - 3) * num_sections + 1
        shared.log.info(f'FramePack start: mode={mode} variant="{variant}" frames={num_frames} sections={num_sections} resolution={resolution} seed={seed} duration={duration} teacache={use_teacache} thres={shared.opts.teacache_thresh} cfgzero={use_cfgzero}')
        shared.log.info(f'FramePack params: steps={steps} start={start_weight} end={end_weight} vision={vision_weight} scale={cfg_scale} distilled={cfg_distilled} rescale={cfg_rescale} shift={shift}')
        init_image = prepare_image(init_image, resolution)
        if end_image is not None:
            end_image = prepare_image(end_image, resolution)
        # image arrays are (height, width, channels), the same order prepare_image unpacks
        h, w, _c = init_image.shape
        p = processing.StableDiffusionProcessingVideo(
            sd_model=shared.sd_model,
            prompt=prompt,
            negative_prompt=negative_prompt,
            styles=styles,
            steps=steps,
            seed=seed,
            width=w,
            height=h,
        )
        prompts = prepare_prompts(p, init_image, prompt, section_prompt, num_sections, vlm_enhance, vlm_model, vlm_system_prompt)

        async_run(
            framepack_worker.worker,
            init_image, end_image,
            start_weight, end_weight, vision_weight,
            prompts, p.negative_prompt, system_prompt, optimized_prompt, vlm_enhance,
            seed,
            duration,
            latent_ws,
            p.steps,
            cfg_scale, cfg_distilled, cfg_rescale,
            shift,
            use_teacache, use_cfgzero, use_preview,
            mp4_fps, mp4_codec, mp4_sf, mp4_video, mp4_frames, mp4_opt, mp4_ext, mp4_interpolate,
            vae_type, variant,
        )

        # relay worker messages until the worker signals 'end'
        output_filename = None
        while True:
            flag, data = framepack_worker.stream.output_queue.next()
            if flag == 'file':
                output_filename = data
                yield output_filename, gr.update(), gr.update()
            elif flag == 'progress':
                preview, text = data
                summary = timer.process.summary(min_time=0.25, total=False).replace('=', ' ')
                memory = shared.mem_mon.summary()
                stats = f"<div class='performance'><p>{summary} {memory}</p></div>"
                yield gr.update(), gr.update(value=preview), f'{text} {stats}'
            elif flag == 'end':
                yield output_filename, gr.update(value=None), gr.update()
                break

        progress.finish_task(task_id)
        yield gr.update(), gr.update(), 'Generate finished'
        return
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
# (height, width) buckets available per base resolution
bucket_options = {
    640: [
        (416, 960), (448, 864), (480, 832), (512, 768), (544, 704), (576, 672), (608, 640),
        (640, 608), (672, 576), (704, 544), (768, 512), (832, 480), (864, 448), (960, 416),
    ],
}


def find_nearest_bucket(h, w, resolution=640):
    """Return the (height, width) bucket whose aspect ratio is closest to h:w.

    Args:
        h: source height.
        w: source width.
        resolution: key into bucket_options; a missing key raises KeyError.

    Returns:
        The best matching bucket; on equal distance the bucket listed later wins.
        None when the bucket list is empty.
    """
    def aspect_distance(bucket):
        bucket_h, bucket_w = bucket
        # cross-multiplied difference between the two aspect ratios
        return abs(h * bucket_w - w * bucket_h)

    # scanning in reverse makes the later bucket win on equal distance
    return min(reversed(bucket_options[resolution]), key=aspect_distance, default=None)
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def hf_clip_vision_encode(image, feature_extractor, image_encoder):
    """Preprocess an image and run it through the image encoder.

    Args:
        image: uint8 numpy array with three dimensions and 3 entries on the last axis.
        feature_extractor: object providing preprocess(images=..., return_tensors=...).
        image_encoder: callable model exposing .device and .dtype.

    Returns:
        The raw output of image_encoder.

    Raises:
        AssertionError: when the image does not satisfy the constraints above.
    """
    assert isinstance(image, np.ndarray)
    assert image.ndim == 3
    assert image.shape[2] == 3
    assert image.dtype == np.uint8

    inputs = feature_extractor.preprocess(images=image, return_tensors="pt")
    inputs = inputs.to(device=image_encoder.device, dtype=image_encoder.dtype)
    return image_encoder(**inputs)
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
import torch
|
||||
import accelerate.accelerator
|
||||
|
||||
from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
|
||||
|
||||
|
||||
# Replace accelerate's convert_outputs_to_fp32 with an identity function, so values passed through it are returned unchanged.
accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
|
||||
|
||||
|
||||
def LayerNorm_forward(self, x):
    """Layer-normalize x with the module's parameters and cast the result back to x's dtype and device."""
    normed = torch.nn.functional.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
    return normed.to(x)
|
||||
|
||||
|
||||
# Install LayerNorm_forward on both the diffusers LayerNorm class and torch.nn.LayerNorm (process-wide monkeypatch).
LayerNorm.forward = LayerNorm_forward
torch.nn.LayerNorm.forward = LayerNorm_forward
|
||||
|
||||
|
||||
def FP32LayerNorm_forward(self, x):
    """Layer-normalize x in float32 and return the result in x's original dtype.

    Weight and bias are cast to float32 when present and passed as None otherwise.
    """
    weight = None if self.weight is None else self.weight.float()
    bias = None if self.bias is None else self.bias.float()
    normed = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
    return normed.to(x.dtype)
|
||||
|
||||
|
||||
# Install FP32LayerNorm_forward on the diffusers FP32LayerNorm class.
FP32LayerNorm.forward = FP32LayerNorm_forward
|
||||
|
||||
|
||||
def RMSNorm_forward(self, hidden_states):
    """RMS-normalize over the last dimension and return the result in the input dtype.

    The mean of squares is computed in float32. When the module has a weight, the
    normalized result is multiplied by that weight cast to the input dtype.
    """
    out_dtype = hidden_states.dtype
    mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = (hidden_states * torch.rsqrt(mean_square + self.eps)).to(out_dtype)
    if self.weight is not None:
        normed = normed * self.weight.to(out_dtype)
    return normed
|
||||
|
||||
|
||||
# Install RMSNorm_forward on the diffusers RMSNorm class.
RMSNorm.forward = RMSNorm_forward
|
||||
|
||||
|
||||
def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
    """Normalize x and modulate it with a scale and shift derived from the conditioning embedding.

    The embedding goes through silu and the linear layer, and is then split in two
    halves along dim 1: the first half is the scale, the second half the shift. Both
    are broadcast over the middle axis of x.
    """
    modulation = self.linear(self.silu(conditioning_embedding))
    scale, shift = modulation.chunk(2, dim=1)
    gain = (1 + scale)[:, None, :]
    offset = shift[:, None, :]
    return self.norm(x) * gain + offset
|
||||
|
||||
|
||||
# Install AdaLayerNormContinuous_forward on the diffusers AdaLayerNormContinuous class.
AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
import torch
|
||||
from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
@torch.no_grad()
def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
    """Encode a single prompt with both text encoders.

    Args:
        prompt: prompt text; must be a str.
        text_encoder: first encoder, called with input ids and attention mask and asked for hidden states.
        text_encoder_2: second encoder, whose pooler_output is used.
        tokenizer: tokenizer paired with text_encoder.
        tokenizer_2: tokenizer paired with text_encoder_2.
        max_length: token budget for the prompt, in addition to the template's crop_start tokens.

    Returns:
        Tuple (llama_vec, clip_l_pooler): the third-from-last hidden state of text_encoder,
        cut to the span between crop_start and the number of attended tokens, and the
        pooled output of text_encoder_2.
    """
    assert isinstance(prompt, str)
    batch = [prompt]

    # LLAMA
    template = DEFAULT_PROMPT_TEMPLATE["template"]
    crop_start = DEFAULT_PROMPT_TEMPLATE["crop_start"]
    llama_inputs = tokenizer(
        [template.format(text) for text in batch],
        padding="max_length",
        max_length=max_length + crop_start,
        truncation=True,
        return_tensors="pt",
        return_length=False,
        return_overflowing_tokens=False,
        return_attention_mask=True,
    )
    input_ids = llama_inputs.input_ids.to(text_encoder.device)
    attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
    attended_length = int(attention_mask.sum())
    llama_outputs = text_encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )
    # keep tokens from crop_start up to the attended length
    keep = slice(crop_start, attended_length)
    llama_vec = llama_outputs.hidden_states[-3][:, keep]
    # every kept position must be an attended token
    assert torch.all(attention_mask[:, keep].bool())

    # CLIP
    clip_inputs = tokenizer_2(
        batch,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )
    clip_ids = clip_inputs.input_ids.to(text_encoder_2.device)
    clip_l_pooler = text_encoder_2(clip_ids, output_hidden_states=False).pooler_output

    return llama_vec, clip_l_pooler
|
||||
|
||||
|
||||
@torch.no_grad()
def vae_decode_fake(latents):
    """Project 16-channel latents to 3 channels with a fixed linear map, clamped to 0..1.

    No VAE is involved: a 1x1x1 3D convolution applies one factor per latent channel
    and output channel, plus a per-output-channel bias.
    """
    # one row per latent channel, three output factors per row (from comfyui)
    rgb_factors = [
        [-0.0395, -0.0331, 0.0445],
        [0.0696, 0.0795, 0.0518],
        [0.0135, -0.0945, -0.0282],
        [0.0108, -0.0250, -0.0765],
        [-0.0209, 0.0032, 0.0224],
        [-0.0804, -0.0254, -0.0639],
        [-0.0991, 0.0271, -0.0669],
        [-0.0646, -0.0422, -0.0400],
        [-0.0696, -0.0595, -0.0894],
        [-0.0799, -0.0208, -0.0375],
        [0.1166, 0.1627, 0.0962],
        [0.1165, 0.0432, 0.0407],
        [-0.2315, -0.1920, -0.1355],
        [-0.0270, 0.0401, -0.0821],
        [-0.0616, -0.0997, -0.0727],
        [0.0249, -0.0469, -0.1703],
    ]
    rgb_bias = [0.0259, -0.0192, -0.0761]

    tensor_opts = {'device': latents.device, 'dtype': latents.dtype}
    # (16, 3) -> (3, 16, 1, 1, 1): out channels, in channels, then a 1x1x1 kernel
    kernel = torch.tensor(rgb_factors, **tensor_opts).T[..., None, None, None]
    offset = torch.tensor(rgb_bias, **tensor_opts)

    projected = torch.nn.functional.conv3d(latents, kernel, bias=offset)
    return projected.clamp(0.0, 1.0)
|
||||
|
||||
|
||||
@torch.no_grad()
def vae_decode(latents, vae, image_mode=False):
    """Decode latents with the VAE after undoing the scaling factor.

    Args:
        latents: latent tensor; in image mode it is split along dim 2.
        vae: model exposing config.scaling_factor, device, dtype and decode(...).sample.
        image_mode: when True, each slice along dim 2 is decoded on its own and the
            results are concatenated along dim 2; otherwise everything is decoded in one call.

    Returns:
        The decoded tensor.
    """
    unscaled = (latents / vae.config.scaling_factor).to(device=vae.device, dtype=vae.dtype)
    if not image_mode:
        return vae.decode(unscaled).sample
    decoded = []
    for frame in unscaled.unbind(2):
        decoded.append(vae.decode(frame.unsqueeze(2)).sample)
    return torch.cat(decoded, dim=2)
|
||||
|
||||
|
||||
@torch.no_grad()
def vae_encode(image, vae):
    """Encode an image with the VAE: sample from the latent distribution and apply the scaling factor."""
    pixels = image.to(device=vae.device, dtype=vae.dtype)
    sampled = vae.encode(pixels).latent_dist.sample()
    return sampled * vae.config.scaling_factor
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,119 @@
|
|||
import math
|
||||
import torch
|
||||
from modules.framepack.pipeline.uni_pc_fm import sample_unipc
|
||||
from modules.framepack.pipeline.wrapper import fm_wrapper
|
||||
from modules.framepack.pipeline.utils import repeat_to_batch_size
|
||||
|
||||
|
||||
def flux_time_shift(t, mu=1.15, sigma=1.0):
    """Apply the time shift exp(mu) / (exp(mu) + (1/t - 1) ** sigma) to t."""
    exp_mu = math.exp(mu)
    return exp_mu / (exp_mu + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
    """Linearly map context_length to mu through (x1, y1) and (x2, y2), capped at log(exp_max)."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1
    return min(slope * context_length + intercept, math.log(exp_max))
|
||||
|
||||
|
||||
def get_flux_sigmas_from_mu(n, mu):
    """Return n + 1 sigmas: a linear ramp from 1 to 0 passed through flux_time_shift with the given mu."""
    ramp = torch.linspace(1, 0, steps=n + 1)
    return flux_time_shift(ramp, mu=mu)
|
||||
|
||||
|
||||
@torch.inference_mode()
def sample_hunyuan(
        transformer,
        sampler='unipc',
        initial_latent=None,
        concat_latent=None,
        strength=1.0,
        width=512,
        height=512,
        frames=16,
        real_guidance_scale=1.0,
        distilled_guidance_scale=6.0,
        guidance_rescale=0.0,
        shift=None,
        num_inference_steps=25,
        batch_size=None,
        generator=None,
        prompt_embeds=None,
        prompt_embeds_mask=None,
        prompt_poolers=None,
        negative_prompt_embeds=None,
        negative_prompt_embeds_mask=None,
        negative_prompt_poolers=None,
        dtype=torch.bfloat16,
        device=None,
        negative_kwargs=None,
        callback=None,
        **kwargs,
):
    """Run the flow-matching sampler over freshly drawn latent noise.

    Args:
        transformer: model wrapped by ``fm_wrapper``; its ``device`` is used when
            ``device`` is not given.
        sampler: sampler name; only ``'unipc'`` is implemented.
        initial_latent: optional latent blended with the noise using the first
            sigma (after the sigmas are scaled by ``strength``).
        concat_latent: optional latent forwarded to the wrapped model.
        strength: multiplier applied to the sigmas when ``initial_latent`` is set.
        width, height, frames: sizes used to build the noise tensor of shape
            ``(batch, 16, (frames + 3) // 4, height // 8, width // 8)``.
        real_guidance_scale: passed to the wrapper as ``cfg_scale``.
        distilled_guidance_scale: multiplied by 1000 and passed as ``guidance``.
        guidance_rescale: passed to the wrapper as ``cfg_rescale``.
        shift: when given, ``mu = log(shift)``; otherwise ``mu`` is derived from
            the latent sequence length.
        num_inference_steps: number of sampling steps.
        batch_size: defaults to ``prompt_embeds.shape[0]``.
        generator: optional ``torch.Generator`` for the initial noise.
        prompt_* / negative_prompt_*: conditioning tensors, repeated to the batch size.
        dtype: compute dtype handed to the wrapper.
        device: target device; defaults to ``transformer.device``.
        negative_kwargs: extra keyword arguments merged over ``kwargs`` for the
            negative branch only.
        callback: forwarded to the sampler.
        **kwargs: extra keyword arguments forwarded to both branches.

    Returns:
        Whatever ``sample_unipc`` returns.

    Raises:
        NotImplementedError: if ``sampler`` is not ``'unipc'``.
    """
    device = device or transformer.device

    if batch_size is None:
        batch_size = int(prompt_embeds.shape[0])

    # `generator` is optional: draw on the generator's own device when one is given,
    # otherwise directly on the target device (previously `generator.device` was
    # dereferenced unconditionally and the default generator=None crashed here)
    noise_device = generator.device if generator is not None else device
    latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=noise_device).to(device=device, dtype=torch.float32)

    _B, _C, T, H, W = latents.shape
    seq_length = T * H * W // 4

    if shift is None:
        mu = calculate_flux_mu(seq_length, exp_max=7.0)
    else:
        mu = math.log(shift)

    sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)

    k_model = fm_wrapper(transformer)

    if initial_latent is not None:
        # start from a partially-noised version of the provided latent
        sigmas = sigmas * strength
        first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
        initial_latent = initial_latent.to(device=device, dtype=torch.float32)
        latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma

    if concat_latent is not None:
        concat_latent = concat_latent.to(latents)

    distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)

    prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
    prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
    prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
    negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
    negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
    negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
    concat_latent = repeat_to_batch_size(concat_latent, batch_size)

    sampler_kwargs = dict(
        dtype=dtype,
        cfg_scale=real_guidance_scale,
        cfg_rescale=guidance_rescale,
        concat_latent=concat_latent,
        positive=dict(
            pooled_projections=prompt_poolers,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_embeds_mask,
            guidance=distilled_guidance,
            **kwargs,
        ),
        negative=dict(
            pooled_projections=negative_prompt_poolers,
            encoder_hidden_states=negative_prompt_embeds,
            encoder_attention_mask=negative_prompt_embeds_mask,
            guidance=distilled_guidance,
            **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
        )
    )

    if sampler == 'unipc':
        results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
    else:
        raise NotImplementedError(f'Sampler {sampler} is not supported.')

    return results
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
import time
|
||||
|
||||
from threading import Thread, Lock
|
||||
|
||||
|
||||
class Listener:
    """Class-level background worker that runs queued callables one at a time.

    Tasks are ``(func, args, kwargs)`` tuples kept in ``task_queue``; a single
    daemon thread, created on the first ``add_task`` call, pops and executes
    them in insertion order.
    """
    task_queue = []
    lock = Lock()
    thread = None

    @classmethod
    def _process_tasks(cls):
        """Worker loop: pop tasks under the lock and run them outside of it."""
        while True:
            task = None
            with cls.lock:
                if cls.task_queue:
                    task = cls.task_queue.pop(0)

            if task is None:
                # nothing queued: back off briefly instead of spinning
                time.sleep(0.001)
                continue

            func, args, kwargs = task
            try:
                func(*args, **kwargs)
            except Exception as e:
                # a failing task must not kill the worker thread
                print(f"Error in listener thread: {e}")

    @classmethod
    def add_task(cls, func, *args, **kwargs):
        """Queue ``func(*args, **kwargs)`` and make sure the worker thread exists."""
        with cls.lock:
            cls.task_queue.append((func, args, kwargs))

            # check-and-start happens under the lock so that concurrent callers
            # cannot both observe `thread is None` and start two workers
            if cls.thread is None:
                cls.thread = Thread(target=cls._process_tasks, daemon=True)
                cls.thread.start()
|
||||
|
||||
|
||||
def async_run(func, *args, **kwargs):
    """Hand ``func(*args, **kwargs)`` to ``Listener.add_task``; nothing is returned."""
    Listener.add_task(func, *args, **kwargs)
|
||||
|
||||
|
||||
class FIFOQueue:
    """Lock-protected first-in-first-out queue backed by a list."""

    def __init__(self):
        self.queue = []
        self.lock = Lock()

    def push(self, item):
        """Append ``item`` at the back of the queue."""
        with self.lock:
            self.queue.append(item)

    def pop(self):
        """Remove and return the front item, or None when the queue is empty."""
        with self.lock:
            return self.queue.pop(0) if self.queue else None

    def top(self):
        """Return the front item without removing it, or None when empty."""
        with self.lock:
            return self.queue[0] if self.queue else None

    def next(self):
        """Block (polling every millisecond) until an item is available, then pop it."""
        while True:
            with self.lock:
                if len(self.queue) > 0:
                    return self.queue.pop(0)

            time.sleep(0.001)
|
||||
|
||||
|
||||
class AsyncStream:
    """Pair of independent FIFO queues: one for input, one for output."""

    def __init__(self):
        self.output_queue = FIFOQueue()
        self.input_queue = FIFOQueue()
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
# Better Flow Matching UniPC by Lvmin Zhang
|
||||
# (c) 2025
|
||||
# CC BY-SA 4.0
|
||||
# Attribution-ShareAlike 4.0 International Licence
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import trange
|
||||
|
||||
|
||||
def expand_dims(v, dims):
    """Append ``dims - 1`` trailing singleton axes to ``v``."""
    trailing = (None,) * (dims - 1)
    return v[(Ellipsis, *trailing)]
|
||||
|
||||
|
||||
class FlowMatchUniPC:
    """Multistep predictor-corrector sampler driven by a model callable.

    ``model`` is called as ``model(x, t, **extra_args)``; ``variant`` selects how
    ``B_h`` is computed in ``update_fn`` (``'bh1'`` or ``'bh2'``).
    """

    def __init__(self, model, extra_args, variant='bh1'):
        self.model = model
        self.variant = variant
        self.extra_args = extra_args

    def model_fn(self, x, t):
        """Evaluate the wrapped model at ``(x, t)`` with the stored extra arguments."""
        return self.model(x, t, **self.extra_args)

    def update_fn(self, x, model_prev_list, t_prev_list, t, order):
        """Advance ``x`` from the last stored time to ``t``.

        Args:
            x: current state.
            model_prev_list: previous model outputs, most recent last.
            t_prev_list: times matching ``model_prev_list``.
            t: target time vector.
            order: number of history entries to use; must not exceed the history length.

        Returns:
            ``(x_t, model_t)``: the corrected state and the model output evaluated
            at the predicted state.

        Raises:
            NotImplementedError: if ``self.variant`` is neither ``'bh1'`` nor ``'bh2'``.
        """
        assert order <= len(model_prev_list)
        dims = x.dim()

        t_prev_0 = t_prev_list[-1]
        # work in lambda = -log(t)
        lambda_prev_0 = - torch.log(t_prev_0)
        lambda_t = - torch.log(t)
        model_prev_0 = model_prev_list[-1]

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = - torch.log(t_prev_i)
            # only element 0 is used: sample() builds the time vectors with expand(),
            # so every entry carries the same value
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            # scaled difference between an older model output and the latest one
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        R = []
        b = []

        hh = -h[0]
        h_phi_1 = torch.expm1(hh)
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError('Bad variant!')

        # build the linear system R @ rhos = b, one row per order
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        # the predictor needs at least one history difference, i.e. order >= 2
        use_predictor = len(D1s) > 0

        if use_predictor:
            D1s = torch.stack(D1s, dim=1)
            if order == 2:
                # fixed coefficient instead of solving the 1x1 system
                rhos_p = torch.tensor([0.5], device=b.device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None
            rhos_p = None

        if order == 1:
            rhos_c = torch.tensor([0.5], device=b.device)
        else:
            rhos_c = torch.linalg.solve(R, b)

        # base update shared by the predictor and the corrector
        x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0

        if use_predictor:
            pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
        else:
            pred_res = 0

        # predictor step, then evaluate the model at the predicted state
        x_t = x_t_ - expand_dims(B_h, dims) * pred_res
        model_t = self.model_fn(x_t, t)

        if D1s is not None:
            corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
        else:
            corr_res = 0

        # corrector step: also uses the difference to the freshly evaluated model output
        D1_t = model_t - model_prev_0
        x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)

        return x_t, model_t

    def sample(self, x, sigmas, callback=None, disable_pbar=False):
        """Iterate over ``sigmas`` and return the last stored model output.

        ``callback``, when given, is called after every step with a dict holding
        ``'x'``, the step index ``'i'`` and ``'denoised'`` (the latest model output).
        """
        # NOTE(review): with len(sigmas) == 2 this yields order == 0, and the
        # `[-order:]` slices below then keep the whole list — confirm intended
        order = min(3, len(sigmas) - 2)
        model_prev_list, t_prev_list = [], []
        for i in trange(len(sigmas) - 1, disable=disable_pbar):
            vec_t = sigmas[i].expand(x.shape[0])

            if i == 0:
                # first step only seeds the history; x is not updated
                model_prev_list = [self.model_fn(x, vec_t)]
                t_prev_list = [vec_t]
            elif i < order:
                # warm-up: the usable order grows with the available history
                init_order = i
                x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
                model_prev_list.append(model_x)
                t_prev_list.append(vec_t)
            else:
                x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
                model_prev_list.append(model_x)
                t_prev_list.append(vec_t)

            # keep only the most recent `order` history entries
            model_prev_list = model_prev_list[-order:]
            t_prev_list = t_prev_list[-order:]

            if callback is not None:
                callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})

        return model_prev_list[-1]
|
||||
|
||||
|
||||
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
    """Sample with ``FlowMatchUniPC`` starting from ``noise``.

    Args:
        model: callable invoked as ``model(x, t, **extra_args)``.
        noise: initial state.
        sigmas: sequence of times to step through.
        extra_args: optional dict of keyword arguments for ``model``.
        callback: optional per-step callback forwarded to the sampler.
        disable: when True the progress bar is disabled.
        variant: ``'bh1'`` or ``'bh2'``.

    Returns:
        The result of ``FlowMatchUniPC.sample``.
    """
    assert variant in ['bh1', 'bh2']
    # the solver expands extra_args with `**`, which fails on the default None
    extra_args = {} if extra_args is None else extra_args
    solver = FlowMatchUniPC(model, extra_args=extra_args, variant=variant)
    return solver.sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
|
||||
|
|
@ -0,0 +1,567 @@
|
|||
import os
|
||||
import cv2
|
||||
import json
|
||||
import random
|
||||
import glob
|
||||
import torch
|
||||
import einops
|
||||
import numpy as np
|
||||
import datetime
|
||||
import torchvision
|
||||
import safetensors.torch as sf
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def min_resize(x, m):
    """Resize image ``x`` so its shorter side becomes ``m``, keeping the aspect ratio.

    ``INTER_AREA`` is used when the image shrinks, ``INTER_LANCZOS4`` otherwise.
    """
    height, width = x.shape[0], x.shape[1]
    if height < width:
        s0 = m
        s1 = int(float(m) / float(height) * float(width))
    else:
        s0 = int(float(m) / float(width) * float(height))
        s1 = m

    shrinking = max(s1, s0) < max(height, width)
    interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
    return cv2.resize(x, (s1, s0), interpolation=interpolation)
|
||||
|
||||
|
||||
def d_resize(x, y):
    """Resize image ``x`` to the height and width of ``y`` (``y`` must be 3-dimensional).

    ``INTER_AREA`` is used when the shorter side shrinks, ``INTER_LANCZOS4`` otherwise.
    """
    target_h, target_w, _ = y.shape
    shrinking = min(target_h, target_w) < min(x.shape[0], x.shape[1])
    interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
    return cv2.resize(x, (target_w, target_h), interpolation=interpolation)
|
||||
|
||||
|
||||
def resize_and_center_crop(image, target_width, target_height):
    """Scale a numpy image to cover the target size, then crop the centre.

    The input is returned untouched when it already has the target size.
    """
    already_sized = target_height == image.shape[0] and target_width == image.shape[1]
    if already_sized:
        return image

    source = Image.fromarray(image)
    src_w, src_h = source.size
    # the larger of the two ratios guarantees both target sides are covered
    scale = max(target_width / src_w, target_height / src_h)
    new_w = int(round(src_w * scale))
    new_h = int(round(src_h * scale))
    source = source.resize((new_w, new_h), Image.LANCZOS)

    box = (
        (new_w - target_width) / 2,
        (new_h - target_height) / 2,
        (new_w + target_width) / 2,
        (new_h + target_height) / 2,
    )
    return np.array(source.crop(box))
|
||||
|
||||
|
||||
def resize_and_center_crop_pytorch(image, target_width, target_height):
    """Bilinearly scale a 4-D tensor to cover the target size, then crop the centre.

    The last two dims are treated as height and width. The input is returned
    untouched when it already has the target size.
    """
    _, _, src_h, src_w = image.shape
    if src_h == target_height and src_w == target_width:
        return image

    # the larger of the two ratios guarantees both target sides are covered
    scale = max(target_width / src_w, target_height / src_h)
    new_w = int(round(src_w * scale))
    new_h = int(round(src_h * scale))
    resized = torch.nn.functional.interpolate(image, size=(new_h, new_w), mode='bilinear', align_corners=False)

    y0 = (new_h - target_height) // 2
    x0 = (new_w - target_width) // 2
    return resized[:, :, y0:y0 + target_height, x0:x0 + target_width]
|
||||
|
||||
|
||||
def resize_without_crop(image, target_width, target_height):
    """Resize a numpy image to exactly the target size with LANCZOS, ignoring aspect ratio.

    The input is returned untouched when it already has the target size.
    """
    already_sized = target_height == image.shape[0] and target_width == image.shape[1]
    if already_sized:
        return image

    resized = Image.fromarray(image).resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized)
|
||||
|
||||
|
||||
def just_crop(image, w, h):
    """Centre-crop ``image`` to the largest region with the aspect ratio ``w:h``.

    No resizing is done. The input is returned untouched when its height and
    width already equal ``h`` and ``w``.
    """
    if h == image.shape[0] and w == image.shape[1]:
        return image

    src_h, src_w = image.shape[:2]
    # largest factor by which the w x h box still fits inside the image
    fit = min(src_h / h, src_w / w)
    crop_w = int(round(w * fit))
    crop_h = int(round(h * fit))
    x0 = (src_w - crop_w) // 2
    y0 = (src_h - crop_h) // 2
    return image[y0:y0 + crop_h, x0:x0 + crop_w]
|
||||
|
||||
|
||||
def write_to_json(data, file_path):
    """Write ``data`` as indented JSON to ``file_path`` via a ``.tmp`` file and ``os.replace``."""
    staging_path = file_path + ".tmp"
    with open(staging_path, 'wt', encoding='utf-8') as staging:
        json.dump(data, staging, indent=4)
    # swap the finished file into place so the target is never left half-written
    os.replace(staging_path, file_path)
|
||||
|
||||
|
||||
def read_from_json(file_path):
    """Load and return the JSON content of the UTF-8 file at ``file_path``."""
    with open(file_path, 'rt', encoding='utf-8') as handle:
        return json.load(handle)
|
||||
|
||||
|
||||
def get_active_parameters(m):
    """Return a ``{name: parameter}`` dict of the parameters of ``m`` that require grad."""
    active = {}
    for name, param in m.named_parameters():
        if param.requires_grad:
            active[name] = param
    return active
|
||||
|
||||
|
||||
def cast_training_params(m, dtype=torch.float32):
    """Cast the data of every trainable parameter of ``m`` to ``dtype``.

    Returns a ``{name: parameter}`` dict of the parameters that were cast;
    parameters that do not require grad are left alone.
    """
    trainable = [(name, param) for name, param in m.named_parameters() if param.requires_grad]
    for _, param in trainable:
        param.data = param.to(dtype)
    return dict(trainable)
|
||||
|
||||
|
||||
def separate_lora_AB(parameters, B_patterns=None):
    """Split ``parameters`` into two dicts by key substring.

    Keys containing any of ``B_patterns`` (default ``['.lora_B.', '__zero__']``)
    go into the second dict, everything else into the first.

    Returns:
        ``(parameters_normal, parameters_B)``.
    """
    patterns = ['.lora_B.', '__zero__'] if B_patterns is None else B_patterns
    parameters_normal, parameters_B = {}, {}

    for key, value in parameters.items():
        matches_b = any(pattern in key for pattern in patterns)
        bucket = parameters_B if matches_b else parameters_normal
        bucket[key] = value

    return parameters_normal, parameters_B
|
||||
|
||||
|
||||
def set_attr_recursive(obj, attr, value):
    """Set a dotted attribute path such as ``"a.b.c"`` on ``obj`` to ``value``."""
    *parents, leaf = attr.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)
    setattr(target, leaf, value)
|
||||
|
||||
|
||||
@torch.no_grad()
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
    """Pick whole batch entries from ``a`` or ``b``.

    Args:
        a: tensor whose first dim is the batch.
        b: alternative tensor; zeros shaped like ``a`` when omitted.
        probability_a: chance of picking ``a`` per entry when no mask is given.
        mask_a: optional boolean mask with one value per batch entry; True keeps ``a``.

    Returns:
        Tensor holding ``a`` where the mask is True and ``b`` elsewhere.
    """
    batch_size = a.size(0)
    fallback = torch.zeros_like(a) if b is None else b
    chooser = (torch.rand(batch_size) < probability_a) if mask_a is None else mask_a

    # reshape to (batch, 1, 1, ...) so the mask broadcasts over every other dim
    chooser = chooser.to(a.device).reshape((batch_size,) + (1,) * (a.dim() - 1))
    return torch.where(chooser, a, fallback)
|
||||
|
||||
|
||||
@torch.no_grad()
def zero_module(module):
    """Zero every parameter of ``module`` in place and return the module."""
    for weight in module.parameters():
        weight.detach().zero_()
    return module
|
||||
|
||||
|
||||
@torch.no_grad()
def supress_lower_channels(m, k, alpha=0.01):
    """Scale the first ``k`` entries along dim 1 of ``m.weight`` by ``alpha``.

    The weight data is replaced by a fresh contiguous copy; ``m`` is returned.
    Asserts that dim 1 has at least ``k`` entries.
    """
    scaled = m.weight.data.clone()

    assert int(scaled.shape[1]) >= k

    scaled[:, :k] = scaled[:, :k] * alpha
    m.weight.data = scaled.contiguous().clone()
    return m
|
||||
|
||||
|
||||
def freeze_module(m):
    """Disable gradients for ``m`` and wrap its ``forward`` in ``torch.no_grad``.

    The forward seen on the first call is remembered in
    ``m._forward_inside_frozen_module``. Returns ``m``.
    """
    already_saved = hasattr(m, '_forward_inside_frozen_module')
    if not already_saved:
        m._forward_inside_frozen_module = m.forward

    m.requires_grad_(False)
    no_grad_forward = torch.no_grad()(m.forward)
    m.forward = no_grad_forward
    return m
|
||||
|
||||
|
||||
def get_latest_safetensors(folder_path):
    """Return the absolute real path of the most recently modified ``*.safetensors`` file.

    Raises:
        ValueError: if ``folder_path`` contains no ``*.safetensors`` file.
    """
    candidates = glob.glob(os.path.join(folder_path, '*.safetensors'))
    if len(candidates) == 0:
        raise ValueError('No file to resume!')

    newest = max(candidates, key=os.path.getmtime)
    return os.path.abspath(os.path.realpath(newest))
|
||||
|
||||
|
||||
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Build a prompt from a random subset of the ``', '``-separated tags.

    The subset size is drawn uniformly from ``[min_length, max_length]`` and
    capped at the number of available tags.
    """
    pool = tags_str.split(', ')
    wanted = random.randint(min_length, max_length)
    chosen = random.sample(pool, k=min(wanted, len(pool)))
    return ', '.join(chosen)
|
||||
|
||||
|
||||
def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
    """Return ``n`` values going from ``a`` to ``b`` as a Python list.

    The 0..1 ramp is raised to ``gamma`` before scaling; with ``round_to_int``
    the values are rounded and converted to ints.
    """
    ramp = np.linspace(0, 1, n) ** gamma
    values = a + (b - a) * ramp
    if round_to_int:
        values = np.round(values).astype(int)
    return values.tolist()
|
||||
|
||||
|
||||
def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
    """Draw one uniform sample from each of ``n`` equal sub-intervals of the range.

    The range runs from ``inclusive`` to ``exclusive``; the result is a Python
    list in interval order. With ``round_to_int`` the values are rounded to ints.
    """
    boundaries = np.linspace(0, 1, n + 1)
    lower, upper = boundaries[:-1], boundaries[1:]
    fractions = np.random.uniform(lower, upper)
    values = inclusive + (exclusive - inclusive) * fractions
    if round_to_int:
        values = np.round(values).astype(int)
    return values.tolist()
|
||||
|
||||
|
||||
def soft_append_bcthw(history, current, overlap=0):
    """Append ``current`` to ``history`` along dim 2, cross-fading ``overlap`` entries.

    With ``overlap <= 0`` the tensors are simply concatenated. Otherwise the last
    ``overlap`` entries of ``history`` are linearly blended into the first
    ``overlap`` entries of ``current``, and the result is cast like ``history``.
    """
    if overlap <= 0:
        return torch.cat([history, current], dim=2)

    assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
    assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"

    # weight of `history` fades from 1 to 0 across the overlapping region
    fade_out = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
    tail = history[:, :, -overlap:]
    head = current[:, :, :overlap]
    blended = fade_out * tail + (1 - fade_out) * head

    pieces = [history[:, :, :-overlap], blended, current[:, :, overlap:]]
    return torch.cat(pieces, dim=2).to(history)
|
||||
|
||||
|
||||
def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
    """Write a 5-D tensor as an H.264 video, tiling the batch into a grid.

    Values are clamped to ``[-1, 1]`` and mapped to ``uint8``. The number of
    tiles per row is the first of 6, 5, 4, 3, 2 that divides the batch size,
    falling back to the batch size itself. The output directory is created if
    needed.

    Returns:
        The ``uint8`` frame tensor that was written.
    """
    batch, _c, _t, _h, _w = x.shape
    per_row = next((p for p in (6, 5, 4, 3, 2) if batch % p == 0), batch)

    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    frames = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    frames = frames.detach().cpu().to(torch.uint8)
    frames = einops.rearrange(frames, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, frames, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
    return frames
|
||||
|
||||
|
||||
def save_bcthw_as_png(x, output_filename):
    """Write a 5-D tensor as one PNG: batch entries stacked vertically, dim 2 laid out horizontally.

    Values are clamped to ``[-1, 1]`` and mapped to ``uint8``; the output
    directory is created if needed. Returns ``output_filename``.
    """
    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    pixels = (torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5).detach().cpu().to(torch.uint8)
    sheet = einops.rearrange(pixels, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(sheet, output_filename)
    return output_filename
|
||||
|
||||
|
||||
def save_bchw_as_png(x, output_filename):
    """Write a 4-D tensor as one PNG with the batch entries placed side by side.

    Values are clamped to ``[-1, 1]`` and mapped to ``uint8``; the output
    directory is created if needed. Returns ``output_filename``.
    """
    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)

    pixels = (torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5).detach().cpu().to(torch.uint8)
    strip = einops.rearrange(pixels, 'b c h w -> c h (b w)')
    torchvision.io.write_png(strip, output_filename)
    return output_filename
|
||||
|
||||
|
||||
def add_tensors_with_padding(tensor1, tensor2):
    """Add two tensors, zero-padding each to the per-dim maximum shape first.

    Equal shapes are added directly. Otherwise both tensors are copied into the
    leading corner of zero tensors of the larger shape and those are summed.

    Returns:
        The sum; in the padded case it uses the promoted dtype of the inputs and
        lives on ``tensor1``'s device.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape
    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

    # allocate the padding buffers with the promoted dtype on the input's device so the
    # padded branch agrees with the plain `tensor1 + tensor2` branch above (a bare
    # torch.zeros() would always produce float32 on the CPU)
    dtype = torch.result_type(tensor1, tensor2)
    padded_tensor1 = torch.zeros(new_shape, dtype=dtype, device=tensor1.device)
    padded_tensor2 = torch.zeros(new_shape, dtype=dtype, device=tensor1.device)

    padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
    padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2

    return padded_tensor1 + padded_tensor2
|
||||
|
||||
|
||||
def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
    """Render ``text`` as black, word-wrapped lines on a white RGB image.

    Args:
        width, height: image size in pixels.
        text: text to draw; empty or whitespace-only text yields a blank image.
        font_path: TrueType font file to load.
        size: font size.

    Returns:
        The image as a numpy array. Lines that would not fit vertically are dropped.
    """
    from PIL import Image, ImageDraw, ImageFont

    txt = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(txt)
    font = ImageFont.truetype(font_path, size=size)

    # split() without arguments drops all whitespace, so both empty and
    # whitespace-only text produce no words (indexing words[0] would fail)
    words = text.split()
    if not words:
        return np.array(txt)

    # Split text into lines that fit within the image width
    lines = []
    current_line = words[0]

    for word in words[1:]:
        line_with_word = f"{current_line} {word}"
        if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
            current_line = line_with_word
        else:
            lines.append(current_line)
            current_line = word

    lines.append(current_line)

    # Draw the text line by line
    y = 0
    line_height = draw.textbbox((0, 0), "A", font=font)[3]

    for line in lines:
        if y + line_height > height:
            break  # stop drawing if the next line will be outside the image
        draw.text((0, y), line, fill="black", font=font)
        y += line_height

    return np.array(txt)
|
||||
|
||||
|
||||
def blue_mark(x):
    """Return a copy of ``x`` with channel 2 sharpened against its 9x9 blur.

    The difference to the blurred channel is amplified 16 times and the result
    is clipped to ``[-1, 1]``.
    """
    marked = x.copy()
    channel = marked[:, :, 2]
    blurred = cv2.blur(channel, (9, 9))
    marked[:, :, 2] = ((channel - blurred) * 16.0 + blurred).clip(-1, 1)
    return marked
|
||||
|
||||
|
||||
def green_mark(x):
    """Return a copy of ``x`` with channels 0 and 2 set to -1, leaving channel 1 as is."""
    marked = x.copy()
    for channel in (2, 0):
        marked[:, :, channel] = -1
    return marked
|
||||
|
||||
|
||||
def frame_mark(x):
    """Return a copy of ``x`` with a frame drawn on it.

    The first and last 64 rows are set to -1, then the first and last 8 columns
    are set to 1 (the columns overwrite the rows in the corners).
    """
    framed = x.copy()
    for rows in (slice(None, 64), slice(-64, None)):
        framed[rows] = -1
    for cols in (slice(None, 8), slice(-8, None)):
        framed[:, cols] = 1
    return framed
|
||||
|
||||
|
||||
@torch.inference_mode()
def pytorch2numpy(imgs):
    """Convert each tensor in ``imgs`` to a ``uint8`` numpy array.

    Dim 0 of every item is moved to the end and values are mapped with
    ``v * 127.5 + 127.5``, clipped to ``[0, 255]``.
    """
    return [
        (img.movedim(0, -1) * 127.5 + 127.5).detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        for img in imgs
    ]
|
||||
|
||||
|
||||
@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack numpy images into one float tensor.

    Values are mapped with ``v / 127.5 - 1.0`` and the last axis of the stacked
    array is moved to position 1.
    """
    stacked = np.stack(imgs, axis=0)
    tensor = torch.from_numpy(stacked).float() / 127.5 - 1.0
    return tensor.movedim(-1, 1)
|
||||
|
||||
|
||||
@torch.no_grad()
def duplicate_prefix_to_suffix(x, count, zero_out=False):
    """Append the first ``count`` entries of ``x`` (or zeros of that shape) along dim 0."""
    suffix = x[:count]
    if zero_out:
        suffix = torch.zeros_like(suffix)
    return torch.cat([x, suffix], dim=0)
|
||||
|
||||
|
||||
def weighted_mse(a, b, weight):
    """Mean of ``weight * (a - b) ** 2``, computed in float32."""
    squared_error = (a.float() - b.float()) ** 2
    return torch.mean(weight.float() * squared_error)
|
||||
|
||||
|
||||
def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
    """Map ``x`` from ``[x_min, x_max]`` to ``[y_min, y_max]``, clamping outside the range.

    The normalised position is raised to ``sigma`` before scaling.
    """
    fraction = (x - x_min) / (x_max - x_min)
    fraction = max(0.0, min(fraction, 1.0))
    return y_min + (fraction ** sigma) * (y_max - y_min)
|
||||
|
||||
|
||||
def expand_to_dims(x, target_dims):
    """View ``x`` with trailing singleton dims appended until it has ``target_dims`` dims.

    Tensors that already have at least ``target_dims`` dims keep their shape.
    """
    missing = max(0, target_dims - x.dim())
    return x.view(*x.shape, *([1] * missing))
|
||||
|
||||
|
||||
def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
    """Tile ``tensor`` along dim 0 until its first dim equals ``batch_size``.

    ``None`` is passed through, and a tensor that already matches is returned as is.

    Raises:
        ValueError: if ``batch_size`` is not a multiple of the first dim.
    """
    if tensor is None:
        return None

    current = tensor.shape[0]
    if current == batch_size:
        return tensor

    repeats, remainder = divmod(batch_size, current)
    if remainder != 0:
        raise ValueError(f"Cannot evenly repeat first dim {current} to match batch_size {batch_size}.")

    return tensor.repeat(repeats, *[1] * (tensor.dim() - 1))
|
||||
|
||||
|
||||
def dim5(x):
    """Shorthand for ``expand_to_dims`` with a target of 5 dims."""
    return expand_to_dims(x, target_dims=5)
|
||||
|
||||
|
||||
def dim4(x):
    """Shorthand for ``expand_to_dims`` with a target of 4 dims."""
    return expand_to_dims(x, target_dims=4)
|
||||
|
||||
|
||||
def dim3(x):
    """Shorthand for ``expand_to_dims`` with a target of 3 dims."""
    return expand_to_dims(x, target_dims=3)
|
||||
|
||||
|
||||
def crop_or_pad_yield_mask(x, length):
    """Force dim 1 of a 3-D tensor to ``length`` and return a validity mask.

    Shorter inputs are zero-padded and the mask is True only over the original
    entries; longer or equal inputs are cropped and the mask is all True.

    Returns:
        ``(tensor, mask)`` where ``mask`` is boolean with shape ``(batch, length)``.
    """
    batch, frames, channels = x.shape

    if frames >= length:
        return x[:, :length, :], torch.ones((batch, length), dtype=torch.bool, device=x.device)

    padded = torch.zeros((batch, length, channels), dtype=x.dtype, device=x.device)
    padded[:, :frames, :] = x
    mask = torch.zeros((batch, length), dtype=torch.bool, device=x.device)
    mask[:, :frames] = True
    return padded, mask
|
||||
|
||||
|
||||
def extend_dim(x, dim, minimal_length, zero_pad=False):
    """Lengthen ``x`` along ``dim`` until it has at least ``minimal_length`` entries.

    The padding is zeros when ``zero_pad`` is set, otherwise copies of the last
    entry along ``dim``. Tensors that are already long enough are returned as is.
    """
    current = int(x.shape[dim])
    if current >= minimal_length:
        return x

    shortfall = minimal_length - current

    if zero_pad:
        pad_shape = list(x.shape)
        pad_shape[dim] = shortfall
        filler = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
    else:
        # index that keeps every dim whole except `dim`, where only the last entry is taken
        idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
        filler = x[idx].repeat_interleave(shortfall, dim=dim)

    return torch.cat([x, filler], dim=dim)
|
||||
|
||||
|
||||
def lazy_positional_encoding(t, repeats=None):
    """Return 256-dim timestep embeddings for ``t`` (a single value or a list).

    With ``repeats`` set, a new dim of that size is inserted at position 1 via
    ``expand``.
    """
    timesteps = t if isinstance(t, list) else [t]

    from diffusers.models.embeddings import get_timestep_embedding

    embedding = get_timestep_embedding(timesteps=torch.tensor(timesteps), embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)

    if repeats is None:
        return embedding

    return embedding[:, None, :].expand(-1, repeats, -1)
|
||||
|
||||
|
||||
def state_dict_offset_merge(A, B, C=None):
    """Return ``A + B`` (or ``A + B - C`` when ``C`` is given) for every key of ``A``.

    Values from ``B`` and ``C`` are cast like the matching value of ``A``.
    """
    merged = {}

    for key, base in A.items():
        total = base + B[key].to(base)
        if C is not None:
            total = total - C[key].to(base)
        merged[key] = total

    return merged
|
||||
|
||||
|
||||
def state_dict_weighted_merge(state_dicts, weights):
    """Weighted average of several state dicts over the keys of the first one.

    The weights are normalised by their sum; values from later dicts are cast
    like the running result.

    Raises:
        ValueError: if the counts differ or the weights sum to zero.
    """
    if len(state_dicts) != len(weights):
        raise ValueError("Number of state dictionaries must match number of weights")

    if not state_dicts:
        return {}

    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("Sum of weights cannot be zero")

    normalized = [w / total_weight for w in weights]
    first, *rest = state_dicts

    merged = {}
    for key in first.keys():
        accumulated = first[key] * normalized[0]
        for other, share in zip(rest, normalized[1:]):
            accumulated += other[key].to(accumulated) * share
        merged[key] = accumulated

    return merged
|
||||
|
||||
|
||||
def group_files_by_folder(all_files):
    """Group file paths by the name of their immediate parent folder.

    Returns a list of lists, in order of each folder name's first appearance.
    """
    by_folder = {}

    for path in all_files:
        parent = os.path.basename(os.path.dirname(path))
        by_folder.setdefault(parent, []).append(path)

    return list(by_folder.values())
|
||||
|
||||
|
||||
def generate_timestamp():
    """Return ``yymmdd_HHMMSS_mmm_r``: current local time, milliseconds and a random 0-9999 suffix."""
    now = datetime.datetime.now()
    stamp = now.strftime('%y%m%d_%H%M%S')
    millis = f"{now.microsecond // 1000:03d}"
    suffix = random.randint(0, 9999)
    return '_'.join((stamp, millis, str(suffix)))
|
||||
|
||||
|
||||
def write_PIL_image_with_png_info(image, metadata, path):
    """Save ``image`` as PNG at ``path`` with ``metadata`` stored as text chunks.

    Returns the image that was passed in.
    """
    from PIL.PngImagePlugin import PngInfo

    text_chunks = PngInfo()
    for name in metadata:
        text_chunks.add_text(name, metadata[name])

    image.save(path, "PNG", pnginfo=text_chunks)
    return image
|
||||
|
||||
|
||||
def torch_safe_save(content, path):
    """Save ``content`` with ``torch.save`` to ``path + '_tmp'``, then move it onto ``path``.

    Returns ``path``.
    """
    staging_path = path + '_tmp'
    torch.save(content, staging_path)
    os.replace(staging_path, path)
    return path
|
||||
|
||||
|
||||
def move_optimizer_to_device(optimizer, device):
    """Replace every tensor in ``optimizer.state`` with ``tensor.to(device)``."""
    for param_state in optimizer.state.values():
        for name in list(param_state):
            entry = param_state[name]
            if isinstance(entry, torch.Tensor):
                param_state[name] = entry.to(device)
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
import torch
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
    """Append trailing singleton axes to ``x`` until it has ``target_dims`` dims."""
    extra = target_dims - x.ndim
    return x[(Ellipsis,) + (None,) * extra]
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
    """Blend ``noise_cfg`` with a copy rescaled to the std of ``noise_pred_text``.

    Standard deviations are taken per batch entry over all other dims. With
    ``guidance_rescale == 0`` the input is returned unchanged.
    """
    if guidance_rescale == 0:
        return noise_cfg

    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))
    std_text = noise_pred_text.std(dim=text_axes, keepdim=True)
    std_cfg = noise_cfg.std(dim=cfg_axes, keepdim=True)

    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
|
||||
|
||||
|
||||
def fm_wrapper(transformer, t_scale=1000.0):
    """Wrap ``transformer`` into a ``k_model(x, sigma, **extra_args)`` callable.

    ``extra_args`` must provide ``dtype``, ``cfg_scale``, ``cfg_rescale``,
    ``concat_latent`` plus ``positive`` and ``negative`` keyword dicts for the
    transformer. The returned callable yields ``x - pred * sigma`` in the dtype
    of the incoming ``x``.
    """
    def k_model(x, sigma, **extra_args):
        compute_dtype = extra_args['dtype']
        cfg_scale = extra_args['cfg_scale']
        cfg_rescale = extra_args['cfg_rescale']
        concat_latent = extra_args['concat_latent']

        input_dtype = x.dtype
        sigma = sigma.float()

        x = x.to(compute_dtype)
        timestep = (sigma * t_scale).to(compute_dtype)

        hidden_states = x if concat_latent is None else torch.cat([x, concat_latent.to(x)], dim=1)

        def predict(branch):
            # one transformer pass with the keyword set of the given branch
            return transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args[branch])[0].float()

        pred_positive = predict('positive')

        # with cfg_scale == 1.0 the negative pass is skipped and zeros are used instead
        pred_negative = torch.zeros_like(pred_positive) if cfg_scale == 1.0 else predict('negative')

        pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
        pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)

        x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)

        return x0.to(dtype=input_dtype)

    return k_model
|
||||
|
|
@ -8,7 +8,7 @@ from diffusers.pipelines import auto_pipeline
|
|||
|
||||
current_steps = 50
|
||||
def sd15_hidiffusion_key():
|
||||
modified_key = dict()
|
||||
modified_key = {}
|
||||
modified_key['down_module_key'] = ['down_blocks.0.downsamplers.0.conv']
|
||||
modified_key['down_module_key_extra'] = ['down_blocks.1']
|
||||
modified_key['up_module_key'] = ['up_blocks.2.upsamplers.0.conv']
|
||||
|
|
@ -22,7 +22,7 @@ def sd15_hidiffusion_key():
|
|||
return modified_key
|
||||
|
||||
def sdxl_hidiffusion_key():
|
||||
modified_key = dict()
|
||||
modified_key = {}
|
||||
modified_key['down_module_key'] = ['down_blocks.1']
|
||||
modified_key['down_module_key_extra'] = ['down_blocks.1.downsamplers.0.conv']
|
||||
modified_key['up_module_key'] = ['up_blocks.1']
|
||||
|
|
@ -42,7 +42,7 @@ def sdxl_hidiffusion_key():
|
|||
|
||||
|
||||
def sdxl_turbo_hidiffusion_key():
|
||||
modified_key = dict()
|
||||
modified_key = {}
|
||||
modified_key['down_module_key'] = ['down_blocks.1']
|
||||
modified_key['up_module_key'] = ['up_blocks.1']
|
||||
modified_key['windown_attn_module_key'] = [
|
||||
|
|
|
|||
|
|
@ -164,14 +164,11 @@ def qwen(question: str, image: Image.Image, repo: str = None, system_prompt: str
|
|||
|
||||
def gemma(question: str, image: Image.Image, repo: str = None, system_prompt: str = None):
|
||||
global processor, model, loaded # pylint: disable=global-statement
|
||||
if not hasattr(transformers, 'Gemma3ForConditionalGeneration'):
|
||||
shared.log.error(f'Interrogate: vlm="{repo}" gemma is not available')
|
||||
return ''
|
||||
if model is None or loaded != repo:
|
||||
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
||||
model = None
|
||||
if '3n' in repo:
|
||||
cls = transformers.Gemma3nForConditionalGeneration
|
||||
cls = transformers.Gemma3nForConditionalGeneration # pylint: disable=no-member
|
||||
else:
|
||||
cls = transformers.Gemma3ForConditionalGeneration
|
||||
model = cls.from_pretrained(
|
||||
|
|
|
|||
|
|
@ -20,9 +20,10 @@ try:
|
|||
import numpy.random # pylint: disable=W0611,C0411 # this causes failure if numpy version changed
|
||||
def obj2sctype(obj):
|
||||
return np.dtype(obj).type
|
||||
np.obj2sctype = obj2sctype # noqa: NPY201
|
||||
np.bool8 = np.bool
|
||||
np.float_ = np.float64 # noqa: NPY201
|
||||
if np.__version__.startswith('2.'): # monkeypatch for np==1.2 compatibility
|
||||
np.obj2sctype = obj2sctype # noqa: NPY201
|
||||
np.bool8 = np.bool
|
||||
np.float_ = np.float64 # noqa: NPY201
|
||||
except Exception as e:
|
||||
errors.log.error(f'Loader: numpy=={np.__version__ if np is not None else None} {e}')
|
||||
errors.log.error('Please restart the app to fix this issue')
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ def edge_detect_for_pixelart(image: PipelineImageInput, image_weight: float = 1.
|
|||
block_size_sq = block_size * block_size
|
||||
new_image = process_image_input(image).to(device, dtype=torch.float32) / 255
|
||||
new_image = new_image.permute(0,3,1,2)
|
||||
batch_size, channels, height, width = new_image.shape
|
||||
batch_size, _channels, height, width = new_image.shape
|
||||
|
||||
min_pool = -torch.nn.functional.max_pool2d(-new_image, block_size, 1, block_size//2, 1, False, False)
|
||||
min_pool = min_pool[:, :, :height, :width]
|
||||
|
|
@ -203,6 +203,7 @@ class JPEGEncoder(ImageProcessingMixin, ConfigMixin):
|
|||
self.norm = norm
|
||||
self.latents_std = latents_std
|
||||
self.latents_mean = latents_mean
|
||||
super().__init__()
|
||||
|
||||
|
||||
def encode(self, images: PipelineImageInput, device: str="cpu") -> torch.FloatTensor:
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ def apply_styles_to_extra(p, style: Style):
|
|||
k = name_map[k]
|
||||
if k in name_exclude: # exclude some fields
|
||||
continue
|
||||
elif hasattr(p, k):
|
||||
if hasattr(p, k):
|
||||
orig = getattr(p, k)
|
||||
if (type(orig) != type(v)) and (orig is not None):
|
||||
if not (type(orig) == int and type(v) == float): # dont convert float to int
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ def create_ui(_blocks: gr.Blocks=None):
|
|||
with gr.Row(elem_id='control_status'):
|
||||
result_txt = gr.HTML(elem_classes=['control-result'], elem_id='control-result')
|
||||
|
||||
with gr.Row(elem_id='control_settings'):
|
||||
with gr.Row(elem_id='control_settings', elem_classes=['settings-column']):
|
||||
|
||||
state = gr.Textbox(value='', visible=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -1001,5 +1001,5 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
|
|||
return ui
|
||||
|
||||
|
||||
def setup_ui(ui, gallery):
|
||||
def setup_ui(ui, gallery: gr.Gallery = None):
|
||||
ui.gallery = gallery
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ def create_ui():
|
|||
timer.startup.record('ui-networks')
|
||||
|
||||
with gr.Row(elem_id="img2img_interface", equal_height=False):
|
||||
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
||||
with gr.Column(variant='compact', elem_id="img2img_settings", elem_classes=['settings-column']):
|
||||
copy_image_buttons = []
|
||||
copy_image_destinations = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from modules.ui_components import ToolButton
|
|||
from modules.interrogate import interrogate
|
||||
|
||||
|
||||
def create_toprow(is_img2img: bool = False, id_part: str = None, negative_visible: bool = True, reprocess_visible: bool = True):
|
||||
def create_toprow(is_img2img: bool = False, id_part: str = None, generate_visible: bool = True, negative_visible: bool = True, reprocess_visible: bool = True):
|
||||
def apply_styles(prompt, prompt_neg, styles):
|
||||
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles, wildcards=not shared.opts.extra_networks_apply_unparsed)
|
||||
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles, wildcards=not shared.opts.extra_networks_apply_unparsed)
|
||||
|
|
@ -29,7 +29,7 @@ def create_toprow(is_img2img: bool = False, id_part: str = None, negative_visibl
|
|||
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box"):
|
||||
reprocess = []
|
||||
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary', visible=generate_visible)
|
||||
if reprocess_visible:
|
||||
reprocess.append(gr.Button('Reprocess', elem_id=f"{id_part}_reprocess", variant='primary', visible=True))
|
||||
reprocess.append(gr.Button('Reprocess decode', elem_id=f"{id_part}_reprocess_decode", variant='primary', visible=False))
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def create_ui():
|
|||
timer.startup.record('ui-networks')
|
||||
|
||||
with gr.Row(elem_id="txt2img_interface", equal_height=False):
|
||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||
with gr.Column(variant='compact', elem_id="txt2img_settings", elem_classes=['settings-column']):
|
||||
|
||||
with gr.Row():
|
||||
width, height = ui_sections.create_resolution_inputs('txt2img')
|
||||
|
|
|
|||
|
|
@ -1,190 +1,49 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
from modules import shared, sd_models, timer, images, ui_common, ui_sections, ui_symbols, call_queue, generation_parameters_copypaste
|
||||
from modules.ui_components import ToolButton
|
||||
from modules.video_models import models_def, video_utils
|
||||
from modules import shared, timer, images, ui_common, ui_sections, generation_parameters_copypaste
|
||||
|
||||
|
||||
debug = shared.log.trace if os.environ.get('SD_VIDEO_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
def engine_change(engine):
|
||||
debug(f'Video change: engine="{engine}"')
|
||||
found = [model.name for model in models_def.models.get(engine, [])]
|
||||
return gr.update(choices=found, value=found[0] if len(found) > 0 else None)
|
||||
|
||||
|
||||
def get_selected(engine, model):
|
||||
found = [model.name for model in models_def.models.get(engine, [])]
|
||||
if len(models_def.models[engine]) > 0 and len(found) > 0:
|
||||
selected = [m for m in models_def.models[engine] if m.name == model][0]
|
||||
return selected
|
||||
return None
|
||||
|
||||
|
||||
def model_change(engine, model):
|
||||
debug(f'Video change: engine="{engine}" model="{model}"')
|
||||
found = [model.name for model in models_def.models.get(engine, [])]
|
||||
selected = [m for m in models_def.models[engine] if m.name == model][0] if len(found) > 0 else None
|
||||
url = video_utils.get_url(selected.url if selected else None)
|
||||
i2v = 'i2v' in selected.name.lower() if selected else False
|
||||
return url, gr.update(visible=i2v)
|
||||
|
||||
|
||||
def model_load(engine, model):
|
||||
debug(f'Video load: engine="{engine}" model="{model}"')
|
||||
selected = get_selected(engine, model)
|
||||
yield f'Video model loading: {selected.name}'
|
||||
if selected:
|
||||
if 'None' in selected.name:
|
||||
sd_models.unload_model_weights()
|
||||
msg = 'Video model unloaded'
|
||||
else:
|
||||
from modules.video_models import video_load
|
||||
msg = video_load.load_model(selected)
|
||||
else:
|
||||
sd_models.unload_model_weights()
|
||||
msg = 'Video model unloaded'
|
||||
yield msg
|
||||
return msg
|
||||
|
||||
|
||||
def run_video(*args):
|
||||
engine, model = args[2], args[3]
|
||||
debug(f'Video run: engine="{engine}" model="{model}"')
|
||||
selected = get_selected(engine, model)
|
||||
if not selected or engine is None or model is None or engine == 'None' or model == 'None':
|
||||
return video_utils.queue_err('model not selected')
|
||||
debug(f'Video run: {str(selected)}')
|
||||
from modules.video_models import video_run
|
||||
if selected and 'Hunyuan' in selected.name:
|
||||
return video_run.generate(*args)
|
||||
elif selected and 'LTX' in selected.name:
|
||||
return video_run.generate(*args)
|
||||
elif selected and 'Mochi' in selected.name:
|
||||
return video_run.generate(*args)
|
||||
elif selected and 'Cog' in selected.name:
|
||||
return video_run.generate(*args)
|
||||
elif selected and 'Allegro' in selected.name:
|
||||
return video_run.generate(*args)
|
||||
elif selected and 'WAN' in selected.name:
|
||||
return video_run.generate(*args)
|
||||
elif selected and 'Latte' in selected.name:
|
||||
return video_run.generate(*args)
|
||||
elif selected and 'anisora' in selected.name.lower():
|
||||
return video_run.generate(*args)
|
||||
return video_utils.queue_err(f'model not found: engine="{engine}" model="{model}"')
|
||||
|
||||
|
||||
def create_ui():
|
||||
shared.log.debug('UI initialize: video')
|
||||
with gr.Blocks(analytics_enabled=False) as _video_interface:
|
||||
prompt, styles, negative, generate, _reprocess, paste, networks_button, _token_counter, _token_button, _token_counter_negative, _token_button_negative = ui_sections.create_toprow(is_img2img=False, id_part="video", negative_visible=True, reprocess_visible=False)
|
||||
prompt, styles, negative, generate_btn, _reprocess, paste, networks_button, _token_counter, _token_button, _token_counter_negative, _token_button_negative = ui_sections.create_toprow(
|
||||
is_img2img=False,
|
||||
id_part="video",
|
||||
negative_visible=True,
|
||||
reprocess_visible=False,
|
||||
)
|
||||
prompt_image = gr.File(label="", elem_id="video_prompt_image", file_count="single", type="binary", visible=False)
|
||||
prompt_image.change(fn=images.image_data, inputs=[prompt_image], outputs=[prompt, prompt_image])
|
||||
|
||||
with gr.Row(variant='compact', elem_id="video_extra_networks", elem_classes=["extra_networks_root"], visible=False) as extra_networks_ui:
|
||||
from modules import ui_extra_networks
|
||||
extra_networks_ui = ui_extra_networks.create_ui(extra_networks_ui, networks_button, 'video', skip_indexing=shared.opts.extra_network_skip_indexing)
|
||||
ui_extra_networks.setup_ui(extra_networks_ui)
|
||||
timer.startup.record('ui-networks')
|
||||
|
||||
with gr.Row(elem_id="video_interface", equal_height=False):
|
||||
with gr.Column(variant='compact', elem_id="video_settings", scale=1):
|
||||
with gr.Tabs(elem_classes=['video-tabs'], elem_id='video-tabs'):
|
||||
overrides = ui_common.create_override_inputs('video')
|
||||
with gr.Tab('Video', id='video-tab') as video_tab:
|
||||
from modules.video_models import video_ui
|
||||
video_ui.create_ui(prompt, negative, styles, overrides)
|
||||
with gr.Tab('FramePack', id='framepack-tab') as framepack_tab:
|
||||
from modules.framepack import framepack_ui
|
||||
framepack_ui.create_ui(prompt, negative, styles, overrides)
|
||||
|
||||
with gr.Row():
|
||||
engine = gr.Dropdown(label='Engine', choices=list(models_def.models), value='None', elem_id="video_engine")
|
||||
model = gr.Dropdown(label='Model', choices=[''], value=None, elem_id="video_model")
|
||||
btn_load = ToolButton(ui_symbols.loading, elem_id="video_model_load")
|
||||
with gr.Row():
|
||||
url = gr.HTML(label='Model URL', elem_id='video_model_url', value='<br><br>')
|
||||
with gr.Accordion(open=True, label="Size", elem_id='video_size_accordion'):
|
||||
with gr.Row():
|
||||
width, height = ui_sections.create_resolution_inputs('video', default_width=832, default_height=480)
|
||||
with gr.Row():
|
||||
frames = gr.Slider(label='Frames', minimum=1, maximum=1024, step=1, value=15, elem_id="video_frames")
|
||||
seed = gr.Number(label='Initial seed', value=-1, elem_id="video_seed", container=True)
|
||||
random_seed = ToolButton(ui_symbols.random, elem_id="video_random_seed")
|
||||
reuse_seed = ToolButton(ui_symbols.reuse, elem_id="video_reuse_seed")
|
||||
with gr.Accordion(open=True, label="Parameters", elem_id='video_parameters_accordion'):
|
||||
steps, sampler_index = ui_sections.create_sampler_and_steps_selection(None, "video")
|
||||
with gr.Row():
|
||||
sampler_shift = gr.Slider(label='Sampler shift', minimum=-1.0, maximum=20.0, step=0.1, value=-1.0, elem_id="video_scheduler_shift")
|
||||
dynamic_shift = gr.Checkbox(label='Dynamic shift', value=False, elem_id="video_dynamic_shift")
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(label='Guidance scale', minimum=-1.0, maximum=14.0, step=0.1, value=-1.0, elem_id="video_guidance_scale")
|
||||
guidance_true = gr.Slider(label='True guidance', minimum=-1.0, maximum=14.0, step=0.1, value=-1.0, elem_id="video_guidance_true")
|
||||
with gr.Accordion(open=True, label="Decode", elem_id='video_decode_accordion'):
|
||||
with gr.Row():
|
||||
vae_type = gr.Dropdown(label='VAE decode', choices=['Default', 'Tiny', 'Remote'], value='Default', elem_id="video_vae_type")
|
||||
vae_tile_frames = gr.Slider(label='Tile frames', minimum=1, maximum=64, step=1, value=16, elem_id="video_vae_tile_frames")
|
||||
with gr.Accordion(open=False, label="Init image", elem_id='video_init_accordion', visible=False) as init_accordion:
|
||||
init_strength = gr.Slider(label='Init strength', minimum=0.0, maximum=1.0, step=0.01, value=0.5, elem_id="video_denoising_strength")
|
||||
gr.HTML("<br>  Init image")
|
||||
init_image = gr.Image(elem_id="video_image", show_label=False, type="pil", image_mode="RGB", height=512)
|
||||
gr.HTML("<br>  Last image")
|
||||
last_image = gr.Image(elem_id="video_last", show_label=False, type="pil", image_mode="RGB", height=512)
|
||||
with gr.Accordion(open=True, label="Output", elem_id='video_output_accordion'):
|
||||
with gr.Row():
|
||||
save_frames = gr.Checkbox(label='Save image frames', value=False, elem_id="video_save_frames")
|
||||
with gr.Row():
|
||||
video_type, video_duration, video_loop, video_pad, video_interpolate = ui_sections.create_video_inputs(tab='video', show_always=True)
|
||||
override_settings = ui_common.create_override_inputs('video')
|
||||
|
||||
# output panel with gallery and video tabs
|
||||
with gr.Column(elem_id='video-output-column', scale=2) as _column_output:
|
||||
with gr.Tabs(elem_classes=['video-output-tabs'], elem_id='video-output-tabs'):
|
||||
with gr.Tab('Frames', id='out-gallery'):
|
||||
gallery, gen_info, html_info, _html_info_formatted, html_log = ui_common.create_output_panel("video", prompt=prompt, preview=False, transfer=False, scale=2)
|
||||
with gr.Tab('Video', id='out-video'):
|
||||
video = gr.Video(label="Output", show_label=False, elem_id='control_output_video', elem_classes=['control-image'], height=512, autoplay=False)
|
||||
|
||||
# connect reuse seed button
|
||||
ui_common.connect_reuse_seed(seed, reuse_seed, gen_info, is_subseed=False)
|
||||
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
|
||||
# handle engine and model change
|
||||
engine.change(fn=engine_change, inputs=[engine], outputs=[model])
|
||||
model.change(fn=model_change, inputs=[engine, model], outputs=[url, init_accordion])
|
||||
btn_load.click(fn=model_load, inputs=[engine, model], outputs=[html_log])
|
||||
# setup extra networks
|
||||
ui_extra_networks.setup_ui(extra_networks_ui, gallery)
|
||||
|
||||
# handle restore fields
|
||||
paste_fields = [
|
||||
(prompt, "Prompt"),
|
||||
(width, "Size-1"),
|
||||
(height, "Size-2"),
|
||||
(frames, "Frames"),
|
||||
(prompt, "Prompt"), # cannot add more fields as they are not defined yet
|
||||
]
|
||||
generation_parameters_copypaste.add_paste_fields("video", None, paste_fields, override_settings)
|
||||
generation_parameters_copypaste.add_paste_fields("video", None, paste_fields, overrides)
|
||||
bindings = generation_parameters_copypaste.ParamBinding(paste_button=paste, tabname="video", source_text_component=prompt, source_image_component=None)
|
||||
generation_parameters_copypaste.register_paste_params_button(bindings)
|
||||
# hidden fields
|
||||
task_id = gr.Textbox(visible=False, value='')
|
||||
ui_state = gr.Textbox(visible=False, value='')
|
||||
# generate args
|
||||
video_args = [
|
||||
task_id, ui_state,
|
||||
engine, model,
|
||||
prompt, negative, styles,
|
||||
width, height,
|
||||
frames,
|
||||
steps, sampler_index,
|
||||
sampler_shift, dynamic_shift,
|
||||
seed,
|
||||
guidance_scale, guidance_true,
|
||||
init_image, init_strength, last_image,
|
||||
vae_type, vae_tile_frames,
|
||||
save_frames,
|
||||
video_type, video_duration, video_loop, video_pad, video_interpolate,
|
||||
override_settings,
|
||||
]
|
||||
# generate function
|
||||
video_dict = dict(
|
||||
fn=call_queue.wrap_gradio_gpu_call(run_video, extra_outputs=[None, '', ''], name='Video'),
|
||||
_js="submit_video",
|
||||
inputs=video_args,
|
||||
outputs=[gallery, video, gen_info, html_info, html_log],
|
||||
show_progress=False,
|
||||
)
|
||||
prompt.submit(**video_dict)
|
||||
generate.click(**video_dict)
|
||||
|
||||
current_tab = gr.Textbox(visible=False, value='video')
|
||||
video_tab.select(fn=lambda: 'video', inputs=[], outputs=[current_tab])
|
||||
framepack_tab.select(fn=lambda: 'framepack', inputs=[], outputs=[current_tab])
|
||||
generate_btn.click(fn=None, _js='submit_video_wrapper', inputs=[current_tab], outputs=[])
|
||||
|
||||
# from framepack_api import create_api # pylint: disable=wrong-import-order
|
||||
|
|
|
|||
|
|
@ -92,7 +92,6 @@ class UpscalerLatent(Upscaler):
|
|||
mode, antialias = 'bicubic', True
|
||||
else:
|
||||
raise log.error(f"Upscale: type=latent model={selected_model} unknown")
|
||||
return img
|
||||
return F.interpolate(img, size=(h, w), mode=mode, antialias=antialias)
|
||||
|
||||
|
||||
|
|
@ -170,7 +169,6 @@ class UpscalerVIPS(Upscaler):
|
|||
if selected_model is None:
|
||||
return img
|
||||
from installer import install
|
||||
from modules.shared import log
|
||||
install('pyvips')
|
||||
try:
|
||||
import pyvips
|
||||
|
|
|
|||
|
|
@ -0,0 +1,173 @@
|
|||
import os
|
||||
import gradio as gr
|
||||
from modules import shared, sd_models, ui_common, ui_sections, ui_symbols, call_queue
|
||||
from modules.ui_components import ToolButton
|
||||
from modules.video_models import models_def, video_utils
|
||||
from modules.video_models import video_run
|
||||
|
||||
|
||||
debug = shared.log.trace if os.environ.get('SD_VIDEO_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
def engine_change(engine):
    """Refresh the model dropdown after the engine selection changes.

    Collects the `name` of every entry registered for `engine` in
    `models_def.models` and returns a Gradio update that offers those names as
    choices. The first name is preselected; the value is None when the engine
    is unknown or has no entries.
    """
    debug(f'Video change: engine="{engine}"')
    names = []
    for entry in models_def.models.get(engine, []):
        names.append(entry.name)
    default = names[0] if names else None
    return gr.update(choices=names, value=default)
|
||||
|
||||
|
||||
def get_selected(engine, model):
    """Look up the model definition named `model` under `engine`.

    Returns the first entry registered for `engine` in `models_def.models`
    whose `name` equals `model`. Returns None when the engine is unknown, has
    no entries, or has no entry with that name.
    """
    # .get() keeps an unknown engine from raising KeyError, and next() with a
    # default replaces the `[...][0]` lookup that raised IndexError whenever
    # the requested name was not present in the engine's list
    candidates = models_def.models.get(engine, [])
    return next((m for m in candidates if m.name == model), None)
|
||||
|
||||
|
||||
def model_change(engine, model):
    """React to a change of the model dropdown.

    Returns a tuple of:
      - the value produced by `video_utils.get_url` for the selected entry's
        `url` (called with None when nothing matches), and
      - a Gradio update that shows the target component only when the selected
        entry's name contains 'i2v' (case-insensitive).
    """
    debug(f'Video change: engine="{engine}" model="{model}"')
    # next() with a default: a name that is not registered for this engine
    # resolves to None instead of raising IndexError from `[...][0]`
    candidates = models_def.models.get(engine, [])
    selected = next((m for m in candidates if m.name == model), None)
    url = video_utils.get_url(selected.url if selected else None)
    i2v = 'i2v' in selected.name.lower() if selected else False
    return url, gr.update(visible=i2v)
|
||||
|
||||
|
||||
def model_load(engine, model):
    """Load (or unload) the selected video model, yielding status messages.

    Generator: first yields a 'loading' status line, then yields the final
    message. When no entry is resolved, or the entry's name contains 'None',
    the current model weights are unloaded via `sd_models.unload_model_weights`;
    otherwise the entry is passed to `video_load.load_model` and its return
    value is used as the final message.
    """
    debug(f'Video load: engine="{engine}" model="{model}"')
    selected = get_selected(engine, model)
    # get_selected may return None; reading `.name` unguarded raised
    # AttributeError here before the unload fallback below could run
    yield f'Video model loading: {selected.name if selected else None}'
    if selected and 'None' not in selected.name:
        from modules.video_models import video_load
        msg = video_load.load_model(selected)
    else:
        sd_models.unload_model_weights()
        msg = 'Video model unloaded'
    yield msg
    return msg
|
||||
|
||||
|
||||
def run_video(*args):
    """Dispatch a video generation request to `video_run.generate`.

    `args[2]` and `args[3]` are read as the engine and model names. Returns
    the result of `video_utils.queue_err` when no model is selected or when the
    selected entry's name matches none of the known model families; otherwise
    forwards all arguments unchanged to `video_run.generate`.
    """
    engine, model = args[2], args[3]
    debug(f'Video run: engine="{engine}" model="{model}"')
    selected = get_selected(engine, model)
    unset = engine is None or model is None or engine == 'None' or model == 'None'
    if not selected or unset:
        return video_utils.queue_err('model not selected')
    debug(f'Video run: {str(selected)}')
    # family markers are matched case-sensitively, except 'anisora' which is
    # compared against the lower-cased name
    families = ('Hunyuan', 'LTX', 'Mochi', 'Cog', 'Allegro', 'WAN', 'Latte')
    supported = any(tag in selected.name for tag in families) or 'anisora' in selected.name.lower()
    if supported:
        return video_run.generate(*args)
    return video_utils.queue_err(f'model not found: engine="{engine}" model="{model}"')
|
||||
|
||||
|
||||
def create_ui(prompt, negative, styles, overrides):
# Builds the video settings and output panels and binds their events.
# `prompt`, `negative`, `styles` and `overrides` are received from the caller and
# forwarded as inputs of the generate call; `prompt` is also handed to the output panel.
# NOTE(review): the nesting of the `with` blocks is not visible in this view — confirm against the real file.
|
||||
with gr.Row():
|
||||
with gr.Column(variant='compact', elem_id="video_settings", elem_classes=['settings-column'], scale=1):
|
||||
with gr.Row():
|
||||
# created hidden (visible=False); its click handler is bound on the last line of this function
generate = gr.Button('Generate', elem_id="video_generate_btn", variant='primary', visible=False)
|
||||
with gr.Row():
|
||||
# engine choices are the keys of models_def.models; the model choices are filled in by engine_change (bound below)
engine = gr.Dropdown(label='Engine', choices=list(models_def.models), value='None', elem_id="video_engine")
|
||||
model = gr.Dropdown(label='Model', choices=[''], value=None, elem_id="video_model")
|
||||
btn_load = ToolButton(ui_symbols.loading, elem_id="video_model_load")
|
||||
with gr.Row():
|
||||
# placeholder content; model_change writes the model URL here through the model.change binding below
url = gr.HTML(label='Model URL', elem_id='video_model_url', value='<br><br>')
|
||||
with gr.Accordion(open=True, label="Size", elem_id='video_size_accordion'):
|
||||
with gr.Row():
|
||||
width, height = ui_sections.create_resolution_inputs('video', default_width=832, default_height=480)
|
||||
with gr.Row():
|
||||
frames = gr.Slider(label='Frames', minimum=1, maximum=1024, step=1, value=15, elem_id="video_frames")
|
||||
seed = gr.Number(label='Initial seed', value=-1, elem_id="video_seed", container=True)
|
||||
random_seed = ToolButton(ui_symbols.random, elem_id="video_random_seed")
|
||||
reuse_seed = ToolButton(ui_symbols.reuse, elem_id="video_reuse_seed")
|
||||
with gr.Accordion(open=True, label="Parameters", elem_id='video_parameters_accordion'):
|
||||
steps, sampler_index = ui_sections.create_sampler_and_steps_selection(None, "video")
|
||||
with gr.Row():
|
||||
sampler_shift = gr.Slider(label='Sampler shift', minimum=-1.0, maximum=20.0, step=0.1, value=-1.0, elem_id="video_scheduler_shift")
|
||||
dynamic_shift = gr.Checkbox(label='Dynamic shift', value=False, elem_id="video_dynamic_shift")
|
||||
with gr.Row():
|
||||
guidance_scale = gr.Slider(label='Guidance scale', minimum=-1.0, maximum=14.0, step=0.1, value=-1.0, elem_id="video_guidance_scale")
|
||||
guidance_true = gr.Slider(label='True guidance', minimum=-1.0, maximum=14.0, step=0.1, value=-1.0, elem_id="video_guidance_true")
|
||||
with gr.Accordion(open=True, label="Decode", elem_id='video_decode_accordion'):
|
||||
with gr.Row():
|
||||
vae_type = gr.Dropdown(label='VAE decode', choices=['Default', 'Tiny', 'Remote'], value='Default', elem_id="video_vae_type")
|
||||
vae_tile_frames = gr.Slider(label='Tile frames', minimum=1, maximum=64, step=1, value=16, elem_id="video_vae_tile_frames")
|
||||
# starts hidden; model_change returns a visibility update for this accordion (shown for names containing 'i2v')
with gr.Accordion(open=False, label="Init image", elem_id='video_init_accordion', visible=False) as init_accordion:
|
||||
init_strength = gr.Slider(label='Init strength', minimum=0.0, maximum=1.0, step=0.01, value=0.5, elem_id="video_denoising_strength")
|
||||
gr.HTML("<br>  Init image")
|
||||
init_image = gr.Image(elem_id="video_image", show_label=False, type="pil", image_mode="RGB", width=256, height=256)
|
||||
gr.HTML("<br>  Last image")
|
||||
last_image = gr.Image(elem_id="video_last", show_label=False, type="pil", image_mode="RGB", width=256, height=256)
|
||||
with gr.Accordion(open=True, label="Output", elem_id='video_output_accordion'):
|
||||
with gr.Row():
|
||||
save_frames = gr.Checkbox(label='Save image frames', value=False, elem_id="video_save_frames")
|
||||
with gr.Row():
|
||||
video_type, video_duration, video_loop, video_pad, video_interpolate = ui_sections.create_video_inputs(tab='video', show_always=True)
|
||||
|
||||
# output panel with gallery and video tabs
|
||||
with gr.Column(elem_id='video-output-column', scale=2) as _column_output:
|
||||
with gr.Tabs(elem_classes=['video-output-tabs'], elem_id='video-output-tabs'):
|
||||
with gr.Tab('Frames', id='out-gallery'):
|
||||
gallery, gen_info, html_info, _html_info_formatted, html_log = ui_common.create_output_panel("video", prompt=prompt, preview=False, transfer=False, scale=2)
|
||||
with gr.Tab('Video', id='out-video'):
|
||||
video = gr.Video(label="Output", show_label=False, elem_id='control_output_video', elem_classes=['control-image'], height=512, autoplay=False)
|
||||
|
||||
# connect reuse seed button
|
||||
ui_common.connect_reuse_seed(seed, reuse_seed, gen_info, is_subseed=False)
|
||||
# the random-seed button writes -1 into the seed field
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
|
||||
# handle engine and model change
|
||||
engine.change(fn=engine_change, inputs=[engine], outputs=[model])
|
||||
model.change(fn=model_change, inputs=[engine, model], outputs=[url, init_accordion])
|
||||
# model_load is a generator; its yielded status strings are written to html_log
btn_load.click(fn=model_load, inputs=[engine, model], outputs=[html_log])
|
||||
# hidden fields
|
||||
task_id = gr.Textbox(visible=False, value='')
|
||||
ui_state = gr.Textbox(visible=False, value='')
|
||||
# passed ahead of video_inputs, so the handler receives task_id and ui_state as its first two arguments
state_inputs = [task_id, ui_state]
|
||||
|
||||
# generate args
|
||||
video_inputs = [
|
||||
engine, model,
|
||||
prompt, negative, styles,
|
||||
width, height,
|
||||
frames,
|
||||
steps, sampler_index,
|
||||
sampler_shift, dynamic_shift,
|
||||
seed,
|
||||
guidance_scale, guidance_true,
|
||||
init_image, init_strength, last_image,
|
||||
vae_type, vae_tile_frames,
|
||||
save_frames,
|
||||
video_type, video_duration, video_loop, video_pad, video_interpolate,
|
||||
overrides,
|
||||
]
|
||||
video_outputs = [
|
||||
gallery,
|
||||
video,
|
||||
gen_info,
|
||||
html_info,
|
||||
html_log,
|
||||
]
|
||||
|
||||
# keyword arguments for the generate click binding: video_run.generate wrapped by
# call_queue.wrap_gradio_gpu_call, with "submit_video" named as the `_js` hook
video_dict = dict(
|
||||
fn=call_queue.wrap_gradio_gpu_call(video_run.generate, extra_outputs=[None, '', ''], name='Video'),
|
||||
_js="submit_video",
|
||||
inputs=state_inputs + video_inputs,
|
||||
outputs=video_outputs,
|
||||
show_progress=False,
|
||||
)
|
||||
generate.click(**video_dict)
|
||||
|
|
@ -362,7 +362,7 @@ class InstantIRPipeline(
|
|||
f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
|
||||
}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
lora_state_dict = dict()
|
||||
lora_state_dict = {}
|
||||
for k, v in unet_state_dict.items():
|
||||
if "ip" in k:
|
||||
k = k.replace("attn2", "attn2.processor")
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class ScriptPixelArt(scripts_postprocessing.ScriptPostprocessing):
|
|||
"pixelart_sharpen_amount": pixelart_sharpen_amount,
|
||||
}
|
||||
|
||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, pixelart_enabled: bool, pixelart_use_edge_detection: bool, pixelart_block_size: int, pixelart_edge_block_size: int, pixelart_image_weight: float, pixelart_sharpen_amount: float):
|
||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, pixelart_enabled: bool, pixelart_use_edge_detection: bool, pixelart_block_size: int, pixelart_edge_block_size: int, pixelart_image_weight: float, pixelart_sharpen_amount: float): # pylint: disable=arguments-differ
|
||||
if not pixelart_enabled:
|
||||
return
|
||||
from modules.postprocess.pixelart import img_to_pixelart, edge_detect_for_pixelart
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
try:
|
||||
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
||||
except:
|
||||
except Exception:
|
||||
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
||||
|
||||
from .transformer import PatchDropout
|
||||
|
|
@ -18,7 +18,7 @@ from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
|||
if os.getenv('ENV_TYPE') == 'deepspeed':
|
||||
try:
|
||||
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
||||
except:
|
||||
except Exception:
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
else:
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
|
@ -27,7 +27,7 @@ try:
|
|||
import xformers
|
||||
import xformers.ops as xops
|
||||
XFORMERS_IS_AVAILBLE = True
|
||||
except:
|
||||
except Exception:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
|
||||
class DropPath(nn.Module):
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from torch import nn
|
|||
|
||||
try:
|
||||
from .hf_model import HFTextEncoder
|
||||
except:
|
||||
except Exception:
|
||||
HFTextEncoder = None
|
||||
from .modified_resnet import ModifiedResNet
|
||||
from .timm_model import TimmModel
|
||||
|
|
@ -23,7 +23,7 @@ from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, Tex
|
|||
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm
|
||||
except:
|
||||
except Exception:
|
||||
FusedLayerNorm = LayerNorm
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class SimpleTokenizer(object):
|
|||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except:
|
||||
except Exception:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from torch.nn import functional as F
|
|||
|
||||
try:
|
||||
from timm.models.layers import trunc_normal_
|
||||
except:
|
||||
except Exception:
|
||||
from timm.layers import trunc_normal_
|
||||
|
||||
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
||||
|
|
|
|||
Loading…
Reference in New Issue