mirror of https://github.com/vladmandic/automatic
add GLM-Image pipeline support
- Add GLM-Image (zai-org/GLM-Image) model detection and loading - Custom pipeline loader with proper component handling: - ByT5 text encoder (cannot use shared T5 due to different hidden size) - Vision-language encoder (9B AR model) - DiT transformer (7B) - Fix EOS token early stopping in AR generation - Add AR token generation progress tracking with terminal progress bar - Fix uninitialized audio variable in processing - Add TAESD support for GLM-Image (using f1 variant)
parent
8500156888
commit
3f259cff9a
|
|
@ -78,6 +78,8 @@ def get_model_type(pipe):
|
|||
model_type = 'prx'
|
||||
elif 'LongCat' in name:
|
||||
model_type = 'longcat'
|
||||
elif 'GlmImage' in name:
|
||||
model_type = 'glm_image'
|
||||
elif 'Ovis-Image' in name:
|
||||
model_type = 'ovis'
|
||||
# video models
|
||||
|
|
|
|||
|
|
@ -400,6 +400,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
infotexts = []
|
||||
output_images = []
|
||||
output_binary = None
|
||||
audio = None
|
||||
|
||||
process_init(p)
|
||||
if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):
|
||||
|
|
|
|||
|
|
@ -156,6 +156,8 @@ def task_specific_kwargs(p, model):
|
|||
task_args['reference_images'] = p.init_images
|
||||
if ('GoogleNanoBananaPipeline' in model_cls) and (p.init_images is not None) and (len(p.init_images) > 0):
|
||||
task_args['image'] = p.init_images[0]
|
||||
if ('GlmImagePipeline' in model_cls) and (p.init_images is not None) and (len(p.init_images) > 0):
|
||||
task_args['image'] = p.init_images
|
||||
if 'BlipDiffusionPipeline' in model_cls:
|
||||
if len(p.init_images) == 0:
|
||||
shared.log.error('BLiP diffusion requires init image')
|
||||
|
|
|
|||
|
|
@ -143,6 +143,8 @@ def guess_by_name(fn, current_guess):
|
|||
new_guess = 'LongCat'
|
||||
elif 'ovis-image' in fn.lower():
|
||||
new_guess = 'Ovis-Image'
|
||||
elif 'glm-image' in fn.lower():
|
||||
new_guess = 'GLM-Image'
|
||||
if debug_load:
|
||||
shared.log.trace(f'Autodetect: method=name file="{fn}" previous="{current_guess}" current="{new_guess}"')
|
||||
return new_guess or current_guess
|
||||
|
|
|
|||
|
|
@ -483,6 +483,10 @@ def load_diffuser_force(detected_model_type, checkpoint_info, diffusers_load_con
|
|||
from pipelines.model_ovis import load_ovis
|
||||
sd_model = load_ovis(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
elif model_type in ['GLM-Image']:
|
||||
from pipelines.model_glm import load_glm_image
|
||||
sd_model = load_glm_image(checkpoint_info, diffusers_load_config)
|
||||
allow_post_quant = False
|
||||
except Exception as e:
|
||||
shared.log.error(f'Load {op}: path="{checkpoint_info.path}" {e}')
|
||||
if debug_load:
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ prev_cls = ''
|
|||
prev_type = ''
|
||||
prev_model = ''
|
||||
lock = threading.Lock()
|
||||
supported = ['sd', 'sdxl', 'sd3', 'f1', 'h1', 'z_image', 'lumina2', 'hunyuanvideo', 'wanai', 'chrono', 'mochivideo', 'pixartsigma', 'pixartalpha', 'hunyuandit', 'omnigen', 'qwen', 'longcat']
|
||||
supported = ['sd', 'sdxl', 'sd3', 'f1', 'h1', 'z_image', 'lumina2', 'hunyuanvideo', 'wanai', 'chrono', 'mochivideo', 'pixartsigma', 'pixartalpha', 'hunyuandit', 'omnigen', 'qwen', 'longcat', 'glm_image']
|
||||
|
||||
|
||||
def warn_once(msg, variant=None):
|
||||
|
|
@ -59,7 +59,7 @@ def get_model(model_type = 'decoder', variant = None):
|
|||
model_cls = 'sd'
|
||||
elif model_cls in {'pixartsigma', 'hunyuandit', 'omnigen', 'auraflow'}:
|
||||
model_cls = 'sdxl'
|
||||
elif model_cls in {'h1', 'z_image', 'lumina2', 'chroma', 'longcat'}:
|
||||
elif model_cls in {'h1', 'z_image', 'lumina2', 'chroma', 'longcat', 'glm_image'}:
|
||||
model_cls = 'f1'
|
||||
elif model_cls in {'wanai', 'qwen', 'chrono'}:
|
||||
variant = variant or 'TAE WanVideo'
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ pipelines = {
|
|||
'HunyuanImage': getattr(diffusers, 'HunyuanImagePipeline', None),
|
||||
'Z-Image': getattr(diffusers, 'ZImagePipeline', None),
|
||||
'LongCat': getattr(diffusers, 'LongCatImagePipeline', None),
|
||||
'GLM-Image': getattr(diffusers, 'GlmImagePipeline', None),
|
||||
# dynamically imported and redefined later
|
||||
'Meissonic': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
'Monetico': getattr(diffusers, 'DiffusionPipeline', None),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,137 @@
|
|||
import time
|
||||
import rich.progress as rp
|
||||
import transformers
|
||||
import diffusers
|
||||
from modules import shared, devices, sd_models, model_quant, sd_hijack_te
|
||||
from pipelines import generic
|
||||
|
||||
|
||||
class GLMTokenProgressProcessor(transformers.LogitsProcessor):
    """Logits processor that reports GLM-Image autoregressive token progress.

    Invoked once per generated token; it drives both the shared UI state
    (`shared.state`) and a terminal progress bar while the AR vision-language
    model emits image tokens. Call `set_total()` before each generate run.
    """

    def __init__(self):
        self.total_tokens = 0    # expected AR token count for the current generate call
        self.current_step = 0    # tokens emitted so far
        self.task_id = None      # handle from shared.state.begin, or None when idle
        self.pbar = None         # rich Progress instance (terminal bar), or None
        self.pbar_task = None    # task id inside self.pbar
        self.start_time = 0      # wall-clock start, used for tok/s rate

    def set_total(self, total_tokens: int):
        """Reset the counter for a new generation expecting *total_tokens* tokens."""
        self.total_tokens = total_tokens
        self.current_step = 0

    def _begin(self):
        # Lazily create the UI task and terminal bar on the first processed token
        self.task_id = shared.state.begin('AR Generation')
        self.start_time = time.time()
        columns = (
            rp.TextColumn('[cyan]AR Generation'),
            rp.TextColumn('{task.fields[speed]}'),
            rp.BarColumn(bar_width=40, complete_style='#327fba', finished_style='#327fba'),
            rp.TaskProgressColumn(),
            rp.MofNCompleteColumn(),
            rp.TimeElapsedColumn(),
            rp.TimeRemainingColumn(),
        )
        self.pbar = rp.Progress(*columns, console=shared.console)
        self.pbar.start()
        self.pbar_task = self.pbar.add_task(description='', total=self.total_tokens, speed='')

    def _finish(self):
        # Tear down the terminal bar first, then end the shared UI task
        if self.pbar is not None:
            self.pbar.stop()
            self.pbar = None
        if self.task_id is not None:
            shared.state.end(self.task_id)
            self.task_id = None

    def __call__(self, input_ids, scores):
        if self.current_step == 0:
            self._begin()
        self.current_step += 1
        shared.state.sampling_step = self.current_step
        shared.state.sampling_steps = self.total_tokens
        if self.pbar is not None and self.pbar_task is not None:
            elapsed = time.time() - self.start_time
            rate = f'{self.current_step / elapsed:.2f}tok/s' if elapsed > 0 else ''
            self.pbar.update(self.pbar_task, completed=self.current_step, speed=rate)
        if self.current_step >= self.total_tokens:
            self._finish()
        return scores  # scores are passed through unmodified
|
||||
|
||||
|
||||
def _wrap_vision_language_generate(pipe):
    """Wrap vision_language_encoder.generate to add progress tracking.

    Installs a GLMTokenProgressProcessor into the `logits_processor` list of
    every generate call so AR token generation reports progress. Cleanup is
    guaranteed via try/finally: the original code only tore down the progress
    bar inside the processor once `total_tokens` steps were seen, so an early
    stop (e.g. EOS before max_new_tokens) or an exception leaked a running
    terminal bar and an open shared.state task.
    """
    if not hasattr(pipe, 'vision_language_encoder') or pipe.vision_language_encoder is None:
        return

    original_generate = pipe.vision_language_encoder.generate
    progress_processor = GLMTokenProgressProcessor()

    def wrapped_generate(*args, **kwargs):
        # max_new_tokens determines the expected AR token count; defaults to 0
        # when the caller relies on generation_config instead — TODO confirm
        max_new_tokens = kwargs.get('max_new_tokens', 0)
        progress_processor.set_total(max_new_tokens)

        # Append the progress processor to any caller-provided processors
        existing_processors = kwargs.get('logits_processor', None)
        if existing_processors is None:
            existing_processors = []
        elif not isinstance(existing_processors, list):
            existing_processors = list(existing_processors)
        kwargs['logits_processor'] = existing_processors + [progress_processor]

        try:
            return original_generate(*args, **kwargs)
        finally:
            # Ensure the bar and UI task are closed even if generation ended
            # early or raised; no-op when the processor already finished them
            if progress_processor.pbar is not None:
                progress_processor.pbar.stop()
                progress_processor.pbar = None
            if progress_processor.task_id is not None:
                shared.state.end(progress_processor.task_id)
                progress_processor.task_id = None

    pipe.vision_language_encoder.generate = wrapped_generate
|
||||
|
||||
|
||||
def load_glm_image(checkpoint_info, diffusers_load_config=None):
    """Build a GlmImagePipeline from individually loaded components.

    Loads the 7B DiT transformer, the ByT5 glyph text encoder and the 9B
    autoregressive vision-language encoder separately, assembles the diffusers
    pipeline, attaches task defaults and AR progress tracking, and returns it.
    """
    if diffusers_load_config is None:
        diffusers_load_config = {}
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)

    load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
    shared.log.debug(f'Load model: type=GLM-Image repo="{repo_id}" offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')

    # DiT decoder (7B), loaded through the quantization-aware helper
    transformer = generic.load_transformer(repo_id, cls_name=diffusers.GlmImageTransformer2DModel, load_config=diffusers_load_config)

    # ByT5 glyph encoder: GLM-Image requires its own encoder (1472 hidden size),
    # so the shared T5 instance cannot be substituted
    text_encoder = generic.load_text_encoder(repo_id, cls_name=transformers.T5EncoderModel, load_config=diffusers_load_config, allow_shared=False)

    # 9B AR model; a conditional-generation model rather than a plain text encoder
    vision_language_encoder = generic.load_text_encoder(repo_id, cls_name=transformers.GlmImageForConditionalGeneration, subfolder="vision_language_encoder", load_config=diffusers_load_config, allow_shared=False)

    pipe = diffusers.GlmImagePipeline.from_pretrained(
        repo_id,
        cache_dir=shared.opts.diffusers_dir,
        transformer=transformer,
        text_encoder=text_encoder,
        vision_language_encoder=vision_language_encoder,
        **load_args,
    )

    pipe.task_args = {
        'output_type': 'np',
        'generate_kwargs': {
            'eos_token_id': None,  # Disable EOS early stopping to ensure all required tokens are generated
        },
    }

    # Drop local refs so offload/GC see only the pipeline-owned components
    del transformer, text_encoder, vision_language_encoder
    sd_hijack_te.init_hijack(pipe)
    _wrap_vision_language_generate(pipe)  # Add progress tracking for AR token generation
    devices.torch_gc(force=True, reason='load')
    return pipe
|
||||
Loading…
Reference in New Issue