automatic/modules/model_t5.py

import transformers


def load_t5(t5=None, cache_dir=None):
    from modules import devices, modelloader
    repo_id = 'stabilityai/stable-diffusion-3-medium-diffusers'
    if t5 is None:  # guard: default arg would otherwise crash on .lower()
        return None
    if 'fp16' in t5.lower():
        # fp16 variant: load the T5 encoder as-is in the configured device dtype
        modelloader.hf_login()
        t5 = transformers.T5EncoderModel.from_pretrained(
            repo_id,
            subfolder='text_encoder_3',
            cache_dir=cache_dir,
            torch_dtype=devices.dtype,
        )
    elif 'fp4' in t5.lower():
        # 4-bit variant: quantize on load via bitsandbytes
        modelloader.hf_login()
        from installer import install
        install('bitsandbytes', quiet=True)
        quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True)
        t5 = transformers.T5EncoderModel.from_pretrained(
            repo_id,
            subfolder='text_encoder_3',
            quantization_config=quantization_config,
            cache_dir=cache_dir,
            torch_dtype=devices.dtype,
        )
    elif 'fp8' in t5.lower():
        # 8-bit variant: quantize on load via bitsandbytes
        modelloader.hf_login()
        from installer import install
        install('bitsandbytes', quiet=True)
        quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
        t5 = transformers.T5EncoderModel.from_pretrained(
            repo_id,
            subfolder='text_encoder_3',
            quantization_config=quantization_config,
            cache_dir=cache_dir,
            torch_dtype=devices.dtype,
        )
    elif 'int8' in t5.lower():
        # int8 variant: load in the device dtype, then compress with NNCF
        modelloader.hf_login()
        from installer import install
        install('nncf==2.7.0', quiet=True)
        from modules.sd_models_compile import nncf_compress_model
        from modules.sd_hijack import NNCF_T5DenseGatedActDense  # T5DenseGatedActDense uses fp32
        t5 = transformers.T5EncoderModel.from_pretrained(
            repo_id,
            subfolder='text_encoder_3',
            cache_dir=cache_dir,
            torch_dtype=devices.dtype,
        )
        # swap each feed-forward block for the fp32-safe wrapper before compression
        for i in range(len(t5.encoder.block)):
            t5.encoder.block[i].layer[1].DenseReluDense = NNCF_T5DenseGatedActDense(
                t5.encoder.block[i].layer[1].DenseReluDense
            )
        t5 = nncf_compress_model(t5)
    else:
        t5 = None
    return t5
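

# Usage sketch (illustrative values: any string containing 'fp16', 'fp4', 'fp8'
# or 'int8' selects the matching branch above; the cache path is an assumption,
# not a fixed app setting):
#
#     t5_enc = load_t5(t5='T5 FP8', cache_dir='models/huggingface')
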
def set_t5(pipe, module, t5=None, cache_dir=None):
    from modules import devices, shared
    if pipe is None or not hasattr(pipe, module):
        return pipe
    t5 = load_t5(t5=t5, cache_dir=cache_dir)
    setattr(pipe, module, t5)
    if shared.cmd_opts.lowvram or shared.opts.diffusers_seq_cpu_offload:
        # sequential offload: park the encoder on cpu and stream it to the device per-module
        from accelerate import cpu_offload
        getattr(pipe, module).to("cpu")
        cpu_offload(getattr(pipe, module), devices.device, offload_buffers=len(getattr(pipe, module)._parameters) > 0)  # pylint: disable=protected-access
    elif shared.cmd_opts.medvram or shared.opts.diffusers_model_cpu_offload:
        # model offload: install diffusers cpu-offload hooks if none are present yet
        if not hasattr(pipe, "_all_hooks") or len(pipe._all_hooks) == 0:  # pylint: disable=protected-access
            pipe.enable_model_cpu_offload(device=devices.device)
        else:
            pipe.maybe_free_model_hooks()
    devices.torch_gc()
    return pipe
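

# Usage sketch (hypothetical call site; this module only runs inside the app,
# since 'modules' and 'installer' are app packages, so this is a sketch rather
# than a standalone script). Assuming a diffusers StableDiffusion3Pipeline is
# already loaded as 'pipe':
#
#     pipe = set_t5(pipe, 'text_encoder_3', t5='T5 FP16', cache_dir=None)
#
# Passing t5=None (or any unrecognized name) sets pipe.text_encoder_3 to None,
# dropping the largest of SD3's three text encoders from the pipeline.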