Merge pull request #50 from toshiaki1729/feature/add-token-count
implement token count on "Edit Caption of Selected Image" tab, related to #40
commit
94afc6e025
|
|
@ -137,6 +137,23 @@ document.addEventListener("DOMContentLoaded", function () {
|
|||
dteModifiedGallery_filter.addClickNextHandler(dataset_tag_editor_gl_filter_images_next_clicked)
|
||||
dteModifiedGallery_filter.addClickCloseHandler(dataset_tag_editor_gl_filter_images_close_clicked)
|
||||
}
|
||||
|
||||
function changeTokenCounterPos(id, id_counter){
|
||||
var prompt = gradioApp().getElementById(id)
|
||||
var counter = gradioApp().getElementById(id_counter)
|
||||
|
||||
if(counter.parentElement == prompt.parentElement){
|
||||
return
|
||||
}
|
||||
|
||||
prompt.parentElement.insertBefore(counter, prompt)
|
||||
counter.classList.add("token-counter-dte")
|
||||
prompt.parentElement.style.position = "relative"
|
||||
counter.style.width = "auto"
|
||||
}
|
||||
changeTokenCounterPos('dte_caption', 'dte_caption_counter')
|
||||
changeTokenCounterPos('dte_edit_caption', 'dte_edit_caption_counter')
|
||||
changeTokenCounterPos('dte_interrogate', 'dte_interrogate_counter')
|
||||
});
|
||||
|
||||
o.observe(gradioApp(), { childList: true, subtree: true })
|
||||
|
|
|
|||
|
|
@ -40,14 +40,14 @@ GeneralConfig = namedtuple('GeneralConfig', [
|
|||
])
|
||||
FilterConfig = namedtuple('FilterConfig', ['sw_prefix', 'sw_suffix', 'sw_regex','sort_by', 'sort_order', 'logic'])
|
||||
BatchEditConfig = namedtuple('BatchEditConfig', ['show_only_selected', 'prepend', 'use_regex', 'target', 'sw_prefix', 'sw_suffix', 'sw_regex', 'sory_by', 'sort_order', 'batch_sort_by', 'batch_sort_order'])
|
||||
EditSelectedConfig = namedtuple('EditSelectedConfig', ['auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
|
||||
EditSelectedConfig = namedtuple('EditSelectedConfig', ['use_raw_token', 'auto_copy', 'sort_on_save', 'warn_change_not_saved', 'use_interrogator_name', 'sort_by', 'sort_order'])
|
||||
MoveDeleteConfig = namedtuple('MoveDeleteConfig', ['range', 'target', 'caption_ext', 'destination'])
|
||||
|
||||
CFG_GENERAL_DEFAULT = GeneralConfig(True, '', '.txt', False, True, 'No', [], False, 0.7, False, 0.35, False, '', '', True, False, False)
|
||||
CFG_FILTER_P_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'AND')
|
||||
CFG_FILTER_N_DEFAULT = FilterConfig(False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, 'OR')
|
||||
CFG_BATCH_EDIT_DEFAULT = BatchEditConfig(True, False, False, 'Only Selected Tags', False, False, False, SortBy.ALPHA.value, SortOrder.ASC.value, SortBy.ALPHA.value, SortOrder.ASC.value)
|
||||
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
|
||||
CFG_EDIT_SELECTED_DEFAULT = EditSelectedConfig(True, False, False, False, '', SortBy.ALPHA.value, SortOrder.ASC.value)
|
||||
CFG_MOVE_DELETE_DEFAULT = MoveDeleteConfig('Selected One', [], '.txt', '')
|
||||
|
||||
class Config:
|
||||
|
|
@ -255,6 +255,7 @@ def on_ui_tabs():
|
|||
ui.batch_edit_captions.rb_sort_by, ui.batch_edit_captions.rb_sort_order
|
||||
]
|
||||
components_edit_selected = [
|
||||
ui.edit_caption_of_selected_image.cb_use_raw_token,
|
||||
ui.edit_caption_of_selected_image.cb_copy_caption_automatically, ui.edit_caption_of_selected_image.cb_sort_caption_on_save,
|
||||
ui.edit_caption_of_selected_image.cb_ask_save_when_caption_changed, ui.edit_caption_of_selected_image.dd_intterogator_names_si,
|
||||
ui.edit_caption_of_selected_image.rb_sort_by, ui.edit_caption_of_selected_image.rb_sort_order
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, List, Callable
|
||||
from functools import reduce
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared
|
||||
from modules import shared, extra_networks, prompt_parser
|
||||
from modules.call_queue import wrap_queued_call
|
||||
from modules.sd_hijack import model_hijack
|
||||
from scripts.dte_instance import dte_module
|
||||
|
||||
from .ui_common import *
|
||||
from .uibase import UIBase
|
||||
from .tokenizer import clip_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .ui_classes import *
|
||||
|
|
@ -24,7 +28,9 @@ class EditCaptionOfSelectedImageUI(UIBase):
|
|||
self.tb_hidden_edit_caption = gr.Textbox()
|
||||
self.btn_hidden_save_caption = gr.Button(elem_id="dataset_tag_editor_btn_hidden_save_caption")
|
||||
with gr.Tab(label='Read Caption from Selected Image'):
|
||||
self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6)
|
||||
self.tb_caption = gr.Textbox(label='Caption of Selected Image', interactive=False, lines=6, elem_id='dte_caption')
|
||||
self.token_counter_caption = gr.HTML(value='<span></span>', elem_id='dte_caption_counter')
|
||||
self.cb_use_raw_token = gr.Checkbox(value=cfg_edit_selected.use_raw_token, label='Use raw CLIP token for token count (without embeddings)')
|
||||
with gr.Row():
|
||||
self.btn_copy_caption = gr.Button(value='Copy and Overwrite')
|
||||
self.btn_prepend_caption = gr.Button(value='Prepend')
|
||||
|
|
@ -34,18 +40,23 @@ class EditCaptionOfSelectedImageUI(UIBase):
|
|||
with gr.Row():
|
||||
self.dd_intterogator_names_si = gr.Dropdown(label = 'Interrogator', choices=dte_module.INTERROGATOR_NAMES, value=cfg_edit_selected.use_interrogator_name, interactive=True, multiselect=False)
|
||||
self.btn_interrogate_si = gr.Button(value='Interrogate')
|
||||
self.tb_interrogate_selected_image = gr.Textbox(label='Interrogate Result', interactive=True, lines=6)
|
||||
with gr.Column():
|
||||
self.tb_interrogate = gr.Textbox(label='Interrogate Result', interactive=True, lines=6, elem_id='dte_interrogate')
|
||||
self.token_counter_interrogate = gr.HTML(value='<span></span>', elem_id='dte_interrogate_counter')
|
||||
with gr.Row():
|
||||
self.btn_copy_interrogate = gr.Button(value='Copy and Overwrite')
|
||||
self.btn_prepend_interrogate = gr.Button(value='Prepend')
|
||||
self.btn_append_interrogate = gr.Button(value='Append')
|
||||
self.cb_copy_caption_automatically = gr.Checkbox(value=cfg_edit_selected.auto_copy, label='Copy caption from selected images automatically')
|
||||
self.cb_sort_caption_on_save = gr.Checkbox(value=cfg_edit_selected.sort_on_save, label='Sort caption on save')
|
||||
with gr.Row(visible=cfg_edit_selected.sort_on_save) as self.sort_settings:
|
||||
self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_edit_selected.sort_by, interactive=True, label='Sort by')
|
||||
self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_edit_selected.sort_order, interactive=True, label='Sort Order')
|
||||
self.cb_ask_save_when_caption_changed = gr.Checkbox(value=cfg_edit_selected.warn_change_not_saved, label='Warn if changes in caption is not saved')
|
||||
self.tb_edit_caption = gr.Textbox(label='Edit Caption', interactive=True, lines=6)
|
||||
with gr.Column():
|
||||
self.cb_copy_caption_automatically = gr.Checkbox(value=cfg_edit_selected.auto_copy, label='Copy caption from selected images automatically')
|
||||
self.cb_sort_caption_on_save = gr.Checkbox(value=cfg_edit_selected.sort_on_save, label='Sort caption on save')
|
||||
with gr.Row(visible=cfg_edit_selected.sort_on_save) as self.sort_settings:
|
||||
self.rb_sort_by = gr.Radio(choices=[e.value for e in SortBy], value=cfg_edit_selected.sort_by, interactive=True, label='Sort by')
|
||||
self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_edit_selected.sort_order, interactive=True, label='Sort Order')
|
||||
self.cb_ask_save_when_caption_changed = gr.Checkbox(value=cfg_edit_selected.warn_change_not_saved, label='Warn if changes in caption is not saved')
|
||||
with gr.Column():
|
||||
self.tb_edit_caption = gr.Textbox(label='Edit Caption', interactive=True, lines=6, elem_id= 'dte_edit_caption')
|
||||
self.token_counter_edit_caption = gr.HTML(value='<span></span>', elem_id='dte_edit_caption_counter')
|
||||
self.btn_apply_changes_selected_image = gr.Button(value='Apply changes to selected image', variant='primary')
|
||||
|
||||
gr.HTML("""Changes are not applied to the text files until the "Save all changes" button is pressed.""")
|
||||
|
|
@ -134,24 +145,24 @@ class EditCaptionOfSelectedImageUI(UIBase):
|
|||
self.btn_interrogate_si.click(
|
||||
fn=interrogate_selected_image,
|
||||
inputs=[self.dd_intterogator_names_si, load_dataset.cb_use_custom_threshold_booru, load_dataset.sl_custom_threshold_booru, load_dataset.cb_use_custom_threshold_waifu, load_dataset.sl_custom_threshold_waifu],
|
||||
outputs=[self.tb_interrogate_selected_image]
|
||||
outputs=[self.tb_interrogate]
|
||||
)
|
||||
|
||||
self.btn_copy_interrogate.click(
|
||||
fn=lambda a:a,
|
||||
inputs=[self.tb_interrogate_selected_image],
|
||||
inputs=[self.tb_interrogate],
|
||||
outputs=[self.tb_edit_caption]
|
||||
)
|
||||
|
||||
self.btn_append_interrogate.click(
|
||||
fn=lambda a, b : b + (', ' if a and b else '') + a,
|
||||
inputs=[self.tb_interrogate_selected_image, self.tb_edit_caption],
|
||||
inputs=[self.tb_interrogate, self.tb_edit_caption],
|
||||
outputs=[self.tb_edit_caption]
|
||||
)
|
||||
|
||||
self.btn_prepend_interrogate.click(
|
||||
fn=lambda a, b : a + (', ' if a and b else '') + b,
|
||||
inputs=[self.tb_interrogate_selected_image, self.tb_edit_caption],
|
||||
inputs=[self.tb_interrogate, self.tb_edit_caption],
|
||||
outputs=[self.tb_edit_caption]
|
||||
)
|
||||
|
||||
|
|
@ -182,3 +193,40 @@ class EditCaptionOfSelectedImageUI(UIBase):
|
|||
inputs=self.cb_sort_caption_on_save,
|
||||
outputs=self.sort_settings
|
||||
)
|
||||
|
||||
def update_token_counter(text:str, use_raw:bool):
|
||||
if use_raw:
|
||||
token_count = clip_tokenizer.token_count(text)
|
||||
max_length = model_hijack.clip.get_target_prompt_token_count(token_count)
|
||||
else:
|
||||
try:
|
||||
text, _ = extra_networks.parse_prompt(text)
|
||||
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
||||
prompt = reduce(lambda list1, list2: list1+list2, prompt_flat_list)
|
||||
except Exception:
|
||||
prompt = text
|
||||
token_count, max_length = model_hijack.get_prompt_lengths(prompt)
|
||||
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
|
||||
|
||||
update_caption_token_counter_args = {
|
||||
'fn' : wrap_queued_call(update_token_counter),
|
||||
'inputs' : [self.tb_caption, self.cb_use_raw_token],
|
||||
'outputs' : [self.token_counter_caption]
|
||||
}
|
||||
update_edit_caption_token_counter_args = {
|
||||
'fn' : wrap_queued_call(update_token_counter),
|
||||
'inputs' : [self.tb_edit_caption, self.cb_use_raw_token],
|
||||
'outputs' : [self.token_counter_edit_caption]
|
||||
}
|
||||
update_interrogate_token_counter_args = {
|
||||
'fn' : wrap_queued_call(update_token_counter),
|
||||
'inputs' : [self.tb_interrogate, self.cb_use_raw_token],
|
||||
'outputs' : [self.token_counter_interrogate]
|
||||
}
|
||||
|
||||
self.cb_use_raw_token.change(**update_caption_token_counter_args)
|
||||
self.cb_use_raw_token.change(**update_edit_caption_token_counter_args)
|
||||
self.cb_use_raw_token.change(**update_interrogate_token_counter_args)
|
||||
self.tb_caption.change(**update_caption_token_counter_args)
|
||||
self.tb_edit_caption.change(**update_edit_caption_token_counter_args)
|
||||
self.tb_interrogate.change(**update_interrogate_token_counter_args)
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from . import clip_tokenizer
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
# Brought from AUTOMATIC1111's stable-diffusion-webui-tokenizer and modified
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui-tokenizer/blob/master/scripts/tokenizer.py
|
||||
|
||||
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenOpenCLIPEmbedder
|
||||
from modules import shared
|
||||
import open_clip.tokenizer
|
||||
|
||||
class VanillaClip:
|
||||
def __init__(self, clip):
|
||||
self.clip = clip
|
||||
|
||||
def vocab(self):
|
||||
return self.clip.tokenizer.get_vocab()
|
||||
|
||||
def byte_decoder(self):
|
||||
return self.clip.tokenizer.byte_decoder
|
||||
|
||||
class OpenClip:
|
||||
def __init__(self, clip):
|
||||
self.clip = clip
|
||||
self.tokenizer = open_clip.tokenizer._tokenizer
|
||||
|
||||
def vocab(self):
|
||||
return self.tokenizer.encoder
|
||||
|
||||
def byte_decoder(self):
|
||||
return self.tokenizer.byte_decoder
|
||||
|
||||
def tokenize(text:str):
|
||||
clip = shared.sd_model.cond_stage_model.wrapped
|
||||
if isinstance(clip, FrozenCLIPEmbedder):
|
||||
clip = VanillaClip(shared.sd_model.cond_stage_model.wrapped)
|
||||
elif isinstance(clip, FrozenOpenCLIPEmbedder):
|
||||
clip = OpenClip(shared.sd_model.cond_stage_model.wrapped)
|
||||
else:
|
||||
raise RuntimeError(f'Unknown CLIP model: {type(clip).__name__}')
|
||||
|
||||
tokens = shared.sd_model.cond_stage_model.tokenize([text])[0]
|
||||
return tokens
|
||||
|
||||
def token_count(text:str):
|
||||
return len(tokenize(text))
|
||||
Loading…
Reference in New Issue