# sd_dreambooth_extension/dreambooth/memory.py

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A collection of utilities for ensuring that training can always occur. Heavily influenced by the
[toma](https://github.com/BlackHC/toma) library.
"""
import functools
import gc
import inspect
import traceback
import torch
import torch.backends.cudnn
from extensions.sd_dreambooth_extension.dreambooth import shared
from extensions.sd_dreambooth_extension.dreambooth.utils.utils import cleanup


def should_reduce_batch_size(exception: Exception) -> bool:
    """
    Checks if `exception` relates to CUDA out-of-memory, CUDNN not supported, or CPU out-of-memory.

    Args:
        exception (`Exception`):
            An exception
    """
    _statements = [
        "CUDA out of memory.",  # CUDA OOM
        "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.",  # CUDNN SNAFU
        "DefaultCPUAllocator: can't allocate memory",  # CPU OOM
    ]
    if isinstance(exception, RuntimeError) and len(exception.args) == 1:
        return any(err in exception.args[0] for err in _statements)
    return False
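
# Example (illustrative, not part of the original module): `should_reduce_batch_size`
# only matches single-argument RuntimeErrors whose message contains one of the
# known OOM strings.
#
#     should_reduce_batch_size(RuntimeError("CUDA out of memory."))   # -> True
#     should_reduce_batch_size(ValueError("CUDA out of memory."))     # -> False (not a RuntimeError)
#     should_reduce_batch_size(RuntimeError("some other failure"))    # -> False (unknown message)
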
profiler = None


def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128,
                               starting_grad_size: int = 128, logging_dir: str = "",
                               cleanup_function: callable = None):
    """
    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
    CUDNN, the batch size and gradient accumulation steps are cut in half and `function` is retried.

    `function` must take `batch_size`, `grad_size`, and `prof` (the active profiler, or None) as its first
    three arguments.

    Args:
        function (`callable`, *optional*):
            A function to wrap
        starting_batch_size (`int`, *optional*):
            The batch size to try and fit into memory
        starting_grad_size (`int`, *optional*):
            The starting number of gradient accumulation steps to use. Halved on every retry, with a floor of 1.
        logging_dir (`str`, *optional*):
            The directory to use for profiler trace logging.
        cleanup_function (`callable`, *optional*):
            A function to call when the batch size reaches zero, before giving up. Useful for clearing memory.
    """
    global profiler
    try:
        profile_memory = shared.profile_db
    except Exception:
        profile_memory = False

    torch.backends.cudnn.benchmark = not profile_memory

    if profile_memory and profiler is None:
        from torch.profiler import profile

        cleanup(True)
        profiler = profile(
            schedule=torch.profiler.schedule(wait=0, warmup=0, active=100, repeat=100),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(f"{logging_dir}"),
            with_stack=True,
            profile_memory=True)
        print("Starting profiler...")
        profiler.start()

    # `prof` is forwarded to the wrapped function: the active profiler when profiling
    # is enabled, None otherwise. (The original `else: prof = None` left `prof`
    # unbound whenever the profiling branch was taken.)
    prof = profiler if profile_memory else None
    if function is None:
        return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size,
                                 starting_grad_size=starting_grad_size, logging_dir=logging_dir,
                                 cleanup_function=cleanup_function)

    batch_size = starting_batch_size
    grad_size = starting_grad_size
    def decorator(*args, **kwargs):
        nonlocal batch_size
        nonlocal grad_size
        nonlocal prof
        gc.collect()
        torch.cuda.empty_cache()
        params = list(inspect.signature(function).parameters.keys())
        # Guard against user error
        if len(params) < (len(args) + 1):
            arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
            raise TypeError(
                f"Batch size was passed into `{function.__name__}` as the first argument when called. "
                f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
            )
        while True:
            if batch_size == 0:
                gc.collect()
                torch.cuda.empty_cache()
                # Execute cleanup_function if it is not None
                if cleanup_function is not None:
                    cleanup_function()
                raise RuntimeError("No executable batch size found, reached zero.")
            try:
                return function(batch_size, grad_size, prof, *args, **kwargs)
            except Exception as e:
                if should_reduce_batch_size(e):
                    gc.collect()
                    torch.cuda.empty_cache()
                    # Halve both sizes and retry; grad accumulation never drops below 1.
                    batch_size //= 2
                    grad_size = max(grad_size // 2, 1)
                    print(f"OOM Detected, reducing batch/grad size to {batch_size}/{grad_size}.")
                    traceback.print_exc()
                else:
                    raise

    return decorator
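

# Usage sketch (the `train` function and its arguments below are hypothetical,
# not part of this module): the decorator injects `batch_size`, `grad_size`, and
# `prof` as the wrapped function's first three arguments and halves the sizes on
# every detected OOM until the call succeeds or the batch size reaches zero.
#
#     @find_executable_batch_size(starting_batch_size=8, starting_grad_size=4)
#     def train(batch_size, grad_size, prof, model_dir):
#         ...  # build dataloaders with `batch_size`, accumulate over `grad_size` steps
#
#     train("path/to/model")  # do not pass batch_size yourself; the decorator supplies it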