diff --git a/modules/sd_offload.py b/modules/sd_offload.py
index d30555ab1..f8cd05658 100644
--- a/modules/sd_offload.py
+++ b/modules/sd_offload.py
@@ -7,7 +7,7 @@ import torch
 import accelerate.hooks
 import accelerate.utils.modeling
 from modules.logger import log
-from modules import shared, devices, errors, model_quant, sd_models
+from modules import shared, devices, errors, model_quant, sd_models, sd_offload_aux
 from modules.timer import process as process_timer


@@ -244,6 +244,7 @@ class OffloadHook(accelerate.hooks.ModelHook):
         if shared.opts.diffusers_offload_pre:
             t0 = time.time()
             debug_move(f'Offload: type=balanced op=pre module={module.__class__.__name__}')
+            sd_offload_aux.evict_aux(reason=f'pre:{module.__class__.__name__}')
             for pipe in get_pipe_variants():
                 for module_name in get_module_names(pipe):
                     module_instance = getattr(pipe, module_name, None)
diff --git a/modules/sd_offload_aux.py b/modules/sd_offload_aux.py
index 244947645..8b228fb1f 100644
--- a/modules/sd_offload_aux.py
+++ b/modules/sd_offload_aux.py
@@ -1,14 +1,8 @@
 import os
-import re
-import sys
-import time
-import inspect
 import dataclasses
 import torch
-import accelerate.hooks
-import accelerate.utils.modeling
 from modules.logger import log
-from modules import shared, devices, errors, model_quant, sd_models, sd_offload
+from modules import shared, devices
 from modules.timer import process as process_timer


@@ -24,23 +18,24 @@ class AuxModel:
     name: str
     size: float # GB

-_aux_models: dict[str, AuxModel] = {}
+
+aux_models: dict[str, AuxModel] = {}


 def register_aux(name: str, model: torch.nn.Module) -> None:
     size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**3
-    _aux_models[name] = AuxModel(model=model, name=name, size=size)
+    aux_models[name] = AuxModel(model=model, name=name, size=size)
     debug_move(f'Offload: type=aux op=register name={name} size={size:.3f}')


 def deregister_aux(name: str) -> None:
-    entry = _aux_models.pop(name, None)
+    entry = aux_models.pop(name, None)
     if entry:
         debug_move(f'Offload: type=aux op=deregister name={name}')


 def evict_aux(exclude: str = None, reason: str = 'evict') -> None:
-    for name, entry in _aux_models.items():
+    for name, entry in aux_models.items():
         if name == exclude:
             continue
         if entry.model is not None and hasattr(entry.model, 'device') and not devices.same_device(entry.model.device, devices.cpu):
@@ -60,7 +55,7 @@ def _do_move_to_cpu(model, op_label, size):


 def move_aux_to_gpu(name: str) -> None:
-    entry = _aux_models.get(name)
+    entry = aux_models.get(name)
     if entry is None or entry.model is None:
         return
     if hasattr(entry.model, 'device') and devices.same_device(entry.model.device, devices.device):
@@ -68,7 +63,8 @@ def move_aux_to_gpu(name: str) -> None:
     # 1. Evict other auxiliary models first
     evict_aux(exclude=name, reason='pre')
     # 2. If balanced offload active, evict diffusers pipeline modules if memory is tight
-    shared.sd_model = sd_offload.apply_balanced_offload(shared.sd_model)
+    from modules.sd_offload import apply_balanced_offload
+    shared.sd_model = apply_balanced_offload(shared.sd_model)
     # 3. Move to GPU (stream + sync)
     if shared.opts.diffusers_offload_streams:
         global move_stream # pylint: disable=global-statement
@@ -85,10 +81,9 @@ def move_aux_to_gpu(name: str) -> None:
 def offload_aux(name: str) -> None:
     if not shared.opts.caption_offload:
         return
-    entry = _aux_models.get(name)
+    entry = aux_models.get(name)
     if entry is None or entry.model is None:
         return
     if hasattr(entry.model, 'device') and devices.same_device(entry.model.device, devices.cpu):
         return
     _do_move_to_cpu(entry.model, f'post:{name}', entry.size)
-
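
For reviewers, a minimal usage sketch of the registry this diff introduces, assuming an auxiliary captioning model as the caller; `load_caption_model` and `model.generate` are hypothetical stand-ins and not part of this change:

```python
# Hypothetical caller of the aux-model registry; only register_aux,
# move_aux_to_gpu and offload_aux come from this diff.
from modules import sd_offload_aux

def caption_image(image):
    model = load_caption_model()  # hypothetical loader, not in this diff
    sd_offload_aux.register_aux('caption', model)  # record model + parameter size in GB
    sd_offload_aux.move_aux_to_gpu('caption')  # evict other aux models, rebalance pipeline, move to GPU
    try:
        return model.generate(image)  # hypothetical inference call
    finally:
        sd_offload_aux.offload_aux('caption')  # back to CPU if shared.opts.caption_offload is set
```

Note the deferred `from modules.sd_offload import apply_balanced_offload` inside `move_aux_to_gpu`: since `sd_offload` now imports this module at top level, a top-level import in the other direction would be circular.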