mirror of https://github.com/vladmandic/automatic
71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
import bisect
|
|
import numpy as np
|
|
import albumentations
|
|
from PIL import Image
|
|
from torch.utils.data import Dataset, ConcatDataset
|
|
|
|
|
|
class ConcatDatasetWithIndex(ConcatDataset):
|
|
"""Modified from original pytorch code to return dataset idx"""
|
|
def __getitem__(self, idx):
|
|
if idx < 0:
|
|
if -idx > len(self):
|
|
raise ValueError("absolute value of index should not exceed dataset length")
|
|
idx = len(self) + idx
|
|
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
|
if dataset_idx == 0:
|
|
sample_idx = idx
|
|
else:
|
|
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
|
return self.datasets[dataset_idx][sample_idx], dataset_idx
|
|
|
|
|
|
class ImagePaths(Dataset):
|
|
def __init__(self, paths, size=None, random_crop=False, labels=None):
|
|
self.size = size
|
|
self.random_crop = random_crop
|
|
|
|
self.labels = dict() if labels is None else labels
|
|
self.labels["file_path_"] = paths
|
|
self._length = len(paths)
|
|
|
|
if self.size is not None and self.size > 0:
|
|
self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
|
|
if not self.random_crop:
|
|
self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
|
|
else:
|
|
self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
|
|
self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
|
|
else:
|
|
self.preprocessor = lambda **kwargs: kwargs
|
|
|
|
def __len__(self):
|
|
return self._length
|
|
|
|
def preprocess_image(self, image_path):
|
|
image = Image.open(image_path)
|
|
if not image.mode == "RGB":
|
|
image = image.convert("RGB")
|
|
image = np.array(image).astype(np.uint8)
|
|
image = self.preprocessor(image=image)["image"]
|
|
image = (image/127.5 - 1.0).astype(np.float32)
|
|
return image
|
|
|
|
def __getitem__(self, i):
|
|
example = dict()
|
|
example["image"] = self.preprocess_image(self.labels["file_path_"][i])
|
|
for k in self.labels:
|
|
example[k] = self.labels[k][i]
|
|
return example
|
|
|
|
|
|
class NumpyPaths(ImagePaths):
|
|
def preprocess_image(self, image_path):
|
|
image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
|
|
image = np.transpose(image, (1,2,0))
|
|
image = Image.fromarray(image, mode="RGB")
|
|
image = np.array(image).astype(np.uint8)
|
|
image = self.preprocessor(image=image)["image"]
|
|
image = (image/127.5 - 1.0).astype(np.float32)
|
|
return image
|