From 6c3e80ba2e6ebe5a99b8c8b556b569a2f9d5b2c5 Mon Sep 17 00:00:00 2001 From: toshiaki1729 <116595002+toshiaki1729@users.noreply.github.com> Date: Sun, 4 Dec 2022 09:31:39 +0900 Subject: [PATCH] Refactoring --- .../dataset_tag_editor/dataset_tag_editor.py | 10 +-- scripts/main.py | 81 +++++++++++-------- 2 files changed, 47 insertions(+), 44 deletions(-) diff --git a/scripts/dataset_tag_editor/dataset_tag_editor.py b/scripts/dataset_tag_editor/dataset_tag_editor.py index 9159fa4..374cd56 100644 --- a/scripts/dataset_tag_editor/dataset_tag_editor.py +++ b/scripts/dataset_tag_editor/dataset_tag_editor.py @@ -117,20 +117,13 @@ class DatasetTagEditor: return [] - def set_path_filter(self, path:Optional[Set[str]] = None): - if path: - self.path_filter = PathFilter(path, PathFilter.Mode.INCLUSIVE) - else: - self.path_filter = PathFilter() - - def get_filtered_imgpath_and_tags(self, filters: List[Dataset.Filter] = [], filter_word: Optional[str] = None) -> Tuple[List[str], Set[str]]: filtered_set = self.dataset.copy() for filter in filters: filtered_set.filter(filter) tag_set = filtered_set.get_tagset() - img_paths = sorted(img_paths) + img_paths = sorted(filtered_set.datas.keys()) if filter_word: # all tags with filter_word @@ -253,7 +246,6 @@ class DatasetTagEditor: deepbooru.model.stop() self.construct_tag_counts() - self.set_path_filter() print(f'Loading Completed: {len(self.dataset)} images found') diff --git a/scripts/main.py b/scripts/main.py index bd1a67a..2a4a84c 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -12,11 +12,10 @@ path_filter = PathFilter() total_image_num = 0 displayed_image_num = 0 -____current_tag_filter = '' current_selection = 0 tmp_selection_img_path_set = set() -selected_image_path = '' -selection_selected_image_path____ = '' +gallery_selected_image_path = '' +selection_selected_image_path = '' # ================================================================ # Callbacks for "Filter and Edit Tags" tab @@ -24,6 +23,8 @@ selection_selected_image_path____ = '' def arrange_tag_order(tags: List[str], sort_by: str, sort_order: str) -> List[str]: tags = dataset_tag_editor.sort_tags(tags=tags, sort_by=sort_by, sort_order=sort_order) + tags_in_filter = [tag for tag in tags if tag in tag_filter.tags] + tags = tags_in_filter + [tag for tag in tags if tag not in tag_filter.tags] return tags @@ -32,16 +33,16 @@ def get_current_txt_filter(): Displayed Images : {displayed_image_num} / {total_image_num} total
Current Tag Filter : {tag_filter} AND {tag_filter_neg}
Current Selection Filter : {current_selection} images
- Selected Image : {selected_image_path} + Selected Image : {gallery_selected_image_path} """ def get_current_txt_selection(): - return f"""Selected Image : {selection_selected_image_path____}""" + return f"""Selected Image : {selection_selected_image_path}""" def load_files_from_dir(dir: str, sort_by: str, sort_order: str, recursive: bool, load_caption_from_filename: bool, use_interrogator: str, use_clip: bool, use_booru: bool): - global total_image_num, displayed_image_num, current_selection, tmp_selection_img_path_set, selected_image_path, tag_filter, tag_filter_neg, path_filter + global total_image_num, displayed_image_num, current_selection, tmp_selection_img_path_set, gallery_selected_image_path, selection_selected_image_path, tag_filter, tag_filter_neg, path_filter interrogate_method = InterrogateMethod.NONE if use_interrogator == 'If Empty': @@ -62,7 +63,8 @@ def load_files_from_dir(dir: str, sort_by: str, sort_order: str, recursive: bool total_image_num = displayed_image_num = len(dataset_tag_editor.get_img_path_set()) tmp_selection_img_path_set = set() current_selection = 0 - selected_image_path = '' + gallery_selected_image_path = '' + selection_selected_image_path = '' return [ img_paths, [], @@ -75,7 +77,7 @@ def load_files_from_dir(dir: str, sort_by: str, sort_order: str, recursive: bool def search_tags(filter_word: str, sort_by: str, sort_order: str): filter_tags = dataset_tag_editor.read_tags(filter_tags) - _, tags = dataset_tag_editor.get_filtered_imgpath_and_tags(filters=[tag_filter, tag_filter_neg], filter_word=filter_word) + _, tags = dataset_tag_editor.get_filtered_imgpath_and_tags(filters=[path_filter, tag_filter, tag_filter_neg], filter_word=filter_word) tags = arrange_tag_order(tags, sort_by=sort_by, sort_order=sort_order) return gr.CheckboxGroup.update(choices=dataset_tag_editor.write_tags(tags)) @@ -86,25 +88,31 @@ def clear_tag_filters(sort_by, sort_order): def rearrange_tag_order(filter_word: str, sort_by: str, sort_order: str): filter_tags = dataset_tag_editor.read_tags(filter_tags) - _, tags = dataset_tag_editor.get_filtered_imgpath_and_tags(filters=[tag_filter, tag_filter_neg], filter_word=filter_word) + _, tags = dataset_tag_editor.get_filtered_imgpath_and_tags(filters=[path_filter, tag_filter, tag_filter_neg], filter_word=filter_word) tags = arrange_tag_order(tags=tags, sort_by=sort_by, sort_order=sort_order) return gr.CheckboxGroup.update(choices=dataset_tag_editor.write_tags(tags)) def filter_gallery_by_checkbox(filter_tags: List[str], filter_word: str, sort_by: str, sort_order: str): + global tag_filter filter_tags = dataset_tag_editor.read_tags(filter_tags) + tag_filter = TagFilter(set(filter_tags), TagFilter.Logic.AND, TagFilter.Mode.INCLUSIVE) return filter_gallery(filter_word=filter_word, sort_by=sort_by, sort_order=sort_order) def filter_gallery(filter_word: str, sort_by: str, sort_order: str): global displayed_image_num, current_selection - img_paths, tags = dataset_tag_editor.get_filtered_imgpath_and_tags(filters=[tag_filter, tag_filter_neg], filter_word=filter_word) - displayed_image_num = len(img_paths) + img_paths, tags = dataset_tag_editor.get_filtered_imgpath_and_tags(filters=[path_filter, tag_filter, tag_filter_neg], filter_word=filter_word) + tags = arrange_tag_order(tags=tags, sort_by=sort_by, sort_order=sort_order) - filter_tags = dataset_tag_editor.write_tags(filter_tags) + filter_tags = [tag for tag in tags if tag in tag_filter.tags] tags = dataset_tag_editor.write_tags(tags) + filter_tags = dataset_tag_editor.write_tags(filter_tags) + + displayed_image_num = len(img_paths) current_selection = len(tmp_selection_img_path_set) tag_txt = ', '.join(tag_filter.tags) + if filter_tags and len(filter_tags) == 0: filter_tags = None return [ @@ -139,14 +147,14 @@ def arrange_selection_order(paths: List[str]) -> List[str]: def selection_index_changed(idx: int): - global tmp_selection_img_path_set, selection_selected_image_path____ + global tmp_selection_img_path_set, selection_selected_image_path idx = int(idx) img_paths = arrange_selection_order(tmp_selection_img_path_set) if idx < 0 or len(img_paths) <= idx: - selection_selected_image_path____ = '' + selection_selected_image_path = '' idx = -1 else: - selection_selected_image_path____ = img_paths[idx] + selection_selected_image_path = img_paths[idx] return [get_current_txt_selection(), idx] @@ -176,14 +184,14 @@ def invert_image_selection(): def remove_image_selection(idx: int): - global tmp_selection_img_path_set, selection_selected_image_path____ + global tmp_selection_img_path_set, selection_selected_image_path idx = int(idx) img_paths = arrange_selection_order(tmp_selection_img_path_set) if idx < 0 or len(img_paths) <= idx: idx = -1 else: tmp_selection_img_path_set.remove(img_paths[idx]) - selection_selected_image_path____ = '' + selection_selected_image_path = '' return [ arrange_selection_order(tmp_selection_img_path_set), @@ -193,10 +201,10 @@ def remove_image_selection(idx: int): def clear_image_selection(): - global tmp_selection_img_path_set, selection_selected_image_path____ + global tmp_selection_img_path_set, selection_selected_image_path, path_filter tmp_selection_img_path_set.clear() - selection_selected_image_path____ = '' - dataset_tag_editor.set_path_filter() + selection_selected_image_path = '' + path_filter = PathFilter() return[ [], get_current_txt_selection(), @@ -205,8 +213,11 @@ def clear_image_selection(): def apply_image_selection_filter(filter_word: str, sort_by: str, sort_order: str): - global tmp_selection_img_path_set - dataset_tag_editor.set_path_filter(tmp_selection_img_path_set) + global path_filter + if len(tmp_selection_img_path_set) > 0: + path_filter = PathFilter(tmp_selection_img_path_set, PathFilter.Mode.INCLUSIVE) + else: + path_filter = PathFilter() return filter_gallery(filter_word=filter_word, sort_by=sort_by, sort_order=sort_order) @@ -215,15 +226,15 @@ def apply_image_selection_filter(filter_word: str, sort_by: str, sort_order: str # ================================================================ def gallery_index_changed(idx: int): - global displayed_image_num, total_image_num, ____current_tag_filter, current_selection, selected_image_path + global gallery_selected_image_path idx = int(idx) img_paths, _ = dataset_tag_editor.get_filtered_imgpath_and_tags(filters=[tag_filter, tag_filter_neg]) tags_txt = '' if 0 <= idx and idx < len(img_paths): - selected_image_path = img_paths[idx] - tags_txt = ', '.join(dataset_tag_editor.get_tags_by_image_path(selected_image_path)) + gallery_selected_image_path = img_paths[idx] + tags_txt = ', '.join(dataset_tag_editor.get_tags_by_image_path(gallery_selected_image_path)) else: - selected_image_path = '' + gallery_selected_image_path = '' idx = -1 return [ @@ -250,12 +261,12 @@ def change_tags_selected_image(tags_text: str, sort_by: str, sort_order: str, id def interrogate_selected_image_clip(): - global selected_image_path - return interrogate_image_clip(selected_image_path) + global gallery_selected_image_path + return interrogate_image_clip(gallery_selected_image_path) def interrogate_selected_image_booru(): - global selected_image_path - return interrogate_image_booru(selected_image_path) + global gallery_selected_image_path + return interrogate_image_booru(gallery_selected_image_path) # ================================================================ @@ -263,7 +274,7 @@ def interrogate_selected_image_booru(): # ================================================================ def on_ui_tabs(): - global displayed_image_num, total_image_num, ____current_tag_filter, current_selection, selected_image_path, selection_selected_image_path____ + global displayed_image_num, total_image_num, current_selection, gallery_selected_image_path, selection_selected_image_path with gr.Blocks(analytics_enabled=False) as dataset_tag_editor_interface: with gr.Row(visible=False): btn_hidden_set_index = gr.Button(elem_id="dataset_tag_editor_btn_hidden_set_index") @@ -404,8 +415,8 @@ def on_ui_tabs(): cbg_tags.change( fn=filter_gallery_by_checkbox, - inputs=[tb_search_tags, rd_sort_by, rd_sort_order], - outputs=[gl_dataset_images, tb_selected_tags, tb_edit_tags, lbl_hidden_image_index, txt_filter] + inputs=[cbg_tags, tb_search_tags, rd_sort_by, rd_sort_order], + outputs=[gl_dataset_images, tb_selected_tags, cbg_tags, tb_edit_tags, lbl_hidden_image_index, txt_filter] ) rd_sort_by.change( @@ -495,7 +506,7 @@ def on_ui_tabs(): btn_apply_image_selection_filter.click( fn=apply_image_selection_filter, inputs=[tb_search_tags, rd_sort_by, rd_sort_order], - outputs=[gl_dataset_images, tb_selected_tags, tb_edit_tags, lbl_hidden_image_index, txt_filter] + outputs=[gl_dataset_images, tb_selected_tags, cbg_tags, tb_edit_tags, lbl_hidden_image_index, txt_filter] ) #---------------------------------------------------------------- @@ -504,7 +515,7 @@ def on_ui_tabs(): btn_hidden_set_index.click( fn=gallery_index_changed, _js="(x, y) => [x, dataset_tag_editor_gl_dataset_images_selected_index()]", - inputs=[lbl_hidden_image_index], + inputs=lbl_hidden_image_index, outputs=[tb_caption_selected_image, txt_filter, lbl_hidden_image_index] )