diff --git a/javascript/99_main.js b/javascript/99_main.js
index 3e704d7..92eb383 100644
--- a/javascript/99_main.js
+++ b/javascript/99_main.js
@@ -137,6 +137,23 @@ document.addEventListener("DOMContentLoaded", function () {
dteModifiedGallery_filter.addClickNextHandler(dataset_tag_editor_gl_filter_images_next_clicked)
dteModifiedGallery_filter.addClickCloseHandler(dataset_tag_editor_gl_filter_images_close_clicked)
}
+
+ function changeTokenCounterPos(id, id_counter){
+     // Move counter element `id_counter` in front of element `id` under the same parent, once.
+     var prompt = gradioApp().getElementById(id)
+     var counter = gradioApp().getElementById(id_counter)
+     // getElementById() yields null until the element is rendered; skip then, and when already moved.
+     if(!prompt || !counter || counter.parentElement === prompt.parentElement){
+         return
+     }
+     prompt.parentElement.insertBefore(counter, prompt)
+     counter.classList.add("token-counter-dte")
+     prompt.parentElement.style.position = "relative" // anchor for the absolutely positioned counter class
+     counter.style.width = "auto"
+ }
+ changeTokenCounterPos('dte_caption', 'dte_caption_counter')
+ changeTokenCounterPos('dte_edit_caption', 'dte_edit_caption_counter')
+ changeTokenCounterPos('dte_interrogate', 'dte_interrogate_counter')
});
o.observe(gradioApp(), { childList: true, subtree: true })
diff --git a/scripts/main.py b/scripts/main.py
index b586089..6a82528 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -40,14 +40,14 @@ GeneralConfig = namedtuple('GeneralConfig', [
])
FilterConfig = namedtuple('FilterConfig', ['sw_prefix', 'sw_suffix', 'sw_regex','sort_by', 'sort_order', 'logic'])
BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order'])
-EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
+EditSelectedConfig = namedtuple('EditSelectedConfig', ['use_raw_token', 'auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination'])
CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False)
CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND')
CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR')
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value)
-CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
+CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(True, False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig('Selected One', [], '.txt', '')
class Config:
@@ -255,6 +255,7 @@ def on_ui_tabs():
ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order
]
components_edit_selected = [
+ ui.edit_caption_of_selected_image.cb_use_raw_token,
ui.edit_caption_of_selected_image.cb_copy_caption_automatically, ui.edit_caption_of_selected_image.cb_sort_caption_on_save,
ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, ui.edit_caption_of_selected_image.dd_intterogator_names_si,
ui.edit_caption_of_selected_image.rb_sort_by, ui.edit_caption_of_selected_image.rb_sort_order
diff --git a/scripts/ui/tab_edit_caption_of_selected_image.py b/scripts/ui/tab_edit_caption_of_selected_image.py
index 1f998f8..76ec283 100644
--- a/scripts/ui/tab_edit_caption_of_selected_image.py
+++ b/scripts/ui/tab_edit_caption_of_selected_image.py
@@ -1,12 +1,16 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Callable
+from functools import reduce
import gradio as gr
-from modules import shared
+from modules import shared, extra_networks, prompt_parser
+from modules.call_queue import wrap_queued_call
+from modules.sd_hijack import model_hijack
from scripts.dte_instance import dte_module
from .ui_common import *
from .uibase import UIBase
+from .tokenizer import clip_tokenizer
if TYPE_CHECKING:
from .ui_classes import *
@@ -24,7 +28,9 @@ class EditCaptionOfSelectedImageUI(UIBase):
self.tb_hidden_edit_caption = gr.Textbox()
self.btn_hidden_save_caption = gr.Button(elem_id="dataset_tag_editor_btn_hidden_save_caption")
with gr.Tab(label='Read Caption from Selected Image'):
- self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6)
+ self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6, elem_id='dte_caption')
+ self.token_counter_caption = gr.HTML(value='', elem_id='dte_caption_counter')
+ self.cb_use_raw_token = gr.Checkbox(value=cfg_edit_selected.use_raw_token, label='Use raw CLIP token for token count (without embeddings)')
with gr.Row():
self.btn_copy_caption = gr.Button(value='Copy and Overwrite')
self.btn_prepend_caption = gr.Button(value='Prepend')
@@ -34,18 +40,23 @@ class EditCaptionOfSelectedImageUI(UIBase):
with gr.Row():
self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_module.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False)
self.btn_interrogate_si = gr.Button(value='Interrogate')
- self.tb_interrogate_selected_image = gr.Textbox(label='Interrogate Result', interactive=True, lines=6)
+ with gr.Column():
+ self.tb_interrogate = gr.Textbox(label='Interrogate Result', interactive=True, lines=6, elem_id='dte_interrogate')
+ self.token_counter_interrogate = gr.HTML(value='', elem_id='dte_interrogate_counter')
with gr.Row():
self.btn_copy_interrogate = gr.Button(value='Copy and Overwrite')
self.btn_prepend_interrogate = gr.Button(value='Prepend')
self.btn_append_interrogate = gr.Button(value='Append')
- self.cb_copy_caption_automatically = gr.Checkbox(value=cfg_edit_selected.auto_copy, label='Copy caption from selected images automatically')
- self.cb_sort_caption_on_save = gr.Checkbox(value=cfg_edit_selected.sort_on_save, label='Sort caption on save')
- with gr.Row(visible=cfg_edit_selected.sort_on_save) as self.sort_settings:
- self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_edit_selected.sort_by, interactive=True, label='Sort by')
- self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_edit_selected.sort_order, interactive=True, label='Sort Order')
- self.cb_ask_save_when_caption_changed = gr.Checkbox(value=cfg_edit_selected.warn_change_not_saved, label='Warn if changes in caption is not saved')
- self.tb_edit_caption = gr.Textbox(label='Edit Caption', interactive=True, lines=6)
+ with gr.Column():
+ self.cb_copy_caption_automatically = gr.Checkbox(value=cfg_edit_selected.auto_copy, label='Copy caption from selected images automatically')
+ self.cb_sort_caption_on_save = gr.Checkbox(value=cfg_edit_selected.sort_on_save, label='Sort caption on save')
+ with gr.Row(visible=cfg_edit_selected.sort_on_save) as self.sort_settings:
+ self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_edit_selected.sort_by, interactive=True, label='Sort by')
+ self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_edit_selected.sort_order, interactive=True, label='Sort Order')
+ self.cb_ask_save_when_caption_changed = gr.Checkbox(value=cfg_edit_selected.warn_change_not_saved, label='Warn if changes in caption is not saved')
+ with gr.Column():
+ self.tb_edit_caption = gr.Textbox(label='Edit Caption', interactive=True, lines=6, elem_id= 'dte_edit_caption')
+ self.token_counter_edit_caption = gr.HTML(value='', elem_id='dte_edit_caption_counter')
self.btn_apply_changes_selected_image = gr.Button(value='Apply changes to selected image', variant='primary')
gr.HTML("""Changes are not applied to the text files until the "Save all changes" button is pressed.""")
@@ -134,24 +145,24 @@ class EditCaptionOfSelectedImageUI(UIBase):
self.btn_interrogate_si.click(
fn=interrogate_selected_image,
inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu],
- outputs=[self.tb_interrogate_selected_image]
+ outputs=[self.tb_interrogate]
)
self.btn_copy_interrogate.click(
fn=lambda a:a,
- inputs=[self.tb_interrogate_selected_image],
+ inputs=[self.tb_interrogate],
outputs=[self.tb_edit_caption]
)
self.btn_append_interrogate.click(
fn=lambda a, b : b + (', ' if a and b else '') + a,
- inputs=[self.tb_interrogate_selected_image, self.tb_edit_caption],
+ inputs=[self.tb_interrogate, self.tb_edit_caption],
outputs=[self.tb_edit_caption]
)
self.btn_prepend_interrogate.click(
fn=lambda a, b : a + (', ' if a and b else '') + b,
- inputs=[self.tb_interrogate_selected_image, self.tb_edit_caption],
+ inputs=[self.tb_interrogate, self.tb_edit_caption],
outputs=[self.tb_edit_caption]
)
@@ -182,3 +193,40 @@ class EditCaptionOfSelectedImageUI(UIBase):
inputs=self.cb_sort_caption_on_save,
outputs=self.sort_settings
)
+
+ def update_token_counter(text:str, use_raw:bool):
+     """Build the 'count/limit' label for `text`; `use_raw` selects the clip_tokenizer count."""
+     if use_raw:
+         raw_count = clip_tokenizer.token_count(text)
+         return f"{raw_count}/{model_hijack.clip.get_target_prompt_token_count(raw_count)}"
+     try:
+         text, _ = extra_networks.parse_prompt(text)
+         _, subprompts, _ = prompt_parser.get_multicond_prompt_list([text])
+         flat_prompt = reduce(lambda joined, part: joined + part, subprompts)
+     except Exception:
+         flat_prompt = text  # best effort: count `text` as currently bound when parsing fails
+     used, limit = model_hijack.get_prompt_lengths(flat_prompt)
+     return f"{used}/{limit}"
+
+ update_caption_token_counter_args = {
+ 'fn' : wrap_queued_call(update_token_counter),
+ 'inputs' : [self.tb_caption, self.cb_use_raw_token],
+ 'outputs' : [self.token_counter_caption]
+ }
+ update_edit_caption_token_counter_args = {
+ 'fn' : wrap_queued_call(update_token_counter),
+ 'inputs' : [self.tb_edit_caption, self.cb_use_raw_token],
+ 'outputs' : [self.token_counter_edit_caption]
+ }
+ update_interrogate_token_counter_args = {
+ 'fn' : wrap_queued_call(update_token_counter),
+ 'inputs' : [self.tb_interrogate, self.cb_use_raw_token],
+ 'outputs' : [self.token_counter_interrogate]
+ }
+
+ self.cb_use_raw_token.change(**update_caption_token_counter_args)
+ self.cb_use_raw_token.change(**update_edit_caption_token_counter_args)
+ self.cb_use_raw_token.change(**update_interrogate_token_counter_args)
+ self.tb_caption.change(**update_caption_token_counter_args)
+ self.tb_edit_caption.change(**update_edit_caption_token_counter_args)
+ self.tb_interrogate.change(**update_interrogate_token_counter_args)
diff --git a/scripts/ui/tokenizer/__init__.py b/scripts/ui/tokenizer/__init__.py
new file mode 100644
index 0000000..76ace67
--- /dev/null
+++ b/scripts/ui/tokenizer/__init__.py
@@ -0,0 +1 @@
+from . import clip_tokenizer
\ No newline at end of file
diff --git a/scripts/ui/tokenizer/clip_tokenizer.py b/scripts/ui/tokenizer/clip_tokenizer.py
new file mode 100644
index 0000000..d257eec
--- /dev/null
+++ b/scripts/ui/tokenizer/clip_tokenizer.py
@@ -0,0 +1,42 @@
+# Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified
+# https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py
+
+from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
+from modules import shared
+import open_clip.tokenizer
+
+class VanillaClip:  # Adapter over an embedder object whose `tokenizer` attribute holds the vocabulary data.
+ def __init__(self, clip):
+ self.clip = clip  # wrapped embedder; the methods below read `clip.tokenizer`
+
+ def vocab(self):
+ return self.clip.tokenizer.get_vocab()  # whatever the tokenizer's get_vocab() returns
+
+ def byte_decoder(self):
+ return self.clip.tokenizer.byte_decoder  # the tokenizer's `byte_decoder` attribute, unmodified
+
+class OpenClip:  # Adapter that pairs an embedder with the `_tokenizer` object found in open_clip.tokenizer.
+ def __init__(self, clip):
+ self.clip = clip  # stored, but not read by the methods of this class
+ self.tokenizer = open_clip.tokenizer._tokenizer  # NOTE(review): private open_clip attribute - confirm it exists in the installed version
+
+ def vocab(self):
+ return self.tokenizer.encoder  # the tokenizer's `encoder` attribute (presumably token -> id; verify)
+
+ def byte_decoder(self):
+ return self.tokenizer.byte_decoder  # the tokenizer's `byte_decoder` attribute, unmodified
+
+def tokenize(text:str):
+    """Tokenize `text` with the loaded model's `cond_stage_model.tokenize` and return the tokens.
+
+    Raises RuntimeError when `cond_stage_model.wrapped` is neither a FrozenCLIPEmbedder nor a
+    FrozenOpenCLIPEmbedder. The check is a plain isinstance test; no adapter object is needed.
+    """
+    cond_stage_model = shared.sd_model.cond_stage_model
+    if not isinstance(cond_stage_model.wrapped, (FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder)):
+        raise RuntimeError(f'Unknown CLIP model: {type(cond_stage_model.wrapped).__name__}')
+
+    return cond_stage_model.tokenize([text])[0]
+
+def token_count(text:str):
+ return len(tokenize(text))  # length of the token sequence tokenize() returns for `text`
diff --git a/style.css b/style.css
new file mode 100644
index 0000000..290108c
--- /dev/null
+++ b/style.css
@@ -0,0 +1,16 @@
+.token-counter-dte{ /* class added to the counter element by changeTokenCounterPos() in javascript/99_main.js */
+ position: absolute; /* the script sets the parent element to position: relative */
+ display: inline-block;
+ right: 2em;
+ min-width: 0 !important;
+ width: auto;
+ z-index: 100;
+}
+
+.token-counter-dte div{ /* div elements inside the counter flow inline */
+ display: inline;
+}
+
+.token-counter-dte span{ /* padding for span elements inside the counter */
+ padding: 0.1em 0.75em;
+}
\ No newline at end of file