mirror of https://github.com/vladmandic/automatic
170 lines
5.8 KiB
Python
170 lines
5.8 KiB
Python
import collections
|
|
import os
|
|
import tarfile
|
|
import urllib
|
|
import zipfile
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
from taming.data.helper_types import Annotation
|
|
from torch._six import string_classes
|
|
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
|
|
from tqdm import tqdm
|
|
|
|
|
|
def _safe_extract_tar(tar, dest):
    """Extract all members of the open TarFile *tar* into *dest*, refusing escapes."""
    if hasattr(tarfile, "data_filter"):
        # Python 3.12+ (and security backports): the "data" filter rejects
        # ``..`` traversal and links pointing outside *dest*, refuses special
        # files and strips dangerous permission bits.
        tar.extractall(path=dest, filter="data")
        return
    # Older interpreters: at minimum ensure no member path resolves outside *dest*.
    root = os.path.realpath(dest)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(root, member.name))
        if os.path.commonpath([root, target]) != root:
            raise tarfile.TarError(
                "Refusing to extract {!r} outside of {!r}".format(member.name, dest)
            )
    tar.extractall(path=dest)


def unpack(path):
    """Extract an archive into the directory that contains it.

    The format is chosen by suffix: ``tar.gz`` (gzip-compressed tar), ``tar``
    (uncompressed tar) or ``zip``. Tar members whose paths would land outside
    that directory are refused.

    Args:
        path: filesystem path (str) of the archive.

    Raises:
        NotImplementedError: if the suffix is not one of the above.
        tarfile.TarError: if a tar member tries to escape the target directory.
    """
    dest = os.path.split(path)[0]
    if path.endswith("tar.gz"):
        with tarfile.open(path, "r:gz") as tar:
            _safe_extract_tar(tar, dest)
    elif path.endswith("tar"):
        with tarfile.open(path, "r:") as tar:
            _safe_extract_tar(tar, dest)
    elif path.endswith("zip"):
        # zipfile.extractall already sanitises absolute and ``..`` member names.
        with zipfile.ZipFile(path, "r") as f:
            f.extractall(path=dest)
    else:
        raise NotImplementedError(
            "Unknown file extension: {}".format(os.path.splitext(path)[1])
        )
|
|
|
|
|
|
def reporthook(bar):
    """Return a progress callback that drives the tqdm progress bar *bar*.

    The callback has the ``(b, bsize, tsize)`` signature that
    ``urllib.request.urlretrieve`` uses for its ``reporthook``: block count,
    block size and total size. When ``tsize`` is given (not ``None``) it
    becomes the bar's total. The bar is then advanced so its position equals
    ``b * bsize``.
    """

    def _on_progress(b=1, bsize=1, tsize=None):
        if tsize is not None:
            bar.total = tsize
        transferred = b * bsize
        # tqdm.update() takes an increment, so convert the absolute position.
        bar.update(transferred - bar.n)

    return _on_progress
|
|
|
|
|
|
def get_root(name):
    """Return the dataset directory ``data/<name>``, creating it if missing.

    The path is relative, so it resolves against the current working
    directory. Missing parent directories are created too.
    """
    dataset_dir = os.path.join("data/", name)
    os.makedirs(dataset_dir, exist_ok=True)
    return dataset_dir
|
|
|
|
|
|
def is_prepared(root):
    """Return True if directory *root* contains a ``.ready`` marker file."""
    marker = Path(root) / ".ready"
    return marker.exists()
|
|
|
|
|
|
def mark_prepared(root):
    """Create an empty ``.ready`` marker file in *root*.

    If the marker already exists, only its modification time is refreshed.
    """
    marker = Path(root) / ".ready"
    marker.touch()
|
|
|
|
|
|
def prompt_download(file_, source, target_dir, content_dir=None):
    """Block until the user has manually placed *file_* in *target_dir*.

    Each round prints instructions to fetch *file_* from *source* and waits
    for Enter. The loop ends as soon as ``target_dir/file_`` exists. If
    *content_dir* is given, it also ends when ``target_dir/content_dir``
    exists (the user placed the extracted content instead of the file).

    Returns:
        The expected path ``target_dir/file_``. It may not exist if the loop
        ended via *content_dir*.
    """
    targetpath = os.path.join(target_dir, file_)
    while not os.path.exists(targetpath):
        content_path = (
            os.path.join(target_dir, content_dir) if content_dir is not None else None
        )
        if content_path is not None and os.path.exists(content_path):
            break
        print("Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath))
        if content_path is not None:
            print("Or place its content into '{}'.".format(content_path))
        input("Press Enter when done...")
    return targetpath
|
|
|
|
|
|
def download_url(file_, url, target_dir):
    """Download *url* to ``target_dir/file_`` while showing a tqdm byte counter.

    *target_dir* is created if it does not exist. ``urlretrieve`` writes the
    target path in binary write mode, replacing any existing file there.

    Args:
        file_: name to save the download under; also the progress-bar label.
        url: URL to fetch.
        target_dir: directory to save into.

    Returns:
        The path the file was written to.
    """
    # A bare ``import urllib`` (as at module level) does not load the
    # ``urllib.request`` submodule, so import it explicitly here.
    import urllib.request

    targetpath = os.path.join(target_dir, file_)
    os.makedirs(target_dir, exist_ok=True)
    with tqdm(
        unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
    ) as bar:
        urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
    return targetpath
|
|
|
|
|
|
def download_urls(urls, target_dir):
    """Download every ``{file_name: url}`` entry of *urls* into *target_dir*.

    Downloads run one after another, in the mapping's iteration order.

    Returns:
        dict mapping each file name to the local path it was saved at.
    """
    return {
        fname: download_url(fname, url, target_dir)
        for fname, url in urls.items()
    }
|
|
|
|
|
|
def quadratic_crop(x, bbox, alpha=1.0):
    """Cut a square patch centred on a bounding box out of an image array.

    Args:
        x: array indexed as ``x[row, col, ...]`` (height axis first, then
            width). Any trailing axes, e.g. channels, are carried through.
        bbox: ``(xmin, ymin, xmax, ymax)`` in pixel coordinates. x values are
            clipped to ``[0, width]`` and y values to ``[0, height]``.
        alpha: scale applied to the longer bbox side to get the crop's side.

    Returns:
        A new array of shape ``(side, side, ...)``, where
        ``side = max(int(alpha * max(bbox_w, bbox_h)), 2)``. Parts of the
        window outside the image are filled by reflect-padding ``x``.
    """
    im_h, im_w = x.shape[:2]
    bbox = np.array(bbox, dtype=np.float32)
    # Clip x-coordinates to the width and y-coordinates to the height.
    # Clipping every coordinate to max(h, w) let boxes spill past the shorter
    # side of a non-square image.
    bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, im_w)
    bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, im_h)

    center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    side = max(int(alpha * max(w, h)), 2)  # never smaller than 2x2

    # Pad by how far a full side length around the centre would overrun the
    # image. This is more than the half-side the crop needs, so the window
    # below always fits inside the padded array.
    required_padding = -1 * min(
        center[0] - side,
        center[1] - side,
        im_w - (center[0] + side),
        im_h - (center[1] + side),
    )
    required_padding = int(np.ceil(required_padding))
    if required_padding > 0:
        padding = [
            [required_padding, required_padding],
            [required_padding, required_padding],
        ]
        padding += [[0, 0]] * (len(x.shape) - 2)  # leave trailing axes unpadded
        x = np.pad(x, padding, "reflect")
        center = center[0] + required_padding, center[1] + required_padding

    xmin = int(center[0] - side / 2)
    ymin = int(center[1] - side / 2)
    # np.array copies, so the result never aliases the (possibly padded) input.
    return np.array(x[ymin : ymin + side, xmin : xmin + side, ...])
|
|
|
|
|
|
def custom_collate(batch):
    r"""Collate a list of samples into a batch, passing Annotation lists through.

    Adapted from ``default_collate`` in PyTorch 1.9.0. The one functional
    deviation: a field that is a non-empty sequence whose first item is an
    ``Annotation`` is returned as the raw list of per-sample sequences, not
    transposed and collated.

    Args:
        batch: list of samples. Every sample is treated as having the same
            structure as ``batch[0]``.

    Returns:
        Tensors are stacked along a new leading dimension. numpy arrays and
        numpy scalars, Python floats (as float64) and ints become tensors.
        ``str``/``bytes`` are returned as a list. Mappings, namedtuples and
        other sequences are collated recursively, field by field.

    Raises:
        TypeError: for numpy arrays of strings/objects, or unsupported types.
        RuntimeError: if sequence fields differ in length across samples.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a DataLoader worker process, stack straight into a
            # shared-memory tensor to avoid an extra copy.
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            # Give ``out`` its final shape up front so torch.stack does not
            # have to resize it.
            out = elem.new(storage).resize_(len(batch), *elem.size())
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # Arrays of string classes or Python objects cannot become tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return custom_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        # (str, bytes) is what ``torch._six.string_classes`` was; that module
        # no longer exists in PyTorch 2.x.
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: custom_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(custom_collate(samples) for samples in zip(*batch)))

    # Added vs. default_collate: keep lists of Annotation objects as-is. The
    # length check avoids an IndexError on empty sequences.
    if (
        isinstance(elem, collections.abc.Sequence)
        and len(elem) > 0
        and isinstance(elem[0], Annotation)
    ):
        return batch
    elif isinstance(elem, collections.abc.Sequence):
        # Every sample must contribute a sequence of the same length;
        # otherwise zip() below would silently truncate.
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(e) == elem_size for e in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [custom_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
|