implement some features to manage token count (#52)
Add: - Count and truncate by token count (#40) - Sort tags by token count. Change: - Move the "use raw clip token..." setting to the "Settings" tab - Show tag frequency, length or token count depending on the "Sort by" option (pull/54/head)
parent
94afc6e025
commit
43d8d8f65f
|
|
@ -1,633 +1,8 @@
|
|||
from pathlib import Path
|
||||
import re
|
||||
from typing import List, Set, Optional
|
||||
from modules import shared
|
||||
from modules.textual_inversion.dataset import re_numbers_at_start
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
|
||||
from scripts.singleton import Singleton
|
||||
|
||||
from . import dataset as ds
|
||||
from . import tagger
|
||||
from . import captioning
|
||||
from . import filters
|
||||
from . import kohya_finetune_metadata as kohya_metadata
|
||||
from . import dataset as ds
|
||||
|
||||
from .dte_logic import DatasetTagEditor, INTERROGATOR_NAMES, interrogate_image
|
||||
|
||||
__all__ = ["ds", "tagger", "captioning", "filters", "kohya_metadata", "INTERROGATOR_NAMES", "interrogate_image", "DatasetTagEditor"]
|
||||
|
||||
# Matches a label of the form "<tag> [<digits>]"; group 1 is the tag part.
re_tags = re.compile(r'^(.+) \[\d+\]$')

WD_TAGGER_NAMES = [
    "wd-v1-4-vit-tagger",
    "wd-v1-4-convnext-tagger",
    "wd-v1-4-vit-tagger-v2",
    "wd-v1-4-convnext-tagger-v2",
    "wd-v1-4-swinv2-tagger-v2",
]
# One threshold per entry of WD_TAGGER_NAMES, in the same order.
WD_TAGGER_THRESHOLDS = [0.35, 0.35, 0.3537, 0.3685, 0.3771] # v1: idk if it's okay  v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf

# Every available interrogator: one captioner, then the taggers.
INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru()]
INTERROGATORS += [
    tagger.WaifuDiffusion(wd_name, wd_threshold)
    for wd_name, wd_threshold in zip(WD_TAGGER_NAMES, WD_TAGGER_THRESHOLDS)
]
INTERROGATOR_NAMES = [interrogator.name() for interrogator in INTERROGATORS]
|
||||
|
||||
|
||||
def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_wd):
    """Run the interrogator named *interrogator_name* on one image file.

    Args:
        path: Image file to open.
        interrogator_name: Compared against ``it.name()`` of each INTERROGATORS entry.
        threshold_booru: Threshold passed to a DeepDanbooru tagger.
        threshold_wd: Threshold passed to a WaifuDiffusion tagger.

    Returns:
        The prediction joined with ', ', or '' when the image cannot be
        opened or no interrogator has the requested name.
    """
    try:
        img = Image.open(path).convert('RGB')
    except Exception:
        # Best effort: an unreadable image yields an empty caption.
        return ''

    res = None
    for it in INTERROGATORS:
        if it.name() != interrogator_name:
            continue
        if isinstance(it, tagger.DeepDanbooru):
            with it as tg:
                res = tg.predict(img, threshold_booru)
        elif isinstance(it, tagger.WaifuDiffusion):
            with it as tg:
                res = tg.predict(img, threshold_wd)
        else:
            with it as cap:
                res = cap.predict(img)

    # An unknown name used to leave `res` unbound and crash at the join.
    if res is None:
        return ''
    return ', '.join(res)
|
||||
|
||||
|
||||
class DatasetTagEditor(Singleton):
|
||||
class SortBy(Enum):
    """Sort criteria accepted by sort_tags; each value is a display string."""
    ALPHA = 'Alphabetical Order'
    FREQ = 'Frequency'
    LEN = 'Length'
|
||||
|
||||
class SortOrder(Enum):
    """Sort direction accepted by sort_tags; each value is a display string."""
    ASC = 'Ascending'
    DESC = 'Descending'
|
||||
|
||||
class InterrogateMethod(Enum):
    """How load_dataset combines interrogation results with existing captions."""
    # No interrogator is started.
    NONE = 0
    # Combined like APPEND; load_dataset's condition is written to interrogate only when no tags exist.
    PREFILL = 1
    # Existing captions are not read; only interrogation results are kept.
    OVERWRITE = 2
    # Interrogation results are placed before the existing tags.
    PREPEND = 3
    # Interrogation results are placed after the existing tags.
    APPEND = 4
|
||||
|
||||
def __init__(self):
    """Start with an empty dataset and empty caches."""
    # from modules.textual_inversion.dataset
    word_regex = shared.opts.dataset_filename_word_regex
    if len(word_regex) > 0:
        self.re_word = re.compile(word_regex)
    else:
        self.re_word = None
    self.dataset = ds.Dataset()
    self.img_idx = {}      # filled by load_dataset: image path -> index
    self.tag_counts = {}   # tag -> number of occurrences
    self.dataset_dir = ''
    self.images = {}       # image path -> opened image
|
||||
|
||||
def get_tag_list(self):
    """Return the counted tags as a list, building the counts first if none are cached."""
    if not len(self.tag_counts):
        self.construct_tag_counts()
    return list(self.tag_counts.keys())
|
||||
|
||||
|
||||
def get_tag_set(self):
    """Return the counted tags as a set, building the counts first if none are cached."""
    if not len(self.tag_counts):
        self.construct_tag_counts()
    return set(self.tag_counts.keys())
|
||||
|
||||
|
||||
def get_tags_by_image_path(self, imgpath:str):
    """Return what the dataset reports as the tags of *imgpath*."""
    image_tags = self.dataset.get_data_tags(imgpath)
    return image_tags
|
||||
|
||||
|
||||
def set_tags_by_image_path(self, imgpath:str, tags:List[str]):
    """Store *tags* for *imgpath* in the dataset, then rebuild the tag counts."""
    caption = ','.join(tags)
    entry = ds.Data(imgpath, caption)
    self.dataset.append_data(entry)
    self.construct_tag_counts()
|
||||
|
||||
|
||||
def write_tags(self, tags:List[str]):
    """Format every non-empty tag as 'tag [count]' using the cached counts (0 if unknown)."""
    if not tags:
        return []
    labelled = []
    for tag in tags:
        if not tag:
            continue
        labelled.append(f'{tag} [{self.tag_counts.get(tag) or 0}]')
    return labelled
|
||||
|
||||
|
||||
def read_tags(self, tags:List[str]):
    """Strip the ' [n]' suffix from each label and return the bare tags.

    Empty labels are skipped. A label that does not match ``re_tags``
    (for example one without a suffix) is kept unchanged instead of
    raising AttributeError on the failed match.
    """
    if not tags:
        return []
    names = []
    for label in tags:
        if not label:
            continue
        matched = re_tags.match(label)
        name = matched.group(1) if matched else label
        if name:
            names.append(name)
    return names
|
||||
|
||||
|
||||
def sort_tags(self, tags:List[str], sort_by:SortBy=SortBy.ALPHA, sort_order:SortOrder=SortOrder.ASC):
    """Return a new sorted list of *tags*.

    *sort_by* and *sort_order* may be enum members or their values.
    Frequency and length sorts break ties alphabetically in ascending
    order, whatever the requested direction.
    """
    sort_by = self.SortBy(sort_by)
    sort_order = self.SortOrder(sort_order)
    descending = sort_order == self.SortOrder.DESC
    # Negating the numeric key reverses it while keeping the tie-break ascending.
    sign = -1 if descending else 1

    if sort_by == self.SortBy.ALPHA:
        return sorted(tags, reverse=descending)
    if sort_by == self.SortBy.FREQ:
        return sorted(tags, key=lambda t: (sign * self.tag_counts.get(t, 0), t))
    if sort_by == self.SortBy.LEN:
        return sorted(tags, key=lambda t: (sign * len(t), t))
    return list(tags)
|
||||
|
||||
|
||||
def get_filtered_imgpaths(self, filters:List[filters.Filter] = []):
    """Return the sorted image paths left after applying every filter to a copy of the dataset."""
    remaining = self.dataset.copy()
    for flt in filters:
        remaining.filter(flt)
    return sorted(remaining.datas.keys())
|
||||
|
||||
|
||||
def get_filtered_imgs(self, filters:List[filters.Filter] = []):
    """Return the cached images for the filtered paths, in sorted path order.

    Filtering is delegated to get_filtered_imgpaths rather than repeated
    here. A path with no entry in self.images yields None.
    """
    return [self.images.get(path) for path in self.get_filtered_imgpaths(filters)]
|
||||
|
||||
|
||||
def get_filtered_imgindices(self, filters:List[filters.Filter] = []):
    """Return the self.img_idx entries for the filtered paths, in sorted path order.

    Filtering is delegated to get_filtered_imgpaths rather than repeated
    here. A path with no entry in self.img_idx yields None.
    """
    return [self.img_idx.get(path) for path in self.get_filtered_imgpaths(filters)]
|
||||
|
||||
|
||||
def get_filtered_tags(self, filters:List[filters.Filter] = [], filter_word:str = '', filter_tags = True, prefix=False, suffix=False, regex=False):
    """Return the tags that match *filter_word*.

    Args:
        filters: Applied to a copy of the dataset when *filter_tags* is true.
        filter_word: Text or regular expression to look for.
        filter_tags: When false, the whole dataset's tag set is searched.
        prefix: Match at the start of the tag.
        suffix: Match at the end of the tag. With both *prefix* and
            *suffix* set, a tag matching either one is kept.
        regex: Treat *filter_word* as a regular expression.

    Returns:
        The set of matching tags. If *filter_word* is not a valid pattern
        or not a string, the unfiltered tag set is returned.
    """
    if filter_tags:
        filtered_set = self.dataset.copy()
        for flt in filters:
            filtered_set.filter(flt)
        tags:Set[str] = filtered_set.get_tagset()
    else:
        tags:Set[str] = self.dataset.get_tagset()

    try:
        if regex:
            # Build each pattern once instead of once per tag.
            patterns = []
            if prefix:
                patterns.append(re.compile("^" + filter_word))
            if suffix:
                patterns.append(re.compile(filter_word + "$"))
            if not patterns:
                patterns.append(re.compile(filter_word))

            def matches(tag):
                return any(p.search(tag) is not None for p in patterns)
        else:
            def matches(tag):
                if prefix and tag.startswith(filter_word):
                    return True
                if suffix and tag.endswith(filter_word):
                    return True
                return not prefix and not suffix and filter_word in tag

        return {tag for tag in tags if matches(tag)}
    except (re.error, TypeError):
        # Best effort: an unusable filter word means "do not filter".
        return tags
|
||||
|
||||
|
||||
def cleanup_tags(self, tags:List[str]):
    """Keep only the entries of *tags* that are in the dataset's tag set, preserving order."""
    known = self.dataset.get_tagset()
    kept = []
    for tag in tags:
        if tag in known:
            kept.append(tag)
    return kept
|
||||
|
||||
def cleanup_tagset(self, tags:Set[str]):
    """Return the intersection of *tags* with the dataset's tag set."""
    return tags & self.dataset.get_tagset()
|
||||
|
||||
|
||||
def get_common_tags(self, filters:List[filters.Filter] = []):
    """Return, sorted, the tags shared by every entry left after filtering."""
    remaining = self.dataset.copy()
    for flt in filters:
        remaining.filter(flt)

    # Start from all tags of the filtered copy and narrow with each entry.
    common = remaining.get_tagset()
    for entry in remaining.datas.values():
        common &= entry.tagset
    return sorted(common)
|
||||
|
||||
|
||||
def replace_tags(self, search_tags:List[str], replace_tags:List[str], filters:List[filters.Filter] = [], prepend:bool = False):
    """Replace, remove and add tags on every filtered image.

    *search_tags* and *replace_tags* are paired by position:
      * a pair with an empty (or None) replacement removes the search tag;
      * search tags beyond the end of *replace_tags* are removed;
      * replacement tags beyond the end of *search_tags* are added to every
        image, before the other tags when *prepend* is true, after otherwise.
    """
    img_paths = self.get_filtered_imgpaths(filters=filters)
    tags_to_append = replace_tags[len(search_tags):]
    # A set makes the per-tag membership test in the loop below cheap.
    tags_to_remove = set(search_tags[len(replace_tags):])
    tags_to_replace = {}
    for search, replacement in zip(search_tags, replace_tags):
        if replacement is None or replacement == '':
            tags_to_remove.add(search)
        else:
            tags_to_replace[search] = replacement

    for img_path in img_paths:
        # Removal is decided first, so a tag marked for removal is never replaced.
        kept = [tags_to_replace.get(t, t) for t in self.dataset.get_data_tags(img_path) if t not in tags_to_remove]
        self.set_tags_by_image_path(img_path, tags_to_append + kept if prepend else kept + tags_to_append)

    self.construct_tag_counts()
|
||||
|
||||
def get_replaced_tagset(self, tags:Set[str], search_tags:List[str], replace_tags:List[str]):
    """Apply the positional search/replace pairs to a tag set and return the new set.

    Search tags with an empty or missing replacement are dropped; other
    search tags are mapped to their replacement; empty results are discarded.
    """
    dropped = search_tags[len(replace_tags):]
    mapping = {}
    for old, new in zip(search_tags, replace_tags):
        if new is None or new == '':
            dropped.append(old)
        else:
            mapping[old] = new

    result = set()
    for tag in tags:
        if tag in dropped:
            continue
        mapped = mapping.get(tag, tag)
        if mapped:
            result.add(mapped)
    return result
|
||||
|
||||
|
||||
def search_and_replace_caption(self, search_text:str, replace_text:str, filters:List[filters.Filter] = [], use_regex:bool = False):
    """Search and replace on the ', '-joined caption of each filtered image.

    The result is split on ',' again, stripped, and stored without empty tags.
    """
    for img_path in self.get_filtered_imgpaths(filters=filters):
        joined = ', '.join(self.dataset.get_data_tags(img_path))
        if use_regex:
            replaced = re.sub(search_text, replace_text, joined)
        else:
            replaced = joined.replace(search_text, replace_text)
        stripped = [piece.strip() for piece in replaced.split(',')]
        self.set_tags_by_image_path(img_path, [piece for piece in stripped if piece])

    self.construct_tag_counts()
|
||||
|
||||
|
||||
def search_and_replace_selected_tags(self, search_text:str, replace_text:str, selected_tags:Optional[Set[str]], filters:List[filters.Filter] = [], use_regex:bool = False):
    """Run search_and_replace_tag_list on the tags of each filtered image and store the result."""
    for img_path in self.get_filtered_imgpaths(filters=filters):
        current = self.dataset.get_data_tags(img_path)
        updated = self.search_and_replace_tag_list(search_text, replace_text, current, selected_tags, use_regex)
        self.set_tags_by_image_path(img_path, updated)

    self.construct_tag_counts()
|
||||
|
||||
|
||||
def search_and_replace_tag_list(self, search_text:str, replace_text:str, tags:List[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False):
    """Search and replace inside individual tags and return the new tag list.

    Args:
        search_text: Text, or pattern when *use_regex* is true, to look for.
        replace_text: Replacement; may contain ',' to produce several tags.
        tags: Tags to process, in order.
        selected_tags: If given, only tags in this set are changed.
        use_regex: Use ``re.sub`` instead of ``str.replace``.

    Returns:
        The resulting tags, split on ',', stripped of surrounding
        whitespace and without empty entries. Stripping keeps a
        replacement such as 'a, b' from producing the tag ' b'.
    """
    if use_regex:
        def substitute(tag):
            return re.sub(search_text, replace_text, tag)
    else:
        def substitute(tag):
            return tag.replace(search_text, replace_text)

    replaced = [substitute(t) if selected_tags is None or t in selected_tags else t for t in tags]
    pieces = (piece.strip() for tag in replaced for piece in tag.split(','))
    return [piece for piece in pieces if piece]
|
||||
|
||||
|
||||
def search_and_replace_tag_set(self, search_text:str, replace_text:str, tags:Set[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False):
    """Search and replace inside individual tags and return the new tag set.

    Args:
        search_text: Text, or pattern when *use_regex* is true, to look for.
        replace_text: Replacement; may contain ',' to produce several tags.
        tags: Tags to process.
        selected_tags: If given, only tags in this set are changed.
        use_regex: Use ``re.sub`` instead of ``str.replace``.

    Returns:
        The resulting tags, split on ',', stripped of surrounding
        whitespace and without empty entries. Stripping keeps a
        replacement such as 'a, b' from producing the tag ' b'.
    """
    if use_regex:
        def substitute(tag):
            return re.sub(search_text, replace_text, tag)
    else:
        def substitute(tag):
            return tag.replace(search_text, replace_text)

    replaced = {substitute(t) if selected_tags is None or t in selected_tags else t for t in tags}
    pieces = (piece.strip() for tag in replaced for piece in tag.split(','))
    return {piece for piece in pieces if piece}
|
||||
|
||||
|
||||
def remove_duplicated_tags(self, filters:List[filters.Filter] = []):
    """Remove repeated tags from every filtered image, keeping first occurrences in order."""
    for path in self.get_filtered_imgpaths(filters):
        tags = self.dataset.get_data_tags(path)
        # dict.fromkeys keeps insertion order, so this is an ordered
        # de-duplication without a linear scan per tag.
        self.set_tags_by_image_path(path, list(dict.fromkeys(tags)))
|
||||
|
||||
|
||||
def remove_tags(self, tags:Set[str], filters:List[filters.Filter] = []):
    """Remove every tag contained in *tags* from each filtered image."""
    for path in self.get_filtered_imgpaths(filters):
        kept = []
        for tag in self.dataset.get_data_tags(path):
            if tag not in tags:
                kept.append(tag)
        self.set_tags_by_image_path(path, kept)
|
||||
|
||||
def sort_filtered_tags(self, filters:List[filters.Filter] = [], **sort_args):
    """Sort the tags of each filtered image with sort_tags; *sort_args* are passed through."""
    for path in self.get_filtered_imgpaths(filters):
        current = self.dataset.get_data_tags(path)
        self.set_tags_by_image_path(path, self.sort_tags(current, **sort_args))
|
||||
|
||||
def get_img_path_list(self):
    """Return the truthy keys of the dataset as a list."""
    paths = []
    for key in self.dataset.datas.keys():
        if key:
            paths.append(key)
    return paths
|
||||
|
||||
|
||||
def get_img_path_set(self):
    """Return the truthy keys of the dataset as a set."""
    paths = set()
    for key in self.dataset.datas.keys():
        if key:
            paths.add(key)
    return paths
|
||||
|
||||
|
||||
def delete_dataset(self, caption_ext:str, filters:List[filters.Filter], delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False):
    """Delete the selected file kinds for every filtered image.

    When images are deleted, the filtered entries are also removed from
    the dataset and the tag counts are rebuilt.
    """
    targets = self.dataset.copy()
    for flt in filters:
        targets.filter(flt)

    for path in targets.datas.keys():
        self.delete_dataset_file(path, caption_ext, delete_image, delete_caption, delete_backup)

    if not delete_image:
        return
    self.dataset.remove(targets)
    self.construct_tag_counts()
|
||||
|
||||
|
||||
def move_dataset(self, dest_dir:str, caption_ext:str, filters:List[filters.Filter], move_image:bool = False, move_caption:bool = False, move_backup:bool = False):
    """Move the selected file kinds of every filtered image to *dest_dir*.

    The tag counts are rebuilt when images were moved.
    """
    targets = self.dataset.copy()
    for flt in filters:
        targets.filter(flt)

    for path in targets.datas.keys():
        self.move_dataset_file(path, caption_ext, dest_dir, move_image, move_caption, move_backup)

    if not move_image:
        return
    self.construct_tag_counts()
|
||||
|
||||
|
||||
def delete_dataset_file(self, img_path:str, caption_ext:str, delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False):
    """Delete the image, caption and/or backup files that belong to one image.

    Nothing happens when *img_path* is not in the dataset. Each kind of
    file is handled in its own try block; errors are printed, not raised.
    Backup files are those with a three-digit suffix ('.000' to '.999').
    """
    if img_path not in self.dataset.datas.keys():
        return

    image_file = Path(img_path)

    if delete_image:
        try:
            if image_file.is_file():
                # Close the cached image before its file is removed.
                if img_path in self.images:
                    self.images[img_path].close()
                    del self.images[img_path]
                image_file.unlink()
                self.dataset.remove_by_path(img_path)
                print(f'[tag-editor] Deleted {image_file.absolute()}')
        except Exception as e:
            print(e)

    if delete_caption:
        try:
            caption_file = image_file.with_suffix(caption_ext)
            if caption_file.is_file():
                caption_file.unlink()
                print(f'[tag-editor] Deleted {caption_file.absolute()}')
        except Exception as e:
            print(e)

    if delete_backup:
        try:
            for number in range(1000):
                backup_file = image_file.with_suffix(f'.{number:0>3d}')
                if not backup_file.is_file():
                    continue
                backup_file.unlink()
                print(f'[tag-editor] Deleted {backup_file.absolute()}')
        except Exception as e:
            print(e)
|
||||
|
||||
|
||||
def move_dataset_file(self, img_path:str, caption_ext:str, dest_dir:str, move_image:bool = False, move_caption:bool = False, move_backup:bool = False):
    """Move the image, caption and/or backup files of one image to *dest_dir*.

    Nothing happens when *img_path* is not in the dataset. The destination
    directory is created on demand, including missing parents. Each kind of
    file is handled in its own try block; errors are printed, not raised.
    Backup files are those with a three-digit suffix ('.000' to '.999').
    """
    if img_path not in self.dataset.datas.keys():
        return

    img_path_obj = Path(img_path)
    dest_dir_obj = Path(dest_dir)

    if move_image or move_caption or move_backup:
        # parents=True: a nested destination no longer raises FileNotFoundError;
        # exist_ok=True: no failure if the directory already exists.
        dest_dir_obj.mkdir(parents=True, exist_ok=True)

    if move_image:
        try:
            dst_path_obj = dest_dir_obj / img_path_obj.name
            if img_path_obj.is_file():
                # Close the cached image before its file is moved.
                if img_path in self.images:
                    self.images[img_path].close()
                    del self.images[img_path]
                img_path_obj.replace(dst_path_obj)
                self.dataset.remove_by_path(img_path)
                print(f'[tag-editor] Moved {img_path_obj.absolute()} -> {dst_path_obj.absolute()}')
        except Exception as e:
            print(e)

    if move_caption:
        try:
            txt_path_obj = img_path_obj.with_suffix(caption_ext)
            dst_path_obj = dest_dir_obj / txt_path_obj.name
            if txt_path_obj.is_file():
                txt_path_obj.replace(dst_path_obj)
                print(f'[tag-editor] Moved {txt_path_obj.absolute()} -> {dst_path_obj.absolute()}')
        except Exception as e:
            print(e)

    if move_backup:
        try:
            for extnum in range(1000):
                bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}')
                dst_path_obj = dest_dir_obj / bak_path_obj.name
                if bak_path_obj.is_file():
                    bak_path_obj.replace(dst_path_obj)
                    print(f'[tag-editor] Moved {bak_path_obj.absolute()} -> {dst_path_obj.absolute()}')
        except Exception as e:
            print(e)
|
||||
|
||||
|
||||
def load_dataset(self, img_dir:str, caption_ext:str, recursive:bool, load_caption_from_filename:bool, interrogate_method:InterrogateMethod, interrogator_names:List[str], threshold_booru:float, threshold_waifu:float, use_temp_dir:bool, kohya_json_path:Optional[str]):
    """Clear the editor and load images and captions from *img_dir*.

    Args:
        img_dir: Directory to scan.
        caption_ext: Suffix of caption files; files with it are not opened as images.
        recursive: Also scan subdirectories.
        load_caption_from_filename: Without a caption file, derive the caption from the file name.
        interrogate_method: How interrogation results are combined with existing tags.
        interrogator_names: Names of the interrogators to run.
        threshold_booru: Threshold passed to DeepDanbooru taggers.
        threshold_waifu: Threshold passed to WaifuDiffusion taggers.
        use_temp_dir: When false, each image gets ``already_saved_as`` set to its absolute path.
        kohya_json_path: If truthy, images and tags are read through kohya_metadata.read instead.

    The method returns early, after printing the error, when the directory
    cannot be listed.
    """
    self.clear()

    img_dir_obj = Path(img_dir)

    print(f'[tag-editor] Loading dataset from {img_dir_obj.absolute()}')
    if recursive:
        print(f'[tag-editor] Also loading from subdirectories.')

    try:
        filepaths = img_dir_obj.glob('**/*') if recursive else img_dir_obj.glob('*')
        filepaths = [p for p in filepaths if p.is_file()]
    except Exception as e:
        print(e)
        print('[tag-editor] Loading Aborted.')
        return

    self.dataset_dir = img_dir

    print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.')

    def load_images(filepaths:List[Path]):
        # Open every file that can be read as an image; skip the rest.
        imgpaths = []
        images = {}
        for img_path in filepaths:
            if img_path.suffix == caption_ext:
                continue
            try:
                img = Image.open(img_path)
            except Exception:
                # Not an image (or unreadable): ignore this file.
                continue
            else:
                abs_path = str(img_path.absolute())
                if not use_temp_dir:
                    img.already_saved_as = abs_path
                images[abs_path] = img

            imgpaths.append(abs_path)
        return imgpaths, images

    def load_captions(imgpaths:List[str]):
        # Build one tag list per image from its caption file or file name.
        taglists = []
        for abs_path in imgpaths:
            img_path = Path(abs_path)
            text_path = img_path.with_suffix(caption_ext)
            caption_text = ''
            if interrogate_method != self.InterrogateMethod.OVERWRITE:
                # from modules/textual_inversion/dataset.py, modified
                if text_path.is_file():
                    caption_text = text_path.read_text('utf8')
                elif load_caption_from_filename:
                    caption_text = img_path.stem
                    caption_text = re.sub(re_numbers_at_start, '', caption_text)
                    if self.re_word:
                        tokens = self.re_word.findall(caption_text)
                        caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens)

            caption_tags = [t.strip() for t in caption_text.split(',')]
            caption_tags = [t for t in caption_tags if t]
            taglists.append(caption_tags)

        return taglists

    try:
        captionings = []
        taggers = []
        if interrogate_method != self.InterrogateMethod.NONE:
            for it in INTERROGATORS:
                if it.name() in interrogator_names:
                    it.start()
                    if isinstance(it, tagger.Tagger):
                        if isinstance(it, tagger.DeepDanbooru):
                            taggers.append((it, threshold_booru))
                        if isinstance(it, tagger.WaifuDiffusion):
                            taggers.append((it, threshold_waifu))
                    elif isinstance(it, captioning.Captioning):
                        captionings.append(it)

        if kohya_json_path:
            imgpaths, self.images, taglists = kohya_metadata.read(img_dir, kohya_json_path, use_temp_dir)
        else:
            imgpaths, self.images = load_images(filepaths)
            taglists = load_captions(imgpaths)

        for img_path, tags in zip(imgpaths, taglists):
            interrogate_tags = []
            img = self.images.get(img_path)
            # PREFILL interrogates only images that have no tags yet. The enum
            # must be reached through self: the bare name is not visible
            # inside a method and raised NameError on the PREFILL path.
            should_interrogate = interrogate_method != self.InterrogateMethod.NONE and (
                interrogate_method != self.InterrogateMethod.PREFILL or not tags
            )
            if should_interrogate:
                if img is None:
                    print(f'Failed to load image {img_path}. Interrogating is aborted.')
                else:
                    img = img.convert('RGB')
                    for cap in captionings:
                        interrogate_tags += cap.predict(img)

                    for tg, threshold in taggers:
                        interrogate_tags += [t for t in tg.predict(img, threshold).keys()]

            if interrogate_method == self.InterrogateMethod.OVERWRITE:
                tags = interrogate_tags
            elif interrogate_method == self.InterrogateMethod.PREPEND:
                tags = interrogate_tags + tags
            else:
                tags = tags + interrogate_tags

            self.set_tags_by_image_path(img_path, tags)

    finally:
        # Stop the interrogators collected above, even after an error.
        if interrogate_method != self.InterrogateMethod.NONE:
            for cap in captionings:
                cap.stop()
            for tg, _ in taggers:
                tg.stop()

    for i, p in enumerate(sorted(self.dataset.datas.keys())):
        self.img_idx[p] = i

    self.construct_tag_counts()
    print(f'[tag-editor] Loading Completed: {len(self.dataset)} images found')
|
||||
|
||||
|
||||
def save_dataset(self, backup:bool, caption_ext:str, write_kohya_metadata:bool, meta_out_path:str, meta_in_path:Optional[str], meta_overwrite:bool, meta_as_caption:bool, meta_full_path:bool):
    """Write each entry's tags, joined with ', ', to its caption file.

    With *backup*, an existing caption file is first renamed to the first
    unused three-digit suffix ('.000' to '.999'). Failures are printed and
    do not stop the loop. Optionally writes kohya metadata afterwards.

    Returns:
        (number of caption files saved, dataset size, dataset directory),
        or (0, 0, '') for an empty dataset.
    """
    if len(self.dataset) == 0:
        return (0, 0, '')

    saved_num = 0
    backup_num = 0
    for data in self.dataset.datas.values():
        img_path = Path(data.imgpath)
        tags = data.tags
        txt_path = img_path.with_suffix(caption_ext)

        # make backup
        if backup and txt_path.is_file():
            candidates = (img_path.with_suffix(f'.{extnum:0>3d}') for extnum in range(1000))
            bak_path = next((c for c in candidates if not c.is_file()), None)
            if bak_path is None:
                print(f"[tag-editor] There are too many backup files with same filename. A backup file of {txt_path} cannot be created.")
            else:
                try:
                    txt_path.rename(bak_path)
                except Exception as e:
                    print(e)
                    print(f"[tag-editor] A backup file of {txt_path} cannot be created.")
                else:
                    backup_num += 1

        # save
        try:
            txt_path.write_text(', '.join(tags), 'utf8')
        except Exception as e:
            print(e)
            print(f"[tag-editor] Warning: {txt_path} cannot be saved.")
        else:
            saved_num += 1

    print(f'[tag-editor] Backup text files: {backup_num}/{len(self.dataset)} under {self.dataset_dir}')
    print(f'[tag-editor] Saved text files: {saved_num}/{len(self.dataset)} under {self.dataset_dir}')

    if write_kohya_metadata:
        kohya_metadata.write(dataset=self.dataset, dataset_dir=self.dataset_dir, out_path=meta_out_path, in_path=meta_in_path, overwrite=meta_overwrite, save_as_caption=meta_as_caption, use_full_path=meta_full_path)
        print(f'[tag-editor] Saved json metadata file in {meta_out_path}')
    return (saved_num, len(self.dataset), self.dataset_dir)
|
||||
|
||||
|
||||
def clear(self):
    """Reset the editor: empty the dataset and caches and forget the directory."""
    self.dataset.clear()
    # These dicts are emptied in place; self.images is rebound to a new dict.
    for cache in (self.tag_counts, self.img_idx):
        cache.clear()
    self.dataset_dir = ''
    self.images = {}
|
||||
|
||||
|
||||
def construct_tag_counts(self):
    """Rebuild self.tag_counts: how often each tag occurs over all dataset entries."""
    self.tag_counts = {}
    counts = self.tag_counts
    for entry in self.dataset.datas.values():
        for tag in entry.tags:
            counts[tag] = counts.get(tag, 0) + 1
|
||||
|
|
|
|||
|
|
@ -0,0 +1,673 @@
|
|||
from pathlib import Path
|
||||
import re
|
||||
from typing import List, Set, Optional
|
||||
from enum import Enum
|
||||
from PIL import Image
|
||||
|
||||
from modules import shared
|
||||
from modules.textual_inversion.dataset import re_numbers_at_start
|
||||
|
||||
from scripts.singleton import Singleton
|
||||
|
||||
from . import tagger, captioning, filters, dataset as ds, kohya_finetune_metadata as kohya_metadata
|
||||
from scripts.tokenizer import clip_tokenizer
|
||||
|
||||
WD_TAGGER_NAMES = [
    "wd-v1-4-vit-tagger",
    "wd-v1-4-convnext-tagger",
    "wd-v1-4-vit-tagger-v2",
    "wd-v1-4-convnext-tagger-v2",
    "wd-v1-4-swinv2-tagger-v2",
]
# One threshold per entry of WD_TAGGER_NAMES, in the same order.
WD_TAGGER_THRESHOLDS = [0.35, 0.35, 0.3537, 0.3685, 0.3771] # v1: idk if it's okay  v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf

# Every available interrogator: one captioner, then the taggers.
INTERROGATORS = [captioning.BLIP(), tagger.DeepDanbooru()]
INTERROGATORS += [
    tagger.WaifuDiffusion(wd_name, wd_threshold)
    for wd_name, wd_threshold in zip(WD_TAGGER_NAMES, WD_TAGGER_THRESHOLDS)
]
INTERROGATOR_NAMES = [interrogator.name() for interrogator in INTERROGATORS]

# Matches "<tag>" with an optional " [<digits>]" suffix; group 1 is the tag part.
re_tags = re.compile(r'^(.+?)( \[\d+\])?$')
|
||||
|
||||
|
||||
def interrogate_image(path:str, interrogator_name:str, threshold_booru, threshold_wd):
    """Run the interrogator named *interrogator_name* on one image file.

    Args:
        path: Image file to open.
        interrogator_name: Compared against ``it.name()`` of each INTERROGATORS entry.
        threshold_booru: Threshold passed to a DeepDanbooru tagger.
        threshold_wd: Threshold passed to a WaifuDiffusion tagger.

    Returns:
        The prediction joined with ', ', or '' when the image cannot be
        opened or no interrogator has the requested name.
    """
    try:
        img = Image.open(path).convert('RGB')
    except Exception:
        # Best effort: an unreadable image yields an empty caption.
        return ''

    res = None
    for it in INTERROGATORS:
        if it.name() != interrogator_name:
            continue
        if isinstance(it, tagger.DeepDanbooru):
            with it as tg:
                res = tg.predict(img, threshold_booru)
        elif isinstance(it, tagger.WaifuDiffusion):
            with it as tg:
                res = tg.predict(img, threshold_wd)
        else:
            with it as cap:
                res = cap.predict(img)

    # An unknown name used to leave `res` unbound and crash at the join.
    if res is None:
        return ''
    return ', '.join(res)
|
||||
|
||||
|
||||
class DatasetTagEditor(Singleton):
|
||||
class SortBy(Enum):
    """Sort criteria accepted by sort_tags and write_tags; each value is a display string."""
    ALPHA = 'Alphabetical Order'
    FREQ = 'Frequency'
    LEN = 'Length'
    # Sorted/displayed using element [1] of the self.tag_tokens entry for the tag.
    TOKEN = 'Token Length'
|
||||
|
||||
class SortOrder(Enum):
    """Sort direction accepted by sort_tags; each value is a display string."""
    ASC = 'Ascending'
    DESC = 'Descending'
|
||||
|
||||
class InterrogateMethod(Enum):
    """Integer-valued modes for using interrogators while loading a dataset.

    NOTE(review): the member names suggest how interrogation results are
    combined with existing captions; the code that consumes them is not
    visible next to this definition, so confirm against load_dataset.
    """
    NONE = 0
    PREFILL = 1
    OVERWRITE = 2
    PREPEND = 3
    APPEND = 4
|
||||
|
||||
def __init__(self):
    """Start with an empty dataset and empty caches."""
    # from modules.textual_inversion.dataset
    word_regex = shared.opts.dataset_filename_word_regex
    if len(word_regex) > 0:
        self.re_word = re.compile(word_regex)
    else:
        self.re_word = None
    self.dataset = ds.Dataset()
    self.img_idx = {}
    self.tag_counts = {}   # tag -> number of occurrences
    self.dataset_dir = ''
    self.images = {}
    self.tag_tokens = {}   # tag -> tuple; element [1] is used as the token count
    self.raw_clip_token_used = None
|
||||
|
||||
def get_tag_list(self):
    """Return the counted tags as a list, building the tag infos first if none are cached."""
    if not len(self.tag_counts):
        self.construct_tag_infos()
    return list(self.tag_counts.keys())
|
||||
|
||||
|
||||
def get_tag_set(self):
    """Return the counted tags as a set, building the tag infos first if none are cached."""
    if not len(self.tag_counts):
        self.construct_tag_infos()
    return set(self.tag_counts.keys())
|
||||
|
||||
|
||||
def get_tags_by_image_path(self, imgpath:str):
    """Return what the dataset reports as the tags of *imgpath*."""
    image_tags = self.dataset.get_data_tags(imgpath)
    return image_tags
|
||||
|
||||
|
||||
def set_tags_by_image_path(self, imgpath:str, tags:List[str]):
    """Store *tags* for *imgpath* in the dataset, then rebuild the tag infos."""
    caption = ','.join(tags)
    entry = ds.Data(imgpath, caption)
    self.dataset.append_data(entry)
    self.construct_tag_infos()
|
||||
|
||||
|
||||
def write_tags(self, tags:List[str], sort_by:SortBy=SortBy.FREQ):
    """Format every non-empty tag for display, annotated according to *sort_by*.

    FREQ appends the cached count, LEN the tag's length, TOKEN element [1]
    of the tag's self.tag_tokens entry; any other criterion adds nothing.
    """
    sort_by = self.SortBy(sort_by)
    if not tags:
        return []

    if sort_by == self.SortBy.FREQ:
        def label(tag):
            return f'{tag} [{self.tag_counts.get(tag) or 0}]'
    elif sort_by == self.SortBy.LEN:
        def label(tag):
            return f'{tag} [{len(tag)}]'
    elif sort_by == self.SortBy.TOKEN:
        def label(tag):
            return f'{tag} [{self.tag_tokens.get(tag, (0, 0))[1]}]'
    else:
        def label(tag):
            return f'{tag}'

    return [label(tag) for tag in tags if tag]
|
||||
|
||||
|
||||
def read_tags(self, tags:List[str]):
    """Strip the optional ' [n]' suffix from each label and return the bare tags.

    Empty labels are skipped. A label that ``re_tags`` does not match
    (possible when it contains a line break) is kept unchanged instead of
    raising AttributeError on the failed match.
    """
    if not tags:
        return []
    names = []
    for label in tags:
        if not label:
            continue
        matched = re_tags.match(label)
        name = matched.group(1) if matched else label
        if name:
            names.append(name)
    return names
|
||||
|
||||
|
||||
def sort_tags(self, tags:List[str], sort_by:SortBy=SortBy.ALPHA, sort_order:SortOrder=SortOrder.ASC):
    """Return a new sorted list of *tags*.

    *sort_by* and *sort_order* may be enum members or their values.
    Frequency, length and token sorts break ties alphabetically in
    ascending order, whatever the requested direction.
    """
    sort_by = self.SortBy(sort_by)
    sort_order = self.SortOrder(sort_order)
    descending = sort_order == self.SortOrder.DESC
    # Negating the numeric key reverses it while keeping the tie-break ascending.
    sign = -1 if descending else 1

    if sort_by == self.SortBy.ALPHA:
        return sorted(tags, reverse=descending)
    if sort_by == self.SortBy.FREQ:
        return sorted(tags, key=lambda t: (sign * self.tag_counts.get(t, 0), t))
    if sort_by == self.SortBy.LEN:
        return sorted(tags, key=lambda t: (sign * len(t), t))
    if sort_by == self.SortBy.TOKEN:
        return sorted(tags, key=lambda t: (sign * self.tag_tokens.get(t, (0, 0))[1], t))
    return list(tags)
|
||||
|
||||
|
||||
def get_filtered_imgpaths(self, filters:List[filters.Filter] = []):
    """Return the sorted image paths left after applying every filter to a copy of the dataset."""
    remaining = self.dataset.copy()
    for flt in filters:
        remaining.filter(flt)
    return sorted(remaining.datas.keys())
|
||||
|
||||
|
||||
def get_filtered_imgs(self, filters:List[filters.Filter] = []):
    """Return the cached images for the filtered paths, in sorted path order.

    Filtering is delegated to get_filtered_imgpaths rather than repeated
    here. A path with no entry in self.images yields None.
    """
    return [self.images.get(path) for path in self.get_filtered_imgpaths(filters)]
|
||||
|
||||
|
||||
def get_filtered_imgindices(self, filters:List[filters.Filter] = []):
    """Return the self.img_idx entries for the filtered paths, in sorted path order.

    Filtering is delegated to get_filtered_imgpaths rather than repeated
    here. A path with no entry in self.img_idx yields None.
    """
    return [self.img_idx.get(path) for path in self.get_filtered_imgpaths(filters)]
|
||||
|
||||
|
||||
def get_filtered_tags(self, filters:List[filters.Filter] = [], filter_word:str = '', filter_tags = True, prefix=False, suffix=False, regex=False):
    """Return the tags that match *filter_word*.

    Args:
        filters: Applied to a copy of the dataset when *filter_tags* is true.
        filter_word: Text or regular expression to look for.
        filter_tags: When false, the whole dataset's tag set is searched.
        prefix: Match at the start of the tag.
        suffix: Match at the end of the tag. With both *prefix* and
            *suffix* set, a tag matching either one is kept.
        regex: Treat *filter_word* as a regular expression.

    Returns:
        The set of matching tags. If *filter_word* is not a valid pattern
        or not a string, the unfiltered tag set is returned.
    """
    if filter_tags:
        filtered_set = self.dataset.copy()
        for flt in filters:
            filtered_set.filter(flt)
        tags:Set[str] = filtered_set.get_tagset()
    else:
        tags:Set[str] = self.dataset.get_tagset()

    try:
        if regex:
            # Build each pattern once instead of once per tag.
            patterns = []
            if prefix:
                patterns.append(re.compile("^" + filter_word))
            if suffix:
                patterns.append(re.compile(filter_word + "$"))
            if not patterns:
                patterns.append(re.compile(filter_word))

            def matches(tag):
                return any(p.search(tag) is not None for p in patterns)
        else:
            def matches(tag):
                if prefix and tag.startswith(filter_word):
                    return True
                if suffix and tag.endswith(filter_word):
                    return True
                return not prefix and not suffix and filter_word in tag

        return {tag for tag in tags if matches(tag)}
    except (re.error, TypeError):
        # Best effort: an unusable filter word means "do not filter".
        return tags
|
||||
|
||||
|
||||
def cleanup_tags(self, tags:List[str]):
    """Keep only the entries of *tags* that are in the dataset's tag set, preserving order."""
    known = self.dataset.get_tagset()
    kept = []
    for tag in tags:
        if tag in known:
            kept.append(tag)
    return kept
|
||||
|
||||
def cleanup_tagset(self, tags:Set[str]):
    """Return the intersection of `tags` with the dataset's current tag set."""
    return tags & self.dataset.get_tagset()
|
||||
|
||||
|
||||
def get_common_tags(self, filters:List[filters.Filter] = []):
    """Return, sorted, the tags shared by every dataset entry that passes `filters`."""
    subset = self.dataset.copy()
    for flt in filters:
        subset.filter(flt)

    # Start from every tag of the subset and narrow it entry by entry.
    common = subset.get_tagset()
    for entry in subset.datas.values():
        common &= entry.tagset

    return sorted(common)
|
||||
|
||||
|
||||
def replace_tags(self, search_tags:List[str], replace_tags:List[str], filters:List[filters.Filter] = [], prepend:bool = False):
    """Replace, remove, or add tags on every image that passes `filters`.

    `search_tags[i]` is replaced by `replace_tags[i]`; an empty or None
    replacement removes the tag. Search tags without a counterpart are
    removed, and surplus replace tags are added to each image (in front when
    `prepend` is true, otherwise at the end). Tag infos are rebuilt afterwards.
    """
    img_paths = self.get_filtered_imgpaths(filters=filters)
    extra_tags = replace_tags[len(search_tags):]
    dropped = search_tags[len(replace_tags):]
    renames = {}
    for old, new in zip(search_tags, replace_tags):
        if new is None or new == '':
            dropped.append(old)
        else:
            renames[old] = new

    for img_path in img_paths:
        kept = [t for t in self.dataset.get_data_tags(img_path) if t not in dropped]
        renamed = [renames.get(t, t) for t in kept]
        if prepend:
            new_tags = extra_tags + renamed
        else:
            new_tags = renamed + extra_tags
        self.set_tags_by_image_path(img_path, new_tags)

    self.construct_tag_infos()
|
||||
|
||||
def get_replaced_tagset(self, tags:Set[str], search_tags:List[str], replace_tags:List[str]):
    """Return `tags` after applying the search/replace pairs, without touching the dataset.

    A search tag paired with an empty or None replacement, or with no
    replacement at all, is removed. Falsy results are dropped.
    """
    dropped = search_tags[len(replace_tags):]
    renames = {}
    for old, new in zip(search_tags, replace_tags):
        if new is None or new == '':
            dropped.append(old)
        else:
            renames[old] = new

    result = set()
    for tag in tags:
        if tag in dropped:
            continue
        mapped = renames.get(tag, tag)
        if mapped:
            result.add(mapped)
    return result
|
||||
|
||||
|
||||
def search_and_replace_caption(self, search_text:str, replace_text:str, filters:List[filters.Filter] = [], use_regex:bool = False):
    """Search and replace over the whole caption of every image that passes `filters`.

    Each image's tags are joined with ', ', the replacement is applied to that
    string (as a regex substitution when `use_regex` is true), and the result
    is split on ',' back into stripped, non-empty tags. Tag infos are rebuilt
    afterwards.
    """
    for img_path in self.get_filtered_imgpaths(filters=filters):
        joined = ', '.join(self.dataset.get_data_tags(img_path))
        if use_regex:
            replaced = re.sub(search_text, replace_text, joined)
        else:
            replaced = joined.replace(search_text, replace_text)
        pieces = [piece.strip() for piece in replaced.split(',')]
        self.set_tags_by_image_path(img_path, [piece for piece in pieces if piece])

    self.construct_tag_infos()
|
||||
|
||||
|
||||
def search_and_replace_selected_tags(self, search_text:str, replace_text:str, selected_tags:Optional[Set[str]], filters:List[filters.Filter] = [], use_regex:bool = False):
    """Run `search_and_replace_tag_list` on the tags of every image that passes `filters`.

    The per-image result is stored back, and tag infos are rebuilt afterwards.
    """
    for img_path in self.get_filtered_imgpaths(filters=filters):
        current = self.dataset.get_data_tags(img_path)
        updated = self.search_and_replace_tag_list(search_text, replace_text, current, selected_tags, use_regex)
        self.set_tags_by_image_path(img_path, updated)

    self.construct_tag_infos()
|
||||
|
||||
|
||||
def search_and_replace_tag_list(self, search_text:str, replace_text:str, tags:List[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False):
    """Return `tags` with `search_text` replaced by `replace_text` inside each tag.

    When `selected_tags` is given, only tags contained in it are rewritten.
    A rewritten tag containing ',' is split into several tags, and empty
    pieces are dropped.
    """
    def substitute(tag):
        if use_regex:
            return re.sub(search_text, replace_text, tag)
        return tag.replace(search_text, replace_text)

    rewritten = []
    for tag in tags:
        if selected_tags is None or tag in selected_tags:
            rewritten.append(substitute(tag))
        else:
            rewritten.append(tag)

    return [piece for tag in rewritten for piece in tag.split(',') if piece]
|
||||
|
||||
|
||||
def search_and_replace_tag_set(self, search_text:str, replace_text:str, tags:Set[str], selected_tags:Optional[Set[str]] = None, use_regex:bool = False):
    """Set-valued counterpart of `search_and_replace_tag_list`.

    Replaces `search_text` with `replace_text` inside each tag (only tags in
    `selected_tags` when it is given), splits results on ',', and returns the
    non-empty pieces as a set.
    """
    def substitute(tag):
        if use_regex:
            return re.sub(search_text, replace_text, tag)
        return tag.replace(search_text, replace_text)

    rewritten = set()
    for tag in tags:
        if selected_tags is None or tag in selected_tags:
            rewritten.add(substitute(tag))
        else:
            rewritten.add(tag)

    return {piece for tag in rewritten for piece in tag.split(',') if piece}
|
||||
|
||||
|
||||
def remove_duplicated_tags(self, filters:List[filters.Filter] = []):
    """Drop repeated tags on every image that passes `filters`, keeping first occurrences in order."""
    for path in self.get_filtered_imgpaths(filters):
        # dict keys keep insertion order, so this is an order-preserving dedup.
        unique = list(dict.fromkeys(self.dataset.get_data_tags(path)))
        self.set_tags_by_image_path(path, unique)
|
||||
|
||||
|
||||
def remove_tags(self, tags:Set[str], filters:List[filters.Filter] = []):
    """Remove every tag contained in `tags` from each image that passes `filters`."""
    for path in self.get_filtered_imgpaths(filters):
        remaining = []
        for tag in self.dataset.get_data_tags(path):
            if tag not in tags:
                remaining.append(tag)
        self.set_tags_by_image_path(path, remaining)
|
||||
|
||||
|
||||
def sort_filtered_tags(self, filters:List[filters.Filter] = [], **sort_args):
    """Re-order the tags of every image that passes `filters` using `self.sort_tags`.

    `sort_args` is forwarded to `sort_tags`; its "sort_by" and "sort_order"
    entries must expose `.value`, which is used in the log line.
    """
    for path in self.get_filtered_imgpaths(filters):
        ordered = self.sort_tags(self.dataset.get_data_tags(path), **sort_args)
        self.set_tags_by_image_path(path, ordered)

    print(f'[tag-editor] Tags are sorted by {sort_args.get("sort_by").value} ({sort_args.get("sort_order").value})')
|
||||
|
||||
|
||||
def truncate_filtered_tags_by_token_count(self, filters:List[filters.Filter] = [], max_token_count:int = 75):
    """Cut the tag list of every image that passes `filters` to a token budget.

    Tags are kept from the front for as long as the caption built from them
    (joined with ', ') tokenizes to at most `max_token_count` tokens; the
    first tag that would exceed the budget and everything after it are
    dropped. Tag infos are rebuilt afterwards.
    """
    for path in self.get_filtered_imgpaths(filters):
        kept = []
        for tag in self.dataset.get_data_tags(path):
            candidate = ', '.join(kept + [tag])
            _, token_count = clip_tokenizer.tokenize(candidate, shared.opts.dataset_editor_use_raw_clip_token)
            if token_count > max_token_count:
                break
            kept.append(tag)
        self.set_tags_by_image_path(path, kept)

    self.construct_tag_infos()
    print(f'[tag-editor] Tags are truncated into token count <= {max_token_count}')
|
||||
|
||||
|
||||
def get_img_path_list(self):
    """Return the truthy image-path keys of the dataset as a list."""
    return list(filter(None, self.dataset.datas.keys()))
|
||||
|
||||
|
||||
def get_img_path_set(self):
    """Return the truthy image-path keys of the dataset as a set."""
    return set(filter(None, self.dataset.datas.keys()))
|
||||
|
||||
|
||||
def delete_dataset(self, caption_ext:str, filters:List[filters.Filter], delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False):
    """Delete the files of every dataset entry that passes `filters`.

    Each matching path is handed to `delete_dataset_file`. When images are
    deleted, the matching entries are also removed from the dataset and tag
    infos are rebuilt.
    """
    targets = self.dataset.copy()
    for flt in filters:
        targets.filter(flt)

    for path in targets.datas.keys():
        self.delete_dataset_file(path, caption_ext, delete_image, delete_caption, delete_backup)

    if not delete_image:
        return
    self.dataset.remove(targets)
    self.construct_tag_infos()
|
||||
|
||||
|
||||
def move_dataset(self, dest_dir:str, caption_ext:str, filters:List[filters.Filter], move_image:bool = False, move_caption:bool = False, move_backup:bool = False):
    """Move the files of every dataset entry that passes `filters` into `dest_dir`.

    Each matching path is handed to `move_dataset_file`; tag infos are rebuilt
    when images were moved.
    """
    targets = self.dataset.copy()
    for flt in filters:
        targets.filter(flt)

    for path in targets.datas.keys():
        self.move_dataset_file(path, caption_ext, dest_dir, move_image, move_caption, move_backup)

    if not move_image:
        return
    self.construct_tag_infos()
|
||||
|
||||
|
||||
def delete_dataset_file(self, img_path:str, caption_ext:str, delete_image:bool = False, delete_caption:bool = False, delete_backup:bool = False):
    """Delete the image, caption, and/or numbered backup files of one dataset entry.

    Does nothing when `img_path` is not a key of the dataset. Each kind of
    file is handled independently; an exception raised while handling one
    kind is printed and does not stop the others.

    Args:
        img_path: dataset key of the image.
        caption_ext: suffix of the caption file next to the image.
        delete_image: delete the image file and remove the entry from the dataset.
        delete_caption: delete the caption file.
        delete_backup: delete backup files with suffixes '.000' to '.999'.
    """
    if img_path not in self.dataset.datas.keys():
        return

    image_file = Path(img_path)

    if delete_image:
        try:
            if image_file.is_file():
                if img_path in self.images:
                    # Close the cached image object before unlinking its file.
                    self.images[img_path].close()
                    del self.images[img_path]
                image_file.unlink()
                self.dataset.remove_by_path(img_path)
                print(f'[tag-editor] Deleted {image_file.absolute()}')
        except Exception as err:
            print(err)

    if delete_caption:
        try:
            caption_file = image_file.with_suffix(caption_ext)
            if caption_file.is_file():
                caption_file.unlink()
                print(f'[tag-editor] Deleted {caption_file.absolute()}')
        except Exception as err:
            print(err)

    if delete_backup:
        try:
            numbered = (image_file.with_suffix(f'.{n:0>3d}') for n in range(1000))
            for backup_file in numbered:
                if backup_file.is_file():
                    backup_file.unlink()
                    print(f'[tag-editor] Deleted {backup_file.absolute()}')
        except Exception as err:
            print(err)
|
||||
|
||||
|
||||
def move_dataset_file(self, img_path:str, caption_ext:str, dest_dir:str, move_image:bool = False, move_caption:bool = False, move_backup:bool = False):
    """Move the image, caption, and/or numbered backup files of one dataset entry.

    Does nothing when `img_path` is not a key of the dataset. The destination
    directory is created (including missing parents) when anything is to be
    moved. Each kind of file is handled independently; an exception raised
    while handling one kind is printed and does not stop the others.

    Args:
        img_path: dataset key of the image.
        caption_ext: suffix of the caption file next to the image.
        dest_dir: directory the files are moved into.
        move_image: move the image file and remove the entry from the dataset.
        move_caption: move the caption file.
        move_backup: move backup files with suffixes '.000' to '.999'.
    """
    if img_path not in self.dataset.datas.keys():
        return

    img_path_obj = Path(img_path)
    dest_dir_obj = Path(dest_dir)

    if (move_image or move_caption or move_backup) and not dest_dir_obj.exists():
        # parents=True: a nested destination under a missing directory used to
        # raise here, outside of any try block.
        dest_dir_obj.mkdir(parents=True, exist_ok=True)

    if move_image:
        try:
            dst_path_obj = dest_dir_obj / img_path_obj.name
            if img_path_obj.is_file():
                if img_path in self.images:
                    # Close the cached image object before moving its file.
                    self.images[img_path].close()
                    del self.images[img_path]
                img_path_obj.replace(dst_path_obj)
                self.dataset.remove_by_path(img_path)
                print(f'[tag-editor] Moved {img_path_obj.absolute()} -> {dst_path_obj.absolute()}')
        except Exception as e:
            print(e)

    if move_caption:
        try:
            txt_path_obj = img_path_obj.with_suffix(caption_ext)
            dst_path_obj = dest_dir_obj / txt_path_obj.name
            if txt_path_obj.is_file():
                txt_path_obj.replace(dst_path_obj)
                print(f'[tag-editor] Moved {txt_path_obj.absolute()} -> {dst_path_obj.absolute()}')
        except Exception as e:
            print(e)

    if move_backup:
        try:
            for extnum in range(1000):
                bak_path_obj = img_path_obj.with_suffix(f'.{extnum:0>3d}')
                dst_path_obj = dest_dir_obj / bak_path_obj.name
                if bak_path_obj.is_file():
                    bak_path_obj.replace(dst_path_obj)
                    print(f'[tag-editor] Moved {bak_path_obj.absolute()} -> {dst_path_obj.absolute()}')
        except Exception as e:
            print(e)
|
||||
|
||||
|
||||
def load_dataset(self, img_dir:str, caption_ext:str, recursive:bool, load_caption_from_filename:bool, interrogate_method:InterrogateMethod, interrogator_names:List[str], threshold_booru:float, threshold_waifu:float, use_temp_dir:bool, kohya_json_path:Optional[str]):
    """Clear the editor and load images and their tags from `img_dir`.

    Args:
        img_dir: directory to scan for images.
        caption_ext: suffix of caption files; files with this suffix are not
            tried as images.
        recursive: also scan subdirectories.
        load_caption_from_filename: when an image has no caption file, derive
            the caption from the image file name.
        interrogate_method: how interrogator output is combined with existing
            tags (NONE, PREFILL, OVERWRITE, PREPEND; anything else appends).
        interrogator_names: names of the interrogators to run.
        threshold_booru: threshold passed to DeepDanbooru taggers.
        threshold_waifu: threshold passed to WaifuDiffusion taggers.
        use_temp_dir: when false, each image's `already_saved_as` attribute is
            set to its absolute path.
        kohya_json_path: if set, images and tags are read through
            `kohya_metadata.read` instead of scanning caption files.

    Returns:
        None. Loading is aborted (after printing the error) when the
        directory cannot be listed.
    """
    self.clear()

    img_dir_obj = Path(img_dir)

    print(f'[tag-editor] Loading dataset from {img_dir_obj.absolute()}')
    if recursive:
        print('[tag-editor] Also loading from subdirectories.')

    try:
        filepaths = img_dir_obj.glob('**/*') if recursive else img_dir_obj.glob('*')
        filepaths = [p for p in filepaths if p.is_file()]
    except Exception as e:
        print(e)
        print('[tag-editor] Loading Aborted.')
        return

    self.dataset_dir = img_dir

    print(f'[tag-editor] Total {len(filepaths)} files under the directory including not image files.')

    def load_images(filepaths:List[Path]):
        """Open every file that PIL accepts; return (absolute paths, path -> image)."""
        imgpaths = []
        images = {}
        for img_path in filepaths:
            if img_path.suffix == caption_ext:
                continue
            try:
                img = Image.open(img_path)
            except Exception:
                # Not an image (or unreadable): skip it. Narrowed from a bare
                # `except:` so KeyboardInterrupt/SystemExit still propagate.
                continue
            else:
                abs_path = str(img_path.absolute())
                if not use_temp_dir:
                    img.already_saved_as = abs_path
                images[abs_path] = img

            imgpaths.append(abs_path)
        return imgpaths, images

    def load_captions(imgpaths:List[str]):
        """Return one tag list per image path, read from caption file or file name."""
        taglists = []
        for abs_path in imgpaths:
            img_path = Path(abs_path)
            text_path = img_path.with_suffix(caption_ext)
            caption_text = ''
            if interrogate_method != self.InterrogateMethod.OVERWRITE:
                # from modules/textual_inversion/dataset.py, modified
                if text_path.is_file():
                    caption_text = text_path.read_text('utf8')
                elif load_caption_from_filename:
                    caption_text = img_path.stem
                    caption_text = re.sub(re_numbers_at_start, '', caption_text)
                    if self.re_word:
                        tokens = self.re_word.findall(caption_text)
                        caption_text = (shared.opts.dataset_filename_join_string or "").join(tokens)

            caption_tags = [t.strip() for t in caption_text.split(',')]
            caption_tags = [t for t in caption_tags if t]
            taglists.append(caption_tags)

        return taglists

    try:
        captionings = []
        taggers = []
        if interrogate_method != self.InterrogateMethod.NONE:
            for it in INTERROGATORS:
                if it.name() in interrogator_names:
                    it.start()
                    if isinstance(it, tagger.Tagger):
                        if isinstance(it, tagger.DeepDanbooru):
                            taggers.append((it, threshold_booru))
                        if isinstance(it, tagger.WaifuDiffusion):
                            taggers.append((it, threshold_waifu))
                    elif isinstance(it, captioning.Captioning):
                        captionings.append(it)

        if kohya_json_path:
            imgpaths, self.images, taglists = kohya_metadata.read(img_dir, kohya_json_path, use_temp_dir)
        else:
            imgpaths, self.images = load_images(filepaths)
            taglists = load_captions(imgpaths)

        for img_path, tags in zip(imgpaths, taglists):
            interrogate_tags = []
            img = self.images.get(img_path)
            # PREFILL interrogates only images that have no tags yet.
            if interrogate_method != self.InterrogateMethod.NONE and ((interrogate_method != self.InterrogateMethod.PREFILL) or (interrogate_method == self.InterrogateMethod.PREFILL and not tags)):
                if img is None:
                    print(f'Failed to load image {img_path}. Interrogating is aborted.')
                else:
                    img = img.convert('RGB')
                    for cap in captionings:
                        interrogate_tags += cap.predict(img)

                    for tg, threshold in taggers:
                        interrogate_tags += [t for t in tg.predict(img, threshold).keys()]

            if interrogate_method == self.InterrogateMethod.OVERWRITE:
                tags = interrogate_tags
            elif interrogate_method == self.InterrogateMethod.PREPEND:
                tags = interrogate_tags + tags
            else:
                tags = tags + interrogate_tags

            self.set_tags_by_image_path(img_path, tags)

    finally:
        # Stop the started interrogators even when loading failed midway.
        if interrogate_method != self.InterrogateMethod.NONE:
            for cap in captionings:
                cap.stop()
            for tg, _ in taggers:
                tg.stop()

    for i, p in enumerate(sorted(self.dataset.datas.keys())):
        self.img_idx[p] = i

    self.construct_tag_infos()
    print(f'[tag-editor] Loading Completed: {len(self.dataset)} images found')
|
||||
|
||||
|
||||
def save_dataset(self, backup:bool, caption_ext:str, write_kohya_metadata:bool, meta_out_path:str, meta_in_path:Optional[str], meta_overwrite:bool, meta_as_caption:bool, meta_full_path:bool):
    """Write every entry's tags to its caption file, optionally backing up the old one.

    With `backup`, an existing caption file is renamed to the first unused
    '.000'-'.999' suffix before the new caption is written. With
    `write_kohya_metadata`, the dataset is also passed to
    `kohya_metadata.write` together with the `meta_*` arguments.

    Returns:
        (number of caption files saved, number of dataset entries, dataset
        directory), or (0, 0, '') when the dataset is empty.
    """
    if len(self.dataset) == 0:
        return (0, 0, '')

    saved_num = 0
    backup_num = 0
    for entry in self.dataset.datas.values():
        img_path = Path(entry.imgpath)
        tags = entry.tags
        txt_path = img_path.with_suffix(caption_ext)

        # make backup
        if backup and txt_path.is_file():
            candidates = (img_path.with_suffix(f'.{n:0>3d}') for n in range(1000))
            bak_path = next((c for c in candidates if not c.is_file()), None)
            if bak_path is None:
                print(f"[tag-editor] There are too many backup files with same filename. A backup file of {txt_path} cannot be created.")
            else:
                try:
                    txt_path.rename(bak_path)
                except Exception as err:
                    print(err)
                    print(f"[tag-editor] A backup file of {txt_path} cannot be created.")
                else:
                    backup_num += 1

        # save
        try:
            txt_path.write_text(', '.join(tags), 'utf8')
        except Exception as err:
            print(err)
            print(f"[tag-editor] Warning: {txt_path} cannot be saved.")
        else:
            saved_num += 1

    print(f'[tag-editor] Backup text files: {backup_num}/{len(self.dataset)} under {self.dataset_dir}')
    print(f'[tag-editor] Saved text files: {saved_num}/{len(self.dataset)} under {self.dataset_dir}')

    if write_kohya_metadata:
        kohya_metadata.write(dataset=self.dataset, dataset_dir=self.dataset_dir, out_path=meta_out_path, in_path=meta_in_path, overwrite=meta_overwrite, save_as_caption=meta_as_caption, use_full_path=meta_full_path)
        print(f'[tag-editor] Saved json metadata file in {meta_out_path}')

    return (saved_num, len(self.dataset), self.dataset_dir)
|
||||
|
||||
|
||||
def clear(self):
    """Empty the dataset, tag counts, tag tokens, image indices and images, and reset `dataset_dir`."""
    for container in (self.dataset, self.tag_counts, self.tag_tokens, self.img_idx):
        container.clear()
    self.dataset_dir = ''
    self.images.clear()
|
||||
|
||||
|
||||
def construct_tag_infos(self):
    """Rebuild `self.tag_counts` and fill in missing entries of `self.tag_tokens`.

    `tag_counts` maps each tag to the number of occurrences across all
    entries. `tag_tokens` caches the result of `clip_tokenizer.tokenize` per
    tag; the cache is emptied first when the "use raw CLIP token" option
    differs from the value recorded on the previous run (or none was recorded).
    """
    self.tag_counts = {}

    cache_is_stale = self.raw_clip_token_used is None or self.raw_clip_token_used != shared.opts.dataset_editor_use_raw_clip_token
    if cache_is_stale:
        self.tag_tokens.clear()

    for entry in self.dataset.datas.values():
        for tag in entry.tags:
            self.tag_counts[tag] = self.tag_counts.get(tag, 0) + 1
            if tag not in self.tag_tokens:
                self.tag_tokens[tag] = clip_tokenizer.tokenize(tag, shared.opts.dataset_editor_use_raw_clip_token)

    # Remember which tokenizer mode the cached values were computed with.
    self.raw_clip_token_used = shared.opts.dataset_editor_use_raw_clip_token
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, NamedTuple, Type, Dict, Any
|
||||
from typing import NamedTuple, Type, Dict, Any
|
||||
from modules import shared, script_callbacks, scripts
|
||||
from modules.shared import opts
|
||||
import gradio as gr
|
||||
|
|
@ -39,15 +39,15 @@ GeneralConfig = namedtuple('GeneralConfig', [
|
|||
'meta_use_full_path'
|
||||
])
|
||||
FilterConfig = namedtuple('FilterConfig', ['sw_prefix', 'sw_suffix', 'sw_regex','sort_by', 'sort_order', 'logic'])
|
||||
BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order'])
|
||||
EditSelectedConfig = namedtuple('EditSelectedConfig', ['use_raw_token', 'auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
|
||||
BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order', 'token_count'])
|
||||
EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
|
||||
MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination'])
|
||||
|
||||
CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False)
|
||||
CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND')
|
||||
CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR')
|
||||
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value)
|
||||
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(True, False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
|
||||
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value, 75)
|
||||
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
|
||||
CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig('Selected One', [], '.txt', '')
|
||||
|
||||
class Config:
|
||||
|
|
@ -252,10 +252,10 @@ def on_ui_tabs():
|
|||
ui.batch_edit_captions.rb_sr_replace_target,
|
||||
ui.batch_edit_captions.tag_select_ui_remove.cb_prefix, ui.batch_edit_captions.tag_select_ui_remove.cb_suffix, ui.batch_edit_captions.tag_select_ui_remove.cb_regex,
|
||||
ui.batch_edit_captions.tag_select_ui_remove.rb_sort_by, ui.batch_edit_captions.tag_select_ui_remove.rb_sort_order,
|
||||
ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order
|
||||
ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order,
|
||||
ui.batch_edit_captions.nb_token_count
|
||||
]
|
||||
components_edit_selected = [
|
||||
ui.edit_caption_of_selected_image.cb_use_raw_token,
|
||||
ui.edit_caption_of_selected_image.cb_copy_caption_automatically, ui.edit_caption_of_selected_image.cb_sort_caption_on_save,
|
||||
ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, ui.edit_caption_of_selected_image.dd_intterogator_names_si,
|
||||
ui.edit_caption_of_selected_image.rb_sort_by, ui.edit_caption_of_selected_image.rb_sort_order
|
||||
|
|
@ -339,6 +339,7 @@ def on_ui_settings():
|
|||
section = ('dataset-tag-editor', "Dataset Tag Editor")
|
||||
shared.opts.add_option("dataset_editor_image_columns", shared.OptionInfo(6, "Number of columns on image gallery", section=section))
|
||||
shared.opts.add_option("dataset_editor_use_temp_files", shared.OptionInfo(False, "Force image gallery to use temporary files", section=section))
|
||||
shared.opts.add_option("dataset_editor_use_raw_clip_token", shared.OptionInfo(True, "Use raw CLIP token to calculate token count (without emphasis or embeddings)", section=section))
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
# Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py
|
||||
|
||||
from typing import List
|
||||
from functools import reduce
|
||||
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
|
||||
from modules import shared, extra_networks, prompt_parser
|
||||
from modules.sd_hijack import model_hijack
|
||||
import open_clip.tokenizer
|
||||
|
||||
class VanillaClip:
    """Accessors for the vocabulary and byte decoder of a wrapped `clip` object's tokenizer."""

    def __init__(self, clip):
        self.clip = clip

    def vocab(self):
        """Return `clip.tokenizer.get_vocab()`."""
        tokenizer = self.clip.tokenizer
        return tokenizer.get_vocab()

    def byte_decoder(self):
        """Return `clip.tokenizer.byte_decoder`."""
        tokenizer = self.clip.tokenizer
        return tokenizer.byte_decoder
|
||||
|
||||
class OpenClip:
    """Accessors for the vocabulary and byte decoder of open_clip's module-level tokenizer."""

    def __init__(self, clip):
        self.clip = clip
        # The tokenizer is open_clip's shared module-level instance, not taken from `clip`.
        self.tokenizer = open_clip.tokenizer._tokenizer

    def vocab(self):
        """Return the tokenizer's `encoder` mapping."""
        tokenizer = self.tokenizer
        return tokenizer.encoder

    def byte_decoder(self):
        """Return the tokenizer's `byte_decoder` mapping."""
        tokenizer = self.tokenizer
        return tokenizer.byte_decoder
|
||||
|
||||
def tokenize(text:str, use_raw_clip:bool=True):
    """Tokenize `text` and return `(tokens, token_count)`.

    With `use_raw_clip`, the text is passed as-is to the loaded model's
    `cond_stage_model.tokenize` and the count is the number of tokens.
    Otherwise the text first goes through the webui's extra-network and
    prompt parsing (falling back to the text as it stands if that raises)
    and is tokenized with `model_hijack.clip.tokenize_line`, whose token
    count is returned.
    """
    if use_raw_clip:
        tokens = shared.sd_model.cond_stage_model.tokenize([text])[0]
        return tokens, len(tokens)

    def _concat(left, right):
        return left + right

    try:
        text, _ = extra_networks.parse_prompt(text)
        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        prompt = reduce(_concat, prompt_flat_list)
    except Exception:
        prompt = text

    token_chunks, token_count = model_hijack.clip.tokenize_line(prompt)
    # Flatten the per-chunk token lists into one list.
    tokens = reduce(_concat, [chunk.tokens for chunk in token_chunks])
    return tokens, token_count
|
||||
|
||||
def get_target_token_count(token_count:int):
    """Return `model_hijack.clip.get_target_prompt_token_count(token_count)`."""
    clip = model_hijack.clip
    return clip.get_target_prompt_token_count(token_count)
|
||||
|
|
@ -127,8 +127,8 @@ class TagFilterUI():
|
|||
tags_in_filter = dte_instance.sort_tags(tags=tags_in_filter, sort_by=self.sort_by, sort_order=self.sort_order)
|
||||
|
||||
tags = tags_in_filter + [tag for tag in tags if tag not in self.filter.tags]
|
||||
tags = dte_instance.write_tags(tags)
|
||||
tags_in_filter = dte_instance.write_tags(tags_in_filter)
|
||||
tags = dte_instance.write_tags(tags, self.sort_by)
|
||||
tags_in_filter = dte_instance.write_tags(tags_in_filter, self.sort_by)
|
||||
|
||||
return gr.CheckboxGroup.update(value=tags_in_filter, choices=tags)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ from .ui_common import *
|
|||
TagFilter = dte_module.filters.TagFilter
|
||||
Filter = dte_module.filters.Filter
|
||||
|
||||
SortBy = dte_instance.SortBy
|
||||
SortOrder = dte_instance.SortOrder
|
||||
|
||||
|
||||
class TagSelectUI():
|
||||
def __init__(self):
|
||||
|
|
@ -32,8 +35,8 @@ class TagSelectUI():
|
|||
self.cb_suffix = gr.Checkbox(label='Suffix', value=False, interactive=True)
|
||||
self.cb_regex = gr.Checkbox(label='Use regex', value=False, interactive=True)
|
||||
with gr.Row():
|
||||
self.rb_sort_by = gr.Radio(choices=['Alphabetical Order', 'Frequency', 'Length'], value=sort_by, interactive=True, label='Sort by')
|
||||
self.rb_sort_order = gr.Radio(choices=['Ascending', 'Descending'], value=sort_order, interactive=True, label='Sort Order')
|
||||
self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=sort_by, interactive=True, label='Sort by')
|
||||
self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=sort_order, interactive=True, label='Sort Order')
|
||||
with gr.Row():
|
||||
self.btn_select_visibles = gr.Button(value='Select visible tags')
|
||||
self.btn_deselect_visibles = gr.Button(value='Deselect visible tags')
|
||||
|
|
@ -105,6 +108,6 @@ class TagSelectUI():
|
|||
self.tags = set(dte_instance.get_filtered_tags(self.get_filters(), filter_tags=True, prefix=self.prefix, suffix=self.suffix, regex=self.regex))
|
||||
self.selected_tags &= self.tags
|
||||
tags = dte_instance.sort_tags(tags=tags, sort_by=self.sort_by, sort_order=self.sort_order)
|
||||
tags = dte_instance.write_tags(tags)
|
||||
selected_tags = dte_instance.write_tags(list(self.selected_tags))
|
||||
tags = dte_instance.write_tags(tags, self.sort_by)
|
||||
selected_tags = dte_instance.write_tags(list(self.selected_tags), self.sort_by)
|
||||
return gr.CheckboxGroup.update(value=selected_tags, choices=tags)
|
||||
|
|
@ -67,7 +67,11 @@ class BatchEditCaptionsUI(UIBase):
|
|||
with gr.Row():
|
||||
self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_batch_edit.batch_sort_by, interactive=True, label='Sort by')
|
||||
self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_batch_edit.batch_sort_order, interactive=True, label='Sort Order')
|
||||
self.btn_sort_selected = gr.Button(value='Sort selected tags', variant='primary')
|
||||
self.btn_sort_selected = gr.Button(value='Sort tags', variant='primary')
|
||||
with gr.Column(variant='panel'):
|
||||
gr.HTML('Truncate tags by token count.')
|
||||
self.nb_token_count = gr.Number(value=cfg_batch_edit.token_count, precision=0)
|
||||
self.btn_truncate_by_token = gr.Button(value='Truncate tags by token count', variant='primary')
|
||||
|
||||
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], load_dataset:LoadDatasetUI, filter_by_tags:FilterByTagsUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
|
||||
load_dataset.btn_load_datasets.click(
|
||||
|
|
@ -173,6 +177,17 @@ class BatchEditCaptionsUI(UIBase):
|
|||
inputs=self.cb_show_only_tags_selected
|
||||
)
|
||||
|
||||
def truncate_by_token_count(token_count:int):
    """Truncate the tags of the filtered images to `token_count` tokens (negative values count as 0) and return the refreshed UI outputs."""
    limit = int(token_count)
    if limit < 0:
        limit = 0
    dte_instance.truncate_filtered_tags_by_token_count(get_filters(), limit)
    return update_filter_and_gallery()
|
||||
|
||||
self.btn_truncate_by_token.click(
|
||||
fn=truncate_by_token_count,
|
||||
inputs=self.nb_token_count,
|
||||
outputs=o_update_filter_and_gallery
|
||||
)
|
||||
|
||||
|
||||
def get_common_tags(self, get_filters:Callable[[], List[dte_module.filters.Filter]], filter_by_tags:FilterByTagsUI):
|
||||
if self.show_only_selected_tags:
|
||||
|
|
|
|||
|
|
@ -1,16 +1,14 @@
|
|||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, List, Callable
|
||||
from functools import reduce
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared, extra_networks, prompt_parser
|
||||
from modules import shared
|
||||
from modules.call_queue import wrap_queued_call
|
||||
from modules.sd_hijack import model_hijack
|
||||
from scripts.dte_instance import dte_module
|
||||
|
||||
from .ui_common import *
|
||||
from .uibase import UIBase
|
||||
from .tokenizer import clip_tokenizer
|
||||
from scripts.tokenizer import clip_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .ui_classes import *
|
||||
|
|
@ -30,7 +28,6 @@ class EditCaptionOfSelectedImageUI(UIBase):
|
|||
with gr.Tab(label='Read Caption from Selected Image'):
|
||||
self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6, elem_id='dte_caption')
|
||||
self.token_counter_caption = gr.HTML(value='<span></span>', elem_id='dte_caption_counter')
|
||||
self.cb_use_raw_token = gr.Checkbox(value=cfg_edit_selected.use_raw_token, label='Use raw CLIP token for token count (without embeddings)')
|
||||
with gr.Row():
|
||||
self.btn_copy_caption = gr.Button(value='Copy and Overwrite')
|
||||
self.btn_prepend_caption = gr.Button(value='Prepend')
|
||||
|
|
@ -194,39 +191,27 @@ class EditCaptionOfSelectedImageUI(UIBase):
|
|||
outputs=self.sort_settings
|
||||
)
|
||||
|
||||
def update_token_counter(text:str, use_raw:bool):
|
||||
if use_raw:
|
||||
token_count = clip_tokenizer.token_count(text)
|
||||
max_length = model_hijack.clip.get_target_prompt_token_count(token_count)
|
||||
else:
|
||||
try:
|
||||
text, _ = extra_networks.parse_prompt(text)
|
||||
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
||||
prompt = reduce(lambda list1, list2: list1+list2, prompt_flat_list)
|
||||
except Exception:
|
||||
prompt = text
|
||||
token_count, max_length = model_hijack.get_prompt_lengths(prompt)
|
||||
def update_token_counter(text:str):
|
||||
_, token_count = clip_tokenizer.tokenize(text, shared.opts.dataset_editor_use_raw_clip_token)
|
||||
max_length = clip_tokenizer.get_target_token_count(token_count)
|
||||
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
|
||||
|
||||
update_caption_token_counter_args = {
|
||||
'fn' : wrap_queued_call(update_token_counter),
|
||||
'inputs' : [self.tb_caption, self.cb_use_raw_token],
|
||||
'inputs' : [self.tb_caption],
|
||||
'outputs' : [self.token_counter_caption]
|
||||
}
|
||||
update_edit_caption_token_counter_args = {
|
||||
'fn' : wrap_queued_call(update_token_counter),
|
||||
'inputs' : [self.tb_edit_caption, self.cb_use_raw_token],
|
||||
'inputs' : [self.tb_edit_caption],
|
||||
'outputs' : [self.token_counter_edit_caption]
|
||||
}
|
||||
update_interrogate_token_counter_args = {
|
||||
'fn' : wrap_queued_call(update_token_counter),
|
||||
'inputs' : [self.tb_interrogate, self.cb_use_raw_token],
|
||||
'inputs' : [self.tb_interrogate],
|
||||
'outputs' : [self.token_counter_interrogate]
|
||||
}
|
||||
|
||||
self.cb_use_raw_token.change(**update_caption_token_counter_args)
|
||||
self.cb_use_raw_token.change(**update_edit_caption_token_counter_args)
|
||||
self.cb_use_raw_token.change(**update_interrogate_token_counter_args)
|
||||
self.tb_caption.change(**update_caption_token_counter_args)
|
||||
self.tb_edit_caption.change(**update_edit_caption_token_counter_args)
|
||||
self.tb_interrogate.change(**update_interrogate_token_counter_args)
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
img_path = dataset_gallery.selected_path
|
||||
if img_path:
|
||||
dte_instance.move_dataset_file(img_path, caption_ext, dest_dir, move_img, move_txt, move_bak)
|
||||
dte_instance.construct_tag_counts()
|
||||
dte_instance.construct_tag_infos()
|
||||
|
||||
elif target_data == 'All Displayed Ones':
|
||||
dte_instance.move_dataset(dest_dir, caption_ext, get_filters(), move_img, move_txt, move_bak)
|
||||
|
|
@ -95,7 +95,7 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
img_path = dataset_gallery.selected_path
|
||||
if img_path:
|
||||
dte_instance.delete_dataset_file(img_path, delete_img, caption_ext, delete_txt, delete_bak)
|
||||
dte_instance.construct_tag_counts()
|
||||
dte_instance.construct_tag_infos()
|
||||
|
||||
elif target_data == 'All Displayed Ones':
|
||||
dte_instance.delete_dataset(caption_ext, get_filters(), delete_img, delete_txt, delete_bak)
|
||||
|
|
|
|||
|
|
@ -1,42 +0,0 @@
|
|||
# Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py
|
||||
|
||||
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
|
||||
from modules import shared
|
||||
import open_clip.tokenizer
|
||||
|
||||
class VanillaClip:
    """Thin adapter around an embedder object that carries a ``tokenizer`` attribute.

    Both accessors read from ``clip.tokenizer`` on every call; nothing is cached.
    """

    def __init__(self, clip):
        # Keep a reference to the wrapped embedder; it is only read lazily below.
        self.clip = clip

    def vocab(self):
        """Return the result of ``get_vocab()`` on the wrapped tokenizer."""
        wrapped_tokenizer = self.clip.tokenizer
        return wrapped_tokenizer.get_vocab()

    def byte_decoder(self):
        """Return the ``byte_decoder`` attribute of the wrapped tokenizer."""
        wrapped_tokenizer = self.clip.tokenizer
        return wrapped_tokenizer.byte_decoder
|
||||
|
||||
class OpenClip:
    """Adapter pairing an embedder object with ``open_clip``'s module-level tokenizer.

    Unlike the wrapped embedder, the tokenizer is taken from
    ``open_clip.tokenizer._tokenizer`` when the adapter is constructed.
    """

    def __init__(self, clip):
        self.clip = clip
        # The tokenizer comes from the open_clip module itself, not from ``clip``.
        tokenizer_module = open_clip.tokenizer
        self.tokenizer = tokenizer_module._tokenizer

    def vocab(self):
        """Return the ``encoder`` attribute of the stored tokenizer."""
        stored_tokenizer = self.tokenizer
        return stored_tokenizer.encoder

    def byte_decoder(self):
        """Return the ``byte_decoder`` attribute of the stored tokenizer."""
        stored_tokenizer = self.tokenizer
        return stored_tokenizer.byte_decoder
|
||||
|
||||
def tokenize(text:str):
    """Tokenize ``text`` with the conditioning model of the loaded SD model.

    The embedder found at ``shared.sd_model.cond_stage_model.wrapped`` must be
    a ``FrozenCLIPEmbedder`` or a ``FrozenOpenCLIPEmbedder``.

    Args:
        text: The prompt/caption text to tokenize.

    Returns:
        The first element of ``cond_stage_model.tokenize([text])``, i.e. the
        tokens produced for ``text``.

    Raises:
        RuntimeError: If the wrapped embedder is of neither supported type.
    """
    cond_stage_model = shared.sd_model.cond_stage_model
    wrapped = cond_stage_model.wrapped
    # Only the type check matters here: the adapter objects that used to be
    # built in each branch were never used afterwards, so they are not created.
    if not isinstance(wrapped, (FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder)):
        raise RuntimeError(f'Unknown CLIP model: {type(wrapped).__name__}')

    # tokenize() works on a batch; pass a one-element batch and unwrap it.
    return cond_stage_model.tokenize([text])[0]
|
||||
|
||||
def token_count(text:str):
    """Return how many tokens ``tokenize`` produces for ``text``."""
    tokens = tokenize(text)
    return len(tokens)
|
||||
Loading…
Reference in New Issue