From 43d8d8f65fb1c873b276d40784c04d1ed64c365a Mon Sep 17 00:00:00 2001 From: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com> Date: Tue, 7 Mar 2023 22:23:48 +0900 Subject: [PATCH] implement some features to manage token count (#52) Add: - Count and truncate by tokens amount (#40) - Sort tags by token count Change: - Move "use raw clip token..." setting to "Settings" tab - Show tag frequency, length or token count depending on the "Sort by" --- scripts/dataset_tag_editor/__init__.py | 629 +--------------- scripts/dataset_tag_editor/dte_logic.py | 673 ++++++++++++++++++ scripts/main.py | 15 +- scripts/{ui => }/tokenizer/__init__.py | 0 scripts/tokenizer/clip_tokenizer.py | 48 ++ scripts/ui/block_tag_filter.py | 4 +- scripts/ui/block_tag_select.py | 11 +- scripts/ui/tab_batch_edit_captions.py | 17 +- .../ui/tab_edit_caption_of_selected_image.py | 31 +- scripts/ui/tab_move_or_delete_files.py | 4 +- scripts/ui/tokenizer/clip_tokenizer.py | 42 -- 11 files changed, 766 insertions(+), 708 deletions(-) create mode 100644 scripts/dataset_tag_editor/dte_logic.py rename scripts/{ui => }/tokenizer/__init__.py (100%) create mode 100644 scripts/tokenizer/clip_tokenizer.py delete mode 100644 scripts/ui/tokenizer/clip_tokenizer.py diff --git a/scripts/dataset_tag_editor/__init__.py b/scripts/dataset_tag_editor/__init__.py index 798cc4e..f601d0f 100644 --- a/scripts/dataset_tag_editor/__init__.py +++ b/scripts/dataset_tag_editor/__init__.py @@ -1,633 +1,8 @@ -from pathlib import Path -import re -from typing import List, Set, Optional -from modules import shared -from modules.textual_inversion.dataset import re_numbers_at_start -from PIL import Image -from enum import Enum - -from scripts.singleton import Singleton - -from . import dataset as ds from . import tagger from . import captioning from . import filters -from . import kohya_finetune_metadata as kohya_metadata +from . 
import dataset as ds +from .dte_logic import DatasetTagEditor, INTERROGATOR_NAMES, interrogate_image __all__ = ["ds", "tagger", "captioning", "filters", "kohya_metadata", "INTERROGATOR_NAMES", "interrogate_image", "DatasetTagEditor"] - -re_tags = re.compile(r'^(.+) \[\d+\]$') - -WD_TAGGER_NAMES = ["wd-v1-4-vit-tagger", "wd-v1-4-convnext-tagger", "wd-v1-4-vit-tagger-v2", "wd-v1-4-convnext-tagger-v2", "wd-v1-4-swinv2-tagger-v2"] -WD_TAGGER_THRESHOLDS = [0.35, 0.35, 0.3537, 0.3685, 0.3771] # v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf - -INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru()] + [tagger.WaifuDiffusion(name, WD_TAGGER_THRESHOLDS[i]) for i, name in enumerate(WD_TAGGER_NAMES)] -INTERROGATOR_NAMES = [it.name() for it in INTERROGATORS] - - -def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_wd): - try: - img = Image.open(path).convert('RGB') - except: - return '' - else: - for it in INTERROGATORS: - if it.name() == interrogator_name: - if isinstance(it, tagger.DeepDanbooru): - with it as tg: - res = tg.predict(img, threshold_booru) - elif isinstance(it, tagger.WaifuDiffusion): - with it as tg: - res = tg.predict(img, threshold_wd) - else: - with it as cap: - res = cap.predict(img) - return ', '.join(res) - - -class DatasetTagEditor(Singleton): - class SortBy(Enum): - ALPHA = 'Alphabetical Order' - FREQ = 'Frequency' - LEN = 'Length' - - class SortOrder(Enum): - ASC = 'Ascending' - DESC = 'Descending' - - class InterrogateMethod(Enum): - NONE = 0 - PREFILL = 1 - OVERWRITE = 2 - PREPEND = 3 - APPEND = 4 - - def __init__(self): - # from modules.textual_inversion.dataset - self.re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None - self.dataset = ds.Dataset() - self.img_idx = dict() - self.tag_counts = {} - self.dataset_dir = '' - self.images = {} - - def get_tag_list(self): - if len(self.tag_counts) == 0: - 
self.construct_tag_counts() - return [key for key in self.tag_counts.keys()] - - - def get_tag_set(self): - if len(self.tag_counts) == 0: - self.construct_tag_counts() - return {key for key in self.tag_counts.keys()} - - - def get_tags_by_image_path(self, imgpath:str): - return self.dataset.get_data_tags(imgpath) - - - def set_tags_by_image_path(self, imgpath:str, tags:List[str]): - self.dataset.append_data(ds.Data(imgpath, ','.join(tags))) - self.construct_tag_counts() - - - def write_tags(self, tags:List[str]): - if tags: - return [f'{tag} [{self.tag_counts.get(tag) or 0}]' for tag in tags if tag] - else: - return [] - - - def read_tags(self, tags:List[str]): - if tags: - tags = [re_tags.match(tag).group(1) for tag in tags if tag] - return [t for t in tags if t] - else: - return [] - - - def sort_tags(self, tags:List[str], sort_by:SortBy=SortBy.ALPHA, sort_order:SortOrder=SortOrder.ASC): - sort_by = self.SortBy(sort_by) - sort_order = self.SortOrder(sort_order) - if sort_by == self.SortBy.ALPHA: - if sort_order == self.SortOrder.ASC: - return sorted(tags, reverse=False) - elif sort_order == self.SortOrder.DESC: - return sorted(tags, reverse=True) - elif sort_by == self.SortBy.FREQ: - if sort_order == self.SortOrder.ASC: - return sorted(tags, key=lambda t:(self.tag_counts.get(t, 0), t), reverse=False) - elif sort_order == self.SortOrder.DESC: - return sorted(tags, key=lambda t:(-self.tag_counts.get(t, 0), t), reverse=False) - elif sort_by == self.SortBy.LEN: - if sort_order == self.SortOrder.ASC: - return sorted(tags, key=lambda t:(len(t), t), reverse=False) - elif sort_order == self.SortOrder.DESC: - return sorted(tags, key=lambda t:(-len(t), t), reverse=False) - return list(tags) - - - def get_filtered_imgpaths(self, filters:List[filters.Filter] = []): - filtered_set = self.dataset.copy() - for filter in filters: - filtered_set.filter(filter) - - img_paths = sorted(filtered_set.datas.keys()) - - return img_paths - - - def get_filtered_imgs(self, 
filters:List[filters.Filter] = []): - filtered_set = self.dataset.copy() - for filter in filters: - filtered_set.filter(filter) - - img_paths = sorted(filtered_set.datas.keys()) - - return [self.images.get(path) for path in img_paths] - - - def get_filtered_imgindices(self, filters:List[filters.Filter] = []): - filtered_set = self.dataset.copy() - for filter in filters: - filtered_set.filter(filter) - - img_paths = sorted(filtered_set.datas.keys()) - - return [self.img_idx.get(p) for p in img_paths] - - - def get_filtered_tags(self, filters:List[filters.Filter] = [], filter_word:str = '', filter_tags = True, prefix=False, suffix=False, regex=False): - if filter_tags: - filtered_set = self.dataset.copy() - for filter in filters: - filtered_set.filter(filter) - tags:Set[str] = filtered_set.get_tagset() - else: - tags:Set[str] = self.dataset.get_tagset() - - result = set() - try: - for tag in tags: - if prefix: - if regex: - if re.search("^" + filter_word, tag) is not None: - result.add(tag) - continue - else: - if tag.startswith(filter_word): - result.add(tag) - continue - if suffix: - if regex: - if re.search(filter_word + "$", tag) is not None: - result.add(tag) - continue - else: - if tag.endswith(filter_word): - result.add(tag) - continue - if not prefix and not suffix: - if regex: - if re.search(filter_word, tag) is not None: - result.add(tag) - continue - else: - if filter_word in tag: - result.add(tag) - continue - except: - return tags - else: - return result - - - def cleanup_tags(self, tags:List[str]): - current_dataset_tags = self.dataset.get_tagset() - return [t for t in tags if t in current_dataset_tags] - - def cleanup_tagset(self, tags:Set[str]): - current_dataset_tagset = self.dataset.get_tagset() - return tags & current_dataset_tagset - - - def get_common_tags(self, filters:List[filters.Filter] = []): - filtered_set = self.dataset.copy() - for filter in filters: - filtered_set.filter(filter) - - result = filtered_set.get_tagset() - for d in 
filtered_set.datas.values(): - result &= d.tagset - - return sorted(result) - - - def replace_tags(self, search_tags:List[str], replace_tags:List[str], filters:List[filters.Filter] = [], prepend:bool = False): - img_paths = self.get_filtered_imgpaths(filters=filters) - tags_to_append = replace_tags[len(search_tags):] - tags_to_remove = search_tags[len(replace_tags):] - tags_to_replace = {} - for i in range(min(len(search_tags), len(replace_tags))): - if replace_tags[i] is None or replace_tags[i] == '': - tags_to_remove.append(search_tags[i]) - else: - tags_to_replace[search_tags[i]] = replace_tags[i] - for img_path in img_paths: - tags_removed = [t for t in self.dataset.get_data_tags(img_path) if t not in tags_to_remove] - tags_replaced = [tags_to_replace.get(t) if t in tags_to_replace.keys() else t for t in tags_removed] - self.set_tags_by_image_path(img_path, tags_to_append + tags_replaced if prepend else tags_replaced + tags_to_append) - - self.construct_tag_counts() - - def get_replaced_tagset(self, tags:Set[str], search_tags:List[str], replace_tags:List[str]): - tags_to_remove = search_tags[len(replace_tags):] - tags_to_replace = {} - for i in range(min(len(search_tags), len(replace_tags))): - if replace_tags[i] is None or replace_tags[i] == '': - tags_to_remove.append(search_tags[i]) - else: - tags_to_replace[search_tags[i]] = replace_tags[i] - tags_removed = {t for t in tags if t not in tags_to_remove} - tags_replaced = {tags_to_replace.get(t) if t in tags_to_replace.keys() else t for t in tags_removed} - return {t for t in tags_replaced if t} - - - def search_and_replace_caption(self, search_text:str, replace_text:str, filters:List[filters.Filter] = [], use_regex:bool = False): - img_paths = self.get_filtered_imgpaths(filters=filters) - - for img_path in img_paths: - caption = ', '.join(self.dataset.get_data_tags(img_path)) - if use_regex: - caption = [t.strip() for t in re.sub(search_text, replace_text, caption).split(',')] - else: - caption = [t.strip() 
for t in caption.replace(search_text, replace_text).split(',')] - caption = [t for t in caption if t] - self.set_tags_by_image_path(img_path, caption) - - self.construct_tag_counts() - - - def search_and_replace_selected_tags(self, search_text:str, replace_text:str, selected_tags:Optional[Set[str]], filters:List[filters.Filter] = [], use_regex:bool = False): - img_paths = self.get_filtered_imgpaths(filters=filters) - - for img_path in img_paths: - tags = self.dataset.get_data_tags(img_path) - tags = self.search_and_replace_tag_list(search_text, replace_text, tags, selected_tags, use_regex) - self.set_tags_by_image_path(img_path, tags) - - self.construct_tag_counts() - - - def search_and_replace_tag_list(self, search_text:str, replace_text:str, tags:List[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False): - if use_regex: - if selected_tags is None: - tags = [re.sub(search_text, replace_text, t) for t in tags] - else: - tags = [re.sub(search_text, replace_text, t) if t in selected_tags else t for t in tags] - else: - if selected_tags is None: - tags = [t.replace(search_text, replace_text) for t in tags] - else: - tags = [t.replace(search_text, replace_text) if t in selected_tags else t for t in tags] - tags = [t2 for t1 in tags for t2 in t1.split(',') if t2] - return [t for t in tags if t] - - - def search_and_replace_tag_set(self, search_text:str, replace_text:str, tags:Set[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False): - if use_regex: - if selected_tags is None: - tags = {re.sub(search_text, replace_text, t) for t in tags} - else: - tags = {re.sub(search_text, replace_text, t) if t in selected_tags else t for t in tags} - else: - if selected_tags is None: - tags = {t.replace(search_text, replace_text) for t in tags} - else: - tags = {t.replace(search_text, replace_text) if t in selected_tags else t for t in tags} - tags = {t2 for t1 in tags for t2 in t1.split(',') if t2} - return {t for t in tags if t} - - - def 
remove_duplicated_tags(self, filters:List[filters.Filter] = []): - img_paths = self.get_filtered_imgpaths(filters) - for path in img_paths: - tags = self.dataset.get_data_tags(path) - res = [] - for t in tags: - if t not in res: - res.append(t) - self.set_tags_by_image_path(path, res) - - - def remove_tags(self, tags:Set[str], filters:List[filters.Filter] = []): - img_paths = self.get_filtered_imgpaths(filters) - for path in img_paths: - res = self.dataset.get_data_tags(path) - res = [t for t in res if t not in tags] - self.set_tags_by_image_path(path, res) - - def sort_filtered_tags(self, filters:List[filters.Filter] = [], **sort_args): - img_paths = self.get_filtered_imgpaths(filters) - for path in img_paths: - tags = self.dataset.get_data_tags(path) - res = self.sort_tags(tags, **sort_args) - self.set_tags_by_image_path(path, res) - - def get_img_path_list(self): - return [k for k in self.dataset.datas.keys() if k] - - - def get_img_path_set(self): - return {k for k in self.dataset.datas.keys() if k} - - - def delete_dataset(self, caption_ext:str, filters:List[filters.Filter], delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False): - filtered_set = self.dataset.copy() - for filter in filters: - filtered_set.filter(filter) - for path in filtered_set.datas.keys(): - self.delete_dataset_file(path, caption_ext, delete_image, delete_caption, delete_backup) - - if delete_image: - self.dataset.remove(filtered_set) - self.construct_tag_counts() - - - def move_dataset(self, dest_dir:str, caption_ext:str, filters:List[filters.Filter], move_image:bool = False, move_caption:bool = False, move_backup:bool = False): - filtered_set = self.dataset.copy() - for filter in filters: - filtered_set.filter(filter) - for path in filtered_set.datas.keys(): - self.move_dataset_file(path, caption_ext, dest_dir, move_image, move_caption, move_backup) - - if move_image: - self.construct_tag_counts() - - - def delete_dataset_file(self, img_path:str, 
caption_ext:str, delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False): - if img_path not in self.dataset.datas.keys(): - return - - img_path_obj = Path(img_path) - - if delete_image: - try: - if img_path_obj.is_file(): - if img_path in self.images: - self.images[img_path].close() - del self.images[img_path] - img_path_obj.unlink() - self.dataset.remove_by_path(img_path) - print(f'[tag-editor] Deleted {img_path_obj.absolute()}') - except Exception as e: - print(e) - - if delete_caption: - try: - txt_path_obj = img_path_obj.with_suffix(caption_ext) - if txt_path_obj.is_file(): - txt_path_obj.unlink() - print(f'[tag-editor] Deleted {txt_path_obj.absolute()}') - except Exception as e: - print(e) - - if delete_backup: - try: - for extnum in range(1000): - bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}') - if bak_path_obj.is_file(): - bak_path_obj.unlink() - print(f'[tag-editor] Deleted {bak_path_obj.absolute()}') - except Exception as e: - print(e) - - - def move_dataset_file(self, img_path:str, caption_ext:str, dest_dir:str, move_image:bool = False, move_caption:bool = False, move_backup:bool = False): - if img_path not in self.dataset.datas.keys(): - return - - img_path_obj = Path(img_path) - dest_dir_obj = Path(dest_dir) - - if (move_image or move_caption or move_backup) and not dest_dir_obj.exists(): - dest_dir_obj.mkdir() - - if move_image: - try: - dst_path_obj = dest_dir_obj / img_path_obj.name - if img_path_obj.is_file(): - if img_path in self.images: - self.images[img_path].close() - del self.images[img_path] - img_path_obj.replace(dst_path_obj) - self.dataset.remove_by_path(img_path) - print(f'[tag-editor] Moved {img_path_obj.absolute()} -> {dst_path_obj.absolute()}') - except Exception as e: - print(e) - - if move_caption: - try: - txt_path_obj = img_path_obj.with_suffix(caption_ext) - dst_path_obj = dest_dir_obj / txt_path_obj.name - if txt_path_obj.is_file(): - txt_path_obj.replace(dst_path_obj) - 
print(f'[tag-editor] Moved {txt_path_obj.absolute()} -> {dst_path_obj.absolute()}') - except Exception as e: - print(e) - - if move_backup: - try: - for extnum in range(1000): - bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}') - dst_path_obj = dest_dir_obj / bak_path_obj.name - if bak_path_obj.is_file(): - bak_path_obj.replace(dst_path_obj) - print(f'[tag-editor] Moved {bak_path_obj.absolute()} -> {dst_path_obj.absolute()}') - except Exception as e: - print(e) - - - def load_dataset(self, img_dir:str, caption_ext:str, recursive:bool, load_caption_from_filename:bool, interrogate_method:InterrogateMethod, interrogator_names:List[str], threshold_booru:float, threshold_waifu:float, use_temp_dir:bool, kohya_json_path:Optional[str]): - self.clear() - - img_dir_obj = Path(img_dir) - - print(f'[tag-editor] Loading dataset from {img_dir_obj.absolute()}') - if recursive: - print(f'[tag-editor] Also loading from subdirectories.') - - try: - filepaths = img_dir_obj.glob('**/*') if recursive else img_dir_obj.glob('*') - filepaths = [p for p in filepaths if p.is_file()] - except Exception as e: - print(e) - print('[tag-editor] Loading Aborted.') - return - - self.dataset_dir = img_dir - - print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.') - - def load_images(filepaths:List[Path]): - imgpaths = [] - images = {} - for img_path in filepaths: - if img_path.suffix == caption_ext: - continue - try: - img = Image.open(img_path) - except: - continue - else: - abs_path = str(img_path.absolute()) - if not use_temp_dir: - img.already_saved_as = abs_path - images[abs_path] = img - - imgpaths.append(abs_path) - return imgpaths, images - - def load_captions(imgpaths:List[str]): - taglists = [] - for abs_path in imgpaths: - img_path = Path(abs_path) - text_path = img_path.with_suffix(caption_ext) - caption_text = '' - if interrogate_method != self.InterrogateMethod.OVERWRITE: - # from modules/textual_inversion/dataset.py, modified - if 
text_path.is_file(): - caption_text = text_path.read_text('utf8') - elif load_caption_from_filename: - caption_text = img_path.stem - caption_text = re.sub(re_numbers_at_start, '', caption_text) - if self.re_word: - tokens = self.re_word.findall(caption_text) - caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens) - - caption_tags = [t.strip() for t in caption_text.split(',')] - caption_tags = [t for t in caption_tags if t] - taglists.append(caption_tags) - - return taglists - - try: - captionings = [] - taggers = [] - if interrogate_method != self.InterrogateMethod.NONE: - for it in INTERROGATORS: - if it.name() in interrogator_names: - it.start() - if isinstance(it, tagger.Tagger): - if isinstance(it, tagger.DeepDanbooru): - taggers.append((it, threshold_booru)) - if isinstance(it, tagger.WaifuDiffusion): - taggers.append((it, threshold_waifu)) - elif isinstance(it, captioning.Captioning): - captionings.append(it) - - if kohya_json_path: - imgpaths, self.images, taglists = kohya_metadata.read(img_dir, kohya_json_path, use_temp_dir) - else: - imgpaths, self.images = load_images(filepaths) - taglists = load_captions(imgpaths) - - for img_path, tags in zip(imgpaths, taglists): - interrogate_tags = [] - img = self.images.get(img_path) - if interrogate_method != self.InterrogateMethod.NONE and ((interrogate_method != self.InterrogateMethod.PREFILL) or (interrogate_method == InterrogateMethod.PREFILL and not tags)): - if img is None: - print(f'Failed to load image {img_path}. 
Interrogating is aborted.') - else: - img = img.convert('RGB') - for cap in captionings: - interrogate_tags += cap.predict(img) - - for tg, threshold in taggers: - interrogate_tags += [t for t in tg.predict(img, threshold).keys()] - - if interrogate_method == self.InterrogateMethod.OVERWRITE: - tags = interrogate_tags - elif interrogate_method == self.InterrogateMethod.PREPEND: - tags = interrogate_tags + tags - else: - tags = tags + interrogate_tags - - self.set_tags_by_image_path(img_path, tags) - - finally: - if interrogate_method != self.InterrogateMethod.NONE: - for cap in captionings: - cap.stop() - for tg, _ in taggers: - tg.stop() - - for i, p in enumerate(sorted(self.dataset.datas.keys())): - self.img_idx[p] = i - - self.construct_tag_counts() - print(f'[tag-editor] Loading Completed: {len(self.dataset)} images found') - - - def save_dataset(self, backup:bool, caption_ext:str, write_kohya_metadata:bool, meta_out_path:str, meta_in_path:Optional[str], meta_overwrite:bool, meta_as_caption:bool, meta_full_path:bool): - if len(self.dataset) == 0: - return (0, 0, '') - - saved_num = 0 - backup_num = 0 - for data in self.dataset.datas.values(): - img_path, tags = Path(data.imgpath), data.tags - txt_path = img_path.with_suffix(caption_ext) - # make backup - if backup and txt_path.is_file(): - for extnum in range(1000): - bak_path = img_path.with_suffix(f'.{extnum:0>3d}') - if not bak_path.is_file(): - break - else: - bak_path = None - if bak_path is None: - print(f"[tag-editor] There are too many backup files with same filename. 
A backup file of {txt_path} cannot be created.") - else: - try: - txt_path.rename(bak_path) - except Exception as e: - print(e) - print(f"[tag-editor] A backup file of {txt_path} cannot be created.") - else: - backup_num += 1 - # save - try: - txt_path.write_text(', '.join(tags), 'utf8') - except Exception as e: - print(e) - print(f"[tag-editor] Warning: {txt_path} cannot be saved.") - else: - saved_num += 1 - print(f'[tag-editor] Backup text files: {backup_num}/{len(self.dataset)} under {self.dataset_dir}') - print(f'[tag-editor] Saved text files: {saved_num}/{len(self.dataset)} under {self.dataset_dir}') - - if(write_kohya_metadata): - kohya_metadata.write(dataset=self.dataset, dataset_dir=self.dataset_dir, out_path=meta_out_path, in_path=meta_in_path, overwrite=meta_overwrite, save_as_caption=meta_as_caption, use_full_path=meta_full_path) - print(f'[tag-editor] Saved json metadata file in {meta_out_path}') - return (saved_num, len(self.dataset), self.dataset_dir) - - - def clear(self): - self.dataset.clear() - self.tag_counts.clear() - self.img_idx.clear() - self.dataset_dir = '' - self.images = {} - - - def construct_tag_counts(self): - self.tag_counts = {} - for data in self.dataset.datas.values(): - for tag in data.tags: - if tag in self.tag_counts.keys(): - self.tag_counts[tag] += 1 - else: - self.tag_counts[tag] = 1 diff --git a/scripts/dataset_tag_editor/dte_logic.py b/scripts/dataset_tag_editor/dte_logic.py new file mode 100644 index 0000000..7744715 --- /dev/null +++ b/scripts/dataset_tag_editor/dte_logic.py @@ -0,0 +1,673 @@ +from pathlib import Path +import re +from typing import List, Set, Optional +from enum import Enum +from PIL import Image + +from modules import shared +from modules.textual_inversion.dataset import re_numbers_at_start + +from scripts.singleton import Singleton + +from . 
import tagger, captioning, filters, dataset as ds, kohya_finetune_metadata as kohya_metadata +from scripts.tokenizer import clip_tokenizer + +WD_TAGGER_NAMES = ["wd-v1-4-vit-tagger", "wd-v1-4-convnext-tagger", "wd-v1-4-vit-tagger-v2", "wd-v1-4-convnext-tagger-v2", "wd-v1-4-swinv2-tagger-v2"] +WD_TAGGER_THRESHOLDS = [0.35, 0.35, 0.3537, 0.3685, 0.3771] # v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf + +INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru()] + [tagger.WaifuDiffusion(name, WD_TAGGER_THRESHOLDS[i]) for i, name in enumerate(WD_TAGGER_NAMES)] +INTERROGATOR_NAMES = [it.name() for it in INTERROGATORS] + +re_tags = re.compile(r'^(.+?)( \[\d+\])?$') + + +def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_wd): + try: + img = Image.open(path).convert('RGB') + except: + return '' + else: + for it in INTERROGATORS: + if it.name() == interrogator_name: + if isinstance(it, tagger.DeepDanbooru): + with it as tg: + res = tg.predict(img, threshold_booru) + elif isinstance(it, tagger.WaifuDiffusion): + with it as tg: + res = tg.predict(img, threshold_wd) + else: + with it as cap: + res = cap.predict(img) + return ', '.join(res) + + +class DatasetTagEditor(Singleton): + class SortBy(Enum): + ALPHA = 'Alphabetical Order' + FREQ = 'Frequency' + LEN = 'Length' + TOKEN = 'Token Length' + + class SortOrder(Enum): + ASC = 'Ascending' + DESC = 'Descending' + + class InterrogateMethod(Enum): + NONE = 0 + PREFILL = 1 + OVERWRITE = 2 + PREPEND = 3 + APPEND = 4 + + def __init__(self): + # from modules.textual_inversion.dataset + self.re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None + self.dataset = ds.Dataset() + self.img_idx = dict() + self.tag_counts = {} + self.dataset_dir = '' + self.images = {} + self.tag_tokens = {} + self.raw_clip_token_used = None + + def get_tag_list(self): + if len(self.tag_counts) == 0: + 
self.construct_tag_infos() + return [key for key in self.tag_counts.keys()] + + + def get_tag_set(self): + if len(self.tag_counts) == 0: + self.construct_tag_infos() + return {key for key in self.tag_counts.keys()} + + + def get_tags_by_image_path(self, imgpath:str): + return self.dataset.get_data_tags(imgpath) + + + def set_tags_by_image_path(self, imgpath:str, tags:List[str]): + self.dataset.append_data(ds.Data(imgpath, ','.join(tags))) + self.construct_tag_infos() + + + def write_tags(self, tags:List[str], sort_by:SortBy=SortBy.FREQ): + sort_by = self.SortBy(sort_by) + if tags: + if sort_by == self.SortBy.FREQ: + return [f'{tag} [{self.tag_counts.get(tag) or 0}]' for tag in tags if tag] + elif sort_by == self.SortBy.LEN: + return [f'{tag} [{len(tag)}]' for tag in tags if tag] + elif sort_by == self.SortBy.TOKEN: + return [f'{tag} [{self.tag_tokens.get(tag, (0, 0))[1]}]' for tag in tags if tag] + else: + return [f'{tag}' for tag in tags if tag] + else: + return [] + + + def read_tags(self, tags:List[str]): + if tags: + tags = [re_tags.match(tag).group(1) for tag in tags if tag] + return [t for t in tags if t] + else: + return [] + + + def sort_tags(self, tags:List[str], sort_by:SortBy=SortBy.ALPHA, sort_order:SortOrder=SortOrder.ASC): + sort_by = self.SortBy(sort_by) + sort_order = self.SortOrder(sort_order) + if sort_by == self.SortBy.ALPHA: + if sort_order == self.SortOrder.ASC: + return sorted(tags, reverse=False) + elif sort_order == self.SortOrder.DESC: + return sorted(tags, reverse=True) + elif sort_by == self.SortBy.FREQ: + if sort_order == self.SortOrder.ASC: + return sorted(tags, key=lambda t:(self.tag_counts.get(t, 0), t), reverse=False) + elif sort_order == self.SortOrder.DESC: + return sorted(tags, key=lambda t:(-self.tag_counts.get(t, 0), t), reverse=False) + elif sort_by == self.SortBy.LEN: + if sort_order == self.SortOrder.ASC: + return sorted(tags, key=lambda t:(len(t), t), reverse=False) + elif sort_order == self.SortOrder.DESC: + return 
sorted(tags, key=lambda t:(-len(t), t), reverse=False) + elif sort_by == self.SortBy.TOKEN: + if sort_order == self.SortOrder.ASC: + return sorted(tags, key=lambda t:(self.tag_tokens.get(t, (0, 0))[1], t), reverse=False) + elif sort_order == self.SortOrder.DESC: + return sorted(tags, key=lambda t:(-self.tag_tokens.get(t, (0, 0))[1], t), reverse=False) + return list(tags) + + + def get_filtered_imgpaths(self, filters:List[filters.Filter] = []): + filtered_set = self.dataset.copy() + for filter in filters: + filtered_set.filter(filter) + + img_paths = sorted(filtered_set.datas.keys()) + + return img_paths + + + def get_filtered_imgs(self, filters:List[filters.Filter] = []): + filtered_set = self.dataset.copy() + for filter in filters: + filtered_set.filter(filter) + + img_paths = sorted(filtered_set.datas.keys()) + + return [self.images.get(path) for path in img_paths] + + + def get_filtered_imgindices(self, filters:List[filters.Filter] = []): + filtered_set = self.dataset.copy() + for filter in filters: + filtered_set.filter(filter) + + img_paths = sorted(filtered_set.datas.keys()) + + return [self.img_idx.get(p) for p in img_paths] + + + def get_filtered_tags(self, filters:List[filters.Filter] = [], filter_word:str = '', filter_tags = True, prefix=False, suffix=False, regex=False): + if filter_tags: + filtered_set = self.dataset.copy() + for filter in filters: + filtered_set.filter(filter) + tags:Set[str] = filtered_set.get_tagset() + else: + tags:Set[str] = self.dataset.get_tagset() + + result = set() + try: + for tag in tags: + if prefix: + if regex: + if re.search("^" + filter_word, tag) is not None: + result.add(tag) + continue + else: + if tag.startswith(filter_word): + result.add(tag) + continue + if suffix: + if regex: + if re.search(filter_word + "$", tag) is not None: + result.add(tag) + continue + else: + if tag.endswith(filter_word): + result.add(tag) + continue + if not prefix and not suffix: + if regex: + if re.search(filter_word, tag) is not None: + 
result.add(tag) + continue + else: + if filter_word in tag: + result.add(tag) + continue + except: + return tags + else: + return result + + + def cleanup_tags(self, tags:List[str]): + current_dataset_tags = self.dataset.get_tagset() + return [t for t in tags if t in current_dataset_tags] + + def cleanup_tagset(self, tags:Set[str]): + current_dataset_tagset = self.dataset.get_tagset() + return tags & current_dataset_tagset + + + def get_common_tags(self, filters:List[filters.Filter] = []): + filtered_set = self.dataset.copy() + for filter in filters: + filtered_set.filter(filter) + + result = filtered_set.get_tagset() + for d in filtered_set.datas.values(): + result &= d.tagset + + return sorted(result) + + + def replace_tags(self, search_tags:List[str], replace_tags:List[str], filters:List[filters.Filter] = [], prepend:bool = False): + img_paths = self.get_filtered_imgpaths(filters=filters) + tags_to_append = replace_tags[len(search_tags):] + tags_to_remove = search_tags[len(replace_tags):] + tags_to_replace = {} + for i in range(min(len(search_tags), len(replace_tags))): + if replace_tags[i] is None or replace_tags[i] == '': + tags_to_remove.append(search_tags[i]) + else: + tags_to_replace[search_tags[i]] = replace_tags[i] + for img_path in img_paths: + tags_removed = [t for t in self.dataset.get_data_tags(img_path) if t not in tags_to_remove] + tags_replaced = [tags_to_replace.get(t) if t in tags_to_replace.keys() else t for t in tags_removed] + self.set_tags_by_image_path(img_path, tags_to_append + tags_replaced if prepend else tags_replaced + tags_to_append) + + self.construct_tag_infos() + + def get_replaced_tagset(self, tags:Set[str], search_tags:List[str], replace_tags:List[str]): + tags_to_remove = search_tags[len(replace_tags):] + tags_to_replace = {} + for i in range(min(len(search_tags), len(replace_tags))): + if replace_tags[i] is None or replace_tags[i] == '': + tags_to_remove.append(search_tags[i]) + else: + tags_to_replace[search_tags[i]] = 
replace_tags[i] + tags_removed = {t for t in tags if t not in tags_to_remove} + tags_replaced = {tags_to_replace.get(t) if t in tags_to_replace.keys() else t for t in tags_removed} + return {t for t in tags_replaced if t} + + + def search_and_replace_caption(self, search_text:str, replace_text:str, filters:List[filters.Filter] = [], use_regex:bool = False): + img_paths = self.get_filtered_imgpaths(filters=filters) + + for img_path in img_paths: + caption = ', '.join(self.dataset.get_data_tags(img_path)) + if use_regex: + caption = [t.strip() for t in re.sub(search_text, replace_text, caption).split(',')] + else: + caption = [t.strip() for t in caption.replace(search_text, replace_text).split(',')] + caption = [t for t in caption if t] + self.set_tags_by_image_path(img_path, caption) + + self.construct_tag_infos() + + + def search_and_replace_selected_tags(self, search_text:str, replace_text:str, selected_tags:Optional[Set[str]], filters:List[filters.Filter] = [], use_regex:bool = False): + img_paths = self.get_filtered_imgpaths(filters=filters) + + for img_path in img_paths: + tags = self.dataset.get_data_tags(img_path) + tags = self.search_and_replace_tag_list(search_text, replace_text, tags, selected_tags, use_regex) + self.set_tags_by_image_path(img_path, tags) + + self.construct_tag_infos() + + + def search_and_replace_tag_list(self, search_text:str, replace_text:str, tags:List[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False): + if use_regex: + if selected_tags is None: + tags = [re.sub(search_text, replace_text, t) for t in tags] + else: + tags = [re.sub(search_text, replace_text, t) if t in selected_tags else t for t in tags] + else: + if selected_tags is None: + tags = [t.replace(search_text, replace_text) for t in tags] + else: + tags = [t.replace(search_text, replace_text) if t in selected_tags else t for t in tags] + tags = [t2 for t1 in tags for t2 in t1.split(',') if t2] + return [t for t in tags if t] + + + def 
search_and_replace_tag_set(self, search_text:str, replace_text:str, tags:Set[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False): + if use_regex: + if selected_tags is None: + tags = {re.sub(search_text, replace_text, t) for t in tags} + else: + tags = {re.sub(search_text, replace_text, t) if t in selected_tags else t for t in tags} + else: + if selected_tags is None: + tags = {t.replace(search_text, replace_text) for t in tags} + else: + tags = {t.replace(search_text, replace_text) if t in selected_tags else t for t in tags} + tags = {t2 for t1 in tags for t2 in t1.split(',') if t2} + return {t for t in tags if t} + + + def remove_duplicated_tags(self, filters:List[filters.Filter] = []): + img_paths = self.get_filtered_imgpaths(filters) + for path in img_paths: + tags = self.dataset.get_data_tags(path) + res = [] + for t in tags: + if t not in res: + res.append(t) + self.set_tags_by_image_path(path, res) + + + def remove_tags(self, tags:Set[str], filters:List[filters.Filter] = []): + img_paths = self.get_filtered_imgpaths(filters) + for path in img_paths: + res = self.dataset.get_data_tags(path) + res = [t for t in res if t not in tags] + self.set_tags_by_image_path(path, res) + + + def sort_filtered_tags(self, filters:List[filters.Filter] = [], **sort_args): + img_paths = self.get_filtered_imgpaths(filters) + for path in img_paths: + tags = self.dataset.get_data_tags(path) + res = self.sort_tags(tags, **sort_args) + self.set_tags_by_image_path(path, res) + print(f'[tag-editor] Tags are sorted by {sort_args.get("sort_by").value} ({sort_args.get("sort_order").value})') + + + def truncate_filtered_tags_by_token_count(self, filters:List[filters.Filter] = [], max_token_count:int = 75): + img_paths = self.get_filtered_imgpaths(filters) + for path in img_paths: + tags = self.dataset.get_data_tags(path) + res = [] + for tag in tags: + _, token_count = clip_tokenizer.tokenize(', '.join(res + [tag]), shared.opts.dataset_editor_use_raw_clip_token) + if 
token_count <= max_token_count: + res.append(tag) + else: + break + self.set_tags_by_image_path(path, res) + + self.construct_tag_infos() + print(f'[tag-editor] Tags are truncated into token count <= {max_token_count}') + + + def get_img_path_list(self): + return [k for k in self.dataset.datas.keys() if k] + + + def get_img_path_set(self): + return {k for k in self.dataset.datas.keys() if k} + + + def delete_dataset(self, caption_ext:str, filters:List[filters.Filter], delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False): + filtered_set = self.dataset.copy() + for filter in filters: + filtered_set.filter(filter) + for path in filtered_set.datas.keys(): + self.delete_dataset_file(path, caption_ext, delete_image, delete_caption, delete_backup) + + if delete_image: + self.dataset.remove(filtered_set) + self.construct_tag_infos() + + + def move_dataset(self, dest_dir:str, caption_ext:str, filters:List[filters.Filter], move_image:bool = False, move_caption:bool = False, move_backup:bool = False): + filtered_set = self.dataset.copy() + for filter in filters: + filtered_set.filter(filter) + for path in filtered_set.datas.keys(): + self.move_dataset_file(path, caption_ext, dest_dir, move_image, move_caption, move_backup) + + if move_image: + self.construct_tag_infos() + + + def delete_dataset_file(self, img_path:str, caption_ext:str, delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False): + if img_path not in self.dataset.datas.keys(): + return + + img_path_obj = Path(img_path) + + if delete_image: + try: + if img_path_obj.is_file(): + if img_path in self.images: + self.images[img_path].close() + del self.images[img_path] + img_path_obj.unlink() + self.dataset.remove_by_path(img_path) + print(f'[tag-editor] Deleted {img_path_obj.absolute()}') + except Exception as e: + print(e) + + if delete_caption: + try: + txt_path_obj = img_path_obj.with_suffix(caption_ext) + if txt_path_obj.is_file(): + txt_path_obj.unlink() + 
print(f'[tag-editor] Deleted {txt_path_obj.absolute()}') + except Exception as e: + print(e) + + if delete_backup: + try: + for extnum in range(1000): + bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}') + if bak_path_obj.is_file(): + bak_path_obj.unlink() + print(f'[tag-editor] Deleted {bak_path_obj.absolute()}') + except Exception as e: + print(e) + + + def move_dataset_file(self, img_path:str, caption_ext:str, dest_dir:str, move_image:bool = False, move_caption:bool = False, move_backup:bool = False): + if img_path not in self.dataset.datas.keys(): + return + + img_path_obj = Path(img_path) + dest_dir_obj = Path(dest_dir) + + if (move_image or move_caption or move_backup) and not dest_dir_obj.exists(): + dest_dir_obj.mkdir() + + if move_image: + try: + dst_path_obj = dest_dir_obj / img_path_obj.name + if img_path_obj.is_file(): + if img_path in self.images: + self.images[img_path].close() + del self.images[img_path] + img_path_obj.replace(dst_path_obj) + self.dataset.remove_by_path(img_path) + print(f'[tag-editor] Moved {img_path_obj.absolute()} -> {dst_path_obj.absolute()}') + except Exception as e: + print(e) + + if move_caption: + try: + txt_path_obj = img_path_obj.with_suffix(caption_ext) + dst_path_obj = dest_dir_obj / txt_path_obj.name + if txt_path_obj.is_file(): + txt_path_obj.replace(dst_path_obj) + print(f'[tag-editor] Moved {txt_path_obj.absolute()} -> {dst_path_obj.absolute()}') + except Exception as e: + print(e) + + if move_backup: + try: + for extnum in range(1000): + bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}') + dst_path_obj = dest_dir_obj / bak_path_obj.name + if bak_path_obj.is_file(): + bak_path_obj.replace(dst_path_obj) + print(f'[tag-editor] Moved {bak_path_obj.absolute()} -> {dst_path_obj.absolute()}') + except Exception as e: + print(e) + + + def load_dataset(self, img_dir:str, caption_ext:str, recursive:bool, load_caption_from_filename:bool, interrogate_method:InterrogateMethod, interrogator_names:List[str], 
threshold_booru:float, threshold_waifu:float, use_temp_dir:bool, kohya_json_path:Optional[str]): + self.clear() + + img_dir_obj = Path(img_dir) + + print(f'[tag-editor] Loading dataset from {img_dir_obj.absolute()}') + if recursive: + print(f'[tag-editor] Also loading from subdirectories.') + + try: + filepaths = img_dir_obj.glob('**/*') if recursive else img_dir_obj.glob('*') + filepaths = [p for p in filepaths if p.is_file()] + except Exception as e: + print(e) + print('[tag-editor] Loading Aborted.') + return + + self.dataset_dir = img_dir + + print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.') + + def load_images(filepaths:List[Path]): + imgpaths = [] + images = {} + for img_path in filepaths: + if img_path.suffix == caption_ext: + continue + try: + img = Image.open(img_path) + except: + continue + else: + abs_path = str(img_path.absolute()) + if not use_temp_dir: + img.already_saved_as = abs_path + images[abs_path] = img + + imgpaths.append(abs_path) + return imgpaths, images + + def load_captions(imgpaths:List[str]): + taglists = [] + for abs_path in imgpaths: + img_path = Path(abs_path) + text_path = img_path.with_suffix(caption_ext) + caption_text = '' + if interrogate_method != self.InterrogateMethod.OVERWRITE: + # from modules/textual_inversion/dataset.py, modified + if text_path.is_file(): + caption_text = text_path.read_text('utf8') + elif load_caption_from_filename: + caption_text = img_path.stem + caption_text = re.sub(re_numbers_at_start, '', caption_text) + if self.re_word: + tokens = self.re_word.findall(caption_text) + caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens) + + caption_tags = [t.strip() for t in caption_text.split(',')] + caption_tags = [t for t in caption_tags if t] + taglists.append(caption_tags) + + return taglists + + try: + captionings = [] + taggers = [] + if interrogate_method != self.InterrogateMethod.NONE: + for it in INTERROGATORS: + if it.name() in 
interrogator_names: + it.start() + if isinstance(it, tagger.Tagger): + if isinstance(it, tagger.DeepDanbooru): + taggers.append((it, threshold_booru)) + if isinstance(it, tagger.WaifuDiffusion): + taggers.append((it, threshold_waifu)) + elif isinstance(it, captioning.Captioning): + captionings.append(it) + + if kohya_json_path: + imgpaths, self.images, taglists = kohya_metadata.read(img_dir, kohya_json_path, use_temp_dir) + else: + imgpaths, self.images = load_images(filepaths) + taglists = load_captions(imgpaths) + + for img_path, tags in zip(imgpaths, taglists): + interrogate_tags = [] + img = self.images.get(img_path) + if interrogate_method != self.InterrogateMethod.NONE and ((interrogate_method != self.InterrogateMethod.PREFILL) or (interrogate_method == self.InterrogateMethod.PREFILL and not tags)): + if img is None: + print(f'Failed to load image {img_path}. Interrogating is aborted.') + else: + img = img.convert('RGB') + for cap in captionings: + interrogate_tags += cap.predict(img) + + for tg, threshold in taggers: + interrogate_tags += [t for t in tg.predict(img, threshold).keys()] + + if interrogate_method == self.InterrogateMethod.OVERWRITE: + tags = interrogate_tags + elif interrogate_method == self.InterrogateMethod.PREPEND: + tags = interrogate_tags + tags + else: + tags = tags + interrogate_tags + + self.set_tags_by_image_path(img_path, tags) + + finally: + if interrogate_method != self.InterrogateMethod.NONE: + for cap in captionings: + cap.stop() + for tg, _ in taggers: + tg.stop() + + for i, p in enumerate(sorted(self.dataset.datas.keys())): + self.img_idx[p] = i + + self.construct_tag_infos() + print(f'[tag-editor] Loading Completed: {len(self.dataset)} images found') + + + def save_dataset(self, backup:bool, caption_ext:str, write_kohya_metadata:bool, meta_out_path:str, meta_in_path:Optional[str], meta_overwrite:bool, meta_as_caption:bool, meta_full_path:bool): + if len(self.dataset) == 0: + return (0, 0, '') + + saved_num = 0 + backup_num = 0 
+ for data in self.dataset.datas.values(): + img_path, tags = Path(data.imgpath), data.tags + txt_path = img_path.with_suffix(caption_ext) + # make backup + if backup and txt_path.is_file(): + for extnum in range(1000): + bak_path = img_path.with_suffix(f'.{extnum:0>3d}') + if not bak_path.is_file(): + break + else: + bak_path = None + if bak_path is None: + print(f"[tag-editor] There are too many backup files with same filename. A backup file of {txt_path} cannot be created.") + else: + try: + txt_path.rename(bak_path) + except Exception as e: + print(e) + print(f"[tag-editor] A backup file of {txt_path} cannot be created.") + else: + backup_num += 1 + # save + try: + txt_path.write_text(', '.join(tags), 'utf8') + except Exception as e: + print(e) + print(f"[tag-editor] Warning: {txt_path} cannot be saved.") + else: + saved_num += 1 + print(f'[tag-editor] Backup text files: {backup_num}/{len(self.dataset)} under {self.dataset_dir}') + print(f'[tag-editor] Saved text files: {saved_num}/{len(self.dataset)} under {self.dataset_dir}') + + if(write_kohya_metadata): + kohya_metadata.write(dataset=self.dataset, dataset_dir=self.dataset_dir, out_path=meta_out_path, in_path=meta_in_path, overwrite=meta_overwrite, save_as_caption=meta_as_caption, use_full_path=meta_full_path) + print(f'[tag-editor] Saved json metadata file in {meta_out_path}') + return (saved_num, len(self.dataset), self.dataset_dir) + + + def clear(self): + self.dataset.clear() + self.tag_counts.clear() + self.tag_tokens.clear() + self.img_idx.clear() + self.dataset_dir = '' + self.images.clear() + + + def construct_tag_infos(self): + self.tag_counts = {} + update_token_count = self.raw_clip_token_used is None or self.raw_clip_token_used != shared.opts.dataset_editor_use_raw_clip_token + + if update_token_count: + self.tag_tokens.clear() + + for data in self.dataset.datas.values(): + for tag in data.tags: + if tag in self.tag_counts.keys(): + self.tag_counts[tag] += 1 + else: + self.tag_counts[tag] = 1 + 
if tag not in self.tag_tokens: + self.tag_tokens[tag] = clip_tokenizer.tokenize(tag, shared.opts.dataset_editor_use_raw_clip_token) + self.raw_clip_token_used = shared.opts.dataset_editor_use_raw_clip_token diff --git a/scripts/main.py b/scripts/main.py index 6a82528..ab85bc2 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -1,4 +1,4 @@ -from typing import List, NamedTuple, Type, Dict, Any +from typing import NamedTuple, Type, Dict, Any from modules import shared, script_callbacks, scripts from modules.shared import opts import gradio as gr @@ -39,15 +39,15 @@ GeneralConfig = namedtuple('GeneralConfig', [ 'meta_use_full_path' ]) FilterConfig = namedtuple('FilterConfig', ['sw_prefix', 'sw_suffix', 'sw_regex','sort_by', 'sort_order', 'logic']) -BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order']) -EditSelectedConfig = namedtuple('EditSelectedConfig', ['use_raw_token', 'auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order']) +BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order', 'token_count']) +EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order']) MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination']) CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False) CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND') CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR') -CFG_BATCH_EDIT_DEFAULT = 
BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value) -CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(True, False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value) +CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value, 75) +CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value) CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig('Selected One', [], '.txt', '') class Config: @@ -252,10 +252,10 @@ def on_ui_tabs(): ui.batch_edit_captions.rb_sr_replace_target, ui.batch_edit_captions.tag_select_ui_remove.cb_prefix, ui.batch_edit_captions.tag_select_ui_remove.cb_suffix, ui.batch_edit_captions.tag_select_ui_remove.cb_regex, ui.batch_edit_captions.tag_select_ui_remove.rb_sort_by, ui.batch_edit_captions.tag_select_ui_remove.rb_sort_order, - ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order + ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order, + ui.batch_edit_captions.nb_token_count ] components_edit_selected = [ - ui.edit_caption_of_selected_image.cb_use_raw_token, ui.edit_caption_of_selected_image.cb_copy_caption_automatically, ui.edit_caption_of_selected_image.cb_sort_caption_on_save, ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, ui.edit_caption_of_selected_image.dd_intterogator_names_si, ui.edit_caption_of_selected_image.rb_sort_by, ui.edit_caption_of_selected_image.rb_sort_order @@ -339,6 +339,7 @@ def on_ui_settings(): section = ('dataset-tag-editor', "Dataset Tag Editor") shared.opts.add_option("dataset_editor_image_columns", shared.OptionInfo(6, "Number of columns on image gallery", section=section)) shared.opts.add_option("dataset_editor_use_temp_files", shared.OptionInfo(False, "Force image gallery to use 
temporary files", section=section)) + shared.opts.add_option("dataset_editor_use_raw_clip_token", shared.OptionInfo(True, "Use raw CLIP token to calculate token count (without emphasis or embeddings)", section=section)) script_callbacks.on_ui_settings(on_ui_settings) diff --git a/scripts/ui/tokenizer/__init__.py b/scripts/tokenizer/__init__.py similarity index 100% rename from scripts/ui/tokenizer/__init__.py rename to scripts/tokenizer/__init__.py diff --git a/scripts/tokenizer/clip_tokenizer.py b/scripts/tokenizer/clip_tokenizer.py new file mode 100644 index 0000000..5d30b13 --- /dev/null +++ b/scripts/tokenizer/clip_tokenizer.py @@ -0,0 +1,48 @@ +# Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified +# https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py + +from typing import List +from functools import reduce +from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder +from modules import shared, extra_networks, prompt_parser +from modules.sd_hijack import model_hijack +import open_clip.tokenizer + +class VanillaClip: + def __init__(self, clip): + self.clip = clip + + def vocab(self): + return self.clip.tokenizer.get_vocab() + + def byte_decoder(self): + return self.clip.tokenizer.byte_decoder + +class OpenClip: + def __init__(self, clip): + self.clip = clip + self.tokenizer = open_clip.tokenizer._tokenizer + + def vocab(self): + return self.tokenizer.encoder + + def byte_decoder(self): + return self.tokenizer.byte_decoder + +def tokenize(text:str, use_raw_clip:bool=True): + if use_raw_clip: + tokens = shared.sd_model.cond_stage_model.tokenize([text])[0] + token_count = len(tokens) + else: + try: + text, _ = extra_networks.parse_prompt(text) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt = reduce(lambda list1, list2: list1+list2, prompt_flat_list) + except Exception: + prompt = text + token_chunks, token_count = 
model_hijack.clip.tokenize_line(prompt) + tokens = reduce(lambda list1, list2: list1+list2, [tc.tokens for tc in token_chunks]) + return tokens, token_count + +def get_target_token_count(token_count:int): + return model_hijack.clip.get_target_prompt_token_count(token_count) \ No newline at end of file diff --git a/scripts/ui/block_tag_filter.py b/scripts/ui/block_tag_filter.py index 23fd70a..99e7fc7 100644 --- a/scripts/ui/block_tag_filter.py +++ b/scripts/ui/block_tag_filter.py @@ -127,8 +127,8 @@ class TagFilterUI(): tags_in_filter = dte_instance.sort_tags(tags=tags_in_filter, sort_by=self.sort_by, sort_order=self.sort_order) tags = tags_in_filter + [tag for tag in tags if tag not in self.filter.tags] - tags = dte_instance.write_tags(tags) - tags_in_filter = dte_instance.write_tags(tags_in_filter) + tags = dte_instance.write_tags(tags, self.sort_by) + tags_in_filter = dte_instance.write_tags(tags_in_filter, self.sort_by) return gr.CheckboxGroup.update(value=tags_in_filter, choices=tags) diff --git a/scripts/ui/block_tag_select.py b/scripts/ui/block_tag_select.py index 47e4b42..04cb511 100644 --- a/scripts/ui/block_tag_select.py +++ b/scripts/ui/block_tag_select.py @@ -6,6 +6,9 @@ from .ui_common import * TagFilter = dte_module.filters.TagFilter Filter = dte_module.filters.Filter +SortBy = dte_instance.SortBy +SortOrder = dte_instance.SortOrder + + class TagSelectUI(): def __init__(self): @@ -32,8 +35,8 @@ class TagSelectUI(): self.cb_suffix = gr.Checkbox(label='Suffix', value=False, interactive=True) self.cb_regex = gr.Checkbox(label='Use regex', value=False, interactive=True) with gr.Row(): - self.rb_sort_by = gr.Radio(choices=['Alphabetical Order', 'Frequency', 'Length'], value=sort_by, interactive=True, label='Sort by') - self.rb_sort_order = gr.Radio(choices=['Ascending', 'Descending'], value=sort_order, interactive=True, label='Sort Order') + self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=sort_by, interactive=True, label='Sort by') + 
self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=sort_order, interactive=True, label='Sort Order') with gr.Row(): self.btn_select_visibles = gr.Button(value='Select visible tags') self.btn_deselect_visibles = gr.Button(value='Deselect visible tags') @@ -105,6 +108,6 @@ class TagSelectUI(): self.tags = set(dte_instance.get_filtered_tags(self.get_filters(), filter_tags=True, prefix=self.prefix, suffix=self.suffix, regex=self.regex)) self.selected_tags &= self.tags tags = dte_instance.sort_tags(tags=tags, sort_by=self.sort_by, sort_order=self.sort_order) - tags = dte_instance.write_tags(tags) - selected_tags = dte_instance.write_tags(list(self.selected_tags)) + tags = dte_instance.write_tags(tags, self.sort_by) + selected_tags = dte_instance.write_tags(list(self.selected_tags), self.sort_by) return gr.CheckboxGroup.update(value=selected_tags, choices=tags) \ No newline at end of file diff --git a/scripts/ui/tab_batch_edit_captions.py b/scripts/ui/tab_batch_edit_captions.py index 3cd8e6b..15bed9b 100644 --- a/scripts/ui/tab_batch_edit_captions.py +++ b/scripts/ui/tab_batch_edit_captions.py @@ -67,7 +67,11 @@ class BatchEditCaptionsUI(UIBase): with gr.Row(): self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_batch_edit.batch_sort_by, interactive=True, label='Sort by') self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_batch_edit.batch_sort_order, interactive=True, label='Sort Order') - self.btn_sort_selected = gr.Button(value='Sort selected tags', variant='primary') + self.btn_sort_selected = gr.Button(value='Sort tags', variant='primary') + with gr.Column(variant='panel'): + gr.HTML('Truncate tags by token count.') + self.nb_token_count = gr.Number(value=cfg_batch_edit.token_count, precision=0) + self.btn_truncate_by_token = gr.Button(value='Truncate tags by token count', variant='primary') def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], 
load_dataset:LoadDatasetUI, filter_by_tags:FilterByTagsUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]): load_dataset.btn_load_datasets.click( @@ -173,6 +177,17 @@ class BatchEditCaptionsUI(UIBase): inputs=self.cb_show_only_tags_selected ) + def truncate_by_token_count(token_count:int): + token_count = max(int(token_count), 0) + dte_instance.truncate_filtered_tags_by_token_count(get_filters(), token_count) + return update_filter_and_gallery() + + self.btn_truncate_by_token.click( + fn=truncate_by_token_count, + inputs=self.nb_token_count, + outputs=o_update_filter_and_gallery + ) + def get_common_tags(self, get_filters:Callable[[], List[dte_module.filters.Filter]], filter_by_tags:FilterByTagsUI): if self.show_only_selected_tags: diff --git a/scripts/ui/tab_edit_caption_of_selected_image.py b/scripts/ui/tab_edit_caption_of_selected_image.py index 76ec283..6c1dffc 100644 --- a/scripts/ui/tab_edit_caption_of_selected_image.py +++ b/scripts/ui/tab_edit_caption_of_selected_image.py @@ -1,16 +1,14 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Callable -from functools import reduce import gradio as gr -from modules import shared, extra_networks, prompt_parser +from modules import shared from modules.call_queue import wrap_queued_call -from modules.sd_hijack import model_hijack from scripts.dte_instance import dte_module from .ui_common import * from .uibase import UIBase -from .tokenizer import clip_tokenizer +from scripts.tokenizer import clip_tokenizer if TYPE_CHECKING: from .ui_classes import * @@ -30,7 +28,6 @@ class EditCaptionOfSelectedImageUI(UIBase): with gr.Tab(label='Read Caption from Selected Image'): self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6, elem_id='dte_caption') self.token_counter_caption = gr.HTML(value='', elem_id='dte_caption_counter') - self.cb_use_raw_token = gr.Checkbox(value=cfg_edit_selected.use_raw_token, 
label='Use raw CLIP token for token count (without embeddings)') with gr.Row(): self.btn_copy_caption = gr.Button(value='Copy and Overwrite') self.btn_prepend_caption = gr.Button(value='Prepend') @@ -194,39 +191,27 @@ class EditCaptionOfSelectedImageUI(UIBase): outputs=self.sort_settings ) - def update_token_counter(text:str, use_raw:bool): - if use_raw: - token_count = clip_tokenizer.token_count(text) - max_length = model_hijack.clip.get_target_prompt_token_count(token_count) - else: - try: - text, _ = extra_networks.parse_prompt(text) - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) - prompt = reduce(lambda list1, list2: list1+list2, prompt_flat_list) - except Exception: - prompt = text - token_count, max_length = model_hijack.get_prompt_lengths(prompt) + def update_token_counter(text:str): + _, token_count = clip_tokenizer.tokenize(text, shared.opts.dataset_editor_use_raw_clip_token) + max_length = clip_tokenizer.get_target_token_count(token_count) return f"{token_count}/{max_length}" update_caption_token_counter_args = { 'fn' : wrap_queued_call(update_token_counter), - 'inputs' : [self.tb_caption, self.cb_use_raw_token], + 'inputs' : [self.tb_caption], 'outputs' : [self.token_counter_caption] } update_edit_caption_token_counter_args = { 'fn' : wrap_queued_call(update_token_counter), - 'inputs' : [self.tb_edit_caption, self.cb_use_raw_token], + 'inputs' : [self.tb_edit_caption], 'outputs' : [self.token_counter_edit_caption] } update_interrogate_token_counter_args = { 'fn' : wrap_queued_call(update_token_counter), - 'inputs' : [self.tb_interrogate, self.cb_use_raw_token], + 'inputs' : [self.tb_interrogate], 'outputs' : [self.token_counter_interrogate] } - self.cb_use_raw_token.change(**update_caption_token_counter_args) - self.cb_use_raw_token.change(**update_edit_caption_token_counter_args) - self.cb_use_raw_token.change(**update_interrogate_token_counter_args) self.tb_caption.change(**update_caption_token_counter_args) 
self.tb_edit_caption.change(**update_edit_caption_token_counter_args) self.tb_interrogate.change(**update_interrogate_token_counter_args) diff --git a/scripts/ui/tab_move_or_delete_files.py b/scripts/ui/tab_move_or_delete_files.py index 46aa68d..f22af9b 100644 --- a/scripts/ui/tab_move_or_delete_files.py +++ b/scripts/ui/tab_move_or_delete_files.py @@ -69,7 +69,7 @@ class MoveOrDeleteFilesUI(UIBase): img_path = dataset_gallery.selected_path if img_path: dte_instance.move_dataset_file(img_path, caption_ext, dest_dir, move_img, move_txt, move_bak) - dte_instance.construct_tag_counts() + dte_instance.construct_tag_infos() elif target_data == 'All Displayed Ones': dte_instance.move_dataset(dest_dir, caption_ext, get_filters(), move_img, move_txt, move_bak) @@ -95,7 +95,7 @@ class MoveOrDeleteFilesUI(UIBase): img_path = dataset_gallery.selected_path if img_path: dte_instance.delete_dataset_file(img_path, delete_img, caption_ext, delete_txt, delete_bak) - dte_instance.construct_tag_counts() + dte_instance.construct_tag_infos() elif target_data == 'All Displayed Ones': dte_instance.delete_dataset(caption_ext, get_filters(), delete_img, delete_txt, delete_bak) diff --git a/scripts/ui/tokenizer/clip_tokenizer.py b/scripts/ui/tokenizer/clip_tokenizer.py deleted file mode 100644 index d257eec..0000000 --- a/scripts/ui/tokenizer/clip_tokenizer.py +++ /dev/null @@ -1,42 +0,0 @@ -# Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified -# https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py - -from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder -from modules import shared -import open_clip.tokenizer - -class VanillaClip: - def __init__(self, clip): - self.clip = clip - - def vocab(self): - return self.clip.tokenizer.get_vocab() - - def byte_decoder(self): - return self.clip.tokenizer.byte_decoder - -class OpenClip: - def __init__(self, clip): - self.clip = clip - self.tokenizer = 
open_clip.tokenizer._tokenizer - - def vocab(self): - return self.tokenizer.encoder - - def byte_decoder(self): - return self.tokenizer.byte_decoder - -def tokenize(text:str): - clip = shared.sd_model.cond_stage_model.wrapped - if isinstance(clip, FrozenCLIPEmbedder): - clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped) - elif isinstance(clip, FrozenOpenCLIPEmbedder): - clip = OpenClip(shared.sd_model.cond_stage_model.wrapped) - else: - raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}') - - tokens = shared.sd_model.cond_stage_model.tokenize([text])[0] - return tokens - -def token_count(text:str): - return len(tokenize(text))