mirror of https://github.com/vladmandic/automatic
fix(glm): rename _wrap_vision_language_generate to hijack_vision_language_generate
parent
de3b2b62eb
commit
7fb58f22c2
|
|
@ -54,7 +54,7 @@ class GLMTokenProgressProcessor(transformers.LogitsProcessor):
|
|||
return scores
|
||||
|
||||
|
||||
def _wrap_vision_language_generate(pipe):
|
||||
def hijack_vision_language_generate(pipe):
|
||||
"""Wrap vision_language_encoder.generate to add progress tracking."""
|
||||
if not hasattr(pipe, 'vision_language_encoder') or pipe.vision_language_encoder is None:
|
||||
return
|
||||
|
|
@ -132,6 +132,6 @@ def load_glm_image(checkpoint_info, diffusers_load_config=None):
|
|||
|
||||
del transformer, text_encoder, vision_language_encoder
|
||||
sd_hijack_te.init_hijack(pipe)
|
||||
_wrap_vision_language_generate(pipe) # Add progress tracking for AR token generation
|
||||
hijack_vision_language_generate(pipe) # Add progress tracking for AR token generation
|
||||
devices.torch_gc(force=True, reason='load')
|
||||
return pipe
|
||||
|
|
|
|||
Loading…
Reference in New Issue