452 lines
18 KiB
Python
452 lines
18 KiB
Python
import gradio as gr
|
|
from PIL import Image
|
|
from typing import Dict, Tuple, List
|
|
|
|
from modules import ui
|
|
from modules import generation_parameters_copypaste as parameters_copypaste
|
|
|
|
from tagger import utils
|
|
from tagger.interrogator import Interrogator as It
|
|
from webui import wrap_gradio_gpu_call
|
|
from tagger.uiset import IOData, QData, it_ret_tp
|
|
from tensorflow import __version__ as tf_version
|
|
from packaging import version
|
|
|
|
# Label of the batch "rewrite" button. on_interrogate compares the value of
# the pressed button against this constant and passes the result of that
# comparison to batch_interrogate.
BATCH_REWRITE = 'Update tag lists'
|
|
|
|
|
|
def unload_interrogators() -> Tuple[str]:
    """Unload every registered interrogator and report how many unloaded.

    Calls ``unload()`` on each value of ``utils.interrogators`` and counts
    the calls that return a truthy value.

    Returns:
        A one-element tuple holding the status message.
    """
    # The generator is consumed in full, so unload() still runs on every
    # interrogator, exactly as an explicit loop would.
    unloaded_models = sum(
        1 for model in utils.interrogators.values() if model.unload()
    )

    return (f'Successfully unload {unloaded_models} model(s)',)
|
|
|
|
|
|
def check_for_errors(name) -> str:
    """Validate the interrogator name and the search/replace tag lists.

    Returns:
        An error message, or an empty string when both checks pass. The
        interrogator name is checked first.
    """
    if name in utils.interrogators:
        if len(QData.search_tags) == len(QData.replace_tags):
            return ''
        return 'search, replace: unequal len, replacements > 1.'

    return f"'{name}': invalid interrogator"
|
|
|
|
|
|
def on_interrogate(button: str, name: str, inverse=False) -> it_ret_tp:
    """Run a batch interrogation with the interrogator called ``name``.

    Args:
        button: value of the button component wired as the first input; it
            is compared against ``BATCH_REWRITE``.
        name: key into ``utils.interrogators``.
        inverse: stored in ``QData.inverse`` before the batch runs.

    Returns:
        Whatever ``batch_interrogate`` returns or, when a precondition
        fails, a 4-tuple of three ``None`` values and the error message.
    """
    # A missing input glob is reported before any other validation runs.
    problem = (
        'No input directory selected'
        if It.input["input_glob"] == ''
        else check_for_errors(name)
    )
    if problem != '':
        return (None, None, None, problem)

    selected: It = utils.interrogators[name]
    QData.inverse = inverse
    is_rewrite = button == BATCH_REWRITE
    return selected.batch_interrogate(is_rewrite)
|
|
|
|
|
|
def on_inverse_interrogate(
    button: str, name: str
) -> Tuple[str, Dict[str, float], str]:
    """Call ``on_interrogate`` with ``inverse=True`` and drop one entry.

    Returns:
        Entries 0, 2 and 3 of the ``on_interrogate`` result. Entry 1 is
        left out (on_ui_tabs wires that slot to the rating confidences,
        which this handler's outputs do not include).
    """
    result = on_interrogate(button, name, True)
    return tuple(result[pos] for pos in (0, 2, 3))
|
|
|
|
|
|
def on_interrogate_image(image: Image.Image, interrogator: str) -> it_ret_tp:
    """Interrogate a single image with the named interrogator.

    Args:
        image: the picture to tag; ``None`` when nothing is selected.
        interrogator: key into ``utils.interrogators``.

    Returns:
        The result of ``interrogate_image`` or, on failure, a 4-tuple whose
        first three entries are ``None`` and whose last entry is the error
        message.
    """
    if image is None:
        return (None, None, None, 'No image selected')

    err = check_for_errors(interrogator)
    if err != '':
        return (None, None, None, err)

    # Use a separate local for the instance instead of rebinding the
    # ``interrogator`` parameter (a str) with a conflicting annotation.
    it: It = utils.interrogators[interrogator]
    return it.interrogate_image(image)
|
|
|
|
|
|
def move_filter_to_input_fn(
    tag_search_filter: str,
    name: str,
    field: str
) -> Tuple[str, str, Dict[str, float], Dict[str, float], str]:
    """Add the tags matching a filter to an input field and re-interrogate.

    Every key of ``It.output[2]`` that contains ``tag_search_filter`` is
    merged into the comma separated list stored in ``It.input[field]``.
    Tags already in the field keep their position, new ones are appended in
    the order they appear in the output, and duplicates are dropped. The
    earlier implementation joined a ``set``, so the whole field could come
    back in a different order on every call.

    Args:
        tag_search_filter: substring a tag must contain to be moved.
        name: interrogator name, forwarded to ``on_interrogate``.
        field: key of ``It.input`` to extend (callers in this file pass
            ``"keep"`` or ``"exclude"``).

    Returns:
        The updated field text followed by the ``on_interrogate`` result, or
        four ``None`` values and an empty string when there is no output yet
        or no tag matches.
    """
    if It.output is None:
        return (None, None, None, None, '')

    matched = [tag for tag in It.output[2] if tag_search_filter in tag]
    if not matched:
        return (None, None, None, None, '')

    current = It.input[field]
    existing = [x.strip() for x in current.split(',')] if current != '' else []

    # dict.fromkeys de-duplicates while keeping first-seen order, which makes
    # the resulting field text deterministic.
    It.input[field] = ', '.join(dict.fromkeys(existing + matched))

    ret = on_interrogate(BATCH_REWRITE, name, QData.inverse)
    return (It.input[field],) + ret
|
|
|
|
|
|
def move_filter_to_keep_fn(
    tag_search_filter: str, name: str
) -> Tuple[str, str, str, Dict[str, float], str]:
    """Move the tags matching the filter into the "keep" field.

    Returns:
        A 5-tuple: an empty string (the value for the first wired output,
        the search box), the updated keep field text, then entries 1, 3 and
        4 of the ``move_filter_to_input_fn`` result.
    """
    ret = move_filter_to_input_fn(tag_search_filter, name, "keep")
    # ratings are not displayed on this tab, so entry 2 is skipped
    return ('',) + ret[:2] + ret[3:]
|
|
|
|
|
|
def move_filter_to_exclude_fn(
    tag_search_filter: str, name: str
) -> Tuple[str, str, str, Dict[str, float], Dict[str, float], str]:
    """Move the tags matching the filter into the "exclude" field.

    Returns:
        A 6-tuple: an empty string (the value for the first wired output,
        the search box) followed by the full ``move_filter_to_input_fn``
        result.
    """
    return ('',) + move_filter_to_input_fn(tag_search_filter, name, "exclude")
|
|
|
|
|
|
def on_tag_search_filter_change(
    part: str
) -> Tuple[str, Dict[str, float], str]:
    """Narrow the displayed tags to those containing ``part``.

    Returns:
        ``(None, None, '')`` when there is no interrogation output yet.
        With fewer than two characters in ``part``, entries 0 and 2 of
        ``It.output`` are returned unfiltered. Otherwise the matching tag
        names joined by ``', '`` and the matching entries of
        ``It.output[2]``. The last element is always an empty string.
    """
    latest = It.output
    if latest is None:
        return (None, None, '')

    if len(part) < 2:
        return (latest[0], latest[2], '')

    matching = {tag: conf for tag, conf in latest[2].items() if part in tag}
    return (', '.join(matching), matching, '')
|
|
|
|
|
|
def on_ui_tabs():
    """Build the Tagger tab and register its event handlers.

    Returns:
        A one-item list holding the tuple
        ``(tagger_interface, "Tagger", "tagger")``.
    """
    # NOTE(review): this block reached review with its indentation stripped;
    # the nesting of the ``with`` blocks below was reconstructed from the
    # statement order and should be confirmed against the intended layout.
    with gr.Blocks(analytics_enabled=False) as tagger_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):

                # input components
                with gr.Tabs():
                    with gr.TabItem(label='Single process'):
                        image = gr.Image(
                            label='Source',
                            source='upload',
                            interactive=True,
                            type="pil"
                        )
                        image_submit = gr.Button(
                            value='Interrogate image',
                            variant='primary'
                        )

                    with gr.TabItem(label='Batch from directory'):
                        input_glob = utils.preset.component(
                            gr.Textbox,
                            value=It.input["input_glob"],
                            label='Input directory - See also settings tab.',
                            placeholder='/path/to/images or to/images/**/*'
                        )
                        output_dir = utils.preset.component(
                            gr.Textbox,
                            value=It.input["output_dir"],
                            label='Output directory',
                            placeholder='Leave blank to save images '
                                        'to the same path.'
                        )

                        # The label of this button is BATCH_REWRITE, the
                        # value on_interrogate compares ``button`` against.
                        batch_rewrite = gr.Button(value=BATCH_REWRITE)

                        batch_submit = gr.Button(
                            value='Interrogate',
                            variant='primary'
                        )

                # Every handler registered below lists ``info`` among its
                # outputs.
                info = gr.HTML(
                    label='Info',
                    interactive=False,
                    elem_classes=['info']
                )

                # preset selector
                with gr.Row(variant='compact'):
                    available_presets = utils.preset.list()
                    # NOTE(review): available_presets[0] raises IndexError
                    # if utils.preset.list() returns an empty list.
                    selected_preset = gr.Dropdown(
                        label='Preset',
                        choices=available_presets,
                        value=available_presets[0]
                    )

                    save_preset_button = gr.Button(
                        value=ui.save_style_symbol
                    )

                    ui.create_refresh_button(
                        selected_preset,
                        lambda: None,
                        lambda: {'choices': utils.preset.list()},
                        'refresh_preset'
                    )

                # interrogator selector
                with gr.Column():
                    with gr.Row(variant='compact'):
                        interrogator_names = utils.refresh_interrogators()
                        # Defaults to the last name in the list, or None
                        # when the list is empty.
                        interrogator = utils.preset.component(
                            gr.Dropdown,
                            label='Interrogator',
                            choices=interrogator_names,
                            value=(
                                None
                                if len(interrogator_names) < 1 else
                                interrogator_names[-1]
                            )
                        )

                        ui.create_refresh_button(
                            interrogator,
                            lambda: None,
                            lambda: {'choices': utils.refresh_interrogators()},
                            'refresh_interrogator'
                        )

                    unload_all_models = gr.Button(
                        value='Unload all interrogate models'
                    )
                add_tags = utils.preset.component(
                    gr.Textbox,
                    label='Additional tags (comma split)',
                    elem_id='additional-tags'
                )
                with gr.Row(variant='compact'):
                    with gr.Column(variant='compact'):
                        threshold = utils.preset.component(
                            gr.Slider,
                            label='Threshold',
                            minimum=0,
                            maximum=1,
                            value=QData.threshold
                        )
                        cumulative = utils.preset.component(
                            gr.Checkbox,
                            label='combine interrogations',
                            value=It.input["cumulative"]
                        )
                        threshold_on_average = utils.preset.component(
                            gr.Checkbox,
                            label='threshold applies on average of all images '
                                  'and interrogations',
                            value=It.input["threshold_on_average"]
                        )
                        save_tags = utils.preset.component(
                            gr.Checkbox,
                            label='Save to tags files',
                            value=IOData.save_tags
                        )
                    with gr.Column(variant='compact'):
                        count_threshold = utils.preset.component(
                            gr.Slider,
                            label='Tag count threshold',
                            minimum=1,
                            maximum=500,
                            value=QData.count_threshold,
                            step=1.0
                        )

                        # The checkbox is only interactive when the imported
                        # tensorflow version compares equal to 2.10.
                        large_query = utils.preset.component(
                            gr.Checkbox,
                            label='huge batch query (tensorflow 2.10, '
                                  'experimental)',
                            value=It.input["large_query"],
                            interactive=version.parse(tf_version) ==
                            version.parse('2.10')
                        )
                        unload_after = utils.preset.component(
                            gr.Checkbox,
                            label='Unload model after running',
                            value=It.input["unload_after"]
                        )
                with gr.Row(variant='compact'):
                    with gr.Column(variant='compact'):
                        search_tags = utils.preset.component(
                            gr.Textbox,
                            label='Search tag, .. ->',
                            elem_id='search-tags'
                        )
                        keep_tags = utils.preset.component(
                            gr.Textbox,
                            label='Kept tag, ..',
                            elem_id='keep-tags'
                        )
                    with gr.Column(variant='compact'):
                        replace_tags = utils.preset.component(
                            gr.Textbox,
                            label='-> Replace tag, ..',
                            elem_id='replace-tags'
                        )
                        exclude_tags = utils.preset.component(
                            gr.Textbox,
                            label='Exclude tag, ..',
                            elem_id='exclude-tags'
                        )

            # output components
            with gr.Column(variant='panel'):
                with gr.Row(variant='compact'):
                    with gr.Column(variant='compact'):
                        move_filter_to_keep = gr.Button(
                            value='Move visible tags to keep tags',
                            variant='secondary'
                        )
                        move_filter_to_exclude = gr.Button(
                            value='Move visible tags to exclude tags',
                            variant='secondary'
                        )
                    with gr.Column(variant='compact'):
                        tag_search_selection = utils.preset.component(
                            gr.Textbox,
                            label='string search selected tags'
                        )
                with gr.Tabs():
                    # The tab items are created first and entered below, so
                    # their ``select`` events can be bound later by name.
                    tab_include = gr.TabItem(label='Ratings and included tags')
                    tab_discard = gr.TabItem(label='Excluded tags')
                    with tab_include:
                        # clickable tags to populate excluded tags
                        tags = gr.HTML(
                            label='Tags',
                            elem_classes=':link',
                        )

                        with gr.Row():
                            parameters_copypaste.bind_buttons(
                                parameters_copypaste.create_buttons(
                                    ["txt2img", "img2img"],
                                ),
                                None,
                                tags
                            )
                        rating_confidences = gr.Label(
                            label='Rating confidences',
                            elem_id='rating-confidences',
                        )
                        tag_confidences = gr.Label(
                            label='Tag confidences',
                            elem_id='tag-confidences',
                        )
                    with tab_discard:
                        # clickable tags to populate keep tags
                        discarded_tags = gr.HTML(
                            label='Tags',
                            elem_classes=':link',
                        )
                        # NOTE(review): reuses elem_id 'tag-confidences',
                        # the same id as tag_confidences above - confirm the
                        # duplicate is intended.
                        excluded_tag_confidences = gr.Label(
                            label='Excluded Tag confidences',
                            elem_id='tag-confidences',
                        )

        # Selecting a results tab runs the batch handler with the
        # batch_rewrite button as its ``button`` input; the second tab uses
        # the inverse variant, which has no rating output.
        tab_include.select(fn=wrap_gradio_gpu_call(on_interrogate),
                           inputs=[batch_rewrite, interrogator],
                           outputs=[tags, rating_confidences, tag_confidences,
                                    info])

        tab_discard.select(fn=wrap_gradio_gpu_call(on_inverse_interrogate),
                           inputs=[batch_rewrite, interrogator],
                           outputs=[discarded_tags, excluded_tag_confidences,
                                    info])

        # The first output of both "move" handlers is the search box itself.
        move_filter_to_keep.click(
            fn=wrap_gradio_gpu_call(move_filter_to_keep_fn),
            inputs=[tag_search_selection, interrogator],
            outputs=[tag_search_selection, keep_tags, discarded_tags,
                     excluded_tag_confidences, info])

        move_filter_to_exclude.click(
            fn=wrap_gradio_gpu_call(move_filter_to_exclude_fn),
            inputs=[tag_search_selection, interrogator],
            outputs=[tag_search_selection, exclude_tags, tags,
                     rating_confidences, tag_confidences, info])

        # It.flip(...), It.set(...), QData.set(...) and
        # IOData.flip_save_tags() are all called right here, while the UI is
        # being built; whatever they return is what gets wrapped and
        # registered as the handler.
        # NOTE(review): presumably each returns a callable - confirm.
        cumulative.input(fn=wrap_gradio_gpu_call(It.flip('cumulative')),
                         inputs=[], outputs=[info])
        large_query.input(fn=wrap_gradio_gpu_call(It.flip('large_query')),
                          inputs=[], outputs=[info])
        unload_after.input(fn=wrap_gradio_gpu_call(It.flip('unload_after')),
                           inputs=[], outputs=[info])
        threshold_on_average.input(fn=wrap_gradio_gpu_call(
            It.flip('threshold_on_average')), inputs=[],
            outputs=[info])

        save_tags.input(fn=wrap_gradio_gpu_call(IOData.flip_save_tags()),
                        inputs=[], outputs=[info])

        # Text fields are handled on blur; each handler's outputs include
        # the field it was triggered from.
        input_glob.blur(fn=wrap_gradio_gpu_call(It.set("input_glob")),
                        inputs=[input_glob], outputs=[input_glob, info])
        output_dir.blur(fn=wrap_gradio_gpu_call(It.set("output_dir")),
                        inputs=[output_dir], outputs=[output_dir, info])

        # Both sliders register the same handler for ``input`` and
        # ``release``.
        threshold.input(fn=wrap_gradio_gpu_call(QData.set("threshold")),
                        inputs=[threshold], outputs=[info])
        threshold.release(fn=wrap_gradio_gpu_call(QData.set("threshold")),
                          inputs=[threshold], outputs=[info])

        count_threshold.input(fn=wrap_gradio_gpu_call(
            QData.set("count_threshold")),
            inputs=[count_threshold], outputs=[info])
        count_threshold.release(fn=wrap_gradio_gpu_call(
            QData.set("count_threshold")),
            inputs=[count_threshold], outputs=[info])

        add_tags.blur(fn=wrap_gradio_gpu_call(It.set('add')),
                      inputs=[add_tags], outputs=[add_tags, info])

        keep_tags.blur(fn=wrap_gradio_gpu_call(It.set('keep')),
                       inputs=[keep_tags], outputs=[keep_tags, info])
        exclude_tags.blur(fn=wrap_gradio_gpu_call(It.set('exclude')),
                          inputs=[exclude_tags], outputs=[exclude_tags, info])

        search_tags.blur(fn=wrap_gradio_gpu_call(It.set('search')),
                         inputs=[search_tags], outputs=[search_tags, info])
        replace_tags.blur(fn=wrap_gradio_gpu_call(It.set('replace')),
                          inputs=[replace_tags], outputs=[replace_tags, info])

        # Filter the displayed tags while typing in the search box.
        # NOTE(review): the ``... if QData.inverse else ...`` expressions are
        # evaluated once, when this call runs during UI construction, so the
        # output components are fixed by the value of QData.inverse at that
        # moment and do not follow later changes to it.
        tag_search_selection.change(
            fn=wrap_gradio_gpu_call(on_tag_search_filter_change),
            inputs=[tag_search_selection],
            outputs=[
                discarded_tags if QData.inverse else tags,
                excluded_tag_confidences if QData.inverse else tag_confidences,
                info])

        # Same handler and outputs again when the search box loses focus;
        # the note on QData.inverse above applies here too.
        tag_search_selection.blur(
            fn=wrap_gradio_gpu_call(on_tag_search_filter_change),
            inputs=[tag_search_selection],
            outputs=[
                discarded_tags if QData.inverse else tags,
                excluded_tag_confidences if QData.inverse else tag_confidences,
                info])

        # Preset handling: apply writes to every component in
        # utils.preset.components, save reads from all of them.
        selected_preset.change(
            fn=utils.preset.apply,
            inputs=[selected_preset],
            outputs=[*utils.preset.components, info])

        save_preset_button.click(
            fn=utils.preset.save,
            inputs=[selected_preset, *utils.preset.components],  # values only
            outputs=[info])

        # Unlike the other handlers, this one is registered without
        # wrap_gradio_gpu_call.
        unload_all_models.click(fn=unload_interrogators, outputs=[info])

        # The single-image handler runs both when the image changes and when
        # the submit button is clicked.
        image.change(
            fn=wrap_gradio_gpu_call(on_interrogate_image),
            inputs=[image, interrogator],
            outputs=[tags, rating_confidences, tag_confidences, info])

        image_submit.click(
            fn=wrap_gradio_gpu_call(on_interrogate_image),
            inputs=[image, interrogator],
            outputs=[tags, rating_confidences, tag_confidences, info])

        # Each batch button passes itself as the first input, which is how
        # on_interrogate can tell the rewrite button from the submit button.
        for button in [batch_rewrite, batch_submit]:
            button.click(
                fn=wrap_gradio_gpu_call(on_interrogate),
                inputs=[button, interrogator],
                outputs=[tags, rating_confidences, tag_confidences, info])

    return [(tagger_interface, "Tagger", "tagger")]
|