Add debug option to add tags to all records, Add tags multi-selection while importing from civitai

main
Alexander Sokol 2025-03-27 18:25:33 +02:00
parent 7072fce409
commit 26434174aa
2 changed files with 106 additions and 30 deletions

View File

@ -6,13 +6,14 @@ from urllib.parse import urlparse, parse_qs
import gradio as gr
import requests
from scripts.mo.data.mapping_utils import create_version_dict
from scripts.mo.data.storage import map_record_to_dict
from scripts.mo.environment import env
from scripts.mo.models import ModelType, Record
from scripts.mo.ui_styled_html import alert_danger, alert_warning
from scripts.mo.utils import is_blank
from scripts.mo.data.mapping_utils import create_version_dict
all_available_tags = []
def _get_model_images(model_version_dict):
result = []
@ -116,30 +117,38 @@ def _on_fetch_url_clicked(url):
response = requests.get(url, headers=headers)
if response.status_code == 200:
data = response.json()
data_dict = create_model_dict(data)
try:
if response.status_code == 200:
data = response.json()
data_dict = create_model_dict(data)
duplicate_warning = ''
if env.check_duplicates():
civurl = f"https://civitai.com/models/{model_id}"
duplicate_candidates = env.storage.get_records_by_query(f"SELECT * FROM RECORD WHERE URL = '{civurl}'")
if len(duplicate_candidates) > 0:
duplicate_list = ['Fetched Model already has at least a version present as record']
for record in duplicate_candidates:
duplicate_list.append(record.name)
duplicate_warning = alert_warning(duplicate_list)
duplicate_warning = ''
if env.check_duplicates():
civurl = f"https://civitai.com/models/{model_id}"
duplicate_candidates = env.storage.get_records_by_query(f"SELECT * FROM RECORD WHERE URL = '{civurl}'")
if len(duplicate_candidates) > 0:
duplicate_list = ['Fetched Model already has at least a version present as record']
for record in duplicate_candidates:
duplicate_list.append(record.name)
duplicate_warning = alert_warning(duplicate_list)
return [
data_dict,
gr.HTML.update(value='' if duplicate_warning == '' else duplicate_warning),
gr.Column.update(visible=True),
*_create_ui_update(data_dict=data_dict, selected_version_id=selected_model_version_id)
]
else:
return [
data_dict,
gr.HTML.update(value='' if duplicate_warning == '' else duplicate_warning),
gr.Column.update(visible=True),
*_create_ui_update(data_dict=data_dict, selected_version_id=selected_model_version_id)
]
else:
return [
None,
gr.HTML.update(value=alert_danger(f'Request failed with status code: {response.status_code}')),
gr.Column.update(visible=False),
*_create_ui_update()
]
except Exception as e:
return [
None,
gr.HTML.update(value=alert_danger(f'Request failed with status code: {response.status_code}')),
gr.HTML.update(value=alert_danger(f'Request failed with error: {e}')),
gr.Column.update(visible=False),
*_create_ui_update()
]
@ -150,7 +159,7 @@ def _create_ui_update(data_dict=None, selected_version=None, selected_version_id
return [
gr.Textbox.update(value=''),
gr.Dropdown.update(value=''),
gr.Textbox.update(value=''),
gr.Dropdown.update(value='', choices=[]),
gr.Dropdown.update(value='', choices=[]),
gr.Textbox.update(value=''),
gr.Gallery.update(value=None),
@ -196,6 +205,17 @@ def _create_ui_update(data_dict=None, selected_version=None, selected_version_id
name = f"{data_dict['name']} [{version['name']}]"
model_type = data_dict['mode_type'].value
tags = data_dict['tags']
available_groups = env.storage.get_available_groups()
tags = [tag.strip() for tag in tags.split(',') if tag.strip()] if tags else []
all_tags = list(set(available_groups + tags))
all_tags.sort()
global all_available_tags
all_available_tags = list(set(all_tags))
model_version = version['name']
model_versions = list(map(lambda x: x['name'], data_dict['versions']))
@ -211,7 +231,7 @@ def _create_ui_update(data_dict=None, selected_version=None, selected_version_id
return [
gr.Textbox.update(value=name),
gr.Dropdown.update(value=model_type),
gr.Textbox.update(value=tags),
gr.Dropdown.update(value=tags, choices=all_tags),
gr.Dropdown.update(value=model_version, choices=model_versions),
gr.Textbox.update(value=image_url),
gr.Gallery.update(version['images']),
@ -286,7 +306,7 @@ def _prepare_import_data(state, import_url, name, use_model_name_as_download_fil
download_url = file['download_url']
description = state['description'] if include_description else ''
groups = list(map(lambda x: x.strip(), tags.split(','))) if bool(tags.strip()) else []
groups = tags
sha256 = file['sha256']
if use_model_name_as_download_filename:
@ -346,7 +366,8 @@ def _on_import_clicked(state, import_url, name, use_model_name_as_download_filen
]
def _on_edit_clicked(state, import_url, name, use_model_name_as_download_filename, model_type_value, tags, model_version_value, preview_url, file_value,
def _on_edit_clicked(state, import_url, name, use_model_name_as_download_filename, model_type_value, tags,
model_version_value, preview_url, file_value,
prompts, include_description):
result = _prepare_import_data(
state=state,
@ -389,6 +410,19 @@ def _on_gallery_select(data: gr.SelectData):
return data.value
def _on_new_tags_to_add(tags_add_textbox, tags_dropdown):
tags = [tag.strip() for tag in tags_add_textbox.split(',') if tag.strip()] if tags_add_textbox else []
choices_tags = tags_dropdown + [tag for tag in tags if tag not in tags_dropdown]
global all_available_tags
all_tags = all_available_tags + [tag for tag in tags if tag not in all_available_tags]
return [
gr.Textbox.update(value=''),
gr.Dropdown.update(value=choices_tags, choices=all_tags),
]
def civitai_import_ui_block():
import_url_textbox = gr.Textbox('',
label='civitai model url or id.',
@ -418,7 +452,25 @@ def civitai_import_ui_block():
'correct model type.',
interactive=True)
tags_textbox = gr.Textbox(label='Tags', interactive=True)
tags_dropdown = gr.Dropdown(label='Tags',
value='',
info='Edit model tags(groups)',
multiselect=True,
interactive=True)
with gr.Column():
tags_add_textbox = gr.Textbox(label='Add comma-separated tags new tags here',
interactive=True,
info='Add comma-separated tags new tags here. '
'For example: "tag1,tag2,tag3" and hit the button to add tags.'
)
with gr.Row():
gr.Markdown()
gr.Markdown()
gr.Markdown()
gr.Markdown()
tags_add_button = gr.Button('Add tags')
model_version_dropdown = gr.Dropdown(label='Model Version',
interactive=True,
info='Select model version.')
@ -457,6 +509,11 @@ def civitai_import_ui_block():
interactive=False,
visible=False)
tags_add_button.click(_on_new_tags_to_add,
inputs=[tags_add_textbox, tags_dropdown],
outputs=[tags_add_textbox, tags_dropdown])
fetch_url_button.click(_on_fetch_url_clicked,
inputs=import_url_textbox,
outputs=[import_model_state,
@ -464,7 +521,7 @@ def civitai_import_ui_block():
content_container,
name_widget,
model_type_dropdown,
tags_textbox,
tags_dropdown,
model_version_dropdown,
preview_url_textbox,
preview_gallery,
@ -482,7 +539,7 @@ def civitai_import_ui_block():
outputs=[
name_widget,
model_type_dropdown,
tags_textbox,
tags_dropdown,
model_version_dropdown,
preview_url_textbox,
preview_gallery,
@ -505,7 +562,7 @@ def civitai_import_ui_block():
name_widget,
use_name_as_filename,
model_type_dropdown,
tags_textbox,
tags_dropdown,
model_version_dropdown,
preview_url_textbox,
files_dropdown,
@ -522,7 +579,7 @@ def civitai_import_ui_block():
name_widget,
use_name_as_filename,
model_type_dropdown,
tags_textbox,
tags_dropdown,
model_version_dropdown,
preview_url_textbox,
files_dropdown,

View File

@ -285,16 +285,35 @@ def _on_remove_all_records_click():
return "All records has been removed."
def _on_add_tag_to_all_records_click(tag):
records = env.storage.get_all_records()
records_updated_count = 0
for record in records:
if tag not in record.groups:
record.groups.append(tag)
env.storage.update_record(record)
records_updated_count += 1
return f'{records_updated_count} records has been updated.'
def _ui_debug_utils():
with gr.Row():
with gr.Column():
remove_duplicates_button = gr.Button("Remove Records duplicate")
remove_all_records = gr.Button("Remove all Records")
add_all_records_tag_text = gr.Textbox(label='Tag to add all to records:',
value='',
max_lines=1,
info='This tag will be added to all records.')
add_tag_to_all_records_button = gr.Button("Add tag to all records")
with gr.Column():
debug_html_output = gr.HTML()
remove_duplicates_button.click(fn=_on_remove_duplicates_click, outputs=[debug_html_output])
remove_all_records.click(fn=_on_remove_all_records_click, outputs=[debug_html_output])
add_tag_to_all_records_button.click(fn=_on_add_tag_to_all_records_click,
inputs=[add_all_records_tag_text], outputs=[debug_html_output])
def debug_ui_block():