# Mirror of https://github.com/vladmandic/automatic (39 lines, 998 B, Python)
import os
|
|
import numpy as np
|
|
import albumentations
|
|
from torch.utils.data import Dataset
|
|
|
|
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
|
|
|
|
|
|
class CustomBase(Dataset):
    """Map-style dataset that forwards ``len()`` and indexing to ``self.data``.

    ``self.data`` starts out as ``None``; subclasses assign an indexable,
    sized container to it after calling ``super().__init__()``.
    """

    def __init__(self, *args, **kwargs):
        # Any positional/keyword arguments are accepted and ignored.
        super().__init__()
        # Until a subclass replaces this, len()/indexing raise TypeError.
        self.data = None

    def __len__(self):
        """Return the length of the backing ``self.data`` container."""
        backing = self.data
        return len(backing)

    def __getitem__(self, i):
        """Return element ``i`` of ``self.data`` as-is, with no transformation."""
        return self.data[i]
|
|
|
|
|
|
|
|
class CustomTrain(CustomBase):
    """Training split built from a plain-text list of image paths.

    Args:
        size: forwarded unchanged to ``ImagePaths``; presumably the target
            image resolution (confirm against ``taming.data.base.ImagePaths``).
        training_images_list_file: path to a UTF-8 text file with one image
            path per line. Empty or whitespace-only lines are skipped.
    """

    def __init__(self, size, training_images_list_file):
        super().__init__()
        # Decode explicitly as UTF-8: the platform default (e.g. cp1252 on
        # Windows) would garble non-ASCII characters in image file names.
        with open(training_images_list_file, "r", encoding="utf-8") as f:
            # Skip blank lines so they are not handed to ImagePaths as "" paths.
            paths = [line for line in f.read().splitlines() if line.strip()]
        # Cropping is deterministic (random_crop=False) for this split as well.
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
|
|
|
|
|
|
class CustomTest(CustomBase):
    """Evaluation split built from a plain-text list of image paths.

    Args:
        size: forwarded unchanged to ``ImagePaths``; presumably the target
            image resolution (confirm against ``taming.data.base.ImagePaths``).
        test_images_list_file: path to a UTF-8 text file with one image
            path per line. Empty or whitespace-only lines are skipped.
    """

    def __init__(self, size, test_images_list_file):
        super().__init__()
        # Decode explicitly as UTF-8: the platform default (e.g. cp1252 on
        # Windows) would garble non-ASCII characters in image file names.
        with open(test_images_list_file, "r", encoding="utf-8") as f:
            # Skip blank lines so they are not handed to ImagePaths as "" paths.
            paths = [line for line in f.read().splitlines() if line.strip()]
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
|
|
|
|
|