Merge branch 'main' into experimental

pull/8/head
toshiaki1729 2022-11-11 00:18:06 +09:00
commit 441f6bba3f
1 changed files with 27 additions and 15 deletions

View File

@ -3,6 +3,7 @@ import re
from typing import Optional, List, Tuple, Set
from modules import shared
from modules.textual_inversion.dataset import re_numbers_at_start
from PIL import Image
re_tags = re.compile(r'^(.+) \[\d+\]$')
@ -141,9 +142,12 @@ class DatasetTagEditor:
def load_dataset(self, img_dir: str, recursive: bool = False):
self.clear()
print(f'Loading dataset from {img_dir}')
try:
filepath_set = get_filepath_set(dir=img_dir, recursive=recursive)
except:
except Exception as e:
print(e)
print('Loading dataset has been aborted.')
return
self.dataset_dir = img_dir
@ -151,23 +155,31 @@ class DatasetTagEditor:
for img_path in filepath_set:
img_dir = os.path.dirname(img_path)
img_filename, img_ext = os.path.splitext(os.path.basename(img_path))
if img_ext == '.png':
text_filename = os.path.join(img_dir, img_filename+'.txt')
# from modules/textual_inversion/dataset.py
if os.path.exists(text_filename) and os.path.isfile(text_filename):
with open(text_filename, "r", encoding="utf8") as ftxt:
filename_text = ftxt.read()
else:
filename_text = img_filename
filename_text = re.sub(re_numbers_at_start, '', filename_text)
if self.re_word:
tokens = self.re_word.findall(filename_text)
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
self.set_tags_by_image_path(img_path, [t.strip() for t in filename_text.split(',')])
if img_ext == '.txt':
continue
try:
img = Image.open(img_path)
except:
img.close()
continue
text_filename = os.path.join(img_dir, img_filename+'.txt')
# from modules/textual_inversion/dataset.py
if os.path.exists(text_filename) and os.path.isfile(text_filename):
with open(text_filename, "r", encoding="utf8") as ftxt:
filename_text = ftxt.read()
else:
filename_text = img_filename
filename_text = re.sub(re_numbers_at_start, '', filename_text)
if self.re_word:
tokens = self.re_word.findall(filename_text)
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
self.set_tags_by_image_path(img_path, [t.strip() for t in filename_text.split(',')])
self.construct_tag_counts()
self.set_img_filter_img_path()
print(f'Loading dataset has been Completed')
def save_dataset(self, backup: bool) -> Tuple[int, int, str]: