diff --git a/javascript/99_main.js b/javascript/99_main.js index 3e704d7..92eb383 100644 --- a/javascript/99_main.js +++ b/javascript/99_main.js @@ -137,6 +137,23 @@ document.addEventListener("DOMContentLoaded", function () { dteModifiedGallery_filter.addClickNextHandler(dataset_tag_editor_gl_filter_images_next_clicked) dteModifiedGallery_filter.addClickCloseHandler(dataset_tag_editor_gl_filter_images_close_clicked) } + + function changeTokenCounterPos(id, id_counter){ + var prompt = gradioApp().getElementById(id) + var counter = gradioApp().getElementById(id_counter) + + if(counter.parentElement == prompt.parentElement){ + return + } + + prompt.parentElement.insertBefore(counter, prompt) + counter.classList.add("token-counter-dte") + prompt.parentElement.style.position = "relative" + counter.style.width = "auto" + } + changeTokenCounterPos('dte_caption', 'dte_caption_counter') + changeTokenCounterPos('dte_edit_caption', 'dte_edit_caption_counter') + changeTokenCounterPos('dte_interrogate', 'dte_interrogate_counter') }); o.observe(gradioApp(), { childList: true, subtree: true }) diff --git a/scripts/main.py b/scripts/main.py index b586089..6a82528 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -40,14 +40,14 @@ GeneralConfig = namedtuple('GeneralConfig', [ ]) FilterConfig = namedtuple('FilterConfig', ['sw_prefix', 'sw_suffix', 'sw_regex','sort_by', 'sort_order', 'logic']) BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order']) -EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order']) +EditSelectedConfig = namedtuple('EditSelectedConfig', ['use_raw_token', 'auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order']) MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination']) CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False) CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND') CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR') CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value) -CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value) +CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(True, False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value) CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig('Selected One', [], '.txt', '') class Config: @@ -255,6 +255,7 @@ def on_ui_tabs(): ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order ] components_edit_selected = [ + ui.edit_caption_of_selected_image.cb_use_raw_token, ui.edit_caption_of_selected_image.cb_copy_caption_automatically, ui.edit_caption_of_selected_image.cb_sort_caption_on_save, ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, ui.edit_caption_of_selected_image.dd_intterogator_names_si, ui.edit_caption_of_selected_image.rb_sort_by, ui.edit_caption_of_selected_image.rb_sort_order diff --git a/scripts/ui/tab_edit_caption_of_selected_image.py b/scripts/ui/tab_edit_caption_of_selected_image.py index 1f998f8..76ec283 100644 --- a/scripts/ui/tab_edit_caption_of_selected_image.py +++ b/scripts/ui/tab_edit_caption_of_selected_image.py @@ -1,12 +1,16 @@ from __future__ import annotations from typing import TYPE_CHECKING, List, Callable +from functools import reduce import gradio as gr -from modules import shared +from modules import shared, extra_networks, prompt_parser +from modules.call_queue import wrap_queued_call +from modules.sd_hijack import model_hijack from scripts.dte_instance import dte_module from .ui_common import * from .uibase import UIBase +from .tokenizer import clip_tokenizer if TYPE_CHECKING: from .ui_classes import * @@ -24,7 +28,9 @@ class EditCaptionOfSelectedImageUI(UIBase): self.tb_hidden_edit_caption = gr.Textbox() self.btn_hidden_save_caption = gr.Button(elem_id="dataset_tag_editor_btn_hidden_save_caption") with gr.Tab(label='Read Caption from Selected Image'): - self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6) + self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6, elem_id='dte_caption') + self.token_counter_caption = gr.HTML(value='', elem_id='dte_caption_counter') + self.cb_use_raw_token = gr.Checkbox(value=cfg_edit_selected.use_raw_token, label='Use raw CLIP token for token count (without embeddings)') with gr.Row(): self.btn_copy_caption = gr.Button(value='Copy and Overwrite') self.btn_prepend_caption = gr.Button(value='Prepend') @@ -34,18 +40,23 @@ class EditCaptionOfSelectedImageUI(UIBase): with gr.Row(): self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_module.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False) self.btn_interrogate_si = gr.Button(value='Interrogate') - self.tb_interrogate_selected_image = gr.Textbox(label='Interrogate Result', interactive=True, lines=6) + with gr.Column(): + self.tb_interrogate = gr.Textbox(label='Interrogate Result', interactive=True, lines=6, elem_id='dte_interrogate') + self.token_counter_interrogate = gr.HTML(value='', elem_id='dte_interrogate_counter') with gr.Row(): self.btn_copy_interrogate = gr.Button(value='Copy and Overwrite') self.btn_prepend_interrogate = gr.Button(value='Prepend') self.btn_append_interrogate = gr.Button(value='Append') - self.cb_copy_caption_automatically = gr.Checkbox(value=cfg_edit_selected.auto_copy, label='Copy caption from selected images automatically') - self.cb_sort_caption_on_save = gr.Checkbox(value=cfg_edit_selected.sort_on_save, label='Sort caption on save') - with gr.Row(visible=cfg_edit_selected.sort_on_save) as self.sort_settings: - self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_edit_selected.sort_by, interactive=True, label='Sort by') - self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_edit_selected.sort_order, interactive=True, label='Sort Order') - self.cb_ask_save_when_caption_changed = gr.Checkbox(value=cfg_edit_selected.warn_change_not_saved, label='Warn if changes in caption is not saved') - self.tb_edit_caption = gr.Textbox(label='Edit Caption', interactive=True, lines=6) + with gr.Column(): + self.cb_copy_caption_automatically = gr.Checkbox(value=cfg_edit_selected.auto_copy, label='Copy caption from selected images automatically') + self.cb_sort_caption_on_save = gr.Checkbox(value=cfg_edit_selected.sort_on_save, label='Sort caption on save') + with gr.Row(visible=cfg_edit_selected.sort_on_save) as self.sort_settings: + self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_edit_selected.sort_by, interactive=True, label='Sort by') + self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_edit_selected.sort_order, interactive=True, label='Sort Order') + self.cb_ask_save_when_caption_changed = gr.Checkbox(value=cfg_edit_selected.warn_change_not_saved, label='Warn if changes in caption is not saved') + with gr.Column(): + self.tb_edit_caption = gr.Textbox(label='Edit Caption', interactive=True, lines=6, elem_id= 'dte_edit_caption') + self.token_counter_edit_caption = gr.HTML(value='', elem_id='dte_edit_caption_counter') self.btn_apply_changes_selected_image = gr.Button(value='Apply changes to selected image', variant='primary') gr.HTML("""Changes are not applied to the text files until the "Save all changes" button is pressed.""") @@ -134,24 +145,24 @@ class EditCaptionOfSelectedImageUI(UIBase): self.btn_interrogate_si.click( fn=interrogate_selected_image, inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu], - outputs=[self.tb_interrogate_selected_image] + outputs=[self.tb_interrogate] ) self.btn_copy_interrogate.click( fn=lambda a:a, - inputs=[self.tb_interrogate_selected_image], + inputs=[self.tb_interrogate], outputs=[self.tb_edit_caption] ) self.btn_append_interrogate.click( fn=lambda a, b : b + (', ' if a and b else '') + a, - inputs=[self.tb_interrogate_selected_image, self.tb_edit_caption], + inputs=[self.tb_interrogate, self.tb_edit_caption], outputs=[self.tb_edit_caption] ) self.btn_prepend_interrogate.click( fn=lambda a, b : a + (', ' if a and b else '') + b, - inputs=[self.tb_interrogate_selected_image, self.tb_edit_caption], + inputs=[self.tb_interrogate, self.tb_edit_caption], outputs=[self.tb_edit_caption] ) @@ -182,3 +193,40 @@ class EditCaptionOfSelectedImageUI(UIBase): inputs=self.cb_sort_caption_on_save, outputs=self.sort_settings ) + + def update_token_counter(text:str, use_raw:bool): + if use_raw: + token_count = clip_tokenizer.token_count(text) + max_length = model_hijack.clip.get_target_prompt_token_count(token_count) + else: + try: + text, _ = extra_networks.parse_prompt(text) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt = reduce(lambda list1, list2: list1+list2, prompt_flat_list) + except Exception: + prompt = text + token_count, max_length = model_hijack.get_prompt_lengths(prompt) + return f"{token_count}/{max_length}" + + update_caption_token_counter_args = { + 'fn' : wrap_queued_call(update_token_counter), + 'inputs' : [self.tb_caption, self.cb_use_raw_token], + 'outputs' : [self.token_counter_caption] + } + update_edit_caption_token_counter_args = { + 'fn' : wrap_queued_call(update_token_counter), + 'inputs' : [self.tb_edit_caption, self.cb_use_raw_token], + 'outputs' : [self.token_counter_edit_caption] + } + update_interrogate_token_counter_args = { + 'fn' : wrap_queued_call(update_token_counter), + 'inputs' : [self.tb_interrogate, self.cb_use_raw_token], + 'outputs' : [self.token_counter_interrogate] + } + + self.cb_use_raw_token.change(**update_caption_token_counter_args) + self.cb_use_raw_token.change(**update_edit_caption_token_counter_args) + self.cb_use_raw_token.change(**update_interrogate_token_counter_args) + self.tb_caption.change(**update_caption_token_counter_args) + self.tb_edit_caption.change(**update_edit_caption_token_counter_args) + self.tb_interrogate.change(**update_interrogate_token_counter_args) diff --git a/scripts/ui/tokenizer/__init__.py b/scripts/ui/tokenizer/__init__.py new file mode 100644 index 0000000..76ace67 --- /dev/null +++ b/scripts/ui/tokenizer/__init__.py @@ -0,0 +1 @@ +from . import clip_tokenizer \ No newline at end of file diff --git a/scripts/ui/tokenizer/clip_tokenizer.py b/scripts/ui/tokenizer/clip_tokenizer.py new file mode 100644 index 0000000..d257eec --- /dev/null +++ b/scripts/ui/tokenizer/clip_tokenizer.py @@ -0,0 +1,42 @@ +# Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified +# https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py + +from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder +from modules import shared +import open_clip.tokenizer + +class VanillaClip: + def __init__(self, clip): + self.clip = clip + + def vocab(self): + return self.clip.tokenizer.get_vocab() + + def byte_decoder(self): + return self.clip.tokenizer.byte_decoder + +class OpenClip: + def __init__(self, clip): + self.clip = clip + self.tokenizer = open_clip.tokenizer._tokenizer + + def vocab(self): + return self.tokenizer.encoder + + def byte_decoder(self): + return self.tokenizer.byte_decoder + +def tokenize(text:str): + clip = shared.sd_model.cond_stage_model.wrapped + if isinstance(clip, FrozenCLIPEmbedder): + clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped) + elif isinstance(clip, FrozenOpenCLIPEmbedder): + clip = OpenClip(shared.sd_model.cond_stage_model.wrapped) + else: + raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}') + + tokens = shared.sd_model.cond_stage_model.tokenize([text])[0] + return tokens + +def token_count(text:str): + return len(tokenize(text)) diff --git a/style.css b/style.css new file mode 100644 index 0000000..290108c --- /dev/null +++ b/style.css @@ -0,0 +1,16 @@ +.token-counter-dte{ + position: absolute; + display: inline-block; + right: 2em; + min-width: 0 !important; + width: auto; + z-index: 100; +} + +.token-counter-dte div{ + display: inline; +} + +.token-counter-dte span{ + padding: 0.1em 0.75em; +} \ No newline at end of file