mirror of https://github.com/vladmandic/automatic
parent c92e329234
commit 87bd347116
@@ -2,13 +2,11 @@
 """
 Warnings:
 - fal/AuraFlow-v0.3: layer_class_name=Linear layer_weight_shape=torch.Size([3072, 2, 1024]) weights_dtype=int8 unsupported
-- Kwai-Kolors/Kolors-diffusers: `set_input_embeddings` not auto-handled for ChatGLMModel
-- kandinsky-community/kandinsky-2-1: `get_input_embeddings` not auto-handled for MultilingualCLIP
+- Kwai-Kolors/Kolors-diffusers: set_input_embeddings not autohandled for ChatGLMModel
+- kandinsky-community/kandinsky-2-1: get_input_embeddings not autohandled for MultilingualCLIP
 Errors:
 - kandinsky-community/kandinsky-3: corrupt output
-- nvidia/Cosmos-Predict2-2B-Text2Image: mat1 and mat2 shapes cannot be multiplied (512x4096 and 1024x2048)
-- nvidia/Cosmos-Predict2-14B-Text2Image: mat1 and mat2 shapes cannot be multiplied (512x4096 and 1024x5120)
-- Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers: CUDA error: device-side assert triggered
+- Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers: CUDA error device-side assert triggered
 Other:
 - HiDream-ai/HiDream-I1-Full: very slow at 30+s/it
 """
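
For context on the mat1/mat2 errors recorded above: torch matrix multiplication requires the inner dimensions to agree, so (512x4096) @ (1024x2048) fails outright, which is the typical symptom of prompt embeddings produced at the wrong width for the model's projection layer. A minimal reproduction (illustrative, not from this repo):

import torch

a = torch.zeros(512, 4096)   # e.g. prompt embeddings
b = torch.zeros(1024, 2048)  # e.g. a projection expecting 1024-wide input
try:
    torch.matmul(a, b)
except RuntimeError as e:
    print(e)  # mat1 and mat2 shapes cannot be multiplied (512x4096 and 1024x2048)
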
@@ -64,7 +62,11 @@ models = {
     "stabilityai/stable-cascade": {},
+    "nvidia/Cosmos-Predict2-2B-Text2Image": {},
+    "nvidia/Cosmos-Predict2-14B-Text2Image": {},
-    # "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers": {},
     "black-forest-labs/FLUX.1-dev": {},
     "black-forest-labs/FLUX.1-Kontext-dev": {},
     "black-forest-labs/FLUX.1-Krea-dev": {},
+    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers": {},
+    "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers": {},
     # "kandinsky-community/kandinsky-3": {},
     # "HiDream-ai/HiDream-I1-Full": {},
     "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": {},
@@ -76,11 +78,6 @@ models = {
     "vladmandic/chroma-unlocked-v48": {},
     "vladmandic/chroma-unlocked-v48-detail-calibrated": {},
 }
-models_tbd = [
-    "black-forest-labs/FLUX.1-dev",
-    "black-forest-labs/FLUX.1-Kontext-dev",
-    "black-forest-labs/FLUX.1-Krea-dev",
-]
 styles = [
     'Fixed Astronaut',
 ]
@@ -115,7 +112,7 @@ def read_history():
     log.info(f'history: file="{fn}" records={len(history)}')


-def write_history(model:str, style:str, image:str='', size:tuple=(0,0), duration:float=0, info:str=''):
+def write_history(model:str, style:str, image:str='', size:tuple=(0,0), generate:float=0, load:float=0, info:str=''):
     fn = os.path.join(output_folder, 'history.json')
     history.append({
         'model': model,
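
A call under the new signature might look like this (all values illustrative; the file path is hypothetical):

write_history(
    model='black-forest-labs/FLUX.1-dev',
    style='Fixed Astronaut',
    image='outputs/black-forest-labs_FLUX.1-dev__Fixed_Astronaut.jpg',  # hypothetical path
    size=(1024, 1024),
    generate=12.345,  # seconds spent generating the image
    load=41.7,        # seconds spent loading/switching the checkpoint
    info='{...}',
)
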
@@ -123,7 +120,8 @@ def write_history(model:str, style:str, image:str='', size:tuple=(0,0), duration
         'style': style,
         'image': image,
         'size': size,
-        'time': duration,
+        'time': generate,
+        'load': load,
         'info': info,
     })
     with open(fn, "w", encoding='utf8') as file:
@@ -147,12 +145,17 @@ def request(endpoint: str, dct: dict = None, method: str = 'POST'):
     return req.json()


-def generate(): # pylint: disable=redefined-outer-name
-    idx = 0
+def main(): # pylint: disable=redefined-outer-name
+    idx_model = 0
+    idx_images = 0
+    t_generate0 = time.time()
     log.info(f'generate: models={len(models)} styles={len(styles)}')
     for model, args in models.items():
-        idx += 1
-        log.info(f'model: n={idx}/{len(models)} name="{model}"')
+        t_model0 = time.time()
+        idx_model += 1
+        model_name = pathvalidate.sanitize_filename(model, replacement_text='_')
+        log.info(f'model: n={idx_model}/{len(models)} name="{model}"') # counter is already incremented above
+        idx_style = 0
         for s, style in enumerate(styles):
             try:
                 model_name = pathvalidate.sanitize_filename(model, replacement_text='_')
@@ -160,39 +163,46 @@ def generate(): # pylint: disable=redefined-outer-name
                 fn = os.path.join(output_folder, f'{model_name}__{style_name}.jpg')
                 if os.path.exists(fn):
                     continue
+                t_load0 = time.time()
                 request(f'/sdapi/v1/checkpoint?sd_model_checkpoint={model}', method='POST')
                 loaded = request('/sdapi/v1/checkpoint', method='GET')
+                t_load1 = time.time()
                 if not loaded or not (model in loaded.get('checkpoint') or model in loaded.get('title') or model in loaded.get('name')):
                     log.error(f' model: error="{model}"')
                     continue
-                t0 = time.time()
+                t_style0 = time.time()
                 params = { 'styles': [style] }
                 for k, v in args.items():
                     params[k] = v
                 log.info(f' style: n={s+1}/{len(styles)} name="{style}" args={params} fn="{fn}"')
                 data = request('/sdapi/v1/txt2img', params)
-                t1 = time.time()
+                t_style1 = time.time()
                 if 'images' in data and len(data['images']) > 0:
+                    idx_style += 1
+                    idx_images += 1
                     b64 = data['images'][0].split(',',1)[-1] # take the payload even when a data-uri prefix is present
                     image = Image.open(io.BytesIO(base64.b64decode(b64)))
                     info = data['info']
-                    log.info(f' image: size={image.width}x{image.height} time={t1-t0:.2f} info={len(info)}')
+                    log.info(f' image: size={image.width}x{image.height} time={t_style1-t_style0:.2f} info={len(info)}')
                     image.save(fn)
-                    write_history(model=model, style=style, image=fn, size=image.size, duration=round(t1-t0, 3), info=info)
+                    write_history(model=model, style=style, image=fn, size=image.size, generate=round(t_style1-t_style0, 3), load=round(t_load1-t_load0, 3), info=info)
                 else:
                     # write_history(model=model, style=style, duration=round(t1-t0, 3), info='no image')
                     log.error(f' model: error="{model}" style="{style}" no image')
             except Exception as e:
                 if 'Connection refused' in str(e) or 'RemoteDisconnected' in str(e):
                     log.error('server offline')
                     os._exit(1)
                 # write_history(model=model, style=style, duration=round(t1-t0, 3), info=str(e))
                 log.error(f' model: error="{model}" style="{style}" exception="{e}"')
+        t_model1 = time.time()
+        if idx_style > 0:
+            log.info(f'model: name="{model}" images={idx_style} time={t_model1-t_model0:.2f}')
+    t_generate1 = time.time()
+    if idx_images > 0:
+        log.info(f'generate: models={idx_model} images={idx_images} time={t_generate1-t_generate0:.2f}')


 if __name__ == "__main__":
     log.info('test-all-models')
-    log.info(f'output="{output_folder}" models={len(models)} styles={len(styles)}')
+    log.info(f'output="{output_folder}"')
     read_history()
-    generate()
-    log.info('done...')
+    main()
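
Since each record now stores 'time' (generation) and 'load' separately, the results file can be summarized afterwards; a sketch (the output folder name is assumed):

import json, os

with open(os.path.join('outputs', 'history.json'), encoding='utf8') as f:  # folder name assumed
    records = json.load(f)
for r in records:
    print(f"{r['model']} / {r['style']}: generate={r['time']:.2f}s load={r.get('load', 0):.2f}s")
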

@@ -429,6 +429,12 @@
         "preview": "Tencent-Hunyuan--HunyuanDiT-v1.2-Diffusers.jpg",
         "extras": "sampler: Default, cfg_scale: 2.0"
     },
+    "Tencent HunyuanDiT 1.1": {
+        "path": "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers",
+        "desc": "Hunyuan-DiT: A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding.",
+        "preview": "Tencent-Hunyuan--HunyuanDiT-v1.2-Diffusers.jpg",
+        "extras": "sampler: Default, cfg_scale: 2.0"
+    },
     "AlphaVLLM Lumina Next SFT": {
         "path": "Alpha-VLLM/Lumina-Next-SFT-diffusers",
@@ -642,11 +642,15 @@ def get_dit_args(load_config:dict={}, module:str=None, device_map:bool=False, al
def do_post_load_quant(sd_model, allow=True):
    from modules import shared
    if shared.opts.sdnq_quantize_weights and (shared.opts.sdnq_quantize_mode == 'post' or (allow and shared.opts.sdnq_quantize_mode == 'auto')):
        shared.log.debug('Load model: post_quant=sdnq')
        sd_model = sdnq_quantize_weights(sd_model)
    if len(shared.opts.optimum_quanto_weights) > 0:
        shared.log.debug('Load model: post_quant=quanto')
        sd_model = optimum_quanto_weights(sd_model)
    if shared.opts.torchao_quantization and (shared.opts.torchao_quantization_mode == 'post' or (allow and shared.opts.torchao_quantization_mode == 'auto')):
        shared.log.debug('Load model: post_quant=torchao')
        sd_model = torchao_quantization(sd_model)
    if shared.opts.layerwise_quantization:
        shared.log.debug('Load model: post_quant=layerwise')
        apply_layerwise(sd_model)
    return sd_model
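
do_post_load_quant() applies each enabled backend in a fixed order (SDNQ, Quanto, TorchAO, then layerwise), so at most the configured subset runs once the weights are materialized. A sketch of a call site, based only on the signature above:

# allow=True lets backends configured as mode='auto' run as well as mode='post'
sd_model = do_post_load_quant(sd_model, allow=True)
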
@@ -29,6 +29,7 @@ debug_load = os.environ.get('SD_LOAD_DEBUG', None)
 debug_process = log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
 diffusers_version = int(diffusers.__version__.split('.')[1])
 checkpoint_tiles = checkpoint_titles # legacy compatibility
+allow_post_quant = None
 pipe_switch_task_exclude = [
     'AnimateDiffPipeline', 'AnimateDiffSDXLPipeline',
     'FluxControlPipeline',
@@ -275,7 +276,7 @@ def load_diffuser_initial(diffusers_load_config, op='model'):

 def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='model'):
     sd_model = None
-    allow_post_quant = True
+    global allow_post_quant # pylint: disable=global-statement
     unload_model_weights(op=op)
     shared.sd_model = None
     try:
@@ -316,7 +317,8 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
         allow_post_quant = True
     elif model_type in ['FLUX']:
         from pipelines.model_flux import load_flux
-        sd_model, allow_post_quant = load_flux(checkpoint_info, diffusers_load_config)
+        sd_model = load_flux(checkpoint_info, diffusers_load_config)
+        allow_post_quant = False
     elif model_type in ['FLEX']:
         from pipelines.model_flex import load_flex
         sd_model = load_flex(checkpoint_info, diffusers_load_config)
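
The change above is part of a broader refactor visible throughout this diff: instead of each loader returning an (sd_model, allow_post_quant) tuple, the flag now lives as a module-level global on sd_models, which pipeline loaders flip when they have already quantized the model. A sketch of the pattern from the loader side (mirrors the generic.py change later in this diff; build_pipeline is a hypothetical helper):

from modules import sd_models

def load_something(checkpoint_info, load_config):
    pipe = build_pipeline(checkpoint_info, load_config)  # hypothetical helper
    sd_models.allow_post_quant = False  # quantization already handled here
    return pipe
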
@@ -398,7 +400,7 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
         if debug_load:
             errors.display(e, 'Load')
         return None
-    return sd_model, allow_post_quant
+    return sd_model


 def load_diffuser_folder(model_type, pipeline, checkpoint_info, diffusers_load_config, op='model'):
@@ -561,6 +563,8 @@ def set_defaults(sd_model, checkpoint_info):


 def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: disable=unused-argument
+    global allow_post_quant # pylint: disable=global-statement
+    allow_post_quant = True # assume default
     logging.getLogger("diffusers").setLevel(logging.ERROR)
     timer.load.record("diffusers")
     diffusers_load_config = {
@@ -589,7 +593,6 @@ def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: di
         return

     sd_model = None
-    allow_post_quant = True
     try:
         # initial load only
         if sd_model is None:
@@ -621,7 +624,7 @@ def load_diffuser(checkpoint_info=None, op='model', revision=None): # pylint: di

         # load with custom loader
         if sd_model is None:
-            sd_model, allow_post_quant = load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op)
+            sd_model = load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op)
             if sd_model is not None and not sd_model:
                 shared.log.error(f'Load {op}: type="{model_type}" pipeline="{pipeline}" not loaded')
                 return
@@ -0,0 +1,25 @@
import diffusers
import transformers
from modules import devices, model_quant


def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
    transformer = None
    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path
    model_quant.load_bnb('Load model: type=FLUX')
    quant = model_quant.get_quant(repo_path)
    if quant == 'fp8':
        quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
        transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
    elif quant == 'fp4':
        quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
        transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
    elif quant == 'nf4':
        quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
        transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
    else:
        transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
    return transformer
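
The three quant branches above differ only in their BitsAndBytesConfig arguments; a table-driven alternative is possible (a sketch only, assuming identical semantics — not part of this commit):

# Sketch: same configs as the branches above, expressed as a lookup table
_BNB_KWARGS = {
    'fp8': dict(load_in_8bit=True),
    'fp4': dict(load_in_4bit=True, bnb_4bit_quant_type='fp4'),
    'nf4': dict(load_in_4bit=True, bnb_4bit_quant_type='nf4'),
}

def _bnb_config(quant):
    kwargs = _BNB_KWARGS.get(quant)
    if kwargs is None:
        return None  # caller loads without quantization_config
    return transformers.BitsAndBytesConfig(bnb_4bit_compute_dtype=devices.dtype, **kwargs)
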
@@ -0,0 +1,360 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, sd_unet, model_te, model_quant, sd_hijack_te


debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None


def load_flux_quanto(checkpoint_info):
    transformer, text_encoder_2 = None, None
    quanto = model_quant.load_quanto('Load model: type=FLUX')

    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path

    try:
        quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
        dtype = state_dict['context_embedder.bias'].dtype
        with torch.device("meta"):
            transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
        quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
        transformer_dtype = transformer.dtype
        if transformer_dtype != devices.dtype:
            try:
                transformer = transformer.to(dtype=devices.dtype)
            except Exception:
                shared.log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
    except Exception as e:
        shared.log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')

    try:
        quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
        dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
        with torch.device("meta"):
            text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
        quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
        text_encoder_2_dtype = text_encoder_2.dtype
        if text_encoder_2_dtype != devices.dtype:
            try:
                text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
            except Exception:
                shared.log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
    except Exception as e:
        shared.log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')

    return transformer, text_encoder_2
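
The Quanto path above uses a common trick: build the module skeleton on the meta device so no real memory is allocated, then let requantize() populate it from the on-disk state dict. Distilled from the same calls as above:

with torch.device("meta"):
    model = transformers.T5EncoderModel(t5_config)  # structure only, no weights allocated yet
quanto.requantize(model, state_dict, quantization_map, device=torch.device("cpu"))
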
def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
    transformer, text_encoder_2 = None, None
    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path
    model_quant.load_bnb('Load model: type=FLUX')
    quant = model_quant.get_quant(repo_path)
    try:
        if quant == 'fp8':
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        elif quant == 'fp4':
            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        elif quant == 'nf4':
            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        else:
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
    except Exception as e:
        shared.log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}")
        transformer, text_encoder_2 = None, None
        if debug:
            errors.display(e, 'FLUX:')
    return transformer, text_encoder_2
def load_quants(kwargs, repo_id, cache_dir, allow_quant): # pylint: disable=unused-argument
    try:
        diffusers_load_config = {
            "torch_dtype": devices.dtype,
            "cache_dir": cache_dir,
        }
        if 'transformer' not in kwargs and model_quant.check_nunchaku('Model'):
            import nunchaku
            nunchaku_precision = nunchaku.utils.get_precision()
            nunchaku_repo = None
            if 'flux.1-kontext' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
            elif 'flux.1-dev' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
            elif 'flux.1-schnell' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
            elif 'flux.1-fill' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" # filename still says flux.1-schnell; worth verifying against the repo
            elif 'flux.1-depth' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" # filename still says flux.1-schnell; worth verifying against the repo
            elif 'shuttle' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
            else:
                shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
            if nunchaku_repo is not None:
                shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
                kwargs['transformer'] = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype)
                kwargs['transformer'].quantization_method = 'SVDQuant'
                if shared.opts.nunchaku_attention:
                    kwargs['transformer'].set_attention_impl("nunchaku-fp16")
        if 'transformer' not in kwargs and model_quant.check_quant('Model'):
            load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
            kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", **load_args, **quant_args)
        if 'text_encoder_2' not in kwargs and model_quant.check_nunchaku('TE'):
            import nunchaku
            nunchaku_precision = nunchaku.utils.get_precision()
            nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
            shared.log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
            kwargs['text_encoder_2'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype)
            kwargs['text_encoder_2'].quantization_method = 'SVDQuant'
        if 'text_encoder_2' not in kwargs and model_quant.check_quant('TE'):
            load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True)
            kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", **load_args, **quant_args)
    except Exception as e:
        shared.log.error(f'Quantization: {e}')
        errors.display(e, 'Quantization:')
    return kwargs
def load_transformer(file_path): # triggered by opts.sd_unet change
    if file_path is None or not os.path.exists(file_path):
        return None
    transformer = None
    quant = model_quant.get_quant(file_path)
    diffusers_load_config = {
        "torch_dtype": devices.dtype,
        "cache_dir": shared.opts.hfcache_dir,
    }
    if quant is not None and quant != 'none':
        shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
        if 'gguf' in file_path.lower():
            from modules import ggml
            _transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
            if _transformer is not None:
                transformer = _transformer
        elif quant == "fp8":
            _transformer = model_quant.load_fp8_model_layerwise(file_path, diffusers.FluxTransformer2DModel.from_single_file, diffusers_load_config)
            if _transformer is not None:
                transformer = _transformer
        elif quant in {'qint8', 'qint4'}:
            _transformer, _text_encoder_2 = load_flux_quanto(file_path)
            if _transformer is not None:
                transformer = _transformer
        elif quant in {'fp8', 'fp4', 'nf4'}: # 'fp8' is consumed by the layerwise branch above, so this arm effectively covers fp4/nf4
            _transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
            if _transformer is not None:
                transformer = _transformer
        elif 'nf4' in quant:
            from pipelines.flux.flux_nf4 import load_flux_nf4
            _transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
            if _transformer is not None:
                transformer = _transformer
    else:
        quant_args = model_quant.create_bnb_config({})
        if quant_args:
            shared.log.info(f'Load module: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
            from pipelines.flux.flux_nf4 import load_flux_nf4
            transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
            if transformer is not None:
                return transformer
        load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
        shared.log.debug(f'Load model: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} args={load_args}')
        transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_args, **quant_args)
    if transformer is None:
        shared.log.error('Failed to load UNet model')
        shared.opts.sd_unet = 'Default'
    return transformer
def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_checkpoint change
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)
    allow_post_quant = False

    prequantized = model_quant.get_quant(checkpoint_info.path)
    shared.log.debug(f'Load model: type=FLUX model="{checkpoint_info.name}" repo="{repo_id}" unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
    debug(f'Load model: type=FLUX config={diffusers_load_config}')

    transformer = None
    text_encoder_1 = None
    text_encoder_2 = None
    vae = None

    # unload current model
    sd_models.unload_model_weights()
    shared.sd_model = None
    devices.torch_gc(force=True, reason='load')

    if shared.opts.teacache_enabled:
        from modules import teacache
        shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
        diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward # patch must be done before transformer is loaded

    # load overrides if any
    if shared.opts.sd_unet != 'Default':
        try:
            debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"')
            transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
            if transformer is None:
                sd_unet.failed_unet.append(shared.opts.sd_unet) # record the failing unet before the option is reset
                shared.opts.sd_unet = 'Default'
        except Exception as e:
            shared.log.error(f"Load model: type=FLUX failed to load UNet: {e}")
            shared.opts.sd_unet = 'Default'
            if debug:
                errors.display(e, 'FLUX UNet:')
    if shared.opts.sd_text_encoder != 'Default':
        try:
            debug(f'Load model: type=FLUX te="{shared.opts.sd_text_encoder}"')
            from modules.model_te import load_t5, load_vit_l
            if 'vit-l' in shared.opts.sd_text_encoder.lower():
                text_encoder_1 = load_vit_l()
            else:
                text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
        except Exception as e:
            shared.log.error(f"Load model: type=FLUX failed to load T5: {e}")
            shared.opts.sd_text_encoder = 'Default'
            if debug:
                errors.display(e, 'FLUX T5:')
    if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
        try:
            debug(f'Load model: type=FLUX vae="{shared.opts.sd_vae}"')
            from modules import sd_vae
            # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
            vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
            if os.path.exists(vae_file):
                vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
                vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
        except Exception as e:
            shared.log.error(f"Load model: type=FLUX failed to load VAE: {e}")
            shared.opts.sd_vae = 'Default'
            if debug:
                errors.display(e, 'FLUX VAE:')

    # load quantized components if any
    if prequantized == 'nf4':
        try:
            from pipelines.flux.flux_nf4 import load_flux_nf4
            _transformer, _text_encoder = load_flux_nf4(checkpoint_info)
            if _transformer is not None:
                transformer = _transformer
            if _text_encoder is not None:
                text_encoder_2 = _text_encoder
        except Exception as e:
            shared.log.error(f"Load model: type=FLUX failed to load NF4 components: {e}")
            if debug:
                errors.display(e, 'FLUX NF4:')
    if prequantized == 'qint8' or prequantized == 'qint4':
        try:
            _transformer, _text_encoder = load_flux_quanto(checkpoint_info)
            if _transformer is not None:
                transformer = _transformer
            if _text_encoder is not None:
                text_encoder_2 = _text_encoder
        except Exception as e:
            shared.log.error(f"Load model: type=FLUX failed to load Quanto components: {e}")
            if debug:
                errors.display(e, 'FLUX Quanto:')

    # initialize pipeline with pre-loaded components
    kwargs = {}
    if transformer is not None:
        kwargs['transformer'] = transformer
        sd_unet.loaded_unet = shared.opts.sd_unet
    if text_encoder_1 is not None:
        kwargs['text_encoder'] = text_encoder_1
        model_te.loaded_te = shared.opts.sd_text_encoder
    if text_encoder_2 is not None:
        kwargs['text_encoder_2'] = text_encoder_2
        model_te.loaded_te = shared.opts.sd_text_encoder
    if vae is not None:
        kwargs['vae'] = vae
    if repo_id == 'sayakpaul/flux.1-dev-nf4':
        repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
    if 'Fill' in repo_id:
        cls = diffusers.FluxFillPipeline
    elif 'Canny' in repo_id:
        cls = diffusers.FluxControlPipeline
    elif 'Depth' in repo_id:
        cls = diffusers.FluxControlPipeline
    elif 'Kontext' in repo_id:
        cls = diffusers.FluxKontextPipeline
        from diffusers import pipelines
        pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
        pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
        pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline

    else:
        cls = diffusers.FluxPipeline
    shared.log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
    for c in kwargs:
        if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
            shared.log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
        if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
            try:
                kwargs[c] = kwargs[c].to(dtype=devices.dtype)
                shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
            except Exception:
                pass

    allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none')
    fn = checkpoint_info.path
    if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
        kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant)
    if fn is not None and fn.endswith('.safetensors') and os.path.isfile(fn): # guard against fn=None before endswith
        pipe = cls.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
        allow_post_quant = True
    else:
        pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)

    if shared.opts.teacache_enabled and model_quant.check_nunchaku('Model'):
        from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
        apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)

    # release memory
    transformer = None
    text_encoder_1 = None
    text_encoder_2 = None
    vae = None
    for k in kwargs.keys():
        kwargs[k] = None
    sd_hijack_te.init_hijack(pipe)
    devices.torch_gc(force=True, reason='load')
    return pipe, allow_post_quant
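
For reference, this load_flux() still returns a (pipe, allow_post_quant) tuple, so a caller would unpack it, matching the pre-refactor call site in sd_models.py shown earlier (sketch):

pipe, allow_post_quant = load_flux(checkpoint_info, {})
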
@@ -0,0 +1,29 @@
from modules import shared, devices


def load_flux_nunchaku(repo_id):
    import nunchaku
    nunchaku_precision = nunchaku.utils.get_precision()
    nunchaku_repo = None
    transformer = None
    if 'flux.1-kontext' in repo_id.lower():
        nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
    elif 'flux.1-dev' in repo_id.lower():
        nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
    elif 'flux.1-schnell' in repo_id.lower():
        nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
    elif 'flux.1-fill' in repo_id.lower():
        nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" # filename still says flux.1-schnell; worth verifying against the repo
    elif 'flux.1-depth' in repo_id.lower():
        nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" # filename still says flux.1-schnell; worth verifying against the repo
    elif 'shuttle' in repo_id.lower():
        nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
    else:
        shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
    if nunchaku_repo is not None:
        shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
        transformer = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype)
        transformer.quantization_method = 'SVDQuant'
        if shared.opts.nunchaku_attention:
            transformer.set_attention_impl("nunchaku-fp16")
    return transformer
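
A sketch of how this helper slots in (repo id illustrative; callers fall back to other loaders when None is returned):

transformer = load_flux_nunchaku('black-forest-labs/FLUX.1-dev')
if transformer is not None:
    kwargs['transformer'] = transformer  # pre-loaded component handed to the pipeline
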
@@ -0,0 +1,73 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from modules import shared, errors, devices, sd_models, model_quant


debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None


def load_flux_quanto(checkpoint_info):
    transformer, text_encoder_2 = None, None
    quanto = model_quant.load_quanto('Load model: type=FLUX')

    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path

    try:
        quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
        dtype = state_dict['context_embedder.bias'].dtype
        with torch.device("meta"):
            transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
        quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
        transformer_dtype = transformer.dtype
        if transformer_dtype != devices.dtype:
            try:
                transformer = transformer.to(dtype=devices.dtype)
            except Exception:
                shared.log.error(f"Load model: type=FLUX Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
    except Exception as e:
        shared.log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')

    try:
        quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
        dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
        with torch.device("meta"):
            text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
        quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
        text_encoder_2_dtype = text_encoder_2.dtype
        if text_encoder_2_dtype != devices.dtype:
            try:
                text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
            except Exception:
                shared.log.error(f"Load model: type=FLUX Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_2_dtype}")
    except Exception as e:
        shared.log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')

    return transformer, text_encoder_2
@@ -2,120 +2,145 @@ import os
 import json
 import diffusers
 import transformers
-from modules import shared, devices, sd_models, model_quant
+from modules import shared, devices, errors, sd_models, model_quant


-debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug = os.environ.get('SD_LOAD_DEBUG', None) is not None


 def load_transformer(repo_id, cls_name, load_config={}, subfolder="transformer", allow_quant=True, variant=None, dtype=None):
-    load_args, quant_args = model_quant.get_dit_args(load_config, module='Model', device_map=True, allow_quant=allow_quant)
-    quant_type = model_quant.get_quant_type(quant_args)
-    dtype = dtype or devices.dtype
     transformer = None
+    try:
+        load_args, quant_args = model_quant.get_dit_args(load_config, module='Model', device_map=True, allow_quant=allow_quant)
+        quant_type = model_quant.get_quant_type(quant_args)
+        dtype = dtype or devices.dtype

-    local_file = None
-    if shared.opts.sd_unet is not None and shared.opts.sd_unet != 'Default':
-        from modules import sd_unet
-        if shared.opts.sd_unet not in list(sd_unet.unet_dict):
-            shared.log.error(f'Load module: type=transformer file="{shared.opts.sd_unet}" not found')
-        elif os.path.exists(sd_unet.unet_dict[shared.opts.sd_unet]):
-            local_file = sd_unet.unet_dict[shared.opts.sd_unet]
+        local_file = None
+        if shared.opts.sd_unet is not None and shared.opts.sd_unet != 'Default':
+            from modules import sd_unet
+            if shared.opts.sd_unet not in list(sd_unet.unet_dict):
+                shared.log.error(f'Load module: type=transformer file="{shared.opts.sd_unet}" not found')
+            elif os.path.exists(sd_unet.unet_dict[shared.opts.sd_unet]):
+                local_file = sd_unet.unet_dict[shared.opts.sd_unet]

-    if local_file is not None and local_file.lower().endswith('.gguf'):
-        shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
-        from modules import ggml
-        ggml.install_gguf()
-        loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained
-        transformer = loader(
-            local_file,
-            quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype),
-            cache_dir=shared.opts.hfcache_dir,
-            **load_args,
-        )
-        transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None)
-    elif local_file is not None and local_file.lower().endswith('.safetensors'):
-        shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
-        loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained
-        transformer = loader(
-            local_file,
-            cache_dir=shared.opts.hfcache_dir,
-            **load_args,
-        )
-        transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None)
-    else:
-        shared.log.debug(f'Load model: transformer="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
-        if dtype is not None:
-            load_args['torch_dtype'] = dtype
-        if subfolder is not None:
-            load_args['subfolder'] = subfolder
-        if variant is not None:
-            load_args['variant'] = variant
-        transformer = cls_name.from_pretrained(
-            repo_id,
-            cache_dir=shared.opts.hfcache_dir,
-            **load_args,
-            **quant_args,
-        )
-    if shared.opts.diffusers_offload_mode != 'none' and transformer is not None:
-        sd_models.move_model(transformer, devices.cpu)
+        if local_file is not None and local_file.lower().endswith('.gguf'):
+            shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
+            from modules import ggml
+            ggml.install_gguf()
+            loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained
+            transformer = loader(
+                local_file,
+                quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype),
+                cache_dir=shared.opts.hfcache_dir,
+                **load_args,
+            )
+            transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None)
+        elif local_file is not None and local_file.lower().endswith('.safetensors'):
+            shared.log.debug(f'Load model: transformer="{local_file}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
+            loader = cls_name.from_single_file if hasattr(cls_name, 'from_single_file') else cls_name.from_pretrained
+            transformer = loader(
+                local_file,
+                cache_dir=shared.opts.hfcache_dir,
+                **load_args,
+            )
+            transformer = model_quant.do_post_load_quant(transformer, allow=quant_type is not None)
+        else:
+            shared.log.debug(f'Load model: transformer="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" args={load_args}')
+            if dtype is not None:
+                load_args['torch_dtype'] = dtype
+            if subfolder is not None:
+                load_args['subfolder'] = subfolder
+            if variant is not None:
+                load_args['variant'] = variant
+            transformer = cls_name.from_pretrained(
+                repo_id,
+                cache_dir=shared.opts.hfcache_dir,
+                **load_args,
+                **quant_args,
+            )
+        sd_models.allow_post_quant = False # we already handled it
+        if shared.opts.diffusers_offload_mode != 'none' and transformer is not None:
+            sd_models.move_model(transformer, devices.cpu)
+    except Exception as e:
+        shared.log.error(f'Load model: type=transformer {e}')
+        if debug:
+            errors.display(e, 'Load:')
+        raise
     return transformer
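
A sketch of a typical call into this generic loader (class and repo are examples; any diffusers transformer class with from_pretrained/from_single_file works):

transformer = load_transformer(
    repo_id='black-forest-labs/FLUX.1-dev',  # example repo
    cls_name=diffusers.FluxTransformer2DModel,
    load_config={},
    subfolder='transformer',
)
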
 def load_text_encoder(repo_id, cls_name, load_config={}, subfolder="text_encoder", allow_quant=True, allow_shared=True, variant=None, dtype=None):
-    load_args, quant_args = model_quant.get_dit_args(load_config, module='TE', device_map=True, allow_quant=allow_quant)
-    quant_type = model_quant.get_quant_type(quant_args)
     text_encoder = None
-    dtype = dtype or devices.dtype
+    try:
+        load_args, quant_args = model_quant.get_dit_args(load_config, module='TE', device_map=True, allow_quant=allow_quant)
+        quant_type = model_quant.get_quant_type(quant_args)
+        dtype = dtype or devices.dtype

-    # load from local file if specified
-    local_file = None
-    if shared.opts.sd_text_encoder is not None and shared.opts.sd_text_encoder != 'Default':
-        from modules import model_te
-        if shared.opts.sd_text_encoder not in list(model_te.te_dict):
-            shared.log.error(f'Load module: type=te file="{shared.opts.sd_text_encoder}" not found')
-        elif os.path.exists(model_te.te_dict[shared.opts.sd_text_encoder]):
-            local_file = model_te.te_dict[shared.opts.sd_text_encoder]
+        # load from local file if specified
+        local_file = None
+        if shared.opts.sd_text_encoder is not None and shared.opts.sd_text_encoder != 'Default':
+            from modules import model_te
+            if shared.opts.sd_text_encoder not in list(model_te.te_dict):
+                shared.log.error(f'Load module: type=te file="{shared.opts.sd_text_encoder}" not found')
+            elif os.path.exists(model_te.te_dict[shared.opts.sd_text_encoder]):
+                local_file = model_te.te_dict[shared.opts.sd_text_encoder]

-    # load from local file gguf
-    if local_file is not None and local_file.lower().endswith('.gguf'):
-        shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"')
-        from modules import ggml
-        ggml.install_gguf()
-        text_encoder = cls_name.from_pretrained(
-            gguf_file=local_file,
-            quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype),
-            cache_dir=shared.opts.hfcache_dir,
-            **load_args,
-        )
-        text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None)
-    # load from local file safetensors
-    elif local_file is not None and local_file.lower().endswith('.safetensors'):
-        shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"')
-        text_encoder = cls_name.from_pretrained(
-            local_file,
-            cache_dir=shared.opts.hfcache_dir,
-            **load_args,
-        )
-        text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None)
-    # use shared t5 if possible
-    elif cls_name == transformers.T5EncoderModel and allow_shared:
-        with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
-            load_args['config'] = transformers.T5Config(**json.load(f))
-        if model_quant.check_nunchaku('TE'):
-            import nunchaku
-            repo_id = 'nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
-            cls_name = nunchaku.NunchakuT5EncoderModel
-            shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="SVDQuant"')
-            text_encoder = nunchaku.NunchakuT5EncoderModel.from_pretrained(
-                repo_id,
-                torch_dtype=dtype,
-            )
-            text_encoder.quantization_method = 'SVDQuant'
-        elif shared.opts.te_shared_t5:
-            repo_id = 'Disty0/t5-xxl'
+        # load from local file gguf
+        if local_file is not None and local_file.lower().endswith('.gguf'):
+            shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"')
+            from modules import ggml
+            ggml.install_gguf()
+            text_encoder = cls_name.from_pretrained(
+                gguf_file=local_file,
+                quantization_config=diffusers.GGUFQuantizationConfig(compute_dtype=dtype),
+                cache_dir=shared.opts.hfcache_dir,
+                **load_args,
+            )
+            text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None)
+        # load from local file safetensors
+        elif local_file is not None and local_file.lower().endswith('.safetensors'):
+            shared.log.debug(f'Load model: text_encoder="{local_file}" cls={cls_name.__name__} quant="{quant_type}"')
+            text_encoder = cls_name.from_pretrained(
+                local_file,
+                cache_dir=shared.opts.hfcache_dir,
+                **load_args,
+            )
+            text_encoder = model_quant.do_post_load_quant(text_encoder, allow=quant_type is not None)
+        # use shared t5 if possible
+        elif cls_name == transformers.T5EncoderModel and allow_shared:
+            with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
+                load_args['config'] = transformers.T5Config(**json.load(f))
+            if model_quant.check_nunchaku('TE'):
+                import nunchaku
+                repo_id = 'nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
+                cls_name = nunchaku.NunchakuT5EncoderModel
+                shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="SVDQuant"')
+                text_encoder = nunchaku.NunchakuT5EncoderModel.from_pretrained(
+                    repo_id,
+                    torch_dtype=dtype,
+                )
+                text_encoder.quantization_method = 'SVDQuant'
+            elif shared.opts.te_shared_t5:
+                repo_id = 'Disty0/t5-xxl'
+                shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" shared={shared.opts.te_shared_t5}')
+                if dtype is not None:
+                    load_args['torch_dtype'] = dtype
+                text_encoder = cls_name.from_pretrained(
+                    repo_id,
+                    cache_dir=shared.opts.hfcache_dir,
+                    **load_args,
+                    **quant_args,
+                )

+        # load from repo
+        if text_encoder is None:
+            shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" shared={shared.opts.te_shared_t5}')
+            if dtype is not None:
+                load_args['torch_dtype'] = dtype
+            if subfolder is not None:
+                load_args['subfolder'] = subfolder
+            if variant is not None:
+                load_args['variant'] = variant
+            text_encoder = cls_name.from_pretrained(
+                repo_id,
+                cache_dir=shared.opts.hfcache_dir,
@@ -123,22 +148,12 @@ def load_text_encoder(repo_id, cls_name, load_config={}, subfolder="text_encoder
+                **load_args,
+                **quant_args,
+            )

-    # load from repo
-    if text_encoder is None:
-        shared.log.debug(f'Load model: text_encoder="{repo_id}" cls={cls_name.__name__} quant="{quant_type}" shared={shared.opts.te_shared_t5}')
-        if dtype is not None:
-            load_args['torch_dtype'] = dtype
-        if subfolder is not None:
-            load_args['subfolder'] = subfolder
-        if variant is not None:
-            load_args['variant'] = variant
-        text_encoder = cls_name.from_pretrained(
-            repo_id,
-            cache_dir=shared.opts.hfcache_dir,
-            **load_args,
-            **quant_args,
-        )

-    if shared.opts.diffusers_offload_mode != 'none' and text_encoder is not None:
-        sd_models.move_model(text_encoder, devices.cpu)
+        sd_models.allow_post_quant = False # we already handled it
+        if shared.opts.diffusers_offload_mode != 'none' and text_encoder is not None:
+            sd_models.move_model(text_encoder, devices.cpu)
+    except Exception as e:
+        shared.log.error(f'Load model: type=te {e}')
+        if debug:
+            errors.display(e, 'Load:')
+        raise
     return text_encoder
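
And the matching text-encoder call (sketch; passing T5EncoderModel is what enables the shared-T5 path above):

text_encoder_2 = load_text_encoder(
    repo_id='black-forest-labs/FLUX.1-dev',  # example repo
    cls_name=transformers.T5EncoderModel,
    subfolder='text_encoder_2',
)
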
@@ -4,7 +4,7 @@ sys.path.append("./")
 # import torch
 # from torchvision import transforms
 from meissonic.transformer import Transformer2DModel as TransformerMeissonic
-from meissonic.pipeline import Pipeline as PipelineMeissonic
+from meissonic.pipeline import MeissonicPipeline
 from meissonic.scheduler import Scheduler as MeissonicScheduler
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer
 from diffusers import VQModel

@@ -21,7 +21,7 @@ vq_model = VQModel.from_pretrained(model_path, subfolder="vqvae", cache_dir=cach
 text_encoder = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", cache_dir=cache_dir)
 tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
 scheduler = MeissonicScheduler.from_pretrained(model_path, subfolder="scheduler")
-pipe = PipelineMeissonic(vq_model, tokenizer=tokenizer, text_encoder=text_encoder, transformer=model, scheduler=scheduler)
+pipe = MeissonicPipeline(vq_model, tokenizer=tokenizer, text_encoder=text_encoder, transformer=model, scheduler=scheduler)
 pipe = pipe.to(device)

 steps = 64
@@ -1,13 +1,9 @@
import os
import diffusers
import transformers
from modules import shared, devices, sd_models, model_quant
from pipelines import generic


debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None


def load_chroma(checkpoint_info, diffusers_load_config={}):
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)
@ -1,360 +1,76 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
import diffusers
|
||||
import transformers
|
||||
from safetensors.torch import load_file
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modules import shared, errors, devices, sd_models, sd_unet, model_te, model_quant, sd_hijack_te
|
||||
from modules import shared, devices, sd_models, model_quant
|
||||
from pipelines import generic
|
||||
|
||||
|
||||
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
def load_flux_quanto(checkpoint_info):
    transformer, text_encoder_2 = None, None
    quanto = model_quant.load_quanto('Load model: type=FLUX')

    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path

    try:
        quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
        dtype = state_dict['context_embedder.bias'].dtype
        with torch.device("meta"):
            transformer = diffusers.FluxTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
        quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
        transformer_dtype = transformer.dtype
        if transformer_dtype != devices.dtype:
            try:
                transformer = transformer.to(dtype=devices.dtype)
            except Exception:
                shared.log.error(f"Load model: type=FLUX failed to cast transformer to {devices.dtype}, keeping dtype {transformer_dtype}")
    except Exception as e:
        shared.log.error(f"Load model: type=FLUX failed to load Quanto transformer: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')

    try:
        quantization_map = os.path.join(repo_path, "text_encoder_2", "quantization_map.json")
        debug(f'Load model: type=FLUX quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder_2"')
        if not os.path.exists(quantization_map):
            repo_id = sd_models.path_to_repo(checkpoint_info)
            quantization_map = hf_hub_download(repo_id, subfolder='text_encoder_2', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
        with open(quantization_map, "r", encoding='utf8') as f:
            quantization_map = json.load(f)
        with open(os.path.join(repo_path, "text_encoder_2", "config.json"), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        state_dict = load_file(os.path.join(repo_path, "text_encoder_2", "model.safetensors"))
        dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
        with torch.device("meta"):
            text_encoder_2 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
        quanto.requantize(text_encoder_2, state_dict, quantization_map, device=torch.device("cpu"))
        text_encoder_2_dtype = text_encoder_2.dtype
        if text_encoder_2_dtype != devices.dtype:
            try:
                text_encoder_2 = text_encoder_2.to(dtype=devices.dtype)
            except Exception:
                shared.log.error(f"Load model: type=FLUX failed to cast text encoder to {devices.dtype}, keeping dtype {text_encoder_2_dtype}")
    except Exception as e:
        shared.log.error(f"Load model: type=FLUX failed to load Quanto text encoder: {e}")
        if debug:
            errors.display(e, 'FLUX Quanto:')

    return transformer, text_encoder_2
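
The removed Quanto path above follows a load-skeleton-then-requantize pattern that can be reproduced standalone. A hedged sketch, assuming optimum-quanto is installed and the repo layout matches the paths used above:

import json
import torch
import diffusers
from safetensors.torch import load_file
from optimum import quanto

def reload_quantized_transformer(repo_path, dtype):
    with open(f'{repo_path}/transformer/quantization_map.json', encoding='utf8') as f:
        qmap = json.load(f)
    state_dict = load_file(f'{repo_path}/transformer/diffusion_pytorch_model.safetensors')
    with torch.device('meta'):  # skeleton only, no weight memory allocated
        # from_config is given the config path here, mirroring the code above
        model = diffusers.FluxTransformer2DModel.from_config(f'{repo_path}/transformer/config.json').to(dtype=dtype)
    quanto.requantize(model, state_dict, qmap, device=torch.device('cpu'))  # materializes the quantized weights
    return model
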

def load_flux_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
    transformer, text_encoder_2 = None, None
    if isinstance(checkpoint_info, str):
        repo_path = checkpoint_info
    else:
        repo_path = checkpoint_info.path
    model_quant.load_bnb('Load model: type=FLUX')
    quant = model_quant.get_quant(repo_path)
    try:
        if quant == 'fp8':
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=devices.dtype)
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        elif quant == 'fp4':
            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='fp4')
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        elif quant == 'nf4':
            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type='nf4')
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        else:
            transformer = diffusers.FluxTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
    except Exception as e:
        shared.log.error(f"Load model: type=FLUX failed to load BnB transformer: {e}")
        transformer, text_encoder_2 = None, None
        if debug:
            errors.display(e, 'FLUX:')
    return transformer, text_encoder_2
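
The three branches above differ only in the bitsandbytes config they build. A compact sketch of the same selection (transformers' BitsAndBytesConfig is shown; diffusers ships an equivalent class, and which one applies to a given component is a judgment call, not fixed API):

import torch
from transformers import BitsAndBytesConfig

def bnb_config_for(quant, compute_dtype=torch.bfloat16):
    if quant == 'fp8':
        return BitsAndBytesConfig(load_in_8bit=True)  # the 8-bit path ignores the 4-bit options
    if quant in ('fp4', 'nf4'):
        return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype, bnb_4bit_quant_type=quant)
    return None
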

def load_quants(kwargs, repo_id, cache_dir, allow_quant): # pylint: disable=unused-argument
    try:
        diffusers_load_config = {
            "torch_dtype": devices.dtype,
            "cache_dir": cache_dir,
        }
        if 'transformer' not in kwargs and model_quant.check_nunchaku('Model'):
            import nunchaku
            nunchaku_precision = nunchaku.utils.get_precision()
            nunchaku_repo = None
            if 'flux.1-kontext' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{nunchaku_precision}_r32-flux.1-kontext-dev.safetensors"
            elif 'flux.1-dev' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{nunchaku_precision}_r32-flux.1-dev.safetensors"
            elif 'flux.1-schnell' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors"
            elif 'flux.1-fill' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/svdq-fp4-flux.1-fill-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" # filename keeps the flux.1-schnell suffix as in the source; possibly a copy-paste slip
            elif 'flux.1-depth' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/svdq-int4-flux.1-depth-dev/svdq-{nunchaku_precision}_r32-flux.1-schnell.safetensors" # same schnell suffix as above
            elif 'shuttle' in repo_id.lower():
                nunchaku_repo = f"mit-han-lab/nunchaku-shuttle-jaguar/svdq-{nunchaku_precision}_r32-shuttle-jaguar.safetensors"
            else:
                shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
            if nunchaku_repo is not None:
                shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
                kwargs['transformer'] = nunchaku.NunchakuFluxTransformer2dModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype)
                kwargs['transformer'].quantization_method = 'SVDQuant'
                if shared.opts.nunchaku_attention:
                    kwargs['transformer'].set_attention_impl("nunchaku-fp16")
        if 'transformer' not in kwargs and model_quant.check_quant('Model'):
            load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
            kwargs['transformer'] = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", **load_args, **quant_args)
        if 'text_encoder_2' not in kwargs and model_quant.check_nunchaku('TE'):
            import nunchaku
            nunchaku_precision = nunchaku.utils.get_precision()
            nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
            shared.log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
            kwargs['text_encoder_2'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype)
            kwargs['text_encoder_2'].quantization_method = 'SVDQuant'
        if 'text_encoder_2' not in kwargs and model_quant.check_quant('TE'):
            load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='TE', device_map=True)
            kwargs['text_encoder_2'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", **load_args, **quant_args)
    except Exception as e:
        shared.log.error(f'Quantization: {e}')
        errors.display(e, 'Quantization:')
    return kwargs
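
How the returned kwargs are consumed downstream, sketched with illustrative values (the repo id is real; the cache directory is a placeholder):

import diffusers

kwargs = load_quants({}, repo_id='black-forest-labs/FLUX.1-dev', cache_dir='/models/hf', allow_quant=True)
# components that were not pre-loaded (no 'transformer' / 'text_encoder_2' key)
# fall back to the pipeline's own from_pretrained download
pipe = diffusers.FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', cache_dir='/models/hf', **kwargs)
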

def load_transformer(file_path): # triggered by opts.sd_unet change
    if file_path is None or not os.path.exists(file_path):
        return None
    transformer = None
    quant = model_quant.get_quant(file_path)
    diffusers_load_config = {
        "torch_dtype": devices.dtype,
        "cache_dir": shared.opts.hfcache_dir,
    }
    if quant is not None and quant != 'none':
        shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
        if 'gguf' in file_path.lower():
            from modules import ggml
            _transformer = ggml.load_gguf(file_path, cls=diffusers.FluxTransformer2DModel, compute_dtype=devices.dtype)
            if _transformer is not None:
                transformer = _transformer
        elif quant == "fp8":
            _transformer = model_quant.load_fp8_model_layerwise(file_path, diffusers.FluxTransformer2DModel.from_single_file, diffusers_load_config)
            if _transformer is not None:
                transformer = _transformer
        elif quant in {'qint8', 'qint4'}:
            _transformer, _text_encoder_2 = load_flux_quanto(file_path)
            if _transformer is not None:
                transformer = _transformer
        elif quant in {'fp8', 'fp4', 'nf4'}:
            _transformer, _text_encoder_2 = load_flux_bnb(file_path, diffusers_load_config)
            if _transformer is not None:
                transformer = _transformer
        elif 'nf4' in quant: # likely unreachable: 'nf4' is already handled by the branch above
            from pipelines.model_flux_nf4 import load_flux_nf4
            _transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=True)
            if _transformer is not None:
                transformer = _transformer
    else:
        quant_args = model_quant.create_bnb_config({})
        if quant_args:
            shared.log.info(f'Load module: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
            from pipelines.model_flux_nf4 import load_flux_nf4
            transformer, _text_encoder_2 = load_flux_nf4(file_path, prequantized=False)
            if transformer is not None:
                return transformer
        load_args, quant_args = model_quant.get_dit_args(diffusers_load_config, module='Model', device_map=True)
        shared.log.debug(f'Load model: type=Flux transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} args={load_args}')
        transformer = diffusers.FluxTransformer2DModel.from_single_file(file_path, **load_args, **quant_args)
    if transformer is None:
        shared.log.error('Failed to load UNet model')
        shared.opts.sd_unet = 'Default'
    return transformer
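
Illustrative use of the override path above; the GGUF filename is hypothetical:

transformer = load_transformer('/models/unet/flux1-dev-Q4_K_S.gguf')
if transformer is not None:
    pipe = diffusers.FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', transformer=transformer, torch_dtype=devices.dtype)
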

def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_checkpoint change
def load_flux(checkpoint_info, diffusers_load_config={}):
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)
    allow_post_quant = False

    prequantized = model_quant.get_quant(checkpoint_info.path)
    shared.log.debug(f'Load model: type=FLUX model="{checkpoint_info.name}" repo="{repo_id}" unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
    debug(f'Load model: type=FLUX config={diffusers_load_config}')
    if 'Fill' in repo_id:
        cls_name = diffusers.FluxFillPipeline
    elif 'Canny' in repo_id:
        cls_name = diffusers.FluxControlPipeline
    elif 'Depth' in repo_id:
        cls_name = diffusers.FluxControlPipeline
    elif 'Kontext' in repo_id:
        cls_name = diffusers.FluxKontextPipeline
        diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
        diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
        diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline
    else:
        cls_name = diffusers.FluxPipeline

    transformer = None
    text_encoder_1 = None
    text_encoder_2 = None
    vae = None
    load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
    shared.log.debug(f'Load model: type=Flux repo="{repo_id}" cls={cls_name.__name__} config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')

    # unload current model
    sd_models.unload_model_weights()
    shared.sd_model = None
    devices.torch_gc(force=True, reason='load')

    if shared.opts.teacache_enabled:
    # optional teacache patch
    if shared.opts.teacache_enabled and not model_quant.check_nunchaku('Model'):
        from modules import teacache
        shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.FluxTransformer2DModel.__name__}')
        diffusers.FluxTransformer2DModel.forward = teacache.teacache_flux_forward # patch must be done before transformer is loaded
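    # nunchaku models are deliberately excluded from the teacache patch: the
    # SVDQuant transformer provides its own forward, and those pipelines get
    # first-block caching later via apply_cache_on_pipe instead (see below)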

    # load overrides if any
    if shared.opts.sd_unet != 'Default':
        try:
            debug(f'Load model: type=FLUX unet="{shared.opts.sd_unet}"')
            transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
            if transformer is None:
                shared.opts.sd_unet = 'Default'
                sd_unet.failed_unet.append(shared.opts.sd_unet)
        except Exception as e:
            shared.log.error(f"Load model: type=FLUX failed to load UNet: {e}")
            shared.opts.sd_unet = 'Default'
            if debug:
                errors.display(e, 'FLUX UNet:')
    if shared.opts.sd_text_encoder != 'Default':
        try:
            debug(f'Load model: type=FLUX te="{shared.opts.sd_text_encoder}"')
            from modules.model_te import load_t5, load_vit_l
            if 'vit-l' in shared.opts.sd_text_encoder.lower():
                text_encoder_1 = load_vit_l()
            else:
                text_encoder_2 = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
        except Exception as e:
            shared.log.error(f"Load model: type=FLUX failed to load T5: {e}")
            shared.opts.sd_text_encoder = 'Default'
            if debug:
                errors.display(e, 'FLUX T5:')
    if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
        try:
            debug(f'Load model: type=FLUX vae="{shared.opts.sd_vae}"')
            from modules import sd_vae
            # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
            vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
            if os.path.exists(vae_file):
                vae_config = os.path.join('configs', 'flux', 'vae', 'config.json')
                vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
        except Exception as e:
            shared.log.error(f"Load model: type=FLUX failed to load VAE: {e}")
            shared.opts.sd_vae = 'Default'
            if debug:
                errors.display(e, 'FLUX VAE:')
    transformer = None
    text_encoder_2 = None

    # load quantized components if any
    if prequantized == 'nf4':
        try:
            from pipelines.model_flux_nf4 import load_flux_nf4
            _transformer, _text_encoder = load_flux_nf4(checkpoint_info)
            if _transformer is not None:
                transformer = _transformer
            if _text_encoder is not None:
                text_encoder_2 = _text_encoder
        except Exception as e:
            shared.log.error(f"Load model: type=FLUX failed to load NF4 components: {e}")
            if debug:
                errors.display(e, 'FLUX NF4:')
    if prequantized == 'qint8' or prequantized == 'qint4':
        try:
            _transformer, _text_encoder = load_flux_quanto(checkpoint_info)
            if _transformer is not None:
                transformer = _transformer
            if _text_encoder is not None:
                text_encoder_2 = _text_encoder
        except Exception as e:
            shared.log.error(f"Load model: type=FLUX failed to load Quanto components: {e}")
            if debug:
                errors.display(e, 'FLUX Quanto:')
    # handle transformer svdquant if available, t5 is handled inside load_text_encoder
    prequantized = model_quant.get_quant(checkpoint_info.path)
    if model_quant.check_nunchaku('Model'):
        from pipelines.flux.flux_nunchaku import load_flux_nunchaku
        transformer = load_flux_nunchaku(repo_id)
    # handle prequantized models
    elif prequantized == 'nf4':
        from pipelines.flux.flux_nf4 import load_flux_nf4
        transformer, text_encoder_2 = load_flux_nf4(checkpoint_info)
    elif prequantized == 'qint8' or prequantized == 'qint4':
        from pipelines.flux.flux_quanto import load_flux_quanto
        transformer, text_encoder_2 = load_flux_quanto(checkpoint_info)
    elif prequantized == 'fp4' or prequantized == 'fp8':
        from pipelines.flux.flux_bnb import load_flux_bnb
        transformer = load_flux_bnb(checkpoint_info, diffusers_load_config)

    # initialize pipeline with pre-loaded components
    kwargs = {}
    if transformer is not None:
        kwargs['transformer'] = transformer
        sd_unet.loaded_unet = shared.opts.sd_unet
    if text_encoder_1 is not None:
        kwargs['text_encoder'] = text_encoder_1
        model_te.loaded_te = shared.opts.sd_text_encoder
    if text_encoder_2 is not None:
        kwargs['text_encoder_2'] = text_encoder_2
        model_te.loaded_te = shared.opts.sd_text_encoder
    if vae is not None:
        kwargs['vae'] = vae
    if repo_id == 'sayakpaul/flux.1-dev-nf4':
        repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
    if 'Fill' in repo_id:
        cls = diffusers.FluxFillPipeline
    elif 'Canny' in repo_id:
        cls = diffusers.FluxControlPipeline
    elif 'Depth' in repo_id:
        cls = diffusers.FluxControlPipeline
    elif 'Kontext' in repo_id:
        cls = diffusers.FluxKontextPipeline
        from diffusers import pipelines
        pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
        pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextPipeline
        pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["flux1kontext"] = diffusers.FluxKontextInpaintPipeline
    # finally load transformer and text encoder if not already loaded
    if transformer is None:
        transformer = generic.load_transformer(repo_id, cls_name=diffusers.FluxTransformer2DModel, load_config=diffusers_load_config)
    if text_encoder_2 is None:
        text_encoder_2 = generic.load_text_encoder(repo_id, cls_name=transformers.T5EncoderModel, load_config=diffusers_load_config)

    else:
        cls = diffusers.FluxPipeline
    shared.log.debug(f'Load model: type=FLUX cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
    for c in kwargs:
        if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
            shared.log.debug(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
        if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
            try:
                kwargs[c] = kwargs[c].to(dtype=devices.dtype)
                shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
            except Exception:
                pass
    pipe = cls_name.from_pretrained(
        repo_id,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        cache_dir=shared.opts.diffusers_dir,
        **load_args,
    )

    allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none')
    fn = checkpoint_info.path
    if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
        kwargs = load_quants(kwargs, repo_id, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant)
    if fn.endswith('.safetensors') and os.path.isfile(fn):
        pipe = cls.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
        allow_post_quant = True
    else:
        pipe = cls.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
    del text_encoder_2
    del transformer

    # optional first-block patch
    if shared.opts.teacache_enabled and model_quant.check_nunchaku('Model'):
        from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
        apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)

    # release memory
    transformer = None
    text_encoder_1 = None
    text_encoder_2 = None
    vae = None
    for k in kwargs.keys():
        kwargs[k] = None
    sd_hijack_te.init_hijack(pipe)
    devices.torch_gc(force=True, reason='load')
    return pipe, allow_post_quant
    return pipe
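
Illustrative end-to-end use of the rewritten loader; device handling, prompt, and step count are arbitrary placeholders:

pipe = load_flux(checkpoint_info)  # checkpoint_info comes from the model registry
pipe = pipe.to(devices.device)
image = pipe(prompt='astronaut in a diner', num_inference_steps=20, guidance_scale=3.5).images[0]
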

@@ -12,7 +12,8 @@ def load_hunyuandit(checkpoint_info, diffusers_load_config={}):
    shared.log.debug(f'Load model: type=HunyuanDiT repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')

    transformer = generic.load_transformer(repo_id, cls_name=diffusers.HunyuanDiT2DModel, load_config=diffusers_load_config)
    text_encoder_2 = generic.load_text_encoder(repo_id, cls_name=transformers.T5EncoderModel, load_config=diffusers_load_config, subfolder="text_encoder_2")
    repo_te = 'Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers' if 'HunyuanDiT-v1' in repo_id else repo_id
    text_encoder_2 = generic.load_text_encoder(repo_te, cls_name=transformers.T5EncoderModel, load_config=diffusers_load_config, subfolder="text_encoder_2", allow_shared=False) # this is not a standard t5

    pipe = diffusers.HunyuanDiTPipeline.from_pretrained(
        repo_id,
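
Context for the repo_te fallback above: the HunyuanDiT second text encoder is a multilingual T5 variant rather than a plain T5, so all v1.x checkpoints reuse the encoder published in the v1.2 repo, and allow_shared=False keeps it out of the shared-TE cache. Loading it directly looks roughly like this (kwargs mirror the generic helper call above; torch_dtype is an assumed default):

text_encoder_2 = transformers.T5EncoderModel.from_pretrained(
    'Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers',
    subfolder='text_encoder_2',
    torch_dtype=devices.dtype,
)
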

@@ -2,21 +2,14 @@ import torch
import diffusers


repo_id = 'Kwai-Kolors/Kolors-diffusers'


def load_kolors(_checkpoint_info, diffusers_load_config={}):
    from modules import shared, devices
    diffusers_load_config['variant'] = "fp16"
    if 'torch_dtype' not in diffusers_load_config:
        diffusers_load_config['torch_dtype'] = torch.float16

    # import torch
    # import transformers
    # encoder_id = 'THUDM/chatglm3-6b'
    # text_encoder = transformers.AutoModel.from_pretrained(encoder_id, torch_dtype=torch.float16, trust_remote_code=True, cache_dir=shared.opts.diffusers_dir)
    # text_encoder = transformers.AutoModel.from_pretrained("THUDM/chatglm3-6b", torch_dtype=torch.float16, trust_remote_code=True).quantize(4).cuda()
    # tokenizer = transformers.AutoTokenizer.from_pretrained(encoder_id, trust_remote_code=True, cache_dir=shared.opts.diffusers_dir)
    repo_id = 'Kwai-Kolors/Kolors-diffusers'
    shared.log.debug(f'Load model: type=Kolors repo="{repo_id}" config={diffusers_load_config} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={diffusers_load_config}')
    pipe = diffusers.KolorsPipeline.from_pretrained(
        repo_id,
        cache_dir=shared.opts.diffusers_dir,
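
A standalone equivalent of the Kolors load above; the fp16 variant is forced because that is the precision the repo publishes (an assumption based on the hardcoded variant):

import torch
import diffusers

pipe = diffusers.KolorsPipeline.from_pretrained(
    'Kwai-Kolors/Kolors-diffusers',
    variant='fp16',
    torch_dtype=torch.float16,
)
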

@@ -1,9 +1,9 @@
import os
import transformers
import diffusers
from modules import errors, shared, sd_models, sd_unet, sd_hijack_te, devices, modelloader, model_quant
from modules import shared, sd_models, sd_unet, sd_hijack_te, devices, modelloader, model_quant


debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None


def load_lumina(_checkpoint_info, diffusers_load_config={}):
@@ -30,7 +30,6 @@ def load_lumina2(checkpoint_info, diffusers_load_config={}):
    load_config, quant_config = model_quant.get_dit_args(diffusers_load_config, module='Model')
    if shared.opts.sd_unet != 'Default':
        try:
            debug(f'Load model: type=Lumina2 unet="{shared.opts.sd_unet}"')
            transformer = diffusers.Lumina2Transformer2DModel.from_single_file(
                sd_unet.unet_dict[shared.opts.sd_unet],
                cache_dir=shared.opts.diffusers_dir,
@@ -43,12 +42,9 @@ def load_lumina2(checkpoint_info, diffusers_load_config={}):
        except Exception as e:
            shared.log.error(f"Load model: type=Lumina2 failed to load UNet: {e}")
            shared.opts.sd_unet = 'Default'
            if debug:
                errors.display(e, 'Lumina2 UNet:')

    if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
        try:
            debug(f'Load model: type=Lumina2 vae="{shared.opts.sd_vae}"')
            from modules import sd_vae
            # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
            vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
@@ -58,8 +54,6 @@ def load_lumina2(checkpoint_info, diffusers_load_config={}):
        except Exception as e:
            shared.log.error(f"Load model: type=Lumina2 failed to load VAE: {e}")
            shared.opts.sd_vae = 'Default'
            if debug:
                errors.display(e, 'Lumina2 VAE:')

    if transformer is None:
        transformer = diffusers.Lumina2Transformer2DModel.from_pretrained(
@@ -6,10 +6,10 @@ def load_meissonic(checkpoint_info, diffusers_load_config={}):
    from modules import shared, devices, modelloader, sd_models, shared_items
    from pipelines.meissonic.transformer import Transformer2DModel as TransformerMeissonic
    from pipelines.meissonic.scheduler import Scheduler as MeissonicScheduler
    from pipelines.meissonic.pipeline import Pipeline as PipelineMeissonic
    from pipelines.meissonic.pipeline_img2img import Img2ImgPipeline as PipelineMeissonicImg2Img
    from pipelines.meissonic.pipeline_inpaint import InpaintPipeline as PipelineMeissonicInpaint
    shared_items.pipelines['Meissonic'] = PipelineMeissonic
    from pipelines.meissonic.pipeline import MeissonicPipeline
    from pipelines.meissonic.pipeline_img2img import MeissonicImg2ImgPipeline
    from pipelines.meissonic.pipeline_inpaint import MeissonicInpaintPipeline
    shared_items.pipelines['Meissonic'] = MeissonicPipeline

    modelloader.hf_login()
    fn = sd_models.path_to_repo(checkpoint_info)
@@ -41,7 +41,7 @@ def load_meissonic(checkpoint_info, diffusers_load_config={}):
        cache_dir=cache_dir,
    )
    scheduler = MeissonicScheduler.from_pretrained(fn, subfolder="scheduler", cache_dir=cache_dir)
    pipe = PipelineMeissonic(
    pipe = MeissonicPipeline(
        vqvae=vqvae.to(devices.dtype),
        text_encoder=text_encoder.to(devices.dtype),
        transformer=model.to(devices.dtype),
@@ -49,8 +49,8 @@ def load_meissonic(checkpoint_info, diffusers_load_config={}):
        scheduler=scheduler,
    )

    diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["meissonic"] = PipelineMeissonic
    diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["meissonic"] = PipelineMeissonicImg2Img
    diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["meissonic"] = PipelineMeissonicInpaint
    diffusers.pipelines.auto_pipeline.AUTO_TEXT2IMAGE_PIPELINES_MAPPING["meissonic"] = MeissonicPipeline
    diffusers.pipelines.auto_pipeline.AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["meissonic"] = MeissonicImg2ImgPipeline
    diffusers.pipelines.auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING["meissonic"] = MeissonicInpaintPipeline
    devices.torch_gc(force=True, reason='load')
    return pipe
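
What the AUTO_*_PIPELINES_MAPPING registration above buys: diffusers' auto-pipeline helpers can now resolve the custom Meissonic classes by task, e.g. swapping a loaded text-to-image pipeline to img2img while reusing its components (a hedged sketch of standard diffusers usage):

img2img = diffusers.AutoPipelineForImage2Image.from_pipe(pipe)  # resolves via the "meissonic" mapping key
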

@@ -1,9 +1,6 @@
import os
import diffusers
from modules import shared, devices, sd_models, model_quant

debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None


def load_omnigen(checkpoint_info, diffusers_load_config={}): # pylint: disable=unused-argument
    repo_id = sd_models.path_to_repo(checkpoint_info)
@@ -1,8 +1,5 @@
import os
from modules import shared, devices, sd_models, model_quant

debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None

def load_omnigen2(checkpoint_info, diffusers_load_config={}): # pylint: disable=unused-argument
    repo_id = sd_models.path_to_repo(checkpoint_info)