implement some features to manage token count (#52)

Add:
- Count and truncate by tokens amount (#40)
- Sort tags by token count

Change:
- Move "use raw clip token..." setting to "Settings" tab
- Show tag frequency, length or token count depending on the "Sort by"
pull/54/head
toshiaki1729 2023-03-07 22:23:48 +09:00 committed by GitHub
parent 94afc6e025
commit 43d8d8f65f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 766 additions and 708 deletions

View File

@ -1,633 +1,8 @@
from pathlib import Path
import re
from typing import List, Set, Optional
from modules import shared
from modules.textual_inversion.dataset import re_numbers_at_start
from PIL import Image
from enum import Enum
from scripts.singleton import Singleton
from . import dataset as ds
from . import tagger
from . import captioning
from . import filters
from . import kohya_finetune_metadata as kohya_metadata
from . import dataset as ds
from .dte_logic import DatasetTagEditor, INTERROGATOR_NAMES, interrogate_image
__all__ = ["ds", "tagger", "captioning", "filters", "kohya_metadata", "INTERROGATOR_NAMES", "interrogate_image", "DatasetTagEditor"]
re_tags = re.compile(r'^(.+) \[\d+\]$')
WD_TAGGER_NAMES = ["wd-v1-4-vit-tagger", "wd-v1-4-convnext-tagger", "wd-v1-4-vit-tagger-v2", "wd-v1-4-convnext-tagger-v2", "wd-v1-4-swinv2-tagger-v2"]
WD_TAGGER_THRESHOLDS = [0.35, 0.35, 0.3537, 0.3685, 0.3771] # v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf
INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru()] + [tagger.WaifuDiffusion(name, WD_TAGGER_THRESHOLDS[i]) for i, name in enumerate(WD_TAGGER_NAMES)]
INTERROGATOR_NAMES = [it.name() for it in INTERROGATORS]
def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_wd):
try:
img = Image.open(path).convert('RGB')
except:
return ''
else:
for it in INTERROGATORS:
if it.name() == interrogator_name:
if isinstance(it, tagger.DeepDanbooru):
with it as tg:
res = tg.predict(img, threshold_booru)
elif isinstance(it, tagger.WaifuDiffusion):
with it as tg:
res = tg.predict(img, threshold_wd)
else:
with it as cap:
res = cap.predict(img)
return ', '.join(res)
class DatasetTagEditor(Singleton):
class SortBy(Enum):
ALPHA = 'Alphabetical Order'
FREQ = 'Frequency'
LEN = 'Length'
class SortOrder(Enum):
ASC = 'Ascending'
DESC = 'Descending'
class InterrogateMethod(Enum):
NONE = 0
PREFILL = 1
OVERWRITE = 2
PREPEND = 3
APPEND = 4
def __init__(self):
# from modules.textual_inversion.dataset
self.re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.dataset = ds.Dataset()
self.img_idx = dict()
self.tag_counts = {}
self.dataset_dir = ''
self.images = {}
def get_tag_list(self):
if len(self.tag_counts) == 0:
self.construct_tag_counts()
return [key for key in self.tag_counts.keys()]
def get_tag_set(self):
if len(self.tag_counts) == 0:
self.construct_tag_counts()
return {key for key in self.tag_counts.keys()}
def get_tags_by_image_path(self, imgpath:str):
return self.dataset.get_data_tags(imgpath)
def set_tags_by_image_path(self, imgpath:str, tags:List[str]):
self.dataset.append_data(ds.Data(imgpath, ','.join(tags)))
self.construct_tag_counts()
def write_tags(self, tags:List[str]):
if tags:
return [f'{tag} [{self.tag_counts.get(tag) or 0}]' for tag in tags if tag]
else:
return []
def read_tags(self, tags:List[str]):
if tags:
tags = [re_tags.match(tag).group(1) for tag in tags if tag]
return [t for t in tags if t]
else:
return []
def sort_tags(self, tags:List[str], sort_by:SortBy=SortBy.ALPHA, sort_order:SortOrder=SortOrder.ASC):
sort_by = self.SortBy(sort_by)
sort_order = self.SortOrder(sort_order)
if sort_by == self.SortBy.ALPHA:
if sort_order == self.SortOrder.ASC:
return sorted(tags, reverse=False)
elif sort_order == self.SortOrder.DESC:
return sorted(tags, reverse=True)
elif sort_by == self.SortBy.FREQ:
if sort_order == self.SortOrder.ASC:
return sorted(tags, key=lambda t:(self.tag_counts.get(t, 0), t), reverse=False)
elif sort_order == self.SortOrder.DESC:
return sorted(tags, key=lambda t:(-self.tag_counts.get(t, 0), t), reverse=False)
elif sort_by == self.SortBy.LEN:
if sort_order == self.SortOrder.ASC:
return sorted(tags, key=lambda t:(len(t), t), reverse=False)
elif sort_order == self.SortOrder.DESC:
return sorted(tags, key=lambda t:(-len(t), t), reverse=False)
return list(tags)
def get_filtered_imgpaths(self, filters:List[filters.Filter] = []):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
img_paths = sorted(filtered_set.datas.keys())
return img_paths
def get_filtered_imgs(self, filters:List[filters.Filter] = []):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
img_paths = sorted(filtered_set.datas.keys())
return [self.images.get(path) for path in img_paths]
def get_filtered_imgindices(self, filters:List[filters.Filter] = []):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
img_paths = sorted(filtered_set.datas.keys())
return [self.img_idx.get(p) for p in img_paths]
def get_filtered_tags(self, filters:List[filters.Filter] = [], filter_word:str = '', filter_tags = True, prefix=False, suffix=False, regex=False):
if filter_tags:
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
tags:Set[str] = filtered_set.get_tagset()
else:
tags:Set[str] = self.dataset.get_tagset()
result = set()
try:
for tag in tags:
if prefix:
if regex:
if re.search("^" + filter_word, tag) is not None:
result.add(tag)
continue
else:
if tag.startswith(filter_word):
result.add(tag)
continue
if suffix:
if regex:
if re.search(filter_word + "$", tag) is not None:
result.add(tag)
continue
else:
if tag.endswith(filter_word):
result.add(tag)
continue
if not prefix and not suffix:
if regex:
if re.search(filter_word, tag) is not None:
result.add(tag)
continue
else:
if filter_word in tag:
result.add(tag)
continue
except:
return tags
else:
return result
def cleanup_tags(self, tags:List[str]):
current_dataset_tags = self.dataset.get_tagset()
return [t for t in tags if t in current_dataset_tags]
def cleanup_tagset(self, tags:Set[str]):
current_dataset_tagset = self.dataset.get_tagset()
return tags & current_dataset_tagset
def get_common_tags(self, filters:List[filters.Filter] = []):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
result = filtered_set.get_tagset()
for d in filtered_set.datas.values():
result &= d.tagset
return sorted(result)
def replace_tags(self, search_tags:List[str], replace_tags:List[str], filters:List[filters.Filter] = [], prepend:bool = False):
img_paths = self.get_filtered_imgpaths(filters=filters)
tags_to_append = replace_tags[len(search_tags):]
tags_to_remove = search_tags[len(replace_tags):]
tags_to_replace = {}
for i in range(min(len(search_tags), len(replace_tags))):
if replace_tags[i] is None or replace_tags[i] == '':
tags_to_remove.append(search_tags[i])
else:
tags_to_replace[search_tags[i]] = replace_tags[i]
for img_path in img_paths:
tags_removed = [t for t in self.dataset.get_data_tags(img_path) if t not in tags_to_remove]
tags_replaced = [tags_to_replace.get(t) if t in tags_to_replace.keys() else t for t in tags_removed]
self.set_tags_by_image_path(img_path, tags_to_append + tags_replaced if prepend else tags_replaced + tags_to_append)
self.construct_tag_counts()
def get_replaced_tagset(self, tags:Set[str], search_tags:List[str], replace_tags:List[str]):
tags_to_remove = search_tags[len(replace_tags):]
tags_to_replace = {}
for i in range(min(len(search_tags), len(replace_tags))):
if replace_tags[i] is None or replace_tags[i] == '':
tags_to_remove.append(search_tags[i])
else:
tags_to_replace[search_tags[i]] = replace_tags[i]
tags_removed = {t for t in tags if t not in tags_to_remove}
tags_replaced = {tags_to_replace.get(t) if t in tags_to_replace.keys() else t for t in tags_removed}
return {t for t in tags_replaced if t}
def search_and_replace_caption(self, search_text:str, replace_text:str, filters:List[filters.Filter] = [], use_regex:bool = False):
img_paths = self.get_filtered_imgpaths(filters=filters)
for img_path in img_paths:
caption = ', '.join(self.dataset.get_data_tags(img_path))
if use_regex:
caption = [t.strip() for t in re.sub(search_text, replace_text, caption).split(',')]
else:
caption = [t.strip() for t in caption.replace(search_text, replace_text).split(',')]
caption = [t for t in caption if t]
self.set_tags_by_image_path(img_path, caption)
self.construct_tag_counts()
def search_and_replace_selected_tags(self, search_text:str, replace_text:str, selected_tags:Optional[Set[str]], filters:List[filters.Filter] = [], use_regex:bool = False):
img_paths = self.get_filtered_imgpaths(filters=filters)
for img_path in img_paths:
tags = self.dataset.get_data_tags(img_path)
tags = self.search_and_replace_tag_list(search_text, replace_text, tags, selected_tags, use_regex)
self.set_tags_by_image_path(img_path, tags)
self.construct_tag_counts()
def search_and_replace_tag_list(self, search_text:str, replace_text:str, tags:List[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False):
if use_regex:
if selected_tags is None:
tags = [re.sub(search_text, replace_text, t) for t in tags]
else:
tags = [re.sub(search_text, replace_text, t) if t in selected_tags else t for t in tags]
else:
if selected_tags is None:
tags = [t.replace(search_text, replace_text) for t in tags]
else:
tags = [t.replace(search_text, replace_text) if t in selected_tags else t for t in tags]
tags = [t2 for t1 in tags for t2 in t1.split(',') if t2]
return [t for t in tags if t]
def search_and_replace_tag_set(self, search_text:str, replace_text:str, tags:Set[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False):
if use_regex:
if selected_tags is None:
tags = {re.sub(search_text, replace_text, t) for t in tags}
else:
tags = {re.sub(search_text, replace_text, t) if t in selected_tags else t for t in tags}
else:
if selected_tags is None:
tags = {t.replace(search_text, replace_text) for t in tags}
else:
tags = {t.replace(search_text, replace_text) if t in selected_tags else t for t in tags}
tags = {t2 for t1 in tags for t2 in t1.split(',') if t2}
return {t for t in tags if t}
def remove_duplicated_tags(self, filters:List[filters.Filter] = []):
img_paths = self.get_filtered_imgpaths(filters)
for path in img_paths:
tags = self.dataset.get_data_tags(path)
res = []
for t in tags:
if t not in res:
res.append(t)
self.set_tags_by_image_path(path, res)
def remove_tags(self, tags:Set[str], filters:List[filters.Filter] = []):
img_paths = self.get_filtered_imgpaths(filters)
for path in img_paths:
res = self.dataset.get_data_tags(path)
res = [t for t in res if t not in tags]
self.set_tags_by_image_path(path, res)
def sort_filtered_tags(self, filters:List[filters.Filter] = [], **sort_args):
img_paths = self.get_filtered_imgpaths(filters)
for path in img_paths:
tags = self.dataset.get_data_tags(path)
res = self.sort_tags(tags, **sort_args)
self.set_tags_by_image_path(path, res)
def get_img_path_list(self):
return [k for k in self.dataset.datas.keys() if k]
def get_img_path_set(self):
return {k for k in self.dataset.datas.keys() if k}
def delete_dataset(self, caption_ext:str, filters:List[filters.Filter], delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
for path in filtered_set.datas.keys():
self.delete_dataset_file(path, caption_ext, delete_image, delete_caption, delete_backup)
if delete_image:
self.dataset.remove(filtered_set)
self.construct_tag_counts()
def move_dataset(self, dest_dir:str, caption_ext:str, filters:List[filters.Filter], move_image:bool = False, move_caption:bool = False, move_backup:bool = False):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
for path in filtered_set.datas.keys():
self.move_dataset_file(path, caption_ext, dest_dir, move_image, move_caption, move_backup)
if move_image:
self.construct_tag_counts()
def delete_dataset_file(self, img_path:str, caption_ext:str, delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False):
if img_path not in self.dataset.datas.keys():
return
img_path_obj = Path(img_path)
if delete_image:
try:
if img_path_obj.is_file():
if img_path in self.images:
self.images[img_path].close()
del self.images[img_path]
img_path_obj.unlink()
self.dataset.remove_by_path(img_path)
print(f'[tag-editor] Deleted {img_path_obj.absolute()}')
except Exception as e:
print(e)
if delete_caption:
try:
txt_path_obj = img_path_obj.with_suffix(caption_ext)
if txt_path_obj.is_file():
txt_path_obj.unlink()
print(f'[tag-editor] Deleted {txt_path_obj.absolute()}')
except Exception as e:
print(e)
if delete_backup:
try:
for extnum in range(1000):
bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}')
if bak_path_obj.is_file():
bak_path_obj.unlink()
print(f'[tag-editor] Deleted {bak_path_obj.absolute()}')
except Exception as e:
print(e)
def move_dataset_file(self, img_path:str, caption_ext:str, dest_dir:str, move_image:bool = False, move_caption:bool = False, move_backup:bool = False):
if img_path not in self.dataset.datas.keys():
return
img_path_obj = Path(img_path)
dest_dir_obj = Path(dest_dir)
if (move_image or move_caption or move_backup) and not dest_dir_obj.exists():
dest_dir_obj.mkdir()
if move_image:
try:
dst_path_obj = dest_dir_obj / img_path_obj.name
if img_path_obj.is_file():
if img_path in self.images:
self.images[img_path].close()
del self.images[img_path]
img_path_obj.replace(dst_path_obj)
self.dataset.remove_by_path(img_path)
print(f'[tag-editor] Moved {img_path_obj.absolute()} -> {dst_path_obj.absolute()}')
except Exception as e:
print(e)
if move_caption:
try:
txt_path_obj = img_path_obj.with_suffix(caption_ext)
dst_path_obj = dest_dir_obj / txt_path_obj.name
if txt_path_obj.is_file():
txt_path_obj.replace(dst_path_obj)
print(f'[tag-editor] Moved {txt_path_obj.absolute()} -> {dst_path_obj.absolute()}')
except Exception as e:
print(e)
if move_backup:
try:
for extnum in range(1000):
bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}')
dst_path_obj = dest_dir_obj / bak_path_obj.name
if bak_path_obj.is_file():
bak_path_obj.replace(dst_path_obj)
print(f'[tag-editor] Moved {bak_path_obj.absolute()} -> {dst_path_obj.absolute()}')
except Exception as e:
print(e)
def load_dataset(self, img_dir:str, caption_ext:str, recursive:bool, load_caption_from_filename:bool, interrogate_method:InterrogateMethod, interrogator_names:List[str], threshold_booru:float, threshold_waifu:float, use_temp_dir:bool, kohya_json_path:Optional[str]):
self.clear()
img_dir_obj = Path(img_dir)
print(f'[tag-editor] Loading dataset from {img_dir_obj.absolute()}')
if recursive:
print(f'[tag-editor] Also loading from subdirectories.')
try:
filepaths = img_dir_obj.glob('**/*') if recursive else img_dir_obj.glob('*')
filepaths = [p for p in filepaths if p.is_file()]
except Exception as e:
print(e)
print('[tag-editor] Loading Aborted.')
return
self.dataset_dir = img_dir
print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.')
def load_images(filepaths:List[Path]):
imgpaths = []
images = {}
for img_path in filepaths:
if img_path.suffix == caption_ext:
continue
try:
img = Image.open(img_path)
except:
continue
else:
abs_path = str(img_path.absolute())
if not use_temp_dir:
img.already_saved_as = abs_path
images[abs_path] = img
imgpaths.append(abs_path)
return imgpaths, images
def load_captions(imgpaths:List[str]):
taglists = []
for abs_path in imgpaths:
img_path = Path(abs_path)
text_path = img_path.with_suffix(caption_ext)
caption_text = ''
if interrogate_method != self.InterrogateMethod.OVERWRITE:
# from modules/textual_inversion/dataset.py, modified
if text_path.is_file():
caption_text = text_path.read_text('utf8')
elif load_caption_from_filename:
caption_text = img_path.stem
caption_text = re.sub(re_numbers_at_start, '', caption_text)
if self.re_word:
tokens = self.re_word.findall(caption_text)
caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
caption_tags = [t.strip() for t in caption_text.split(',')]
caption_tags = [t for t in caption_tags if t]
taglists.append(caption_tags)
return taglists
try:
captionings = []
taggers = []
if interrogate_method != self.InterrogateMethod.NONE:
for it in INTERROGATORS:
if it.name() in interrogator_names:
it.start()
if isinstance(it, tagger.Tagger):
if isinstance(it, tagger.DeepDanbooru):
taggers.append((it, threshold_booru))
if isinstance(it, tagger.WaifuDiffusion):
taggers.append((it, threshold_waifu))
elif isinstance(it, captioning.Captioning):
captionings.append(it)
if kohya_json_path:
imgpaths, self.images, taglists = kohya_metadata.read(img_dir, kohya_json_path, use_temp_dir)
else:
imgpaths, self.images = load_images(filepaths)
taglists = load_captions(imgpaths)
for img_path, tags in zip(imgpaths, taglists):
interrogate_tags = []
img = self.images.get(img_path)
if interrogate_method != self.InterrogateMethod.NONE and ((interrogate_method != self.InterrogateMethod.PREFILL) or (interrogate_method == InterrogateMethod.PREFILL and not tags)):
if img is None:
print(f'Failed to load image {img_path}. Interrogating is aborted.')
else:
img = img.convert('RGB')
for cap in captionings:
interrogate_tags += cap.predict(img)
for tg, threshold in taggers:
interrogate_tags += [t for t in tg.predict(img, threshold).keys()]
if interrogate_method == self.InterrogateMethod.OVERWRITE:
tags = interrogate_tags
elif interrogate_method == self.InterrogateMethod.PREPEND:
tags = interrogate_tags + tags
else:
tags = tags + interrogate_tags
self.set_tags_by_image_path(img_path, tags)
finally:
if interrogate_method != self.InterrogateMethod.NONE:
for cap in captionings:
cap.stop()
for tg, _ in taggers:
tg.stop()
for i, p in enumerate(sorted(self.dataset.datas.keys())):
self.img_idx[p] = i
self.construct_tag_counts()
print(f'[tag-editor] Loading Completed: {len(self.dataset)} images found')
def save_dataset(self, backup:bool, caption_ext:str, write_kohya_metadata:bool, meta_out_path:str, meta_in_path:Optional[str], meta_overwrite:bool, meta_as_caption:bool, meta_full_path:bool):
if len(self.dataset) == 0:
return (0, 0, '')
saved_num = 0
backup_num = 0
for data in self.dataset.datas.values():
img_path, tags = Path(data.imgpath), data.tags
txt_path = img_path.with_suffix(caption_ext)
# make backup
if backup and txt_path.is_file():
for extnum in range(1000):
bak_path = img_path.with_suffix(f'.{extnum:0>3d}')
if not bak_path.is_file():
break
else:
bak_path = None
if bak_path is None:
print(f"[tag-editor] There are too many backup files with same filename. A backup file of {txt_path} cannot be created.")
else:
try:
txt_path.rename(bak_path)
except Exception as e:
print(e)
print(f"[tag-editor] A backup file of {txt_path} cannot be created.")
else:
backup_num += 1
# save
try:
txt_path.write_text(', '.join(tags), 'utf8')
except Exception as e:
print(e)
print(f"[tag-editor] Warning: {txt_path} cannot be saved.")
else:
saved_num += 1
print(f'[tag-editor] Backup text files: {backup_num}/{len(self.dataset)} under {self.dataset_dir}')
print(f'[tag-editor] Saved text files: {saved_num}/{len(self.dataset)} under {self.dataset_dir}')
if(write_kohya_metadata):
kohya_metadata.write(dataset=self.dataset, dataset_dir=self.dataset_dir, out_path=meta_out_path, in_path=meta_in_path, overwrite=meta_overwrite, save_as_caption=meta_as_caption, use_full_path=meta_full_path)
print(f'[tag-editor] Saved json metadata file in {meta_out_path}')
return (saved_num, len(self.dataset), self.dataset_dir)
def clear(self):
self.dataset.clear()
self.tag_counts.clear()
self.img_idx.clear()
self.dataset_dir = ''
self.images = {}
def construct_tag_counts(self):
self.tag_counts = {}
for data in self.dataset.datas.values():
for tag in data.tags:
if tag in self.tag_counts.keys():
self.tag_counts[tag] += 1
else:
self.tag_counts[tag] = 1

View File

@ -0,0 +1,673 @@
from pathlib import Path
import re
from typing import List, Set, Optional
from enum import Enum
from PIL import Image
from modules import shared
from modules.textual_inversion.dataset import re_numbers_at_start
from scripts.singleton import Singleton
from . import tagger, captioning, filters, dataset as ds, kohya_finetune_metadata as kohya_metadata
from scripts.tokenizer import clip_tokenizer
WD_TAGGER_NAMES = ["wd-v1-4-vit-tagger", "wd-v1-4-convnext-tagger", "wd-v1-4-vit-tagger-v2", "wd-v1-4-convnext-tagger-v2", "wd-v1-4-swinv2-tagger-v2"]
WD_TAGGER_THRESHOLDS = [0.35, 0.35, 0.3537, 0.3685, 0.3771] # v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf
INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru()] + [tagger.WaifuDiffusion(name, WD_TAGGER_THRESHOLDS[i]) for i, name in enumerate(WD_TAGGER_NAMES)]
INTERROGATOR_NAMES = [it.name() for it in INTERROGATORS]
re_tags = re.compile(r'^(.+?)( \[\d+\])?$')
def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_wd):
try:
img = Image.open(path).convert('RGB')
except:
return ''
else:
for it in INTERROGATORS:
if it.name() == interrogator_name:
if isinstance(it, tagger.DeepDanbooru):
with it as tg:
res = tg.predict(img, threshold_booru)
elif isinstance(it, tagger.WaifuDiffusion):
with it as tg:
res = tg.predict(img, threshold_wd)
else:
with it as cap:
res = cap.predict(img)
return ', '.join(res)
class DatasetTagEditor(Singleton):
class SortBy(Enum):
ALPHA = 'Alphabetical Order'
FREQ = 'Frequency'
LEN = 'Length'
TOKEN = 'Token Length'
class SortOrder(Enum):
ASC = 'Ascending'
DESC = 'Descending'
class InterrogateMethod(Enum):
NONE = 0
PREFILL = 1
OVERWRITE = 2
PREPEND = 3
APPEND = 4
def __init__(self):
# from modules.textual_inversion.dataset
self.re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.dataset = ds.Dataset()
self.img_idx = dict()
self.tag_counts = {}
self.dataset_dir = ''
self.images = {}
self.tag_tokens = {}
self.raw_clip_token_used = None
def get_tag_list(self):
if len(self.tag_counts) == 0:
self.construct_tag_infos()
return [key for key in self.tag_counts.keys()]
def get_tag_set(self):
if len(self.tag_counts) == 0:
self.construct_tag_infos()
return {key for key in self.tag_counts.keys()}
def get_tags_by_image_path(self, imgpath:str):
return self.dataset.get_data_tags(imgpath)
def set_tags_by_image_path(self, imgpath:str, tags:List[str]):
self.dataset.append_data(ds.Data(imgpath, ','.join(tags)))
self.construct_tag_infos()
def write_tags(self, tags:List[str], sort_by:SortBy=SortBy.FREQ):
sort_by = self.SortBy(sort_by)
if tags:
if sort_by == self.SortBy.FREQ:
return [f'{tag} [{self.tag_counts.get(tag) or 0}]' for tag in tags if tag]
elif sort_by == self.SortBy.LEN:
return [f'{tag} [{len(tag)}]' for tag in tags if tag]
elif sort_by == self.SortBy.TOKEN:
return [f'{tag} [{self.tag_tokens.get(tag, (0, 0))[1]}]' for tag in tags if tag]
else:
return [f'{tag}' for tag in tags if tag]
else:
return []
def read_tags(self, tags:List[str]):
if tags:
tags = [re_tags.match(tag).group(1) for tag in tags if tag]
return [t for t in tags if t]
else:
return []
def sort_tags(self, tags:List[str], sort_by:SortBy=SortBy.ALPHA, sort_order:SortOrder=SortOrder.ASC):
sort_by = self.SortBy(sort_by)
sort_order = self.SortOrder(sort_order)
if sort_by == self.SortBy.ALPHA:
if sort_order == self.SortOrder.ASC:
return sorted(tags, reverse=False)
elif sort_order == self.SortOrder.DESC:
return sorted(tags, reverse=True)
elif sort_by == self.SortBy.FREQ:
if sort_order == self.SortOrder.ASC:
return sorted(tags, key=lambda t:(self.tag_counts.get(t, 0), t), reverse=False)
elif sort_order == self.SortOrder.DESC:
return sorted(tags, key=lambda t:(-self.tag_counts.get(t, 0), t), reverse=False)
elif sort_by == self.SortBy.LEN:
if sort_order == self.SortOrder.ASC:
return sorted(tags, key=lambda t:(len(t), t), reverse=False)
elif sort_order == self.SortOrder.DESC:
return sorted(tags, key=lambda t:(-len(t), t), reverse=False)
elif sort_by == self.SortBy.TOKEN:
if sort_order == self.SortOrder.ASC:
return sorted(tags, key=lambda t:(self.tag_tokens.get(t, (0, 0))[1], t), reverse=False)
elif sort_order == self.SortOrder.DESC:
return sorted(tags, key=lambda t:(-self.tag_tokens.get(t, (0, 0))[1], t), reverse=False)
return list(tags)
def get_filtered_imgpaths(self, filters:List[filters.Filter] = []):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
img_paths = sorted(filtered_set.datas.keys())
return img_paths
def get_filtered_imgs(self, filters:List[filters.Filter] = []):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
img_paths = sorted(filtered_set.datas.keys())
return [self.images.get(path) for path in img_paths]
def get_filtered_imgindices(self, filters:List[filters.Filter] = []):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
img_paths = sorted(filtered_set.datas.keys())
return [self.img_idx.get(p) for p in img_paths]
def get_filtered_tags(self, filters:List[filters.Filter] = [], filter_word:str = '', filter_tags = True, prefix=False, suffix=False, regex=False):
if filter_tags:
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
tags:Set[str] = filtered_set.get_tagset()
else:
tags:Set[str] = self.dataset.get_tagset()
result = set()
try:
for tag in tags:
if prefix:
if regex:
if re.search("^" + filter_word, tag) is not None:
result.add(tag)
continue
else:
if tag.startswith(filter_word):
result.add(tag)
continue
if suffix:
if regex:
if re.search(filter_word + "$", tag) is not None:
result.add(tag)
continue
else:
if tag.endswith(filter_word):
result.add(tag)
continue
if not prefix and not suffix:
if regex:
if re.search(filter_word, tag) is not None:
result.add(tag)
continue
else:
if filter_word in tag:
result.add(tag)
continue
except:
return tags
else:
return result
def cleanup_tags(self, tags:List[str]):
current_dataset_tags = self.dataset.get_tagset()
return [t for t in tags if t in current_dataset_tags]
def cleanup_tagset(self, tags:Set[str]):
current_dataset_tagset = self.dataset.get_tagset()
return tags & current_dataset_tagset
def get_common_tags(self, filters:List[filters.Filter] = []):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
result = filtered_set.get_tagset()
for d in filtered_set.datas.values():
result &= d.tagset
return sorted(result)
def replace_tags(self, search_tags:List[str], replace_tags:List[str], filters:List[filters.Filter] = [], prepend:bool = False):
img_paths = self.get_filtered_imgpaths(filters=filters)
tags_to_append = replace_tags[len(search_tags):]
tags_to_remove = search_tags[len(replace_tags):]
tags_to_replace = {}
for i in range(min(len(search_tags), len(replace_tags))):
if replace_tags[i] is None or replace_tags[i] == '':
tags_to_remove.append(search_tags[i])
else:
tags_to_replace[search_tags[i]] = replace_tags[i]
for img_path in img_paths:
tags_removed = [t for t in self.dataset.get_data_tags(img_path) if t not in tags_to_remove]
tags_replaced = [tags_to_replace.get(t) if t in tags_to_replace.keys() else t for t in tags_removed]
self.set_tags_by_image_path(img_path, tags_to_append + tags_replaced if prepend else tags_replaced + tags_to_append)
self.construct_tag_infos()
def get_replaced_tagset(self, tags:Set[str], search_tags:List[str], replace_tags:List[str]):
tags_to_remove = search_tags[len(replace_tags):]
tags_to_replace = {}
for i in range(min(len(search_tags), len(replace_tags))):
if replace_tags[i] is None or replace_tags[i] == '':
tags_to_remove.append(search_tags[i])
else:
tags_to_replace[search_tags[i]] = replace_tags[i]
tags_removed = {t for t in tags if t not in tags_to_remove}
tags_replaced = {tags_to_replace.get(t) if t in tags_to_replace.keys() else t for t in tags_removed}
return {t for t in tags_replaced if t}
def search_and_replace_caption(self, search_text:str, replace_text:str, filters:List[filters.Filter] = [], use_regex:bool = False):
img_paths = self.get_filtered_imgpaths(filters=filters)
for img_path in img_paths:
caption = ', '.join(self.dataset.get_data_tags(img_path))
if use_regex:
caption = [t.strip() for t in re.sub(search_text, replace_text, caption).split(',')]
else:
caption = [t.strip() for t in caption.replace(search_text, replace_text).split(',')]
caption = [t for t in caption if t]
self.set_tags_by_image_path(img_path, caption)
self.construct_tag_infos()
def search_and_replace_selected_tags(self, search_text:str, replace_text:str, selected_tags:Optional[Set[str]], filters:List[filters.Filter] = [], use_regex:bool = False):
img_paths = self.get_filtered_imgpaths(filters=filters)
for img_path in img_paths:
tags = self.dataset.get_data_tags(img_path)
tags = self.search_and_replace_tag_list(search_text, replace_text, tags, selected_tags, use_regex)
self.set_tags_by_image_path(img_path, tags)
self.construct_tag_infos()
def search_and_replace_tag_list(self, search_text:str, replace_text:str, tags:List[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False):
if use_regex:
if selected_tags is None:
tags = [re.sub(search_text, replace_text, t) for t in tags]
else:
tags = [re.sub(search_text, replace_text, t) if t in selected_tags else t for t in tags]
else:
if selected_tags is None:
tags = [t.replace(search_text, replace_text) for t in tags]
else:
tags = [t.replace(search_text, replace_text) if t in selected_tags else t for t in tags]
tags = [t2 for t1 in tags for t2 in t1.split(',') if t2]
return [t for t in tags if t]
def search_and_replace_tag_set(self, search_text:str, replace_text:str, tags:Set[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False):
if use_regex:
if selected_tags is None:
tags = {re.sub(search_text, replace_text, t) for t in tags}
else:
tags = {re.sub(search_text, replace_text, t) if t in selected_tags else t for t in tags}
else:
if selected_tags is None:
tags = {t.replace(search_text, replace_text) for t in tags}
else:
tags = {t.replace(search_text, replace_text) if t in selected_tags else t for t in tags}
tags = {t2 for t1 in tags for t2 in t1.split(',') if t2}
return {t for t in tags if t}
def remove_duplicated_tags(self, filters:List[filters.Filter] = []):
img_paths = self.get_filtered_imgpaths(filters)
for path in img_paths:
tags = self.dataset.get_data_tags(path)
res = []
for t in tags:
if t not in res:
res.append(t)
self.set_tags_by_image_path(path, res)
def remove_tags(self, tags:Set[str], filters:List[filters.Filter] = []):
img_paths = self.get_filtered_imgpaths(filters)
for path in img_paths:
res = self.dataset.get_data_tags(path)
res = [t for t in res if t not in tags]
self.set_tags_by_image_path(path, res)
def sort_filtered_tags(self, filters:List[filters.Filter] = [], **sort_args):
img_paths = self.get_filtered_imgpaths(filters)
for path in img_paths:
tags = self.dataset.get_data_tags(path)
res = self.sort_tags(tags, **sort_args)
self.set_tags_by_image_path(path, res)
print(f'[tag-editor] Tags are sorted by {sort_args.get("sort_by").value} ({sort_args.get("sort_order").value})')
def truncate_filtered_tags_by_token_count(self, filters:List[filters.Filter] = [], max_token_count:int = 75):
img_paths = self.get_filtered_imgpaths(filters)
for path in img_paths:
tags = self.dataset.get_data_tags(path)
res = []
for tag in tags:
_, token_count = clip_tokenizer.tokenize(', '.join(res + [tag]), shared.opts.dataset_editor_use_raw_clip_token)
if token_count <= max_token_count:
res.append(tag)
else:
break
self.set_tags_by_image_path(path, res)
self.construct_tag_infos()
print(f'[tag-editor] Tags are truncated into token count <= {max_token_count}')
def get_img_path_list(self):
return [k for k in self.dataset.datas.keys() if k]
def get_img_path_set(self):
return {k for k in self.dataset.datas.keys() if k}
def delete_dataset(self, caption_ext:str, filters:List[filters.Filter], delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
for path in filtered_set.datas.keys():
self.delete_dataset_file(path, caption_ext, delete_image, delete_caption, delete_backup)
if delete_image:
self.dataset.remove(filtered_set)
self.construct_tag_infos()
def move_dataset(self, dest_dir:str, caption_ext:str, filters:List[filters.Filter], move_image:bool = False, move_caption:bool = False, move_backup:bool = False):
filtered_set = self.dataset.copy()
for filter in filters:
filtered_set.filter(filter)
for path in filtered_set.datas.keys():
self.move_dataset_file(path, caption_ext, dest_dir, move_image, move_caption, move_backup)
if move_image:
self.construct_tag_infos()
def delete_dataset_file(self, img_path:str, caption_ext:str, delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False):
if img_path not in self.dataset.datas.keys():
return
img_path_obj = Path(img_path)
if delete_image:
try:
if img_path_obj.is_file():
if img_path in self.images:
self.images[img_path].close()
del self.images[img_path]
img_path_obj.unlink()
self.dataset.remove_by_path(img_path)
print(f'[tag-editor] Deleted {img_path_obj.absolute()}')
except Exception as e:
print(e)
if delete_caption:
try:
txt_path_obj = img_path_obj.with_suffix(caption_ext)
if txt_path_obj.is_file():
txt_path_obj.unlink()
print(f'[tag-editor] Deleted {txt_path_obj.absolute()}')
except Exception as e:
print(e)
if delete_backup:
try:
for extnum in range(1000):
bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}')
if bak_path_obj.is_file():
bak_path_obj.unlink()
print(f'[tag-editor] Deleted {bak_path_obj.absolute()}')
except Exception as e:
print(e)
def move_dataset_file(self, img_path:str, caption_ext:str, dest_dir:str, move_image:bool = False, move_caption:bool = False, move_backup:bool = False):
if img_path not in self.dataset.datas.keys():
return
img_path_obj = Path(img_path)
dest_dir_obj = Path(dest_dir)
if (move_image or move_caption or move_backup) and not dest_dir_obj.exists():
dest_dir_obj.mkdir()
if move_image:
try:
dst_path_obj = dest_dir_obj / img_path_obj.name
if img_path_obj.is_file():
if img_path in self.images:
self.images[img_path].close()
del self.images[img_path]
img_path_obj.replace(dst_path_obj)
self.dataset.remove_by_path(img_path)
print(f'[tag-editor] Moved {img_path_obj.absolute()} -> {dst_path_obj.absolute()}')
except Exception as e:
print(e)
if move_caption:
try:
txt_path_obj = img_path_obj.with_suffix(caption_ext)
dst_path_obj = dest_dir_obj / txt_path_obj.name
if txt_path_obj.is_file():
txt_path_obj.replace(dst_path_obj)
print(f'[tag-editor] Moved {txt_path_obj.absolute()} -> {dst_path_obj.absolute()}')
except Exception as e:
print(e)
if move_backup:
try:
for extnum in range(1000):
bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}')
dst_path_obj = dest_dir_obj / bak_path_obj.name
if bak_path_obj.is_file():
bak_path_obj.replace(dst_path_obj)
print(f'[tag-editor] Moved {bak_path_obj.absolute()} -> {dst_path_obj.absolute()}')
except Exception as e:
print(e)
def load_dataset(self, img_dir:str, caption_ext:str, recursive:bool, load_caption_from_filename:bool, interrogate_method:InterrogateMethod, interrogator_names:List[str], threshold_booru:float, threshold_waifu:float, use_temp_dir:bool, kohya_json_path:Optional[str]):
self.clear()
img_dir_obj = Path(img_dir)
print(f'[tag-editor] Loading dataset from {img_dir_obj.absolute()}')
if recursive:
print(f'[tag-editor] Also loading from subdirectories.')
try:
filepaths = img_dir_obj.glob('**/*') if recursive else img_dir_obj.glob('*')
filepaths = [p for p in filepaths if p.is_file()]
except Exception as e:
print(e)
print('[tag-editor] Loading Aborted.')
return
self.dataset_dir = img_dir
print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.')
def load_images(filepaths:List[Path]):
imgpaths = []
images = {}
for img_path in filepaths:
if img_path.suffix == caption_ext:
continue
try:
img = Image.open(img_path)
except:
continue
else:
abs_path = str(img_path.absolute())
if not use_temp_dir:
img.already_saved_as = abs_path
images[abs_path] = img
imgpaths.append(abs_path)
return imgpaths, images
def load_captions(imgpaths:List[str]):
taglists = []
for abs_path in imgpaths:
img_path = Path(abs_path)
text_path = img_path.with_suffix(caption_ext)
caption_text = ''
if interrogate_method != self.InterrogateMethod.OVERWRITE:
# from modules/textual_inversion/dataset.py, modified
if text_path.is_file():
caption_text = text_path.read_text('utf8')
elif load_caption_from_filename:
caption_text = img_path.stem
caption_text = re.sub(re_numbers_at_start, '', caption_text)
if self.re_word:
tokens = self.re_word.findall(caption_text)
caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
caption_tags = [t.strip() for t in caption_text.split(',')]
caption_tags = [t for t in caption_tags if t]
taglists.append(caption_tags)
return taglists
try:
captionings = []
taggers = []
if interrogate_method != self.InterrogateMethod.NONE:
for it in INTERROGATORS:
if it.name() in interrogator_names:
it.start()
if isinstance(it, tagger.Tagger):
if isinstance(it, tagger.DeepDanbooru):
taggers.append((it, threshold_booru))
if isinstance(it, tagger.WaifuDiffusion):
taggers.append((it, threshold_waifu))
elif isinstance(it, captioning.Captioning):
captionings.append(it)
if kohya_json_path:
imgpaths, self.images, taglists = kohya_metadata.read(img_dir, kohya_json_path, use_temp_dir)
else:
imgpaths, self.images = load_images(filepaths)
taglists = load_captions(imgpaths)
for img_path, tags in zip(imgpaths, taglists):
interrogate_tags = []
img = self.images.get(img_path)
if interrogate_method != self.InterrogateMethod.NONE and ((interrogate_method != self.InterrogateMethod.PREFILL) or (interrogate_method == self.InterrogateMethod.PREFILL and not tags)):
if img is None:
print(f'Failed to load image {img_path}. Interrogating is aborted.')
else:
img = img.convert('RGB')
for cap in captionings:
interrogate_tags += cap.predict(img)
for tg, threshold in taggers:
interrogate_tags += [t for t in tg.predict(img, threshold).keys()]
if interrogate_method == self.InterrogateMethod.OVERWRITE:
tags = interrogate_tags
elif interrogate_method == self.InterrogateMethod.PREPEND:
tags = interrogate_tags + tags
else:
tags = tags + interrogate_tags
self.set_tags_by_image_path(img_path, tags)
finally:
if interrogate_method != self.InterrogateMethod.NONE:
for cap in captionings:
cap.stop()
for tg, _ in taggers:
tg.stop()
for i, p in enumerate(sorted(self.dataset.datas.keys())):
self.img_idx[p] = i
self.construct_tag_infos()
print(f'[tag-editor] Loading Completed: {len(self.dataset)} images found')
def save_dataset(self, backup:bool, caption_ext:str, write_kohya_metadata:bool, meta_out_path:str, meta_in_path:Optional[str], meta_overwrite:bool, meta_as_caption:bool, meta_full_path:bool):
if len(self.dataset) == 0:
return (0, 0, '')
saved_num = 0
backup_num = 0
for data in self.dataset.datas.values():
img_path, tags = Path(data.imgpath), data.tags
txt_path = img_path.with_suffix(caption_ext)
# make backup
if backup and txt_path.is_file():
for extnum in range(1000):
bak_path = img_path.with_suffix(f'.{extnum:0>3d}')
if not bak_path.is_file():
break
else:
bak_path = None
if bak_path is None:
print(f"[tag-editor] There are too many backup files with same filename. A backup file of {txt_path} cannot be created.")
else:
try:
txt_path.rename(bak_path)
except Exception as e:
print(e)
print(f"[tag-editor] A backup file of {txt_path} cannot be created.")
else:
backup_num += 1
# save
try:
txt_path.write_text(', '.join(tags), 'utf8')
except Exception as e:
print(e)
print(f"[tag-editor] Warning: {txt_path} cannot be saved.")
else:
saved_num += 1
print(f'[tag-editor] Backup text files: {backup_num}/{len(self.dataset)} under {self.dataset_dir}')
print(f'[tag-editor] Saved text files: {saved_num}/{len(self.dataset)} under {self.dataset_dir}')
if(write_kohya_metadata):
kohya_metadata.write(dataset=self.dataset, dataset_dir=self.dataset_dir, out_path=meta_out_path, in_path=meta_in_path, overwrite=meta_overwrite, save_as_caption=meta_as_caption, use_full_path=meta_full_path)
print(f'[tag-editor] Saved json metadata file in {meta_out_path}')
return (saved_num, len(self.dataset), self.dataset_dir)
def clear(self):
self.dataset.clear()
self.tag_counts.clear()
self.tag_tokens.clear()
self.img_idx.clear()
self.dataset_dir = ''
self.images.clear()
def construct_tag_infos(self):
self.tag_counts = {}
update_token_count = self.raw_clip_token_used is None or self.raw_clip_token_used != shared.opts.dataset_editor_use_raw_clip_token
if update_token_count:
self.tag_tokens.clear()
for data in self.dataset.datas.values():
for tag in data.tags:
if tag in self.tag_counts.keys():
self.tag_counts[tag] += 1
else:
self.tag_counts[tag] = 1
if tag not in self.tag_tokens:
self.tag_tokens[tag] = clip_tokenizer.tokenize(tag, shared.opts.dataset_editor_use_raw_clip_token)
self.raw_clip_token_used = shared.opts.dataset_editor_use_raw_clip_token

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Type, Dict, Any
from typing import NamedTuple, Type, Dict, Any
from modules import shared, script_callbacks, scripts
from modules.shared import opts
import gradio as gr
@ -39,15 +39,15 @@ GeneralConfig = namedtuple('GeneralConfig', [
'meta_use_full_path'
])
FilterConfig = namedtuple('FilterConfig', ['sw_prefix', 'sw_suffix', 'sw_regex','sort_by', 'sort_order', 'logic'])
BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order'])
EditSelectedConfig = namedtuple('EditSelectedConfig', ['use_raw_token', 'auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order', 'token_count'])
EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination'])
CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False)
CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND')
CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR')
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value)
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(True, False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value, 75)
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig('Selected One', [], '.txt', '')
class Config:
@ -252,10 +252,10 @@ def on_ui_tabs():
ui.batch_edit_captions.rb_sr_replace_target,
ui.batch_edit_captions.tag_select_ui_remove.cb_prefix, ui.batch_edit_captions.tag_select_ui_remove.cb_suffix, ui.batch_edit_captions.tag_select_ui_remove.cb_regex,
ui.batch_edit_captions.tag_select_ui_remove.rb_sort_by, ui.batch_edit_captions.tag_select_ui_remove.rb_sort_order,
ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order
ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order,
ui.batch_edit_captions.nb_token_count
]
components_edit_selected = [
ui.edit_caption_of_selected_image.cb_use_raw_token,
ui.edit_caption_of_selected_image.cb_copy_caption_automatically, ui.edit_caption_of_selected_image.cb_sort_caption_on_save,
ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, ui.edit_caption_of_selected_image.dd_intterogator_names_si,
ui.edit_caption_of_selected_image.rb_sort_by, ui.edit_caption_of_selected_image.rb_sort_order
@ -339,6 +339,7 @@ def on_ui_settings():
section = ('dataset-tag-editor', "Dataset Tag Editor")
shared.opts.add_option("dataset_editor_image_columns", shared.OptionInfo(6, "Number of columns on image gallery", section=section))
shared.opts.add_option("dataset_editor_use_temp_files", shared.OptionInfo(False, "Force image gallery to use temporary files", section=section))
shared.opts.add_option("dataset_editor_use_raw_clip_token", shared.OptionInfo(True, "Use raw CLIP token to calculate token count (without emphasis or embeddings)", section=section))
script_callbacks.on_ui_settings(on_ui_settings)

View File

@ -0,0 +1,48 @@
# Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified
# https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py
from typing import List
from functools import reduce
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
from modules import shared, extra_networks, prompt_parser
from modules.sd_hijack import model_hijack
import open_clip.tokenizer
class VanillaClip:
def __init__(self, clip):
self.clip = clip
def vocab(self):
return self.clip.tokenizer.get_vocab()
def byte_decoder(self):
return self.clip.tokenizer.byte_decoder
class OpenClip:
def __init__(self, clip):
self.clip = clip
self.tokenizer = open_clip.tokenizer._tokenizer
def vocab(self):
return self.tokenizer.encoder
def byte_decoder(self):
return self.tokenizer.byte_decoder
def tokenize(text:str, use_raw_clip:bool=True):
if use_raw_clip:
tokens = shared.sd_model.cond_stage_model.tokenize([text])[0]
token_count = len(tokens)
else:
try:
text, _ = extra_networks.parse_prompt(text)
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt = reduce(lambda list1, list2: list1+list2, prompt_flat_list)
except Exception:
prompt = text
token_chunks, token_count = model_hijack.clip.tokenize_line(prompt)
tokens = reduce(lambda list1, list2: list1+list2, [tc.tokens for tc in token_chunks])
return tokens, token_count
def get_target_token_count(token_count:int):
return model_hijack.clip.get_target_prompt_token_count(token_count)

View File

@ -127,8 +127,8 @@ class TagFilterUI():
tags_in_filter = dte_instance.sort_tags(tags=tags_in_filter, sort_by=self.sort_by, sort_order=self.sort_order)
tags = tags_in_filter + [tag for tag in tags if tag not in self.filter.tags]
tags = dte_instance.write_tags(tags)
tags_in_filter = dte_instance.write_tags(tags_in_filter)
tags = dte_instance.write_tags(tags, self.sort_by)
tags_in_filter = dte_instance.write_tags(tags_in_filter, self.sort_by)
return gr.CheckboxGroup.update(value=tags_in_filter, choices=tags)

View File

@ -6,6 +6,9 @@ from .ui_common import *
TagFilter = dte_module.filters.TagFilter
Filter = dte_module.filters.Filter
SortBy = dte_instance.SortBy
SortOrder = dte_instance.SortOrder
class TagSelectUI():
def __init__(self):
@ -32,8 +35,8 @@ class TagSelectUI():
self.cb_suffix = gr.Checkbox(label='Suffix', value=False, interactive=True)
self.cb_regex = gr.Checkbox(label='Use regex', value=False, interactive=True)
with gr.Row():
self.rb_sort_by = gr.Radio(choices=['Alphabetical Order', 'Frequency', 'Length'], value=sort_by, interactive=True, label='Sort by')
self.rb_sort_order = gr.Radio(choices=['Ascending', 'Descending'], value=sort_order, interactive=True, label='Sort Order')
self.rb_sort_by = gr.Radio(hoices=[e.value for e in SortBy], value=sort_by, interactive=True, label='Sort by')
self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=sort_order, interactive=True, label='Sort Order')
with gr.Row():
self.btn_select_visibles = gr.Button(value='Select visible tags')
self.btn_deselect_visibles = gr.Button(value='Deselect visible tags')
@ -105,6 +108,6 @@ class TagSelectUI():
self.tags = set(dte_instance.get_filtered_tags(self.get_filters(), filter_tags=True, prefix=self.prefix, suffix=self.suffix, regex=self.regex))
self.selected_tags &= self.tags
tags = dte_instance.sort_tags(tags=tags, sort_by=self.sort_by, sort_order=self.sort_order)
tags = dte_instance.write_tags(tags)
selected_tags = dte_instance.write_tags(list(self.selected_tags))
tags = dte_instance.write_tags(tags, self.sort_by)
selected_tags = dte_instance.write_tags(list(self.selected_tags), self.sort_by)
return gr.CheckboxGroup.update(value=selected_tags, choices=tags)

View File

@ -67,7 +67,11 @@ class BatchEditCaptionsUI(UIBase):
with gr.Row():
self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_batch_edit.batch_sort_by, interactive=True, label='Sort by')
self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_batch_edit.batch_sort_order, interactive=True, label='Sort Order')
self.btn_sort_selected = gr.Button(value='Sort selected tags', variant='primary')
self.btn_sort_selected = gr.Button(value='Sort tags', variant='primary')
with gr.Column(variant='panel'):
gr.HTML('Truncate tags by token count.')
self.nb_token_count = gr.Number(value=cfg_batch_edit.token_count, precision=0)
self.btn_truncate_by_token = gr.Button(value='Truncate tags by token count', variant='primary')
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], load_dataset:LoadDatasetUI, filter_by_tags:FilterByTagsUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
load_dataset.btn_load_datasets.click(
@ -173,6 +177,17 @@ class BatchEditCaptionsUI(UIBase):
inputs=self.cb_show_only_tags_selected
)
def truncate_by_token_count(token_count:int):
token_count = max(int(token_count), 0)
dte_instance.truncate_filtered_tags_by_token_count(get_filters(), token_count)
return update_filter_and_gallery()
self.btn_truncate_by_token.click(
fn=truncate_by_token_count,
inputs=self.nb_token_count,
outputs=o_update_filter_and_gallery
)
def get_common_tags(self, get_filters:Callable[[], List[dte_module.filters.Filter]], filter_by_tags:FilterByTagsUI):
if self.show_only_selected_tags:

View File

@ -1,16 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Callable
from functools import reduce
import gradio as gr
from modules import shared, extra_networks, prompt_parser
from modules import shared
from modules.call_queue import wrap_queued_call
from modules.sd_hijack import model_hijack
from scripts.dte_instance import dte_module
from .ui_common import *
from .uibase import UIBase
from .tokenizer import clip_tokenizer
from scripts.tokenizer import clip_tokenizer
if TYPE_CHECKING:
from .ui_classes import *
@ -30,7 +28,6 @@ class EditCaptionOfSelectedImageUI(UIBase):
with gr.Tab(label='Read Caption from Selected Image'):
self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6, elem_id='dte_caption')
self.token_counter_caption = gr.HTML(value='<span></span>', elem_id='dte_caption_counter')
self.cb_use_raw_token = gr.Checkbox(value=cfg_edit_selected.use_raw_token, label='Use raw CLIP token for token count (without embeddings)')
with gr.Row():
self.btn_copy_caption = gr.Button(value='Copy and Overwrite')
self.btn_prepend_caption = gr.Button(value='Prepend')
@ -194,39 +191,27 @@ class EditCaptionOfSelectedImageUI(UIBase):
outputs=self.sort_settings
)
def update_token_counter(text:str, use_raw:bool):
if use_raw:
token_count = clip_tokenizer.token_count(text)
max_length = model_hijack.clip.get_target_prompt_token_count(token_count)
else:
try:
text, _ = extra_networks.parse_prompt(text)
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt = reduce(lambda list1, list2: list1+list2, prompt_flat_list)
except Exception:
prompt = text
token_count, max_length = model_hijack.get_prompt_lengths(prompt)
def update_token_counter(text:str):
_, token_count = clip_tokenizer.tokenize(text, shared.opts.dataset_editor_use_raw_clip_token)
max_length = clip_tokenizer.get_target_token_count(token_count)
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
update_caption_token_counter_args = {
'fn' : wrap_queued_call(update_token_counter),
'inputs' : [self.tb_caption, self.cb_use_raw_token],
'inputs' : [self.tb_caption],
'outputs' : [self.token_counter_caption]
}
update_edit_caption_token_counter_args = {
'fn' : wrap_queued_call(update_token_counter),
'inputs' : [self.tb_edit_caption, self.cb_use_raw_token],
'inputs' : [self.tb_edit_caption],
'outputs' : [self.token_counter_edit_caption]
}
update_interrogate_token_counter_args = {
'fn' : wrap_queued_call(update_token_counter),
'inputs' : [self.tb_interrogate, self.cb_use_raw_token],
'inputs' : [self.tb_interrogate],
'outputs' : [self.token_counter_interrogate]
}
self.cb_use_raw_token.change(**update_caption_token_counter_args)
self.cb_use_raw_token.change(**update_edit_caption_token_counter_args)
self.cb_use_raw_token.change(**update_interrogate_token_counter_args)
self.tb_caption.change(**update_caption_token_counter_args)
self.tb_edit_caption.change(**update_edit_caption_token_counter_args)
self.tb_interrogate.change(**update_interrogate_token_counter_args)

View File

@ -69,7 +69,7 @@ class MoveOrDeleteFilesUI(UIBase):
img_path = dataset_gallery.selected_path
if img_path:
dte_instance.move_dataset_file(img_path, caption_ext, dest_dir, move_img, move_txt, move_bak)
dte_instance.construct_tag_counts()
dte_instance.construct_tag_infos()
elif target_data == 'All Displayed Ones':
dte_instance.move_dataset(dest_dir, caption_ext, get_filters(), move_img, move_txt, move_bak)
@ -95,7 +95,7 @@ class MoveOrDeleteFilesUI(UIBase):
img_path = dataset_gallery.selected_path
if img_path:
dte_instance.delete_dataset_file(img_path, delete_img, caption_ext, delete_txt, delete_bak)
dte_instance.construct_tag_counts()
dte_instance.construct_tag_infos()
elif target_data == 'All Displayed Ones':
dte_instance.delete_dataset(caption_ext, get_filters(), delete_img, delete_txt, delete_bak)

View File

@ -1,42 +0,0 @@
# Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified
# https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
from modules import shared
import open_clip.tokenizer
class VanillaClip:
def __init__(self, clip):
self.clip = clip
def vocab(self):
return self.clip.tokenizer.get_vocab()
def byte_decoder(self):
return self.clip.tokenizer.byte_decoder
class OpenClip:
def __init__(self, clip):
self.clip = clip
self.tokenizer = open_clip.tokenizer._tokenizer
def vocab(self):
return self.tokenizer.encoder
def byte_decoder(self):
return self.tokenizer.byte_decoder
def tokenize(text:str):
clip = shared.sd_model.cond_stage_model.wrapped
if isinstance(clip, FrozenCLIPEmbedder):
clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped)
elif isinstance(clip, FrozenOpenCLIPEmbedder):
clip = OpenClip(shared.sd_model.cond_stage_model.wrapped)
else:
raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}')
tokens = shared.sd_model.cond_stage_model.tokenize([text])[0]
return tokens
def token_count(text:str):
return len(tokenize(text))