# sd-webui-inpaint-anything/ia_threading.py
#
# 187 lines
# 5.4 KiB
# Python

import gc
import inspect
import threading
from contextlib import ContextDecorator
from functools import wraps
import torch
from modules import devices, safe, shared
from modules.sd_models import load_model, reload_model_weights
# Module-level state shared by the offload/reload helpers in this module.
# Model moved to the CPU by pre_offload_model_weights (None when nothing is offloaded).
backup_sd_model = None
# Device backup_sd_model was on before it was offloaded.
backup_device = None
# Checkpoint info recorded by backup_reload_ckpt_info so the reload helpers can
# switch the model back to it later (None when nothing needs restoring).
backup_ckpt_info = None

# One-slot semaphore: every helper that touches the state above does its work
# inside ``with model_access_sem`` so only one of them handles the model at a time.
model_access_sem = threading.Semaphore(1)
def clear_cache():
    """Collect unreachable Python objects, then call webui's ``devices.torch_gc()``.

    What ``torch_gc`` releases is decided by webui's ``devices`` module
    (presumably cached accelerator memory — TODO confirm).
    """
    # Run the cyclic collector first so objects kept alive only by reference
    # cycles are freed before the device-side cleanup runs.
    gc.collect()
    devices.torch_gc()
def is_sdxl_lowvram(sd_model):
    """Tell whether webui's low/medium-VRAM mode applies to ``sd_model``.

    Despite the name, this is truthy for *any* model when ``--lowvram`` or
    ``--medvram`` is set. The ``medvram_sdxl`` option only counts when the
    model has a ``conditioner`` attribute (presumably the SDXL marker).
    The result is the value of the boolean expression, not coerced to ``bool``.
    """
    opts = shared.cmd_opts
    # ``and`` binds tighter than ``or``; the grouping below is the one Python
    # already applied, now written out explicitly.
    return (
        opts.lowvram
        or opts.medvram
        or (getattr(opts, "medvram_sdxl", False) and hasattr(sd_model, "conditioner"))
    )
def webui_reload_model_weights(sd_model=None, info=None):
    """Switch the webui model to checkpoint ``info``.

    First tries webui's ``reload_model_weights(sd_model=..., info=...)``. If
    that raises anything, it falls back to ``load_model(checkpoint_info=info)``;
    note the fallback does not receive ``sd_model``.

    NOTE(review): the exception from the first attempt is discarded without
    being reported; only errors from the fallback reach the caller.
    """
    try:
        reload_model_weights(sd_model=sd_model, info=info)
    except Exception:
        # Fallback path: load the checkpoint through webui's full loader.
        load_model(checkpoint_info=info)
def pre_offload_model_weights(sem):
    """Move the active webui model to the CPU, remembering where it was.

    Runs while holding ``sem``. It does nothing when:

    - there is no model,
    - ``is_sdxl_lowvram`` is truthy for it, or
    - its ``device`` attribute (CPU if missing) already equals ``devices.cpu``.

    Otherwise the model and its device are stored in ``backup_sd_model`` and
    ``backup_device`` so a reload helper can move it back. The model is then
    moved to the CPU and ``clear_cache()`` is called.
    """
    global backup_sd_model, backup_device
    with sem:
        model = shared.sd_model
        if model is None or is_sdxl_lowvram(model):
            return
        if getattr(model, "device", devices.cpu) != devices.cpu:
            backup_sd_model = model
            backup_device = model.device
            model.to(devices.cpu)
            clear_cache()
def await_pre_offload_model_weights():
    """Run ``pre_offload_model_weights`` on a worker thread and wait for it.

    The worker is joined immediately, so this blocks like a direct call. The
    differences are that the work runs on another thread, and any exception
    raised there is reported by ``threading`` rather than re-raised here.
    """
    worker = threading.Thread(
        target=pre_offload_model_weights,
        args=(model_access_sem,),
    )
    worker.start()
    worker.join()
def pre_reload_model_weights(sem):
    """Undo earlier offloading and checkpoint swaps while holding ``sem``.

    1. If ``pre_offload_model_weights`` offloaded a model, move it back to its
       recorded device and clear the backup.
    2. If a model is loaded and ``backup_ckpt_info`` holds a checkpoint saved
       by ``backup_reload_ckpt_info``, reload that checkpoint and clear the record.

    Behaves the same as ``post_reload_model_weights``.
    """
    global backup_sd_model, backup_device, backup_ckpt_info
    with sem:
        if not (backup_sd_model is None or backup_device is None):
            backup_sd_model.to(backup_device)
            backup_sd_model = backup_device = None
        if shared.sd_model is None or backup_ckpt_info is None:
            return
        webui_reload_model_weights(sd_model=shared.sd_model, info=backup_ckpt_info)
        backup_ckpt_info = None
def await_pre_reload_model_weights():
    """Run ``pre_reload_model_weights`` on a worker thread and wait for it.

    The join happens right after start, so the caller blocks until the restore
    is done. Exceptions in the worker are reported by ``threading``, not
    re-raised here.
    """
    worker = threading.Thread(
        target=pre_reload_model_weights,
        args=(model_access_sem,),
    )
    worker.start()
    worker.join()
def backup_reload_ckpt_info(sem, info):
    """Load checkpoint ``info`` into the webui model, remembering the current one.

    While holding ``sem``, it first moves any model offloaded by
    ``pre_offload_model_weights`` back to its device. Then, if a model is
    loaded, it stores that model's ``sd_checkpoint_info`` in
    ``backup_ckpt_info`` (so a reload helper can switch back) and reloads the
    weights from ``info``.

    NOTE(review): ``backup_ckpt_info`` is overwritten unconditionally, so two
    calls without a restore in between lose the first recorded checkpoint.
    """
    global backup_sd_model, backup_device, backup_ckpt_info
    with sem:
        if backup_sd_model is not None and backup_device is not None:
            backup_sd_model.to(backup_device)
            backup_sd_model = backup_device = None
        current_model = shared.sd_model
        if current_model is not None:
            backup_ckpt_info = current_model.sd_checkpoint_info
            webui_reload_model_weights(sd_model=current_model, info=info)
def await_backup_reload_ckpt_info(info):
    """Run ``backup_reload_ckpt_info(model_access_sem, info)`` on a worker thread and wait.

    Blocks until the checkpoint swap is finished. Exceptions raised in the
    worker are reported by ``threading`` rather than re-raised to the caller.
    """
    worker = threading.Thread(
        target=backup_reload_ckpt_info,
        args=(model_access_sem, info),
    )
    worker.start()
    worker.join()
def post_reload_model_weights(sem):
    """Restore the webui model after an extension operation, holding ``sem``.

    - A model offloaded by ``pre_offload_model_weights`` is moved back to
      ``backup_device`` and the backup is cleared.
    - If a model is loaded and ``backup_ckpt_info`` is set, that checkpoint is
      reloaded and the record is cleared.
    """
    global backup_sd_model, backup_device, backup_ckpt_info
    with sem:
        offloaded = backup_sd_model is not None and backup_device is not None
        if offloaded:
            backup_sd_model.to(backup_device)
            backup_sd_model = backup_device = None

        if shared.sd_model is None or backup_ckpt_info is None:
            return
        webui_reload_model_weights(sd_model=shared.sd_model, info=backup_ckpt_info)
        backup_ckpt_info = None
def async_post_reload_model_weights():
    """Start ``post_reload_model_weights`` on a background thread without waiting.

    Returns immediately. The restore itself runs once the worker acquires
    ``model_access_sem``. Exceptions in the worker are reported by
    ``threading``, not seen by the caller.
    """
    threading.Thread(
        target=post_reload_model_weights,
        args=(model_access_sem,),
    ).start()
def acquire_release_semaphore(sem):
    """Wait until ``sem`` can be acquired, then release it immediately.

    Acts as a barrier. With a one-slot semaphore it returns only after
    whichever holder is currently inside its critical section has left.
    """
    sem.acquire()
    sem.release()
def await_acquire_release_semaphore():
    """Block until ``model_access_sem`` is free, using a short-lived worker thread.

    The worker simply acquires and releases the semaphore. Joining it means
    any offload/reload currently holding the semaphore has finished.
    """
    worker = threading.Thread(
        target=acquire_release_semaphore,
        args=(model_access_sem,),
    )
    worker.start()
    worker.join()
def clear_cache_decorator(func):
    """Wrap ``func`` so ``clear_cache()`` runs before it starts and after it ends.

    Plain functions get a plain wrapper and generator functions get a generator
    wrapper, so callers that check for generator functions still see one.

    The trailing ``clear_cache()`` runs in a ``finally`` block. It therefore
    also runs when ``func`` raises (e.g. an out-of-memory error, when releasing
    memory matters most) or when a wrapped generator is closed early.

    Args:
        func: Function or generator function to wrap.

    Returns:
        A ``functools.wraps``-decorated wrapper of the same kind as ``func``.
        A generator wrapper also forwards the generator's return value.
    """
    if inspect.isgeneratorfunction(func):
        @wraps(func)
        def yield_wrapper(*args, **kwargs):
            clear_cache()
            try:
                return (yield from func(*args, **kwargs))
            finally:
                clear_cache()

        return yield_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        clear_cache()
        try:
            return func(*args, **kwargs)
        finally:
            clear_cache()

    return wrapper
def post_reload_decorator(func):
    """Serialize ``func`` against pending model work and restore the model afterwards.

    Before ``func`` runs, the wrapper calls ``await_acquire_release_semaphore``.
    That waits until ``model_access_sem`` is free, so any in-flight
    offload/reload has finished.

    Afterwards it starts ``async_post_reload_model_weights`` in the background.
    That moves an offloaded model back and reloads a checkpoint recorded by
    ``backup_reload_ckpt_info``.

    The restore is issued from a ``finally`` block. It therefore also happens
    when ``func`` raises or a wrapped generator is closed early. Previously an
    exception skipped it and left the model on a swapped checkpoint or on the CPU.

    Args:
        func: Function or generator function to wrap.

    Returns:
        A ``functools.wraps``-decorated wrapper of the same kind as ``func``.
        A generator wrapper also forwards the generator's return value.
    """
    if inspect.isgeneratorfunction(func):
        @wraps(func)
        def yield_wrapper(*args, **kwargs):
            await_acquire_release_semaphore()
            try:
                return (yield from func(*args, **kwargs))
            finally:
                async_post_reload_model_weights()

        return yield_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        await_acquire_release_semaphore()
        try:
            return func(*args, **kwargs)
        finally:
            async_post_reload_model_weights()

    return wrapper
def offload_reload_decorator(func):
    """Offload the webui model while ``func`` runs, then restore it.

    Before ``func`` runs, ``await_pre_offload_model_weights`` moves the active
    model to the CPU when ``pre_offload_model_weights`` decides that applies.
    Afterwards, ``async_post_reload_model_weights`` moves it back in the
    background.

    The restore is issued from a ``finally`` block. It therefore also happens
    when ``func`` raises or a wrapped generator is closed early. Previously an
    exception left the model offloaded to the CPU.

    Args:
        func: Function or generator function to wrap.

    Returns:
        A ``functools.wraps``-decorated wrapper of the same kind as ``func``.
        A generator wrapper also forwards the generator's return value.
    """
    if inspect.isgeneratorfunction(func):
        @wraps(func)
        def yield_wrapper(*args, **kwargs):
            await_pre_offload_model_weights()
            try:
                return (yield from func(*args, **kwargs))
            finally:
                async_post_reload_model_weights()

        return yield_wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        await_pre_offload_model_weights()
        try:
            return func(*args, **kwargs)
        finally:
            async_post_reload_model_weights()

    return wrapper
class torch_default_load_cd(ContextDecorator):
    """Context manager / decorator that temporarily replaces ``torch.load``.

    While active, ``torch.load`` is set to webui's ``safe.unsafe_torch_load``
    (presumably the unrestricted loader that webui's ``safe`` module normally
    wraps — TODO confirm). On exit the previous ``torch.load`` is put back.

    SECURITY NOTE(review): files loaded while this is active bypass webui's
    ``safe`` loader, so use it only for trusted checkpoints.

    ``ContextDecorator`` reuses a single instance for every call of a decorated
    function. Saving the original loader per ``__enter__`` would therefore
    capture the already-patched loader on a nested, recursive or overlapping
    call, and leave ``torch.load`` patched forever. Instead, entries are
    reference-counted at class level under a lock. The original is saved only
    by the outermost entry and restored only by the last exit.
    """

    # Class-level bookkeeping shared by every instance (see class docstring).
    _patch_lock = threading.Lock()
    _patch_depth = 0
    _saved_load = None

    def __init__(self):
        # Kept for backward compatibility; updated on every __enter__ to the
        # loader that will be restored when the outermost use exits.
        self.backup_load = safe.load

    def __enter__(self):
        owner = torch_default_load_cd
        with owner._patch_lock:
            if owner._patch_depth == 0:
                owner._saved_load = torch.load
                torch.load = safe.unsafe_torch_load
            owner._patch_depth += 1
            self.backup_load = owner._saved_load
        return self

    def __exit__(self, *exc):
        owner = torch_default_load_cd
        with owner._patch_lock:
            # Guard against an unmatched __exit__ driving the count negative.
            if owner._patch_depth > 0:
                owner._patch_depth -= 1
                if owner._patch_depth == 0:
                    torch.load = owner._saved_load
                    owner._saved_load = None
        # Never suppress exceptions raised inside the block.
        return False