# Mirror of https://github.com/bmaltais/kohya_ss (Python, 671 lines, 27 KiB)
import gradio as gr
|
|
from easygui import boolbox
|
|
from PIL import Image
|
|
from .common_gui import get_folder_path, scriptdir, list_dirs
|
|
from math import ceil
|
|
import os
|
|
import re
|
|
|
|
from .custom_logging import setup_logging
|
|
|
|
# Set up logging
log = setup_logging()

# Number of image rows created in the UI; also used as the page size when
# the number of pages is computed.
IMAGES_TO_SHOW = 5
# File-name suffixes treated as images; they are matched against the
# lower-cased file name, so the check is case-insensitive.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
|
|
|
|
|
|
def _get_caption_path(image_file, images_dir, caption_ext):
    """
    Build the path where the caption file for an image is expected.

    Only the file-name part of ``image_file`` is kept; its extension is
    replaced by ``caption_ext`` and the result is placed in ``images_dir``.
    Returns ``None`` when ``image_file`` is empty or ``None``.
    """
    if image_file:
        # Any directory component of the image path is discarded on purpose.
        stem, _image_ext = os.path.splitext(os.path.basename(image_file))
        return os.path.join(images_dir, stem + caption_ext)
    return None
|
|
|
|
|
|
def _get_quick_tags(quick_tags_text):
    """
    Parse the quick-tags text box into an ordered list and a lookup set.

    Args:
        quick_tags_text: Comma separated tags. ``None`` or an empty string
            is accepted and yields no tags.

    Returns:
        A ``(quick_tags, quick_tags_set)`` tuple. ``quick_tags`` holds the
        stripped, non-empty tags in their original order and spelling;
        ``quick_tags_set`` holds the same tags lower-cased, for
        case-insensitive membership tests.
    """
    if not quick_tags_text:
        # A cleared text box may deliver None instead of "".
        return [], set()

    stripped = (part.strip() for part in quick_tags_text.split(","))
    quick_tags = [tag for tag in stripped if tag]
    quick_tags_set = {tag.lower() for tag in quick_tags}
    return quick_tags, quick_tags_set
|
|
|
|
|
|
def _get_tag_checkbox_updates(caption, quick_tags, quick_tags_set):
    """
    Build the checkbox-group update for one caption.

    The choices are the quick tags followed by any caption tags that are not
    quick tags; the selected values are the tags present in the caption.
    Tags are matched case-insensitively and every tag is represented by a
    single spelling (the quick tag's, or the first one seen in the caption),
    so each selected value is guaranteed to be one of the choices and no
    choice is listed twice.

    Args:
        caption: Comma separated caption text; ``None`` is treated as empty.
        quick_tags: Ordered list of quick tags.
        quick_tags_set: Lower-cased quick tags. Unused here (the lookup is
            derived from ``quick_tags``) but kept so existing callers keep
            working.

    Returns:
        A ``gr.CheckboxGroup`` carrying the new choices and selected values.
    """
    caption_tags = [c.strip() for c in (caption or "").split(",") if c.strip()]

    # lower-cased tag -> spelling shown in the UI; dict order gives the
    # "quick tags first, then extra caption tags" ordering of the choices.
    spelling = {}
    for tag in quick_tags:
        spelling.setdefault(tag.lower(), tag)
    for tag in caption_tags:
        spelling.setdefault(tag.lower(), tag)

    choices = list(spelling.values())
    # dict.fromkeys drops repeated tags while keeping caption order.
    selected = list(dict.fromkeys(spelling[tag.lower()] for tag in caption_tags))
    return gr.CheckboxGroup(choices=choices, value=selected)
|
|
|
|
|
|
def paginate_go(page, max_page):
    """
    Jump to the page typed into the "go to page" box.

    The text is converted to an integer and clamped by ``paginate``. If that
    raises ``ValueError`` or ``TypeError`` a warning is shown and a no-op
    update is returned, leaving the current page untouched.
    """
    try:
        # A zero offset lets paginate() do nothing but the clamping.
        return paginate(int(page), max_page, 0)
    except (TypeError, ValueError):
        gr.Warning(f"Invalid page number: {page}")
    return gr.update()
|
|
|
|
|
|
def derive_target_folder(control_folder):
    """
    Guess the target-image folder that belongs to a control-image folder.

    The siblings of ``control_folder`` are scanned in alphabetical order and
    the first directory whose name contains "target" (case-insensitive) is
    returned. The control folder itself is never returned. When no sibling
    matches, ``<parent>/target`` is returned even though it may not exist.

    Args:
        control_folder: Path of the control-image folder; may end with a
            path separator and may be relative.

    Returns:
        The derived folder path, or ``""`` when ``control_folder`` is empty
        or does not exist.
    """
    if not control_folder or not os.path.exists(control_folder):
        return ""

    # normpath drops trailing separators; without it dirname("a/control/")
    # would be the control folder itself instead of its parent.
    control_path = os.path.normpath(control_folder)
    parent_dir = os.path.dirname(control_path)
    control_name = os.path.basename(control_path)

    try:
        # A bare relative name has an empty dirname, which listdir rejects.
        # Sorting makes the choice independent of directory listing order.
        siblings = sorted(os.listdir(parent_dir or os.curdir))
    except OSError:
        # Unreadable parent: fall back to the conventional default below.
        siblings = []

    for item in siblings:
        if item == control_name or "target" not in item.lower():
            continue
        candidate = os.path.join(parent_dir, item)
        if os.path.isdir(candidate):
            return candidate

    return os.path.join(parent_dir, "target")
|
|
|
|
|
|
def paginate(page, max_page, page_change):
    """
    Move ``page`` by ``page_change`` and clamp the result to 1..``max_page``.

    ``page`` is converted to an integer first; the returned page number is
    always an ``int``. The lower bound wins when ``max_page`` is below 1.
    """
    candidate = int(page) + page_change
    if candidate > max_page:
        candidate = max_page
    if candidate < 1:
        candidate = 1
    return int(candidate)
|
|
|
|
|
|
def save_caption(caption, caption_ext, image_file, images_dir):
    """
    Write ``caption`` to the caption file that belongs to ``image_file``.

    Args:
        caption: Caption text; ``None`` is written as an empty file.
        caption_ext: Extension of the caption file (for example ``".txt"``).
        image_file: Name of the image the caption belongs to.
        images_dir: Folder in which the caption file is written.

    Returns:
        A ``gr.Markdown`` update: a visible confirmation on success, a
        visible warning when the file could not be written, and a hidden
        update when the image, folder or extension is not set.
    """
    # Nothing can be saved until images are loaded and an extension is
    # chosen; bail out instead of failing while building the path.
    if not (image_file and images_dir and caption_ext):
        return gr.Markdown(visible=False)

    caption_path = _get_caption_path(image_file, images_dir, caption_ext)
    try:
        with open(caption_path, "w", encoding="utf-8") as f:
            f.write(caption or "")
    except OSError as e:
        # Report the failure in the UI rather than crashing the event handler.
        log.error(f"Failed to write caption to {caption_path}: {e}")
        return gr.Markdown(
            f"⚠️ Could not save caption to `{caption_path}`: {e}", visible=True
        )

    log.info(f"Wrote captions to {caption_path}")
    return gr.Markdown(f"💾 Caption saved to `{caption_path}`", visible=True)
|
|
|
|
|
|
def delete_images_and_caption(
    image_file, control_images_dir, target_images_dir, caption_ext
):
    """
    Delete the control image, the target image and the caption of one entry.

    The caption file is looked up in the target folder. Files that do not
    exist are skipped. Returns a hidden ``gr.Markdown`` when ``image_file``
    is empty, otherwise a visible message naming the deleted entry.
    """

    def _remove_if_present(kind, path):
        # Delete and log one file; missing files are ignored silently.
        if os.path.exists(path):
            os.remove(path)
            log.info(f"Deleted {kind}: {path}")

    if not image_file:
        return gr.Markdown(visible=False)

    _remove_if_present("control image", os.path.join(control_images_dir, image_file))
    _remove_if_present("target image", os.path.join(target_images_dir, image_file))

    caption_path = _get_caption_path(image_file, target_images_dir, caption_ext)
    if caption_path:
        _remove_if_present("caption file", caption_path)

    return gr.Markdown(f"🗑️ Deleted files for `{image_file}`", visible=True)
|
|
|
|
|
|
def update_quick_tags(quick_tags_text, *image_caption_texts):
    """
    Rebuild the tag checkbox groups after the quick tags changed.

    Returns one checkbox-group update per caption text, in the order the
    captions were passed in.
    """
    tags, tag_lookup = _get_quick_tags(quick_tags_text)
    updates = []
    for caption_text in image_caption_texts:
        updates.append(_get_tag_checkbox_updates(caption_text, tags, tag_lookup))
    return updates
|
|
|
|
|
|
def update_image_caption(
    quick_tags_text, caption, image_file, images_dir, caption_ext, auto_save_is_checked
):
    """
    React to an edited caption: optionally save it, then refresh its tags.

    When ``auto_save_is_checked`` is truthy the caption is written to disk
    first. The return value is the checkbox-group update that reflects the
    tags now present in ``caption``.
    """
    if auto_save_is_checked:
        save_caption(caption, caption_ext, image_file, images_dir)

    # _get_quick_tags returns (list, set), exactly the two trailing arguments.
    return _get_tag_checkbox_updates(caption, *_get_quick_tags(quick_tags_text))
|
|
|
|
|
|
def update_image_tags(
    quick_tags_text,
    selected_tags,
    image_file,
    images_dir,
    caption_ext,
    auto_save_is_checked,
):
    """
    Rebuild a caption from the ticked tag checkboxes.

    Selected quick tags come first, in quick-tag order and spelling, followed
    by the remaining selected tags in the order they were selected. Tags are
    compared case-insensitively and each tag appears only once.

    Args:
        quick_tags_text: Comma separated quick tags.
        selected_tags: Tags currently ticked; ``None`` is treated as empty.
        image_file: Name of the image the caption belongs to.
        images_dir: Folder holding the caption file.
        caption_ext: Extension of the caption file.
        auto_save_is_checked: When truthy, the caption is saved to disk.

    Returns:
        The caption text, tags joined by ``", "``.
    """
    quick_tags, quick_tags_set = _get_quick_tags(quick_tags_text)
    selected_tags = selected_tags or []
    selected_keys = {tag.lower() for tag in selected_tags}

    # quick_tags_set is lower-cased, so both membership tests must compare
    # lower-cased tags; comparing raw tags listed every quick tag that
    # contains an uppercase letter twice.
    ordered = [t for t in quick_tags if t.lower() in selected_keys]
    ordered += [t for t in selected_tags if t.lower() not in quick_tags_set]

    output_tags = []
    seen = set()
    for tag in ordered:
        key = tag.lower()
        if key not in seen:
            seen.add(key)
            output_tags.append(tag)

    caption = ", ".join(output_tags)

    if auto_save_is_checked:
        save_caption(caption, caption_ext, image_file, images_dir)

    return caption
|
|
|
|
|
|
def import_tags_from_captions(
    images_dir,
    caption_ext,
    quick_tags_text,
    ignore_load_tags_word_count,
    ask_overwrite=True,
):
    """
    Collect the unique tags found in the caption files of ``images_dir``.

    Every image file in the folder is paired with its caption file; the
    caption is split on commas and each new tag (compared
    case-insensitively) is kept when it has at most
    ``ignore_load_tags_word_count`` whitespace-separated words.

    Args:
        images_dir: Folder containing the images and their caption files.
        caption_ext: Extension of the caption files.
        quick_tags_text: Current quick tags; when non-empty and
            ``ask_overwrite`` is true the user is asked to confirm.
        ignore_load_tags_word_count: Maximum number of words a tag may have.
        ask_overwrite: Whether to ask before replacing existing quick tags.

    Returns:
        The tags sorted case-insensitively and joined by ``", "``, or a
        no-op ``gr.update()`` when the inputs are invalid or the user
        declines the overwrite.
    """
    if not images_dir or not os.path.exists(images_dir):
        gr.Warning(
            "Image folder is not set or does not exist. Please load images first."
        )
        return gr.update()

    if not caption_ext:
        gr.Warning("Please provide an extension for the caption files.")
        return gr.update()

    if quick_tags_text and ask_overwrite:
        if not boolbox(
            "Are you sure you wish to overwrite the current quick tags?",
            choices=("Yes", "No"),
        ):
            return gr.update()

    tags = []
    tags_set = set()
    # The context manager releases the directory handle even when the loop
    # is left early by an exception.
    with os.scandir(images_dir) as entries:
        for entry in entries:
            if not (entry.is_file() and entry.name.lower().endswith(IMAGE_EXTENSIONS)):
                continue

            caption_file_path = _get_caption_path(entry.name, images_dir, caption_ext)
            if not os.path.exists(caption_file_path):
                continue

            try:
                with open(caption_file_path, "r", encoding="utf-8") as f:
                    caption = f.read()
            except (OSError, UnicodeDecodeError) as e:
                # One unreadable caption should not abort the whole import.
                log.warning(f"Skipping unreadable caption file {caption_file_path}: {e}")
                continue

            for tag in caption.split(","):
                tag = tag.strip()
                if not tag:
                    continue
                tag_key = tag.lower()
                if tag_key in tags_set:
                    continue
                # split() counts words correctly even with repeated spaces.
                if len(tag.split()) <= ignore_load_tags_word_count:
                    tags.append(tag)
                    tags_set.add(tag_key)

    # Ensure tags are alphabetically sorted (case-insensitive)
    tags_sorted = sorted(tags, key=str.lower)
    gr.Info(f"Imported {len(tags_sorted)} unique tags.")
    return ", ".join(tags_sorted)
|
|
|
|
|
|
def load_images(
    target_images_dir,
    control_images_dir,
    caption_ext,
    quick_tags_text,
    ignore_load_tags_word_count,
):
    """
    Find the images shared by the target and control folders.

    Both folders are listed, the image files present in both are sorted and
    each pair is checked for an aspect-ratio difference above 0.01 (pairs
    that cannot be opened are logged and skipped). Quick tags are then
    re-imported from the target folder's captions without confirmation.

    Returns:
        A seven item list: the shared file names, the target folder, the
        control folder, the page number (always 1), the number of pages, a
        status ``gr.Markdown`` and the new quick-tags value. On a validation
        error the first three items are ``None``, both page values are 1 and
        the quick tags are left unchanged.
    """

    def error_message(msg):
        # Show the problem as a toast and as the status line, and return a
        # value for every output so stale state is cleared.
        gr.Warning(msg)
        banner = gr.Markdown(f"⚠️ {msg}", visible=True)
        return [None, None, None, 1, 1, banner, gr.update()]

    def is_unusable(folder):
        return not folder or not os.path.exists(folder)

    def image_names(folder):
        return {
            name
            for name in os.listdir(folder)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        }

    if is_unusable(target_images_dir):
        return error_message("Target image folder is missing or does not exist.")
    if is_unusable(control_images_dir):
        return error_message("Control image folder is missing or does not exist.")
    if not caption_ext:
        return error_message("Please provide an extension for the caption files.")

    target_names = image_names(target_images_dir)
    control_names = image_names(control_images_dir)
    # Sorted once here; the list is kept in state and reused for paging.
    shared_files = sorted(target_names & control_names)

    if not shared_files:
        return error_message(
            "No shared images found between the target and control directories."
        )

    mismatched_files = []
    for image_file in shared_files:
        target_path = os.path.join(target_images_dir, image_file)
        control_path = os.path.join(control_images_dir, image_file)

        try:
            with Image.open(target_path) as target_img, Image.open(control_path) as control_img:
                target_ar = target_img.width / target_img.height
                control_ar = control_img.width / control_img.height

                if abs(target_ar - control_ar) > 1e-2:
                    mismatched_files.append(image_file)
                    log.warning(f"Aspect ratio mismatch for {image_file}: Target AR is {target_ar:.4f}, Control AR is {control_ar:.4f}")
        except Exception as e:
            log.error(f"Could not load or process image {image_file}. Error: {e}")

    if mismatched_files:
        gr.Warning(f"Found {len(mismatched_files)} images with aspect ratio mismatches. Use the 'Apply Correction' button to fix them.")

    total_images = len(shared_files)
    max_pages = ceil(total_images / IMAGES_TO_SHOW)

    info = f"✅ Loaded {total_images} shared images. {max_pages} pages total."
    gr.Info(info)

    # Import tags
    new_quick_tags = import_tags_from_captions(
        target_images_dir,
        caption_ext,
        quick_tags_text,
        ignore_load_tags_word_count,
        ask_overwrite=False,
    )

    status = gr.Markdown(info, visible=True)
    return [
        shared_files,
        target_images_dir,
        control_images_dir,
        1,
        max_pages,
        status,
        new_quick_tags,
    ]
|
|
|
|
|
|
def crop_image(image, target_aspect_ratio):
    """
    Center-crop ``image`` to ``target_aspect_ratio`` (width / height).

    An image that is too wide keeps its full height and loses equal strips
    on the left and right; otherwise it keeps its full width and loses equal
    strips at the top and bottom.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method, such as a PIL image.
        target_aspect_ratio: Desired width / height ratio.

    Returns:
        Whatever ``image.crop`` returns for the computed box.
    """
    width, height = image.size
    current_aspect_ratio = width / height

    # The box uses integer coordinates only. With fractional offsets both
    # edges are rounded independently by the imaging library, which can make
    # the result one pixel wider or taller than requested.
    if current_aspect_ratio > target_aspect_ratio:
        # Crop width
        new_width = max(1, int(target_aspect_ratio * height))
        left = (width - new_width) // 2
        box = (left, 0, left + new_width, height)
    else:
        # Crop height
        new_height = max(1, min(height, int(width / target_aspect_ratio)))
        top = (height - new_height) // 2
        box = (0, top, width, top + new_height)

    return image.crop(box)
|
|
|
|
|
|
def pad_image(image, target_aspect_ratio, color="white"):
    """
    Pad ``image`` with ``color`` until it has ``target_aspect_ratio``.

    A new image of the same mode is created and the original is pasted in
    its centre. An image that is too wide gains height; otherwise it gains
    width.
    """
    width, height = image.size

    if width / height > target_aspect_ratio:
        # Pad height
        canvas_size = (width, int(width / target_aspect_ratio))
        offset = (0, int((canvas_size[1] - height) / 2))
    else:
        # Pad width
        canvas_size = (int(height * target_aspect_ratio), height)
        offset = (int((canvas_size[0] - width) / 2), 0)

    canvas = Image.new(image.mode, canvas_size, color)
    canvas.paste(image, offset)
    return canvas
|
|
|
|
|
|
def save_image_with_backup(image_path, image_to_save):
    """
    Overwrite ``image_path`` with ``image_to_save``, keeping a backup.

    The first time a file is modified, the original is copied into an
    ``original`` sub-folder next to it; an existing backup is never
    replaced. All failures are logged and swallowed, so the function always
    returns ``None``.

    Args:
        image_path: Path of the existing image file to overwrite.
        image_to_save: Object with a ``save(path)`` method, such as a PIL
            image.
    """
    import shutil  # local import: only this function needs it

    if not os.path.exists(image_path):
        log.error(f"Image path does not exist: {image_path}")
        return

    try:
        original_dir = os.path.join(os.path.dirname(image_path), "original")
        os.makedirs(original_dir, exist_ok=True)

        base_filename = os.path.basename(image_path)
        backup_path = os.path.join(original_dir, base_filename)

        if not os.path.exists(backup_path):
            # Copy instead of rename: if the save below fails, a rename would
            # leave no file at image_path at all.
            # NOTE(review): a rename presumably also fails on Windows while
            # the caller still has the source image open - confirm.
            shutil.copy2(image_path, backup_path)
            log.info(f"Backed up original image to {backup_path}")
        else:
            log.info(f"Backup already exists for {base_filename}, skipping backup.")

        image_to_save.save(image_path)
        log.info(f"Saved modified image to {image_path}")

    except Exception as e:
        # Best effort by design: report the failure and let the caller
        # continue with the next image.
        log.error(f"Error saving image with backup: {e}")
|
|
|
|
|
|
def apply_correction(
    image_files,
    target_images_dir,
    control_images_dir,
    correction_method,
    save_padded,
):
    """
    Crop or pad image pairs whose aspect ratios differ by more than 0.01.

    For every mismatching pair the wider image is brought to the aspect
    ratio of the narrower one. Corrected images are written (with a backup
    of the original) only when ``save_padded`` is truthy; otherwise the
    pairs are only counted and nothing on disk changes.

    Args:
        image_files: File names present in both folders.
        target_images_dir: Folder holding the target images.
        control_images_dir: Folder holding the control images.
        correction_method: ``"Crop"``, ``"Pad"`` or ``"None"``.
        save_padded: Whether corrected images are written to disk.

    Returns:
        A no-op ``gr.update()``; results are reported through toasts.
    """
    if not image_files:
        gr.Warning("No images loaded.")
        return gr.update()

    if correction_method == "None":
        gr.Info("No correction method selected.")
        return gr.update()

    # Resolve the method once, before any image is opened.
    correctors = {"Crop": crop_image, "Pad": pad_image}
    corrector = correctors.get(correction_method)
    if corrector is None:
        gr.Warning(f"Unknown correction method: {correction_method}")
        return gr.update()

    mismatched_files = 0
    corrected_files = 0
    for image_file in image_files:
        target_image_path = os.path.join(target_images_dir, image_file)
        control_image_path = os.path.join(control_images_dir, image_file)

        try:
            with Image.open(target_image_path) as target_img, Image.open(
                control_image_path
            ) as control_img:
                target_aspect_ratio = target_img.width / target_img.height
                control_aspect_ratio = control_img.width / control_img.height

                if abs(target_aspect_ratio - control_aspect_ratio) <= 1e-2:
                    continue
                mismatched_files += 1

                if not save_padded:
                    # Nothing would be written, so skip the image processing.
                    continue

                if target_aspect_ratio > control_aspect_ratio:
                    # Target is wider, so we correct it
                    image_to_correct = target_img
                    path_to_save = target_image_path
                    correct_aspect_ratio = control_aspect_ratio
                else:
                    # Control is wider, so we correct it
                    image_to_correct = control_img
                    path_to_save = control_image_path
                    correct_aspect_ratio = target_aspect_ratio

                modified_image = corrector(image_to_correct, correct_aspect_ratio)
                save_image_with_backup(path_to_save, modified_image)
                corrected_files += 1

        except Exception as e:
            # Keep going: one unreadable image must not stop the batch.
            log.error(f"Could not process or save image {image_file}. Error: {e}")

    if save_padded:
        gr.Info(f"Corrected {corrected_files} images with aspect ratio mismatches.")
    else:
        # Do not claim a correction when nothing was written.
        gr.Info(
            f"Found {mismatched_files} images with aspect ratio mismatches. "
            "Enable 'Save Padded Images' to write the corrected files."
        )
    return gr.update()
|
|
|
|
|
|
def update_images(
    image_files,
    target_images_dir,
    control_images_dir,
    caption_ext,
    quick_tags_text,
    page,
):
    """
    Build the component updates for the image rows of one page.

    Args:
        image_files: File names kept in state by ``load_images``; may be
            ``None`` or empty when nothing is loaded.
        target_images_dir: Folder of the target images and caption files.
        control_images_dir: Folder of the control images.
        caption_ext: Extension of the caption files.
        quick_tags_text: Comma separated quick tags (``None`` allowed).
        page: 1-based page number; a missing value is treated as page 1.

    Returns:
        A flat list of ``IMAGES_TO_SHOW * 6 + 2`` updates in the order the
        outputs are wired in the GUI (``update_outputs``): row visibility,
        file names, target image paths, control image paths, captions, tag
        checkbox groups, and finally the two pagination rows. When nothing
        is loaded every row and both pagination rows are hidden and all
        values are cleared.
    """
    # The empty state goes through the same loop as a normal page, so it
    # always yields one correctly typed update per wired output.
    has_images = bool(image_files and target_images_dir)
    files = image_files if has_images else []

    quick_tags, quick_tags_set = _get_quick_tags(quick_tags_text or "")
    start_index = (int(page or 1) - 1) * IMAGES_TO_SHOW

    # Build component updates in lists
    rows_update, files_update, target_paths, control_paths, captions, tags_checks = [], [], [], [], [], []

    for i in range(IMAGES_TO_SHOW):
        image_index = start_index + i
        is_visible = 0 <= image_index < len(files)

        rows_update.append(gr.Row(visible=is_visible))

        image_file, target_path, control_path, caption = None, None, None, ""
        if is_visible:
            image_file = files[image_index]
            target_path = os.path.join(target_images_dir, image_file)
            control_path = os.path.join(control_images_dir, image_file)
            caption_file_path = _get_caption_path(image_file, target_images_dir, caption_ext)
            if caption_file_path and os.path.exists(caption_file_path):
                with open(caption_file_path, "r", encoding="utf-8") as f:
                    caption = f.read()

        files_update.append(image_file)
        target_paths.append(target_path)
        control_paths.append(control_path)
        captions.append(caption)
        tags_checks.append(_get_tag_checkbox_updates(caption, quick_tags, quick_tags_set))

    # Combine all updates into a single list
    outputs = (
        rows_update + files_update + target_paths + control_paths + captions + tags_checks
    )
    # Pagination rows are shown only while there is something to page through.
    outputs.extend([gr.Row(visible=has_images), gr.Row(visible=has_images)])

    return outputs
|
|
|
|
|
|
# Gradio UI
|
|
def gradio_kontext_manual_caption_gui_tab(headless=False, default_images_dir=None):
    """
    Build the "Kontext Manual Captioning" tab and wire up its event handlers.

    The tab contains the folder and option controls, the quick-tags editor,
    two pagination bars and IMAGES_TO_SHOW hidden image rows that are filled
    in by ``update_images``.

    Args:
        headless: When true, the "📂" folder-picker buttons are created
            hidden (``visible=not headless``).
        default_images_dir: Folder whose sub-directories seed the two folder
            dropdowns; defaults to ``os.path.join(scriptdir, "data")``.

    NOTE(review): presumably this must be called inside an active Gradio
    Blocks/Tabs context - confirm against the caller.
    NOTE(review): the nesting of the layout blocks below was reconstructed
    from statement order; confirm it against the rendered UI.
    """
    from .common_gui import create_refresh_button

    default_images_dir = default_images_dir or os.path.join(scriptdir, "data")

    # REFACTOR: Simplify directory update logic
    def update_dir_list(path):
        # Offer an empty entry followed by whatever list_dirs yields for path.
        # FIX: Convert generator from list_dirs to a list
        return gr.Dropdown(choices=[""] + list(list_dirs(path)))

    # REFACTOR: Define pagination UI and logic in one place
    def render_pagination_with_logic(page, max_page):
        # Creates one (initially hidden) pagination bar. The hidden gr.Number
        # components only carry the constant -1 / +1 offset into paginate().
        with gr.Row(visible=False) as pagination_row:
            gr.Button("◀ Prev").click(paginate, inputs=[page, max_page, gr.Number(-1, visible=False)], outputs=[page])
            page_count = gr.Text("Page 1 / 1", show_label=False, interactive=False, text_align="center")
            page_goto_text = gr.Textbox(show_label=False, placeholder="Go to page...", container=False, scale=1)
            gr.Button("Next ▶").click(paginate, inputs=[page, max_page, gr.Number(1, visible=False)], outputs=[page])
            page_goto_text.submit(paginate_go, inputs=[page_goto_text, max_page], outputs=[page])
        return pagination_row, page_count

    with gr.Tab("Kontext Manual Captioning"):
        gr.Markdown("This utility allows quick captioning and tagging of images for fine-tuning with before and after images.")

        # REFACTOR: Use gr.State for non-UI data
        # Holds the list of shared file names returned by load_images.
        image_files_state = gr.State([])

        # Hidden components: the status line plus the paging position and the
        # folders that were in effect when the images were loaded.
        info_box = gr.Markdown(visible=False)
        page = gr.Number(value=1, visible=False)
        max_page = gr.Number(value=1, visible=False)
        loaded_images_dir = gr.Text(visible=False)
        loaded_control_images_dir = gr.Text(visible=False)

        with gr.Group():
            with gr.Row():
                # FIX: Convert generator from list_dirs to a list
                control_images_dir = gr.Dropdown(label="Control image folder", choices=[""] + list(list_dirs(default_images_dir)), value="", interactive=True, allow_custom_value=True, scale=2)
                # NOTE(review): `.value` is read from the component object
                # itself; confirm it reflects the user's current selection.
                create_refresh_button(control_images_dir, lambda: None, lambda: {"choices": list(list_dirs(control_images_dir.value or default_images_dir))}, "open_folder_small")
                gr.Button("📂", elem_id="open_folder_small", elem_classes=["tool"], visible=not headless).click(get_folder_path, outputs=control_images_dir, show_progress=False)

                # FIX: Convert generator from list_dirs to a list
                target_images_dir = gr.Dropdown(label="Target image folder", choices=[""] + list(list_dirs(default_images_dir)), value="", interactive=True, allow_custom_value=True, scale=2)
                create_refresh_button(target_images_dir, lambda: None, lambda: {"choices": list(list_dirs(target_images_dir.value or default_images_dir))}, "open_folder_small")
                gr.Button("📂", elem_id="open_folder_small", elem_classes=["tool"], visible=not headless).click(get_folder_path, outputs=target_images_dir, show_progress=False)

            with gr.Row():
                caption_ext = gr.Dropdown(label="Caption file extension", choices=[".cap", ".caption", ".txt"], value=".txt", interactive=True, allow_custom_value=True)
                auto_save = gr.Checkbox(label="Autosave", value=True, interactive=True)

            load_images_button = gr.Button("Load Images", variant="primary")

            with gr.Row():
                aspect_ratio_correction = gr.Dropdown(["None", "Crop", "Pad"], label="Aspect Ratio Correction", value="None")
                save_padded_images = gr.Checkbox(label="Save Padded Images", value=False)

            apply_correction_button = gr.Button("Apply Correction")

            # The correction works on the folders recorded at load time, not
            # on the current dropdown values.
            apply_correction_button.click(
                apply_correction,
                inputs=[
                    image_files_state,
                    loaded_images_dir,
                    loaded_control_images_dir,
                    aspect_ratio_correction,
                    save_padded_images,
                ],
                outputs=info_box,
            )

        # Refresh the dropdown choices whenever a folder value changes.
        target_images_dir.change(update_dir_list, inputs=target_images_dir, outputs=target_images_dir, show_progress=False)
        # Changing the control folder also fills in a derived target folder,
        # but only while the target dropdown is still empty.
        control_images_dir.change(
            lambda path, current_target: (
                update_dir_list(path),
                derive_target_folder(path)
                if not current_target
                else current_target,
            ),
            inputs=[control_images_dir, target_images_dir],
            outputs=[control_images_dir, target_images_dir],
            show_progress=False,
        )

        with gr.Group():
            quick_tags_text = gr.Textbox(label="Quick Tags", placeholder="Comma separated list of tags", interactive=True)
            with gr.Row():
                ignore_load_tags_word_count = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Ignore Imported Tags Above Word Count", interactive=True, scale=2)

            with gr.Row():
                import_tags_button = gr.Button("Import tags from captions", scale=1)

        # Pagination bar above the image rows.
        pagination_row1, page_count1 = render_pagination_with_logic(page, max_page)

        # Parallel lists, one entry per image row, later concatenated into the
        # output list of update_images.
        image_rows, image_files, target_image_images, control_image_images, image_caption_texts, image_tag_checks, save_buttons, delete_buttons = [], [], [], [], [], [], [], []
        for i in range(IMAGES_TO_SHOW):
            with gr.Row(visible=False) as row:
                # Hidden text box carrying the file name shown in this row.
                image_file = gr.Text(visible=False)
                with gr.Column():
                    control_image_image = gr.Image(type="filepath", label="Control Image")
                with gr.Column():
                    target_image_image = gr.Image(type="filepath", label="Target Image")
                with gr.Column(scale=2):
                    image_caption_text = gr.TextArea(label="Captions", placeholder="Input captions for target image", interactive=True)
                    tag_checkboxes = gr.CheckboxGroup([], label="Tags", interactive=True)
                with gr.Column(min_width=40):
                    # The save button starts hidden because Autosave defaults
                    # to on; auto_save.change below toggles it.
                    save_button = gr.Button("💾", elem_id="save_button", visible=False)
                    delete_button = gr.Button("🗑️", elem_id="delete_button")

            image_rows.append(row); image_files.append(image_file); control_image_images.append(control_image_image); target_image_images.append(target_image_image)
            image_caption_texts.append(image_caption_text); image_tag_checks.append(tag_checkboxes); save_buttons.append(save_button); delete_buttons.append(delete_button)

            # Caption edits refresh the checkboxes; checkbox edits rewrite the
            # caption. Both save to disk when Autosave is ticked.
            image_caption_text.input(update_image_caption, inputs=[quick_tags_text, image_caption_text, image_file, loaded_images_dir, caption_ext, auto_save], outputs=tag_checkboxes)
            tag_checkboxes.input(update_image_tags, inputs=[quick_tags_text, tag_checkboxes, image_file, loaded_images_dir, caption_ext, auto_save], outputs=[image_caption_text])
            save_button.click(save_caption, inputs=[image_caption_text, caption_ext, image_file, loaded_images_dir], outputs=info_box)
            # After deleting, the folders are reloaded so the file list in
            # state no longer contains the removed entry.
            delete_button.click(
                delete_images_and_caption,
                inputs=[
                    image_file,
                    loaded_control_images_dir,
                    loaded_images_dir,
                    caption_ext,
                ],
                outputs=info_box,
            ).then(
                load_images,
                inputs=[
                    loaded_images_dir,
                    loaded_control_images_dir,
                    caption_ext,
                    quick_tags_text,
                    ignore_load_tags_word_count,
                ],
                outputs=[
                    image_files_state,
                    loaded_images_dir,
                    loaded_control_images_dir,
                    page,
                    max_page,
                    info_box,
                    quick_tags_text,
                ],
            )

        # Pagination bar below the image rows.
        pagination_row2, page_count2 = render_pagination_with_logic(page, max_page)

        quick_tags_text.change(update_quick_tags, inputs=[quick_tags_text] + image_caption_texts, outputs=image_tag_checks)
        import_tags_button.click(
            import_tags_from_captions,
            inputs=[
                loaded_images_dir,
                caption_ext,
                quick_tags_text,
                ignore_load_tags_word_count,
            ],
            outputs=quick_tags_text,
        )

        # Same order as the seven item list returned by load_images.
        load_images_outputs = [
            image_files_state,
            loaded_images_dir,
            loaded_control_images_dir,
            page,
            max_page,
            info_box,
            quick_tags_text,
        ]
        load_images_button.click(
            load_images,
            inputs=[
                target_images_dir,
                control_images_dir,
                caption_ext,
                quick_tags_text,
                ignore_load_tags_word_count,
            ],
            outputs=load_images_outputs,
        )

        # REFACTOR: A single trigger to update the images view
        update_trigger_inputs = [image_files_state, loaded_images_dir, loaded_control_images_dir, caption_ext, quick_tags_text, page]
        # The order of these groups must match the list update_images returns.
        update_outputs = (
            image_rows + image_files + target_image_images + control_image_images +
            image_caption_texts + image_tag_checks + [pagination_row1, pagination_row2]
        )

        # Trigger update when page or data changes
        page.change(update_images, inputs=update_trigger_inputs, outputs=update_outputs, show_progress=False)
        image_files_state.change(update_images, inputs=update_trigger_inputs, outputs=update_outputs, show_progress=False)

        # Manual save buttons are only shown while Autosave is off.
        auto_save.change(lambda is_auto_save: [gr.Button(visible=not is_auto_save)] * IMAGES_TO_SHOW, inputs=auto_save, outputs=save_buttons)
        # Keep both "Page x / y" labels in sync with the paging state.
        page.change(lambda p, m: (f"Page {int(p)} / {int(m)}", f"Page {int(p)} / {int(m)}"), inputs=[page, max_page], outputs=[page_count1, page_count2], show_progress=False)
        max_page.change(lambda p, m: (f"Page {int(p)} / {int(m)}", f"Page {int(p)} / {int(m)}"), inputs=[page, max_page], outputs=[page_count1, page_count2], show_progress=False)
|