stable-diffusion-webui-wd14.../tagger/ui.py

452 lines
18 KiB
Python

import gradio as gr
from PIL import Image
from typing import Dict, Tuple, List
from modules import ui
from modules import generation_parameters_copypaste as parameters_copypaste
from tagger import utils
from tagger.interrogator import Interrogator as It
from webui import wrap_gradio_gpu_call
from tagger.uiset import IOData, QData, it_ret_tp
from tensorflow import __version__ as tf_version
from packaging import version
BATCH_REWRITE = 'Update tag lists'
def unload_interrogators() -> List[str]:
unloaded_models = 0
for i in utils.interrogators.values():
if i.unload():
unloaded_models = unloaded_models + 1
return (f'Successfully unload {unloaded_models} model(s)',)
def check_for_errors(name) -> str:
if name not in utils.interrogators:
return f"'{name}': invalid interrogator"
if len(QData.search_tags) != len(QData.replace_tags):
return 'search, replace: unequal len, replacements > 1.'
return ''
def on_interrogate(button: str, name: str, inverse=False) -> it_ret_tp:
if It.input["input_glob"] == '':
return (None, None, None, 'No input directory selected')
err = check_for_errors(name)
if err != '':
return (None, None, None, err)
it: It = utils.interrogators[name]
QData.inverse = inverse
return it.batch_interrogate(button == BATCH_REWRITE)
def on_inverse_interrogate(
button: str, name: str
) -> Tuple[str, Dict[str, float], str]:
ret = on_interrogate(button, name, True)
return (ret[0], ret[2], ret[3])
def on_interrogate_image(image: Image, interrogator: str) -> it_ret_tp:
if image is None:
return (None, None, None, 'No image selected')
err = check_for_errors(interrogator)
if err != '':
return (None, None, None, err)
interrogator: It = utils.interrogators[interrogator]
return interrogator.interrogate_image(image)
def move_filter_to_input_fn(
tag_search_filter: str,
name: str,
field: str
) -> Tuple[str, str, Dict[str, float], Dict[str, float], str]:
if It.output is None:
return (None, None, None, None, '')
filt = {(k, v) for k, v in It.output[2].items() if tag_search_filter in k}
if len(filt) == 0:
return (None, None, None, None, '')
add = set(dict(filt).keys())
if It.input[field] != '':
add = add.union({x.strip() for x in It.input[field].split(',')})
It.input[field] = ', '.join(add)
ret = on_interrogate(BATCH_REWRITE, name, QData.inverse)
return (It.input[field],) + ret
def move_filter_to_keep_fn(
tag_search_filter: str, name: str
) -> Tuple[str, str, Dict[str, float], str]:
ret = move_filter_to_input_fn(tag_search_filter, name, "keep")
# ratings are not displayed on this tab
return ('',) + ret[:2] + ret[3:]
def move_filter_to_exclude_fn(
tag_search_filter: str, name: str
) -> Tuple[str, str, Dict[str, float], Dict[str, float], str]:
return ('',) + move_filter_to_input_fn(tag_search_filter, name, "exclude")
def on_tag_search_filter_change(
part: str
) -> Tuple[str, Dict[str, float], str]:
if It.output is None:
return (None, None, '')
if len(part) < 2:
return (It.output[0], It.output[2], '')
tags = dict(filter(lambda x: part in x[0], It.output[2].items()))
return (', '.join(tags.keys()), tags, '')
def on_ui_tabs():
with gr.Blocks(analytics_enabled=False) as tagger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
# input components
with gr.Tabs():
with gr.TabItem(label='Single process'):
image = gr.Image(
label='Source',
source='upload',
interactive=True,
type="pil"
)
image_submit = gr.Button(
value='Interrogate image',
variant='primary'
)
with gr.TabItem(label='Batch from directory'):
input_glob = utils.preset.component(
gr.Textbox,
value=It.input["input_glob"],
label='Input directory - See also settings tab.',
placeholder='/path/to/images or to/images/**/*'
)
output_dir = utils.preset.component(
gr.Textbox,
value=It.input["output_dir"],
label='Output directory',
placeholder='Leave blank to save images '
'to the same path.'
)
batch_rewrite = gr.Button(value=BATCH_REWRITE)
batch_submit = gr.Button(
value='Interrogate',
variant='primary'
)
info = gr.HTML(
label='Info',
interactive=False,
elem_classes=['info']
)
# preset selector
with gr.Row(variant='compact'):
available_presets = utils.preset.list()
selected_preset = gr.Dropdown(
label='Preset',
choices=available_presets,
value=available_presets[0]
)
save_preset_button = gr.Button(
value=ui.save_style_symbol
)
ui.create_refresh_button(
selected_preset,
lambda: None,
lambda: {'choices': utils.preset.list()},
'refresh_preset'
)
# interrogator selector
with gr.Column():
with gr.Row(variant='compact'):
interrogator_names = utils.refresh_interrogators()
interrogator = utils.preset.component(
gr.Dropdown,
label='Interrogator',
choices=interrogator_names,
value=(
None
if len(interrogator_names) < 1 else
interrogator_names[-1]
)
)
ui.create_refresh_button(
interrogator,
lambda: None,
lambda: {'choices': utils.refresh_interrogators()},
'refresh_interrogator'
)
unload_all_models = gr.Button(
value='Unload all interrogate models'
)
add_tags = utils.preset.component(
gr.Textbox,
label='Additional tags (comma split)',
elem_id='additional-tags'
)
with gr.Row(variant='compact'):
with gr.Column(variant='compact'):
threshold = utils.preset.component(
gr.Slider,
label='Threshold',
minimum=0,
maximum=1,
value=QData.threshold
)
cumulative = utils.preset.component(
gr.Checkbox,
label='combine interrogations',
value=It.input["cumulative"]
)
threshold_on_average = utils.preset.component(
gr.Checkbox,
label='threshold applies on average of all images '
'and interrogations',
value=It.input["threshold_on_average"]
)
save_tags = utils.preset.component(
gr.Checkbox,
label='Save to tags files',
value=IOData.save_tags
)
with gr.Column(variant='compact'):
count_threshold = utils.preset.component(
gr.Slider,
label='Tag count threshold',
minimum=1,
maximum=500,
value=QData.count_threshold,
step=1.0
)
large_query = utils.preset.component(
gr.Checkbox,
label='huge batch query (tensorflow 2.10, '
'experimental)',
value=It.input["large_query"],
interactive=version.parse(tf_version) ==
version.parse('2.10')
)
unload_after = utils.preset.component(
gr.Checkbox,
label='Unload model after running',
value=It.input["unload_after"]
)
with gr.Row(variant='compact'):
with gr.Column(variant='compact'):
search_tags = utils.preset.component(
gr.Textbox,
label='Search tag, .. ->',
elem_id='search-tags'
)
keep_tags = utils.preset.component(
gr.Textbox,
label='Kept tag, ..',
elem_id='keep-tags'
)
with gr.Column(variant='compact'):
replace_tags = utils.preset.component(
gr.Textbox,
label='-> Replace tag, ..',
elem_id='replace-tags'
)
exclude_tags = utils.preset.component(
gr.Textbox,
label='Exclude tag, ..',
elem_id='exclude-tags'
)
# output components
with gr.Column(variant='panel'):
with gr.Row(variant='compact'):
with gr.Column(variant='compact'):
move_filter_to_keep = gr.Button(
value='Move visible tags to keep tags',
variant='secondary'
)
move_filter_to_exclude = gr.Button(
value='Move visible tags to exclude tags',
variant='secondary'
)
with gr.Column(variant='compact'):
tag_search_selection = utils.preset.component(
gr.Textbox,
label='string search selected tags'
)
with gr.Tabs():
tab_include = gr.TabItem(label='Ratings and included tags')
tab_discard = gr.TabItem(label='Excluded tags')
with tab_include:
# clickable tags to populate excluded tags
tags = gr.HTML(
label='Tags',
elem_classes=':link',
)
with gr.Row():
parameters_copypaste.bind_buttons(
parameters_copypaste.create_buttons(
["txt2img", "img2img"],
),
None,
tags
)
rating_confidences = gr.Label(
label='Rating confidences',
elem_id='rating-confidences',
)
tag_confidences = gr.Label(
label='Tag confidences',
elem_id='tag-confidences',
)
with tab_discard:
# clickable tags to populate keep tags
discarded_tags = gr.HTML(
label='Tags',
elem_classes=':link',
)
excluded_tag_confidences = gr.Label(
label='Excluded Tag confidences',
elem_id='tag-confidences',
)
tab_include.select(fn=wrap_gradio_gpu_call(on_interrogate),
inputs=[batch_rewrite, interrogator],
outputs=[tags, rating_confidences, tag_confidences,
info])
tab_discard.select(fn=wrap_gradio_gpu_call(on_inverse_interrogate),
inputs=[batch_rewrite, interrogator],
outputs=[discarded_tags, excluded_tag_confidences,
info])
move_filter_to_keep.click(
fn=wrap_gradio_gpu_call(move_filter_to_keep_fn),
inputs=[tag_search_selection, interrogator],
outputs=[tag_search_selection, keep_tags, discarded_tags,
excluded_tag_confidences, info])
move_filter_to_exclude.click(
fn=wrap_gradio_gpu_call(move_filter_to_exclude_fn),
inputs=[tag_search_selection, interrogator],
outputs=[tag_search_selection, exclude_tags, tags,
rating_confidences, tag_confidences, info])
cumulative.input(fn=wrap_gradio_gpu_call(It.flip('cumulative')),
inputs=[], outputs=[info])
large_query.input(fn=wrap_gradio_gpu_call(It.flip('large_query')),
inputs=[], outputs=[info])
unload_after.input(fn=wrap_gradio_gpu_call(It.flip('unload_after')),
inputs=[], outputs=[info])
threshold_on_average.input(fn=wrap_gradio_gpu_call(
It.flip('threshold_on_average')), inputs=[],
outputs=[info])
save_tags.input(fn=wrap_gradio_gpu_call(IOData.flip_save_tags()),
inputs=[], outputs=[info])
input_glob.blur(fn=wrap_gradio_gpu_call(It.set("input_glob")),
inputs=[input_glob], outputs=[input_glob, info])
output_dir.blur(fn=wrap_gradio_gpu_call(It.set("output_dir")),
inputs=[output_dir], outputs=[output_dir, info])
threshold.input(fn=wrap_gradio_gpu_call(QData.set("threshold")),
inputs=[threshold], outputs=[info])
threshold.release(fn=wrap_gradio_gpu_call(QData.set("threshold")),
inputs=[threshold], outputs=[info])
count_threshold.input(fn=wrap_gradio_gpu_call(
QData.set("count_threshold")),
inputs=[count_threshold], outputs=[info])
count_threshold.release(fn=wrap_gradio_gpu_call(
QData.set("count_threshold")),
inputs=[count_threshold], outputs=[info])
add_tags.blur(fn=wrap_gradio_gpu_call(It.set('add')),
inputs=[add_tags], outputs=[add_tags, info])
keep_tags.blur(fn=wrap_gradio_gpu_call(It.set('keep')),
inputs=[keep_tags], outputs=[keep_tags, info])
exclude_tags.blur(fn=wrap_gradio_gpu_call(It.set('exclude')),
inputs=[exclude_tags], outputs=[exclude_tags, info])
search_tags.blur(fn=wrap_gradio_gpu_call(It.set('search')),
inputs=[search_tags], outputs=[search_tags, info])
replace_tags.blur(fn=wrap_gradio_gpu_call(It.set('replace')),
inputs=[replace_tags], outputs=[replace_tags, info])
# register events
tag_search_selection.change(
fn=wrap_gradio_gpu_call(on_tag_search_filter_change),
inputs=[tag_search_selection],
outputs=[
discarded_tags if QData.inverse else tags,
excluded_tag_confidences if QData.inverse else tag_confidences,
info])
# register events
tag_search_selection.blur(
fn=wrap_gradio_gpu_call(on_tag_search_filter_change),
inputs=[tag_search_selection],
outputs=[
discarded_tags if QData.inverse else tags,
excluded_tag_confidences if QData.inverse else tag_confidences,
info])
# register events
selected_preset.change(
fn=utils.preset.apply,
inputs=[selected_preset],
outputs=[*utils.preset.components, info])
save_preset_button.click(
fn=utils.preset.save,
inputs=[selected_preset, *utils.preset.components], # values only
outputs=[info])
unload_all_models.click(fn=unload_interrogators, outputs=[info])
image.change(
fn=wrap_gradio_gpu_call(on_interrogate_image),
inputs=[image, interrogator],
outputs=[tags, rating_confidences, tag_confidences, info])
image_submit.click(
fn=wrap_gradio_gpu_call(on_interrogate_image),
inputs=[image, interrogator],
outputs=[tags, rating_confidences, tag_confidences, info])
for button in [batch_rewrite, batch_submit]:
button.click(
fn=wrap_gradio_gpu_call(on_interrogate),
inputs=[button, interrogator],
outputs=[tags, rating_confidences, tag_confidences, info])
return [(tagger_interface, "Tagger", "tagger")]