add GLM-Image pipeline support

- Add GLM-Image (zai-org/GLM-Image) model detection and loading
- Custom pipeline loader with proper component handling:
  - ByT5 text encoder (cannot use shared T5 due to different hidden size)
  - Vision-language encoder (9B AR model)
  - DiT transformer (7B)
- Fix premature EOS early stopping in AR generation (EOS stopping is disabled so all required image tokens are generated)
- Add AR token generation progress tracking with terminal progress bar (see the sketch after the commit metadata)
- Fix uninitialized audio variable in processing
- Add TAESD support for GLM-Image (using f1 variant)
pull/4548/head
CalamitousFelicitousness 2026-01-14 03:33:49 +00:00
parent 8500156888
commit 3f259cff9a
8 changed files with 151 additions and 2 deletions
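
For context on the two AR-generation items in the commit message: transformers calls every entry in logits_processor once per decoded token, which is what makes it usable as a progress hook, and passing eos_token_id=None to generate() disables EOS-based early stopping. A minimal standalone sketch of both mechanisms, using only the public transformers API (gpt2 is a stand-in checkpoint, not the GLM-Image encoder):

    import time
    import transformers

    class TokenCounter(transformers.LogitsProcessor):
        """Called by generate() once per new token; logits pass through unchanged."""
        def __init__(self, total):
            self.total = total
            self.step = 0
            self.t0 = time.time()

        def __call__(self, input_ids, scores):
            self.step += 1
            dt = max(time.time() - self.t0, 1e-6)
            print(f'token {self.step}/{self.total} ({self.step / dt:.1f} tok/s)')
            return scores

    tok = transformers.AutoTokenizer.from_pretrained('gpt2')
    model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
    ids = tok('a photo of', return_tensors='pt').input_ids
    out = model.generate(
        ids,
        max_new_tokens=8,
        eos_token_id=None,  # no EOS early stop: exactly max_new_tokens are produced
        logits_processor=transformers.LogitsProcessorList([TokenCounter(8)]),
    )

The new pipelines/model_glm.py below wires the same two ideas into GLM-Image: a LogitsProcessor drives the terminal progress bar, and task_args passes eos_token_id=None into the AR encoder's generate().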


@@ -78,6 +78,8 @@ def get_model_type(pipe):
         model_type = 'prx'
     elif 'LongCat' in name:
         model_type = 'longcat'
+    elif 'GlmImage' in name:
+        model_type = 'glm_image'
     elif 'Ovis-Image' in name:
         model_type = 'ovis'
     # video models


@@ -400,6 +400,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     infotexts = []
     output_images = []
     output_binary = None
+    audio = None
     process_init(p)
     if p.scripts is not None and isinstance(p.scripts, scripts_manager.ScriptRunner):


@@ -156,6 +156,8 @@ def task_specific_kwargs(p, model):
         task_args['reference_images'] = p.init_images
     if ('GoogleNanoBananaPipeline' in model_cls) and (p.init_images is not None) and (len(p.init_images) > 0):
         task_args['image'] = p.init_images[0]
+    if ('GlmImagePipeline' in model_cls) and (p.init_images is not None) and (len(p.init_images) > 0):
+        task_args['image'] = p.init_images
     if 'BlipDiffusionPipeline' in model_cls:
         if len(p.init_images) == 0:
             shared.log.error('BLiP diffusion requires init image')


@@ -143,6 +143,8 @@ def guess_by_name(fn, current_guess):
         new_guess = 'LongCat'
     elif 'ovis-image' in fn.lower():
         new_guess = 'Ovis-Image'
+    elif 'glm-image' in fn.lower():
+        new_guess = 'GLM-Image'
     if debug_load:
         shared.log.trace(f'Autodetect: method=name file="{fn}" previous="{current_guess}" current="{new_guess}"')
     return new_guess or current_guess


@@ -483,6 +483,10 @@ def load_diffuser_force(detected_model_type, checkpoint_info, diffusers_load_config
             from pipelines.model_ovis import load_ovis
             sd_model = load_ovis(checkpoint_info, diffusers_load_config)
             allow_post_quant = False
+        elif model_type in ['GLM-Image']:
+            from pipelines.model_glm import load_glm_image
+            sd_model = load_glm_image(checkpoint_info, diffusers_load_config)
+            allow_post_quant = False
     except Exception as e:
         shared.log.error(f'Load {op}: path="{checkpoint_info.path}" {e}')
         if debug_load:


@@ -38,7 +38,7 @@ prev_cls = ''
 prev_type = ''
 prev_model = ''
 lock = threading.Lock()
-supported = ['sd', 'sdxl', 'sd3', 'f1', 'h1', 'z_image', 'lumina2', 'hunyuanvideo', 'wanai', 'chrono', 'mochivideo', 'pixartsigma', 'pixartalpha', 'hunyuandit', 'omnigen', 'qwen', 'longcat']
+supported = ['sd', 'sdxl', 'sd3', 'f1', 'h1', 'z_image', 'lumina2', 'hunyuanvideo', 'wanai', 'chrono', 'mochivideo', 'pixartsigma', 'pixartalpha', 'hunyuandit', 'omnigen', 'qwen', 'longcat', 'glm_image']
 def warn_once(msg, variant=None):
@@ -59,7 +59,7 @@ def get_model(model_type = 'decoder', variant = None):
         model_cls = 'sd'
     elif model_cls in {'pixartsigma', 'hunyuandit', 'omnigen', 'auraflow'}:
         model_cls = 'sdxl'
-    elif model_cls in {'h1', 'z_image', 'lumina2', 'chroma', 'longcat'}:
+    elif model_cls in {'h1', 'z_image', 'lumina2', 'chroma', 'longcat', 'glm_image'}:
         model_cls = 'f1'
     elif model_cls in {'wanai', 'qwen', 'chrono'}:
         variant = variant or 'TAE WanVideo'

@@ -49,6 +49,7 @@ pipelines = {
     'HunyuanImage': getattr(diffusers, 'HunyuanImagePipeline', None),
     'Z-Image': getattr(diffusers, 'ZImagePipeline', None),
     'LongCat': getattr(diffusers, 'LongCatImagePipeline', None),
+    'GLM-Image': getattr(diffusers, 'GlmImagePipeline', None),
     # dynamically imported and redefined later
     'Meissonic': getattr(diffusers, 'DiffusionPipeline', None),
     'Monetico': getattr(diffusers, 'DiffusionPipeline', None),

pipelines/model_glm.py (new file, 137 lines)

@@ -0,0 +1,137 @@
+import time
+import rich.progress as rp
+import transformers
+import diffusers
+from modules import shared, devices, sd_models, model_quant, sd_hijack_te
+from pipelines import generic
+
+
+class GLMTokenProgressProcessor(transformers.LogitsProcessor):
+    """LogitsProcessor that tracks autoregressive token generation progress for GLM-Image."""
+
+    def __init__(self):
+        self.total_tokens = 0
+        self.current_step = 0
+        self.task_id = None
+        self.pbar = None
+        self.pbar_task = None
+        self.start_time = 0
+
+    def set_total(self, total_tokens: int):
+        self.total_tokens = total_tokens
+        self.current_step = 0
+
+    def __call__(self, input_ids, scores):
+        if self.current_step == 0:
+            self.task_id = shared.state.begin('AR Generation')
+            self.start_time = time.time()
+            self.pbar = rp.Progress(
+                rp.TextColumn('[cyan]AR Generation'),
+                rp.TextColumn('{task.fields[speed]}'),
+                rp.BarColumn(bar_width=40, complete_style='#327fba', finished_style='#327fba'),
+                rp.TaskProgressColumn(),
+                rp.MofNCompleteColumn(),
+                rp.TimeElapsedColumn(),
+                rp.TimeRemainingColumn(),
+                console=shared.console,
+            )
+            self.pbar.start()
+            self.pbar_task = self.pbar.add_task(description='', total=self.total_tokens, speed='')
+        self.current_step += 1
+        shared.state.sampling_step = self.current_step
+        shared.state.sampling_steps = self.total_tokens
+        if self.pbar is not None and self.pbar_task is not None:
+            elapsed = time.time() - self.start_time
+            speed = f'{self.current_step / elapsed:.2f}tok/s' if elapsed > 0 else ''
+            self.pbar.update(self.pbar_task, completed=self.current_step, speed=speed)
+        if self.current_step >= self.total_tokens:
+            if self.pbar is not None:
+                self.pbar.stop()
+                self.pbar = None
+            if self.task_id is not None:
+                shared.state.end(self.task_id)
+                self.task_id = None
+        return scores
+
+
+def _wrap_vision_language_generate(pipe):
+    """Wrap vision_language_encoder.generate to add progress tracking."""
+    if not hasattr(pipe, 'vision_language_encoder') or pipe.vision_language_encoder is None:
+        return
+    original_generate = pipe.vision_language_encoder.generate
+    progress_processor = GLMTokenProgressProcessor()
+
+    def wrapped_generate(*args, **kwargs):
+        # Get max_new_tokens to determine total tokens
+        max_new_tokens = kwargs.get('max_new_tokens', 0)
+        progress_processor.set_total(max_new_tokens)
+        # Add progress processor to logits_processor list
+        existing_processors = kwargs.get('logits_processor', None)
+        if existing_processors is None:
+            existing_processors = []
+        elif not isinstance(existing_processors, list):
+            existing_processors = list(existing_processors)
+        kwargs['logits_processor'] = existing_processors + [progress_processor]
+        return original_generate(*args, **kwargs)
+
+    pipe.vision_language_encoder.generate = wrapped_generate
+
+
+def load_glm_image(checkpoint_info, diffusers_load_config=None):
+    if diffusers_load_config is None:
+        diffusers_load_config = {}
+    repo_id = sd_models.path_to_repo(checkpoint_info)
+    sd_models.hf_auth_check(checkpoint_info)
+    load_args, _quant_args = model_quant.get_dit_args(diffusers_load_config, allow_quant=False)
+    shared.log.debug(f'Load model: type=GLM-Image repo="{repo_id}" offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype} args={load_args}')
+    # Load transformer (DiT decoder - 7B) with quantization support
+    transformer = generic.load_transformer(
+        repo_id,
+        cls_name=diffusers.GlmImageTransformer2DModel,
+        load_config=diffusers_load_config
+    )
+    # Load text encoder (ByT5 for glyph) - cannot use shared T5 as GLM-Image requires specific ByT5 encoder (1472 hidden size)
+    text_encoder = generic.load_text_encoder(
+        repo_id,
+        cls_name=transformers.T5EncoderModel,
+        load_config=diffusers_load_config,
+        allow_shared=False
+    )
+    # Load vision-language encoder (AR model - 9B)
+    # Note: This is a conditional generation model, different from typical text encoders
+    vision_language_encoder = generic.load_text_encoder(
+        repo_id,
+        cls_name=transformers.GlmImageForConditionalGeneration,
+        subfolder='vision_language_encoder',
+        load_config=diffusers_load_config,
+        allow_shared=False
+    )
+    pipe = diffusers.GlmImagePipeline.from_pretrained(
+        repo_id,
+        cache_dir=shared.opts.diffusers_dir,
+        transformer=transformer,
+        text_encoder=text_encoder,
+        vision_language_encoder=vision_language_encoder,
+        **load_args,
+    )
+    pipe.task_args = {
+        'output_type': 'np',
+        'generate_kwargs': {
+            'eos_token_id': None,  # Disable EOS early stopping to ensure all required tokens are generated
+        },
+    }
+    del transformer, text_encoder, vision_language_encoder
+    sd_hijack_te.init_hijack(pipe)
+    _wrap_vision_language_generate(pipe)  # Add progress tracking for AR token generation
+    devices.torch_gc(force=True, reason='load')
+    return pipe
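
Outside SD.Next, the net effect of the loader's task_args is roughly the following invocation. This is a schematic sketch: it assumes the upstream GlmImagePipeline accepts generate_kwargs as the task_args wiring implies, the repo id comes from the commit message, and the prompt is illustrative:

    import torch
    import diffusers

    pipe = diffusers.GlmImagePipeline.from_pretrained('zai-org/GLM-Image', torch_dtype=torch.bfloat16).to('cuda')
    image = pipe(
        prompt='a glass teapot on a wooden table',
        output_type='np',  # matches task_args above
        generate_kwargs={'eos_token_id': None},  # matches task_args: no EOS early stop in the AR stage
    ).images[0]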