diff --git a/scripts/ui/block_dataset_gallery.py b/scripts/ui/block_dataset_gallery.py index 73daeab..9f07897 100644 --- a/scripts/ui/block_dataset_gallery.py +++ b/scripts/ui/block_dataset_gallery.py @@ -1,8 +1,14 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Callable, List + import gradio as gr from .ui_common import * from .uibase import UIBase +if TYPE_CHECKING: + from .ui_classes import * + class DatasetGalleryUI(UIBase): def __init__(self): self.selected_path = '' @@ -16,7 +22,7 @@ class DatasetGalleryUI(UIBase): self.nb_hidden_image_index_prev = gr.Number(value=None, label='hidden_idx_prev') self.gl_dataset_images = gr.Gallery(label='Dataset Images', elem_id="dataset_tag_editor_dataset_gallery").style(grid=image_columns) - def set_callbacks(self, load_dataset, gallery_state, get_filters): + def set_callbacks(self, load_dataset:LoadDatasetUI, gallery_state:GalleryStateUI, get_filters:Callable[[], dte_module.filters.Filter]): gallery_state.register_value('Selected Image', self.selected_path) load_dataset.btn_load_datasets.click( diff --git a/scripts/ui/block_gallery_state.py b/scripts/ui/block_gallery_state.py index cab3013..8c3cf69 100644 --- a/scripts/ui/block_gallery_state.py +++ b/scripts/ui/block_gallery_state.py @@ -1,8 +1,13 @@ +from __future__ import annotations +from typing import TYPE_CHECKING import gradio as gr from .ui_common import * from .uibase import UIBase +if TYPE_CHECKING: + from .ui_classes import * + class GalleryStateUI(UIBase): def __init__(self): self.texts = dict() @@ -22,7 +27,7 @@ class GalleryStateUI(UIBase): def create_ui(self): self.txt_gallery = gr.HTML(value=self.get_current_gallery_txt()) - def set_callbacks(self, dataset_gallery): + def set_callbacks(self, dataset_gallery:DatasetGalleryUI): dataset_gallery.nb_hidden_image_index.change( fn=self.update_gallery_txt, inputs=None, diff --git a/scripts/ui/block_load_dataset.py b/scripts/ui/block_load_dataset.py index d22ef25..ffff377 100644 --- a/scripts/ui/block_load_dataset.py +++ b/scripts/ui/block_load_dataset.py @@ -1,4 +1,5 @@ -from typing import List +from __future__ import annotations +from typing import TYPE_CHECKING, List, Callable import gradio as gr from modules import shared @@ -7,6 +8,8 @@ from modules.shared import opts from .ui_common import * from .uibase import UIBase +if TYPE_CHECKING: + from .ui_classes import * INTERROGATOR_NAMES = dte_module.INTERROGATOR_NAMES InterrogateMethod = dte_instance.InterrogateMethod @@ -42,7 +45,7 @@ class LoadDatasetUI(UIBase): self.cb_use_custom_threshold_waifu = gr.Checkbox(value=cfg_general.use_custom_threshold_waifu, label='Use Custom Threshold (WDv1.4 Tagger)', interactive=True) self.sl_custom_threshold_waifu = gr.Slider(minimum=0, maximum=1, value=cfg_general.custom_threshold_waifu, step=0.01, interactive=True, label='WDv1.4 Tagger Score Threshold') - def set_callbacks(self, o_update_filter_and_gallery, toprow, dataset_gallery, filter_by_tags, filter_by_selection, batch_edit_captions, update_filter_and_gallery): + def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], toprow:ToprowUI, dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, filter_by_selection:FilterBySelectionUI, batch_edit_captions:BatchEditCaptionsUI, update_filter_and_gallery:Callable[[], List]): def load_files_from_dir( dir: str, caption_file_ext: str, diff --git a/scripts/ui/block_toprow.py b/scripts/ui/block_toprow.py index 738e07c..9f47454 100644 --- a/scripts/ui/block_toprow.py +++ b/scripts/ui/block_toprow.py @@ -1,8 +1,13 @@ +from __future__ import annotations +from typing import TYPE_CHECKING import gradio as gr from .ui_common import * from .uibase import UIBase +if TYPE_CHECKING: + from .ui_classes import * + class ToprowUI(UIBase): def create_ui(self, cfg_general): @@ -26,7 +31,7 @@ class ToprowUI(UIBase): with gr.Row(visible=False): self.txt_result = gr.Textbox(label='Results', interactive=False) - def set_callbacks(self, load_dataset): + def set_callbacks(self, load_dataset:LoadDatasetUI): def save_all_changes(backup: bool, save_kohya_metadata:bool, metadata_output:str, metadata_input:str, metadata_overwrite:bool, metadata_as_caption:bool, metadata_use_fullpath:bool): if not metadata_input: diff --git a/scripts/ui/tab_batch_edit_captions.py b/scripts/ui/tab_batch_edit_captions.py index 89395f3..3cd8e6b 100644 --- a/scripts/ui/tab_batch_edit_captions.py +++ b/scripts/ui/tab_batch_edit_captions.py @@ -1,9 +1,14 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, List, Callable import gradio as gr from .ui_common import * from .uibase import UIBase from .block_tag_select import TagSelectUI +if TYPE_CHECKING: + from .ui_classes import * + SortBy = dte_instance.SortBy SortOrder = dte_instance.SortOrder @@ -12,7 +17,7 @@ class BatchEditCaptionsUI(UIBase): self.tag_select_ui_remove = TagSelectUI() self.show_only_selected_tags = False - def create_ui(self, cfg_batch_edit, get_filters): + def create_ui(self, cfg_batch_edit, get_filters:Callable[[], List[dte_module.filters.Filter]]): with gr.Tab(label='Search and Replace'): with gr.Column(variant='panel'): gr.HTML('Edit common tags.') @@ -64,7 +69,7 @@ class BatchEditCaptionsUI(UIBase): self.rb_sort_order = gr.Radio(choices=[e.value for e in SortOrder], value=cfg_batch_edit.batch_sort_order, interactive=True, label='Sort Order') self.btn_sort_selected = gr.Button(value='Sort selected tags', variant='primary') - def set_callbacks(self, o_update_filter_and_gallery, load_dataset, filter_by_tags, get_filters, update_filter_and_gallery): + def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], load_dataset:LoadDatasetUI, filter_by_tags:FilterByTagsUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]): load_dataset.btn_load_datasets.click( fn=lambda:['', ''], outputs=[self.tb_common_tags, self.tb_edit_tags] @@ -169,7 +174,7 @@ class BatchEditCaptionsUI(UIBase): ) - def get_common_tags(self, get_filters, filter_by_tags): + def get_common_tags(self, get_filters:Callable[[], List[dte_module.filters.Filter]], filter_by_tags:FilterByTagsUI): if self.show_only_selected_tags: tags = ', '.join([t for t in dte_instance.get_common_tags(filters=get_filters()) if t in filter_by_tags.tag_filter_ui.filter.tags]) else: diff --git a/scripts/ui/tab_edit_caption_of_selected_image.py b/scripts/ui/tab_edit_caption_of_selected_image.py index 9ef2f17..1f998f8 100644 --- a/scripts/ui/tab_edit_caption_of_selected_image.py +++ b/scripts/ui/tab_edit_caption_of_selected_image.py @@ -1,3 +1,5 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, List, Callable import gradio as gr from modules import shared @@ -6,6 +8,9 @@ from scripts.dte_instance import dte_module from .ui_common import * from .uibase import UIBase +if TYPE_CHECKING: + from .ui_classes import * + SortBy = dte_instance.SortBy SortOrder = dte_instance.SortOrder @@ -45,7 +50,7 @@ class EditCaptionOfSelectedImageUI(UIBase): gr.HTML("""Changes are not applied to the text files until the "Save all changes" button is pressed.""") - def set_callbacks(self, o_update_filter_and_gallery, dataset_gallery, load_dataset, get_filters, update_filter_and_gallery): + def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, load_dataset:LoadDatasetUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]): load_dataset.btn_load_datasets.click( fn=lambda:['', -1], outputs=[self.tb_caption, self.nb_hidden_image_index_save_or_not] diff --git a/scripts/ui/tab_filter_by_selection.py b/scripts/ui/tab_filter_by_selection.py index 57898c5..7b242cf 100644 --- a/scripts/ui/tab_filter_by_selection.py +++ b/scripts/ui/tab_filter_by_selection.py @@ -1,9 +1,13 @@ -from typing import List +from __future__ import annotations +from typing import TYPE_CHECKING, List, Callable import gradio as gr from .ui_common import * from .uibase import UIBase +if TYPE_CHECKING: + from .ui_classes import * + filters = dte_module.filters @@ -17,7 +21,7 @@ class FilterBySelectionUI(UIBase): def get_current_txt_selection(self): return f"""Selected Image : {self.selected_path}""" - def create_ui(self, image_columns): + def create_ui(self, image_columns:int): with gr.Row(visible=False): self.btn_hidden_set_selection_index = gr.Button(elem_id="dataset_tag_editor_btn_hidden_set_selection_index") self.nb_hidden_selection_image_index = gr.Number(value=-1) @@ -38,7 +42,7 @@ class FilterBySelectionUI(UIBase): self.btn_apply_image_selection_filter = gr.Button(value='Apply selection filter', variant='primary') - def set_callbacks(self, o_update_filter_and_gallery, dataset_gallery, filter_by_tags, get_filters, update_filter_and_gallery): + def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]): def selection_index_changed(idx:int = -1): idx = int(idx) if idx is not None else -1 img_paths = arrange_selection_order(self.tmp_selection) diff --git a/scripts/ui/tab_filter_by_tags.py b/scripts/ui/tab_filter_by_tags.py index c0466b4..63a44a3 100644 --- a/scripts/ui/tab_filter_by_tags.py +++ b/scripts/ui/tab_filter_by_tags.py @@ -1,9 +1,14 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, List, Callable import gradio as gr from .ui_common import * from .uibase import UIBase from .block_tag_filter import TagFilterUI +if TYPE_CHECKING: + from .ui_classes import * + filters = dte_module.filters class FilterByTagsUI(UIBase): @@ -31,7 +36,7 @@ class FilterByTagsUI(UIBase): logic_n = filters.TagFilter.Logic.AND if cfg_filter_n.logic=='AND' else filters.TagFilter.Logic.NONE if cfg_filter_n.logic=='NONE' else filters.TagFilter.Logic.OR self.tag_filter_ui_neg.create_ui(get_filters, logic_n, cfg_filter_n.sort_by, cfg_filter_n.sort_order, cfg_filter_n.sw_prefix, cfg_filter_n.sw_suffix, cfg_filter_n.sw_regex) - def set_callbacks(self, o_update_gallery, o_update_filter_and_gallery, batch_edit_captions, move_or_delete_files, update_gallery, update_filter_and_gallery, get_filters): + def set_callbacks(self, o_update_gallery:List[gr.components.Component], o_update_filter_and_gallery:List[gr.components.Component], batch_edit_captions:BatchEditCaptionsUI, move_or_delete_files:MoveOrDeleteFilesUI, update_gallery:Callable[[], List], update_filter_and_gallery:Callable[[], List], get_filters:Callable[[], List[dte_module.filters.Filter]]): common_callback = lambda : \ update_gallery() + \ batch_edit_captions.get_common_tags(get_filters, self) + \ diff --git a/scripts/ui/tab_move_or_delete_files.py b/scripts/ui/tab_move_or_delete_files.py index a27b62d..46aa68d 100644 --- a/scripts/ui/tab_move_or_delete_files.py +++ b/scripts/ui/tab_move_or_delete_files.py @@ -1,10 +1,15 @@ -from typing import List +from __future__ import annotations +from typing import TYPE_CHECKING, List, Callable import gradio as gr from .ui_common import * from .uibase import UIBase +if TYPE_CHECKING: + from .ui_classes import * + + class MoveOrDeleteFilesUI(UIBase): def __init__(self): self.target_data = 'Selected One' @@ -25,7 +30,7 @@ class MoveOrDeleteFilesUI(UIBase): def get_current_move_or_delete_target_num(self): return self.current_target_txt - def set_callbacks(self, o_update_filter_and_gallery, dataset_gallery, filter_by_tags, batch_edit_captions, filter_by_selection, edit_caption_of_selected_image, get_filters, update_filter_and_gallery): + def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, batch_edit_captions:BatchEditCaptionsUI, filter_by_selection:FilterBySelectionUI, edit_caption_of_selected_image:EditCaptionOfSelectedImageUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]): def _get_current_move_or_delete_target_num(): if self.target_data == 'Selected One': self.current_target_txt = f'Target dataset num: {1 if dataset_gallery.selected_index != -1 else 0}'