automatic/pipelines/qwen/qwen_nunchaku.py

23 lines
1.2 KiB
Python

from modules import shared, devices
def load_qwen_nunchaku(repo_id):
    """Load a Nunchaku SVDQuant-quantized transformer for a Qwen-Image model.

    Only repos whose id ends with ``qwen-image`` (case-insensitive) are supported.
    Such a repo is mapped to a prequantized checkpoint file in
    ``nunchaku-tech/nunchaku-qwen-image``. The filename embeds the precision
    reported by ``nunchaku.utils.get_precision()``.

    Args:
        repo_id: id of the base model repo being loaded.

    Returns:
        The loaded ``NunchakuQwenImageTransformer2DModel`` with
        ``quantization_method`` set to ``'SVDQuant'``. Returns ``None`` when
        ``repo_id`` is unsupported, nunchaku cannot be imported, or the installed
        nunchaku lacks the Qwen-Image transformer. The reason is logged via
        ``shared.log.error``.
    """
    # Reject unsupported repos first so they never pay for (or get masked by) nunchaku imports.
    if not (repo_id or '').lower().endswith('qwen-image'):
        shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" unsupported')
        return None
    # A missing/broken nunchaku install is handled like an outdated one: log and return None.
    try:
        import nunchaku
    except Exception as e:
        shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" nunchaku unavailable: {e}')
        return None
    try:
        from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
    except Exception as e:
        # Older nunchaku releases do not ship this module; include the cause so other import errors are not hidden.
        shared.log.error(f'Load module: quant=Nunchaku module=transformer repo="{repo_id}" low nunchaku version: {e}')
        return None
    nunchaku_precision = nunchaku.utils.get_precision()
    nunchaku_repo = f"nunchaku-tech/nunchaku-qwen-image/svdq-{nunchaku_precision}_r128-qwen-image.safetensors" # r32 vs R128
    shared.log.debug(f'Load module: quant=Nunchaku module=transformer repo="{nunchaku_repo}" precision={nunchaku_precision} offload={shared.opts.nunchaku_offload} attention={shared.opts.nunchaku_attention}')
    transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(nunchaku_repo, offload=shared.opts.nunchaku_offload, torch_dtype=devices.dtype) # pylint: disable=no-member
    transformer.quantization_method = 'SVDQuant'
    return transformer