sd_dreambooth_extension/dreambooth/utils/utils.py

248 lines
6.6 KiB
Python

from __future__ import annotations
import gc
import html
import importlib.util
import logging
import os
import sys
import traceback
from typing import Optional
import importlib_metadata
from packaging import version
from dreambooth import shared
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import torch
from huggingface_hub import HfFolder, whoami
from helpers.mytqdm import mytqdm
from dreambooth.shared import status
def printi(msg, params=None, log=True):
if log:
status.textinfo = msg
if status.job_count > status.job_no:
status.job_no += 1
if params:
mytqdm.write(msg, params)
else:
mytqdm.write(msg)
def sanitize_tags(name):
tags = name.split(",")
name = ""
for tag in tags:
tag = tag.replace(" ", "_").strip()
name = "".join(x for x in tag if (x.isalnum() or x in "._-"))
name = name.replace(" ", "_")
return "".join(x for x in name if (x.isalnum() or x in "._-,"))
def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in "._-"))
def printm(msg=""):
from dreambooth import shared
use_logger = True
try:
from core.handlers.config import ConfigHandler
except:
use_logger = False
if shared.debug or use_logger:
allocated = round(torch.cuda.memory_allocated(0) / 1024**3, 1)
cached = round(torch.cuda.memory_reserved(0) / 1024**3, 1)
if use_logger:
logging.getLogger(__name__).info(f"{msg}({allocated}/{cached})")
else:
print(f"{msg}({allocated}/{cached})")
def cleanup(do_print: bool = False):
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
except:
print("cleanup exception")
if do_print:
print("Cleanup completed.")
def xformers_check():
env_vars_true_values = {"1", "ON", "YES", "TRUE"}
env_vars_true_and_auto_values = env_vars_true_values.union({"AUTO"})
use_tf = os.environ.get("USE_TF", "AUTO").upper()
use_torch = os.environ.get("USE_TORCH", "AUTO").upper()
if (
use_torch in env_vars_true_and_auto_values
and use_tf not in env_vars_true_values
):
_torch_available = importlib.util.find_spec("torch") is not None
if _torch_available:
try:
_torch_version = importlib_metadata.version("torch")
except importlib_metadata.PackageNotFoundError:
print("No metadatapackage")
_torch_available = False
else:
_torch_available = False
try:
_xformers_version = importlib_metadata.version("xformers")
if _torch_available:
import torch
if version.Version(torch.__version__) < version.Version("1.12"):
raise ValueError("PyTorch version must be >= 1.12")
if version.Version(_xformers_version) < version.Version("0.0.17.dev"):
raise ValueError("Xformers version must be >= 0.0.17.dev")
has_xformers = True
except Exception as e:
# print(f"Exception importing xformers: {e}")
has_xformers = False
return has_xformers
def list_optimizer():
optimizer_list = ["Torch AdamW"]
try:
if shared.device.type != "mps":
from bitsandbytes.optim import AdamW8bit
optimizer_list.append("8bit AdamW")
except:
pass
try:
from dadaptation import DAdaptAdam
optimizer_list.append("AdamW Dadaptation")
except:
pass
try:
from dadaptation import DAdaptAdan
optimizer_list.append("Adan Dadaptation")
except:
pass
try:
from dadaptation import DAdaptSGD
optimizer_list.append("SGD Dadaptation")
except:
pass
return optimizer_list
def list_attention():
has_xformers = xformers_check()
import diffusers.utils
diffusers.utils.is_xformers_available = xformers_check
if has_xformers:
return ["default", "xformers"]
else:
return ["default"]
def list_precisions():
precisions = ["no", "fp16"]
try:
if torch.cuda.is_bf16_supported():
precisions.append("bf16")
except:
pass
return precisions
def list_schedulers():
return [
"linear",
"linear_with_warmup",
"cosine",
"cosine_annealing",
"cosine_annealing_with_restarts",
"cosine_with_restarts",
"polynomial",
"constant",
"constant_with_warmup",
"dadapt_with_warmup",
]
def wrap_gpu_call(func, extra_outputs=None):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
try:
status.begin()
res = func(*args, **kwargs)
status.end()
except Exception as e:
status.end()
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr)
arg_str = f"Arguments: {str(args)} {str(kwargs)}"
print(arg_str[:max_debug_str_len], file=sys.stderr)
if len(arg_str) > max_debug_str_len:
print(
f"(Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)",
file=sys.stderr,
)
print(traceback.format_exc(), file=sys.stderr)
status.job = ""
status.job_count = 0
if extra_outputs_array is None:
extra_outputs_array = [None, ""]
res = extra_outputs_array + [
f"<div class='error'>{html.escape(type(e).__name__ + ': ' + str(e))}</div>"
]
return res
return f
def get_full_repo_name(
model_id: str, organization: Optional[str] = None, token: Optional[str] = None
):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def verify_locon_installed(args):
is_locon_installed = os.path.exists(
os.path.join(
shared.script_path,
"extensions",
"a1111-sd-webui-locon",
)
)
if args.save_lora_for_extra_net and args.use_lora_extended and not is_locon_installed:
raise Exception(
r"a1111-sd-webui-locon extension is required to save "
r"extra net for extended lora. Please install "
r"https://github.com/KohakuBlueleaf/a1111-sd-webui-locon"
)