Merge pull request #50 from toshiaki1729/feature/add-token-count

implement token count on "Edit Caption of Selected Image" tab
related to #40
pull/52/head
toshiaki1729 2023-03-07 01:08:25 +09:00 committed by GitHub
commit 94afc6e025
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 141 additions and 16 deletions

View File

@ -137,6 +137,23 @@ document.addEventListener("DOMContentLoaded", function () {
dteModifiedGallery_filter.addClickNextHandler(dataset_tag_editor_gl_filter_images_next_clicked)
dteModifiedGallery_filter.addClickCloseHandler(dataset_tag_editor_gl_filter_images_close_clicked)
}
function changeTokenCounterPos(id, id_counter){
var prompt = gradioApp().getElementById(id)
var counter = gradioApp().getElementById(id_counter)
if(counter.parentElement == prompt.parentElement){
return
}
prompt.parentElement.insertBefore(counter, prompt)
counter.classList.add("token-counter-dte")
prompt.parentElement.style.position = "relative"
counter.style.width = "auto"
}
changeTokenCounterPos('dte_caption', 'dte_caption_counter')
changeTokenCounterPos('dte_edit_caption', 'dte_edit_caption_counter')
changeTokenCounterPos('dte_interrogate', 'dte_interrogate_counter')
});
o.observe(gradioApp(), { childList: true, subtree: true })

View File

@ -40,14 +40,14 @@ GeneralConfig = namedtuple('GeneralConfig', [
])
FilterConfig = namedtuple('FilterConfig', ['sw_prefix', 'sw_suffix', 'sw_regex','sort_by', 'sort_order', 'logic'])
BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order'])
EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
EditSelectedConfig = namedtuple('EditSelectedConfig', ['use_raw_token', 'auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination'])
CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False)
CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND')
CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR')
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value)
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(True, False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig('Selected One', [], '.txt', '')
class Config:
@ -255,6 +255,7 @@ def on_ui_tabs():
ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order
]
components_edit_selected = [
ui.edit_caption_of_selected_image.cb_use_raw_token,
ui.edit_caption_of_selected_image.cb_copy_caption_automatically, ui.edit_caption_of_selected_image.cb_sort_caption_on_save,
ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, ui.edit_caption_of_selected_image.dd_intterogator_names_si,
ui.edit_caption_of_selected_image.rb_sort_by, ui.edit_caption_of_selected_image.rb_sort_order

View File

@ -1,12 +1,16 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Callable
from functools import reduce
import gradio as gr
from modules import shared
from modules import shared, extra_networks, prompt_parser
from modules.call_queue import wrap_queued_call
from modules.sd_hijack import model_hijack
from scripts.dte_instance import dte_module
from .ui_common import *
from .uibase import UIBase
from .tokenizer import clip_tokenizer
if TYPE_CHECKING:
from .ui_classes import *
@ -24,7 +28,9 @@ class EditCaptionOfSelectedImageUI(UIBase):
self.tb_hidden_edit_caption = gr.Textbox()
self.btn_hidden_save_caption = gr.Button(elem_id="dataset_tag_editor_btn_hidden_save_caption")
with gr.Tab(label='Read Caption from Selected Image'):
self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6)
self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6, elem_id='dte_caption')
self.token_counter_caption = gr.HTML(value='<span></span>', elem_id='dte_caption_counter')
self.cb_use_raw_token = gr.Checkbox(value=cfg_edit_selected.use_raw_token, label='Use raw CLIP token for token count (without embeddings)')
with gr.Row():
self.btn_copy_caption = gr.Button(value='Copy and Overwrite')
self.btn_prepend_caption = gr.Button(value='Prepend')
@ -34,18 +40,23 @@ class EditCaptionOfSelectedImageUI(UIBase):
with gr.Row():
self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_module.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False)
self.btn_interrogate_si = gr.Button(value='Interrogate')
self.tb_interrogate_selected_image = gr.Textbox(label='Interrogate Result', interactive=True, lines=6)
with gr.Column():
self.tb_interrogate = gr.Textbox(label='Interrogate Result', interactive=True, lines=6, elem_id='dte_interrogate')
self.token_counter_interrogate = gr.HTML(value='<span></span>', elem_id='dte_interrogate_counter')
with gr.Row():
self.btn_copy_interrogate = gr.Button(value='Copy and Overwrite')
self.btn_prepend_interrogate = gr.Button(value='Prepend')
self.btn_append_interrogate = gr.Button(value='Append')
self.cb_copy_caption_automatically = gr.Checkbox(value=cfg_edit_selected.auto_copy, label='Copy caption from selected images automatically')
self.cb_sort_caption_on_save = gr.Checkbox(value=cfg_edit_selected.sort_on_save, label='Sort caption on save')
with gr.Row(visible=cfg_edit_selected.sort_on_save) as self.sort_settings:
self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_edit_selected.sort_by, interactive=True, label='Sort by')
self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_edit_selected.sort_order, interactive=True, label='Sort Order')
self.cb_ask_save_when_caption_changed = gr.Checkbox(value=cfg_edit_selected.warn_change_not_saved, label='Warn if changes in caption is not saved')
self.tb_edit_caption = gr.Textbox(label='Edit Caption', interactive=True, lines=6)
with gr.Column():
self.cb_copy_caption_automatically = gr.Checkbox(value=cfg_edit_selected.auto_copy, label='Copy caption from selected images automatically')
self.cb_sort_caption_on_save = gr.Checkbox(value=cfg_edit_selected.sort_on_save, label='Sort caption on save')
with gr.Row(visible=cfg_edit_selected.sort_on_save) as self.sort_settings:
self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_edit_selected.sort_by, interactive=True, label='Sort by')
self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_edit_selected.sort_order, interactive=True, label='Sort Order')
self.cb_ask_save_when_caption_changed = gr.Checkbox(value=cfg_edit_selected.warn_change_not_saved, label='Warn if changes in caption is not saved')
with gr.Column():
self.tb_edit_caption = gr.Textbox(label='Edit Caption', interactive=True, lines=6, elem_id= 'dte_edit_caption')
self.token_counter_edit_caption = gr.HTML(value='<span></span>', elem_id='dte_edit_caption_counter')
self.btn_apply_changes_selected_image = gr.Button(value='Apply changes to selected image', variant='primary')
gr.HTML("""Changes are not applied to the text files until the "Save all changes" button is pressed.""")
@ -134,24 +145,24 @@ class EditCaptionOfSelectedImageUI(UIBase):
self.btn_interrogate_si.click(
fn=interrogate_selected_image,
inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu],
outputs=[self.tb_interrogate_selected_image]
outputs=[self.tb_interrogate]
)
self.btn_copy_interrogate.click(
fn=lambda a:a,
inputs=[self.tb_interrogate_selected_image],
inputs=[self.tb_interrogate],
outputs=[self.tb_edit_caption]
)
self.btn_append_interrogate.click(
fn=lambda a, b : b + (', ' if a and b else '') + a,
inputs=[self.tb_interrogate_selected_image, self.tb_edit_caption],
inputs=[self.tb_interrogate, self.tb_edit_caption],
outputs=[self.tb_edit_caption]
)
self.btn_prepend_interrogate.click(
fn=lambda a, b : a + (', ' if a and b else '') + b,
inputs=[self.tb_interrogate_selected_image, self.tb_edit_caption],
inputs=[self.tb_interrogate, self.tb_edit_caption],
outputs=[self.tb_edit_caption]
)
@ -182,3 +193,40 @@ class EditCaptionOfSelectedImageUI(UIBase):
inputs=self.cb_sort_caption_on_save,
outputs=self.sort_settings
)
def update_token_counter(text:str, use_raw:bool):
if use_raw:
token_count = clip_tokenizer.token_count(text)
max_length = model_hijack.clip.get_target_prompt_token_count(token_count)
else:
try:
text, _ = extra_networks.parse_prompt(text)
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt = reduce(lambda list1, list2: list1+list2, prompt_flat_list)
except Exception:
prompt = text
token_count, max_length = model_hijack.get_prompt_lengths(prompt)
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
update_caption_token_counter_args = {
'fn' : wrap_queued_call(update_token_counter),
'inputs' : [self.tb_caption, self.cb_use_raw_token],
'outputs' : [self.token_counter_caption]
}
update_edit_caption_token_counter_args = {
'fn' : wrap_queued_call(update_token_counter),
'inputs' : [self.tb_edit_caption, self.cb_use_raw_token],
'outputs' : [self.token_counter_edit_caption]
}
update_interrogate_token_counter_args = {
'fn' : wrap_queued_call(update_token_counter),
'inputs' : [self.tb_interrogate, self.cb_use_raw_token],
'outputs' : [self.token_counter_interrogate]
}
self.cb_use_raw_token.change(**update_caption_token_counter_args)
self.cb_use_raw_token.change(**update_edit_caption_token_counter_args)
self.cb_use_raw_token.change(**update_interrogate_token_counter_args)
self.tb_caption.change(**update_caption_token_counter_args)
self.tb_edit_caption.change(**update_edit_caption_token_counter_args)
self.tb_interrogate.change(**update_interrogate_token_counter_args)

View File

@ -0,0 +1 @@
from . import clip_tokenizer

View File

@ -0,0 +1,42 @@
# Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified
# https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
from modules import shared
import open_clip.tokenizer
class VanillaClip:
def __init__(self, clip):
self.clip = clip
def vocab(self):
return self.clip.tokenizer.get_vocab()
def byte_decoder(self):
return self.clip.tokenizer.byte_decoder
class OpenClip:
def __init__(self, clip):
self.clip = clip
self.tokenizer = open_clip.tokenizer._tokenizer
def vocab(self):
return self.tokenizer.encoder
def byte_decoder(self):
return self.tokenizer.byte_decoder
def tokenize(text:str):
clip = shared.sd_model.cond_stage_model.wrapped
if isinstance(clip, FrozenCLIPEmbedder):
clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped)
elif isinstance(clip, FrozenOpenCLIPEmbedder):
clip = OpenClip(shared.sd_model.cond_stage_model.wrapped)
else:
raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}')
tokens = shared.sd_model.cond_stage_model.tokenize([text])[0]
return tokens
def token_count(text:str):
return len(tokenize(text))

16
style.css Normal file
View File

@ -0,0 +1,16 @@
.token-counter-dte{
position: absolute;
display: inline-block;
right: 2em;
min-width: 0 !important;
width: auto;
z-index: 100;
}
.token-counter-dte div{
display: inline;
}
.token-counter-dte span{
padding: 0.1em 0.75em;
}