mirror of https://github.com/vladmandic/automatic
33 lines
1.3 KiB
Python
33 lines
1.3 KiB
Python
import os
|
|
import time
|
|
from modules import shared, errors, timer, sd_models
|
|
|
|
|
|
def hijack_encode_prompt(*args, **kwargs):
    """Timed, error-tolerant replacement for the pipeline's ``encode_prompt``.

    Forwards all arguments to ``shared.sd_model.orig_encode_prompt`` (the
    original method saved by ``init_hijack``) and wraps the call with state
    tracking, timing and offload handling.

    If ``max_sequence_length`` is passed as a keyword argument, it is raised
    to at least the value of the ``HIDREAM_MAX_SEQUENCE_LENGTH`` environment
    variable (default 256); larger caller-supplied values are kept.

    Returns:
        Whatever the original ``encode_prompt`` returns, or ``None`` if it
        raised (the error is logged and displayed rather than propagated).
    """
    shared.state.begin('TE')
    t0 = time.time()
    if 'max_sequence_length' in kwargs:
        # Environment values are always strings; convert before comparing,
        # otherwise max(int, str) raises TypeError whenever the variable is set.
        env_value = os.environ.get('HIDREAM_MAX_SEQUENCE_LENGTH', None)
        try:
            min_sequence_length = int(env_value) if env_value is not None else 256
        except ValueError:
            shared.log.error(f'Encode prompt: invalid HIDREAM_MAX_SEQUENCE_LENGTH={env_value}')
            min_sequence_length = 256
        kwargs['max_sequence_length'] = max(kwargs['max_sequence_length'], min_sequence_length)
    # if hasattr(shared.sd_model, 'text_encoder') and shared.sd_model.text_encoder is not None:
    #     sd_models.move_model(shared.sd_model.text_encoder, devices.device)
    try:
        res = shared.sd_model.orig_encode_prompt(*args, **kwargs)
    except Exception as e:
        # Best-effort: report the failure and let the caller deal with None.
        shared.log.error(f'Encode prompt: {e}')
        errors.display(e, 'Encode prompt')
        res = None
    t1 = time.time()
    # Record elapsed wall-clock time under the 'te' key.
    timer.process.add('te', t1-t0)
    if hasattr(shared.sd_model, "maybe_free_model_hooks"):
        shared.sd_model.maybe_free_model_hooks()
    # apply_balanced_offload returns the model object to keep using.
    shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
    shared.state.end()
    return res
|
|
|
|
|
|
def init_hijack(pipe):
    """Install ``hijack_encode_prompt`` on ``pipe`` in place of its ``encode_prompt``.

    Nothing happens unless ``shared.opts.te_hijack`` is truthy, ``pipe`` is
    not ``None``, the pipe has an ``encode_prompt`` attribute, and it has not
    been hijacked before (no ``orig_encode_prompt`` attribute yet). The
    original method is preserved as ``pipe.orig_encode_prompt``.
    """
    if not shared.opts.te_hijack:
        return
    if pipe is None:
        return
    already_hijacked = hasattr(pipe, 'orig_encode_prompt')
    if already_hijacked or not hasattr(pipe, 'encode_prompt'):
        return
    # shared.log.debug(f'Model: cls={pipe.__class__.__name__} hijack encode')
    original = pipe.encode_prompt
    pipe.orig_encode_prompt = original
    pipe.encode_prompt = hijack_encode_prompt
|