# automatic/modules/linfusion/__init__.py
from modules import shared, sd_models, devices
from .linfusion import LinFusion
from .attention import GeneralizedLinearAttention
applied: LinFusion = None
def detect(pipeline):
if pipeline.__class__.__name__ == 'StableDiffusionXLPipeline':
return "Yuanshi/LinFusion-XL"
if pipeline.__class__.__name__ == 'StableDiffusionPipeline':
return "Yuanshi/LinFusion-1-5"
return None
def apply(pipeline, pretrained: bool = True):
global applied # pylint: disable=global-statement
if applied is not None:
return
# linfusion = LinFusion.construct_for(pipeline=pipeline)
if not pretrained:
model_path = None
default_config = LinFusion.get_default_config(unet=pipeline.unet)
applied = LinFusion(**default_config).to(device=pipeline.unet.device, dtype=pipeline.unet.dtype)
applied.mount_to(unet=pipeline.unet)
else:
model_path = detect(pipeline)
if model_path is None:
shared.log.error('LinFusion: unsupported model type')
return
applied = LinFusion.from_pretrained(model_path, cache_dir=shared.opts.hfcache_dir).to(device=pipeline.unet.device, dtype=pipeline.unet.dtype)
applied.mount_to(unet=pipeline.unet)
shared.log.debug(f'LinFusion: apply class={applied.__class__.__name__} model="{model_path}" modules={len(applied.modules_dict)}')
def unapply(pipeline):
global applied # pylint: disable=global-statement
if applied is None:
return
shared.log.debug('LinFusion: unapply')
sd_models.set_diffusers_attention(pipeline)
devices.torch_gc()
applied = None