187 lines
5.4 KiB
Python
187 lines
5.4 KiB
Python
import gc
|
|
import inspect
|
|
import threading
|
|
from contextlib import ContextDecorator
|
|
from functools import wraps
|
|
|
|
import torch
|
|
from modules import devices, safe, shared
|
|
from modules.sd_models import load_model, reload_model_weights
|
|
|
|
# Bookkeeping shared by the offload/restore helpers below:
# - backup_sd_model: the model object that was moved to CPU (or None),
# - backup_device: the device it was on before being moved (or None),
# - backup_ckpt_info: checkpoint info to reload afterwards (or None).
backup_sd_model = None
backup_device = None
backup_ckpt_info = None

# Binary semaphore serializing access to the shared web UI model; every
# helper in this module that touches the model runs while holding it.
model_access_sem = threading.Semaphore(value=1)
|
|
|
|
|
|
def clear_cache():
    """Run ``gc.collect()`` and then the web UI's ``devices.torch_gc()`` helper.

    The Python collector runs first so that objects it frees can be released by
    ``devices.torch_gc()`` (presumably it empties the accelerator cache -- its
    body is not visible here).
    """
    for cleanup_step in (gc.collect, devices.torch_gc):
        cleanup_step()
|
|
|
|
|
|
def is_sdxl_lowvram(sd_model):
    """Report whether a low/medium-VRAM mode applies to *sd_model*.

    Despite the name, the result is truthy whenever ``--lowvram`` or
    ``--medvram`` is set, regardless of model type. With ``medvram_sdxl`` (read
    via ``getattr``, so a missing option counts as False), it is truthy only when
    *sd_model* has a ``conditioner`` attribute (presumably the SDXL marker --
    confirm). The raw value of the ``or``/``and`` chain is returned, not a
    coerced bool.
    """
    opts = shared.cmd_opts
    # ``and`` binds tighter than ``or``; the parentheses just make that explicit.
    return opts.lowvram or opts.medvram or (
        getattr(opts, "medvram_sdxl", False) and hasattr(sd_model, "conditioner")
    )
|
|
|
|
|
|
def webui_reload_model_weights(sd_model=None, info=None):
    """Load checkpoint *info*, falling back to a full model load on failure.

    ``reload_model_weights`` is tried first. If it raises, the failure is logged
    with its traceback (it used to be discarded silently) and
    ``load_model(checkpoint_info=info)`` is called instead. An exception raised
    by that fallback propagates to the caller.

    Args:
        sd_model: Model whose weights should be replaced (passed through).
        info: Checkpoint info to load (passed through to both loaders).
    """
    try:
        reload_model_weights(sd_model=sd_model, info=info)
    except Exception:
        import logging

        logging.getLogger(__name__).warning(
            "reload_model_weights failed; falling back to load_model", exc_info=True
        )
        load_model(checkpoint_info=info)
|
|
|
|
|
|
def pre_offload_model_weights(sem):
    """Move the active web UI model to CPU and remember where it came from.

    Runs while holding *sem*. Nothing happens when there is no model, when
    ``is_sdxl_lowvram`` is truthy for it, or when its ``device`` attribute
    (defaulting to ``devices.cpu`` if absent) already is the CPU. Otherwise the
    model and its original device are stored in ``backup_sd_model`` /
    ``backup_device``, the model is moved to ``devices.cpu``, and
    ``clear_cache()`` is called.

    Args:
        sem: Semaphore (used as a context manager) serializing model access.
    """
    global backup_sd_model, backup_device

    with sem:
        model = shared.sd_model
        if model is None or is_sdxl_lowvram(model):
            return
        if getattr(model, "device", devices.cpu) == devices.cpu:
            return  # already on CPU: nothing to offload

        backup_sd_model = model
        backup_device = model.device
        model.to(devices.cpu)
        clear_cache()
|
|
|
|
|
|
def await_pre_offload_model_weights():
    """Run ``pre_offload_model_weights`` on a worker thread and wait for it.

    The worker is handed ``model_access_sem``. Because the work happens on a
    separate ``threading.Thread``, an exception raised there is not re-raised
    in the caller.
    """
    worker = threading.Thread(
        target=pre_offload_model_weights,
        args=(model_access_sem,),
    )
    worker.start()
    worker.join()
|
|
|
|
|
|
def pre_reload_model_weights(sem):
    """Undo a previous offload and/or checkpoint swap, holding *sem*.

    1. If a model was offloaded (``backup_sd_model`` and ``backup_device`` both
       set), move it back to ``backup_device`` and clear both globals.
    2. If ``backup_ckpt_info`` is set and a model is loaded, reload that
       checkpoint through ``webui_reload_model_weights`` and clear
       ``backup_ckpt_info``.

    Args:
        sem: Semaphore (used as a context manager) serializing model access.
    """
    global backup_sd_model, backup_device, backup_ckpt_info

    with sem:
        was_offloaded = backup_sd_model is not None and backup_device is not None
        if was_offloaded:
            backup_sd_model.to(backup_device)
            backup_sd_model = None
            backup_device = None

        saved_ckpt = backup_ckpt_info
        if saved_ckpt is not None and shared.sd_model is not None:
            webui_reload_model_weights(sd_model=shared.sd_model, info=saved_ckpt)
            backup_ckpt_info = None
|
|
|
|
|
|
def await_pre_reload_model_weights():
    """Run ``pre_reload_model_weights`` on a worker thread and wait for it.

    The worker is handed ``model_access_sem``. Exceptions raised on the worker
    thread are not re-raised in the caller.
    """
    worker = threading.Thread(
        target=pre_reload_model_weights,
        args=(model_access_sem,),
    )
    worker.start()
    worker.join()
|
|
|
|
|
|
def backup_reload_ckpt_info(sem, info):
    """Load checkpoint *info*, remembering the current one for later restore.

    Runs while holding *sem*. An offloaded model (``backup_sd_model`` and
    ``backup_device`` both set) is first moved back to its device and the two
    globals are cleared. Then, if a model is loaded, its
    ``sd_checkpoint_info`` is saved in ``backup_ckpt_info`` (overwriting any
    previous value) and *info* is loaded via ``webui_reload_model_weights``.

    Args:
        sem: Semaphore (used as a context manager) serializing model access.
        info: Checkpoint info to switch to.
    """
    global backup_sd_model, backup_device, backup_ckpt_info

    with sem:
        if backup_sd_model is not None and backup_device is not None:
            backup_sd_model.to(backup_device)
            backup_sd_model = backup_device = None

        active_model = shared.sd_model
        if active_model is None:
            return

        backup_ckpt_info = active_model.sd_checkpoint_info
        webui_reload_model_weights(sd_model=active_model, info=info)
|
|
|
|
|
|
def await_backup_reload_ckpt_info(info):
    """Run ``backup_reload_ckpt_info(model_access_sem, info)`` on a worker thread.

    Blocks until the worker finishes. Exceptions raised on the worker thread
    are not re-raised in the caller.

    Args:
        info: Checkpoint info to switch to.
    """
    worker = threading.Thread(
        target=backup_reload_ckpt_info,
        args=(model_access_sem, info),
    )
    worker.start()
    worker.join()
|
|
|
|
|
|
def post_reload_model_weights(sem):
    """Restore the model state saved by the offload/backup helpers.

    Holding *sem*, this moves an offloaded model back to ``backup_device`` and
    reloads ``backup_ckpt_info`` if one was saved. The logic is exactly that of
    ``pre_reload_model_weights``, which it delegates to instead of keeping a
    verbatim copy, so the two restore paths cannot drift apart.

    Args:
        sem: Semaphore (used as a context manager) serializing model access.
    """
    pre_reload_model_weights(sem)
|
|
|
|
|
|
def async_post_reload_model_weights():
    """Start ``post_reload_model_weights`` on a background thread and return.

    The thread is handed ``model_access_sem`` and is not joined. Callers that
    need the restore to finish first can wait with
    ``await_acquire_release_semaphore()``.
    """
    threading.Thread(
        target=post_reload_model_weights,
        args=(model_access_sem,),
    ).start()
|
|
|
|
|
|
def acquire_release_semaphore(sem):
    """Block until *sem* can be acquired, then release it immediately.

    Used as a barrier: returns only once whoever held *sem* has let go.

    Args:
        sem: A semaphore/lock exposing ``acquire()`` and ``release()``.
    """
    sem.acquire()
    sem.release()
|
|
|
|
|
|
def await_acquire_release_semaphore():
    """Wait until ``model_access_sem`` is free, using a worker thread.

    The worker acquires and releases the semaphore; this function joins it, so
    it returns only after any current holder (e.g. a background reload) is done.
    """
    worker = threading.Thread(
        target=acquire_release_semaphore,
        args=(model_access_sem,),
    )
    worker.start()
    worker.join()
|
|
|
|
|
|
def clear_cache_decorator(func):
    """Wrap *func* so ``clear_cache()`` runs before it starts and after it ends.

    Plain functions and generator functions are both supported; for a generator
    function, the wrapper is itself a generator that re-yields *func*'s items.
    The trailing ``clear_cache()`` sits in a ``finally`` clause. Memory is
    therefore also reclaimed when *func* raises (e.g. an out-of-memory error)
    or when the generator is closed before exhaustion. Previously it was
    skipped in both cases.

    Args:
        func: Function or generator function to wrap.

    Returns:
        A ``functools.wraps``-decorated wrapper of the same kind as *func*.
    """
    if inspect.isgeneratorfunction(func):
        @wraps(func)
        def yield_wrapper(*args, **kwargs):
            clear_cache()
            try:
                yield from func(*args, **kwargs)
            finally:
                clear_cache()

        return yield_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        clear_cache()
        try:
            return func(*args, **kwargs)
        finally:
            clear_cache()

    return wrapper
|
|
|
|
|
|
def post_reload_decorator(func):
    """Wrap *func* with a model-access barrier before and a restore after.

    Before *func* runs, ``await_acquire_release_semaphore()`` blocks until
    ``model_access_sem`` is free. Afterwards ``async_post_reload_model_weights()``
    starts a background restore of any backed-up model state. The restore is
    issued from a ``finally`` clause, so it also happens when *func* raises or
    when a generator is closed early. Previously a failure left the swapped
    checkpoint in place.

    Args:
        func: Function or generator function to wrap.

    Returns:
        A ``functools.wraps``-decorated wrapper of the same kind as *func*.
    """
    if inspect.isgeneratorfunction(func):
        @wraps(func)
        def yield_wrapper(*args, **kwargs):
            await_acquire_release_semaphore()
            try:
                yield from func(*args, **kwargs)
            finally:
                async_post_reload_model_weights()

        return yield_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        await_acquire_release_semaphore()
        try:
            return func(*args, **kwargs)
        finally:
            async_post_reload_model_weights()

    return wrapper
|
|
|
|
|
|
def offload_reload_decorator(func):
    """Wrap *func* so the web UI model is offloaded during it and restored after.

    Before *func* runs, ``await_pre_offload_model_weights()`` moves the active
    model to CPU (when applicable). Afterwards ``async_post_reload_model_weights()``
    starts a background restore. The restore is issued from a ``finally`` clause.
    An exception in *func*, or early closing of a generator, therefore no longer
    leaves the model stranded on CPU.

    Args:
        func: Function or generator function to wrap.

    Returns:
        A ``functools.wraps``-decorated wrapper of the same kind as *func*.
    """
    if inspect.isgeneratorfunction(func):
        @wraps(func)
        def yield_wrapper(*args, **kwargs):
            await_pre_offload_model_weights()
            try:
                yield from func(*args, **kwargs)
            finally:
                async_post_reload_model_weights()

        return yield_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        await_pre_offload_model_weights()
        try:
            return func(*args, **kwargs)
        finally:
            async_post_reload_model_weights()

    return wrapper
|
|
|
|
|
|
class torch_default_load_cd(ContextDecorator):
    """Temporarily replace ``torch.load`` with ``safe.unsafe_torch_load``.

    Usable as ``with torch_default_load_cd(): ...`` or as a decorator. On exit,
    the ``torch.load`` that was active on entry is put back and exceptions are
    not suppressed.

    Saved loaders are kept on a stack. A ``ContextDecorator`` instance is shared
    by every call of the function it decorates. With a single saved slot, a
    nested or reentrant entry would overwrite it with ``unsafe_torch_load`` and
    leave that loader installed after the outermost exit.
    """

    def __init__(self):
        # Kept for backward compatibility: holds ``safe.load`` until the first
        # __enter__, then the loader saved by the most recent entry.
        self.backup_load = safe.load
        # LIFO stack of the ``torch.load`` values replaced by each entry.
        self._saved_loads = []

    def __enter__(self):
        self.backup_load = torch.load
        self._saved_loads.append(self.backup_load)
        torch.load = safe.unsafe_torch_load
        return self

    def __exit__(self, *exc):
        torch.load = self._saved_loads.pop()
        return False  # never swallow exceptions
|