Add debug option to add tags to all records, Add tags multi-selection while importing from civitai
parent
7072fce409
commit
26434174aa
|
|
@ -6,13 +6,14 @@ from urllib.parse import urlparse, parse_qs
|
|||
import gradio as gr
|
||||
import requests
|
||||
|
||||
from scripts.mo.data.mapping_utils import create_version_dict
|
||||
from scripts.mo.data.storage import map_record_to_dict
|
||||
from scripts.mo.environment import env
|
||||
from scripts.mo.models import ModelType, Record
|
||||
from scripts.mo.ui_styled_html import alert_danger, alert_warning
|
||||
from scripts.mo.utils import is_blank
|
||||
from scripts.mo.data.mapping_utils import create_version_dict
|
||||
|
||||
all_available_tags = []
|
||||
|
||||
def _get_model_images(model_version_dict):
|
||||
result = []
|
||||
|
|
@ -116,30 +117,38 @@ def _on_fetch_url_clicked(url):
|
|||
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
data_dict = create_model_dict(data)
|
||||
try:
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
data_dict = create_model_dict(data)
|
||||
|
||||
duplicate_warning = ''
|
||||
if env.check_duplicates():
|
||||
civurl = f"https://civitai.com/models/{model_id}"
|
||||
duplicate_candidates = env.storage.get_records_by_query(f"SELECT * FROM RECORD WHERE URL = '{civurl}'")
|
||||
if len(duplicate_candidates) > 0:
|
||||
duplicate_list = ['Fetched Model already has at least a version present as record']
|
||||
for record in duplicate_candidates:
|
||||
duplicate_list.append(record.name)
|
||||
duplicate_warning = alert_warning(duplicate_list)
|
||||
duplicate_warning = ''
|
||||
if env.check_duplicates():
|
||||
civurl = f"https://civitai.com/models/{model_id}"
|
||||
duplicate_candidates = env.storage.get_records_by_query(f"SELECT * FROM RECORD WHERE URL = '{civurl}'")
|
||||
if len(duplicate_candidates) > 0:
|
||||
duplicate_list = ['Fetched Model already has at least a version present as record']
|
||||
for record in duplicate_candidates:
|
||||
duplicate_list.append(record.name)
|
||||
duplicate_warning = alert_warning(duplicate_list)
|
||||
|
||||
return [
|
||||
data_dict,
|
||||
gr.HTML.update(value='' if duplicate_warning == '' else duplicate_warning),
|
||||
gr.Column.update(visible=True),
|
||||
*_create_ui_update(data_dict=data_dict, selected_version_id=selected_model_version_id)
|
||||
]
|
||||
else:
|
||||
return [
|
||||
data_dict,
|
||||
gr.HTML.update(value='' if duplicate_warning == '' else duplicate_warning),
|
||||
gr.Column.update(visible=True),
|
||||
*_create_ui_update(data_dict=data_dict, selected_version_id=selected_model_version_id)
|
||||
]
|
||||
else:
|
||||
return [
|
||||
None,
|
||||
gr.HTML.update(value=alert_danger(f'Request failed with status code: {response.status_code}')),
|
||||
gr.Column.update(visible=False),
|
||||
*_create_ui_update()
|
||||
]
|
||||
except Exception as e:
|
||||
return [
|
||||
None,
|
||||
gr.HTML.update(value=alert_danger(f'Request failed with status code: {response.status_code}')),
|
||||
gr.HTML.update(value=alert_danger(f'Request failed with error: {e}')),
|
||||
gr.Column.update(visible=False),
|
||||
*_create_ui_update()
|
||||
]
|
||||
|
|
@ -150,7 +159,7 @@ def _create_ui_update(data_dict=None, selected_version=None, selected_version_id
|
|||
return [
|
||||
gr.Textbox.update(value=''),
|
||||
gr.Dropdown.update(value=''),
|
||||
gr.Textbox.update(value=''),
|
||||
gr.Dropdown.update(value='', choices=[]),
|
||||
gr.Dropdown.update(value='', choices=[]),
|
||||
gr.Textbox.update(value=''),
|
||||
gr.Gallery.update(value=None),
|
||||
|
|
@ -196,6 +205,17 @@ def _create_ui_update(data_dict=None, selected_version=None, selected_version_id
|
|||
name = f"{data_dict['name']} [{version['name']}]"
|
||||
model_type = data_dict['mode_type'].value
|
||||
tags = data_dict['tags']
|
||||
|
||||
available_groups = env.storage.get_available_groups()
|
||||
|
||||
tags = [tag.strip() for tag in tags.split(',') if tag.strip()] if tags else []
|
||||
all_tags = list(set(available_groups + tags))
|
||||
all_tags.sort()
|
||||
|
||||
global all_available_tags
|
||||
|
||||
all_available_tags = list(set(all_tags))
|
||||
|
||||
model_version = version['name']
|
||||
model_versions = list(map(lambda x: x['name'], data_dict['versions']))
|
||||
|
||||
|
|
@ -211,7 +231,7 @@ def _create_ui_update(data_dict=None, selected_version=None, selected_version_id
|
|||
return [
|
||||
gr.Textbox.update(value=name),
|
||||
gr.Dropdown.update(value=model_type),
|
||||
gr.Textbox.update(value=tags),
|
||||
gr.Dropdown.update(value=tags, choices=all_tags),
|
||||
gr.Dropdown.update(value=model_version, choices=model_versions),
|
||||
gr.Textbox.update(value=image_url),
|
||||
gr.Gallery.update(version['images']),
|
||||
|
|
@ -286,7 +306,7 @@ def _prepare_import_data(state, import_url, name, use_model_name_as_download_fil
|
|||
|
||||
download_url = file['download_url']
|
||||
description = state['description'] if include_description else ''
|
||||
groups = list(map(lambda x: x.strip(), tags.split(','))) if bool(tags.strip()) else []
|
||||
groups = tags
|
||||
sha256 = file['sha256']
|
||||
|
||||
if use_model_name_as_download_filename:
|
||||
|
|
@ -346,7 +366,8 @@ def _on_import_clicked(state, import_url, name, use_model_name_as_download_filen
|
|||
]
|
||||
|
||||
|
||||
def _on_edit_clicked(state, import_url, name, use_model_name_as_download_filename, model_type_value, tags, model_version_value, preview_url, file_value,
|
||||
def _on_edit_clicked(state, import_url, name, use_model_name_as_download_filename, model_type_value, tags,
|
||||
model_version_value, preview_url, file_value,
|
||||
prompts, include_description):
|
||||
result = _prepare_import_data(
|
||||
state=state,
|
||||
|
|
@ -389,6 +410,19 @@ def _on_gallery_select(data: gr.SelectData):
|
|||
return data.value
|
||||
|
||||
|
||||
def _on_new_tags_to_add(tags_add_textbox, tags_dropdown):
|
||||
tags = [tag.strip() for tag in tags_add_textbox.split(',') if tag.strip()] if tags_add_textbox else []
|
||||
choices_tags = tags_dropdown + [tag for tag in tags if tag not in tags_dropdown]
|
||||
|
||||
global all_available_tags
|
||||
all_tags = all_available_tags + [tag for tag in tags if tag not in all_available_tags]
|
||||
|
||||
return [
|
||||
gr.Textbox.update(value=''),
|
||||
gr.Dropdown.update(value=choices_tags, choices=all_tags),
|
||||
]
|
||||
|
||||
|
||||
def civitai_import_ui_block():
|
||||
import_url_textbox = gr.Textbox('',
|
||||
label='civitai model url or id.',
|
||||
|
|
@ -418,7 +452,25 @@ def civitai_import_ui_block():
|
|||
'correct model type.',
|
||||
interactive=True)
|
||||
|
||||
tags_textbox = gr.Textbox(label='Tags', interactive=True)
|
||||
tags_dropdown = gr.Dropdown(label='Tags',
|
||||
value='',
|
||||
info='Edit model tags(groups)',
|
||||
multiselect=True,
|
||||
interactive=True)
|
||||
|
||||
with gr.Column():
|
||||
tags_add_textbox = gr.Textbox(label='Add comma-separated tags new tags here',
|
||||
interactive=True,
|
||||
info='Add comma-separated tags new tags here. '
|
||||
'For example: "tag1,tag2,tag3" and hit the button to add tags.'
|
||||
)
|
||||
with gr.Row():
|
||||
gr.Markdown()
|
||||
gr.Markdown()
|
||||
gr.Markdown()
|
||||
gr.Markdown()
|
||||
tags_add_button = gr.Button('Add tags')
|
||||
|
||||
model_version_dropdown = gr.Dropdown(label='Model Version',
|
||||
interactive=True,
|
||||
info='Select model version.')
|
||||
|
|
@ -457,6 +509,11 @@ def civitai_import_ui_block():
|
|||
interactive=False,
|
||||
visible=False)
|
||||
|
||||
|
||||
tags_add_button.click(_on_new_tags_to_add,
|
||||
inputs=[tags_add_textbox, tags_dropdown],
|
||||
outputs=[tags_add_textbox, tags_dropdown])
|
||||
|
||||
fetch_url_button.click(_on_fetch_url_clicked,
|
||||
inputs=import_url_textbox,
|
||||
outputs=[import_model_state,
|
||||
|
|
@ -464,7 +521,7 @@ def civitai_import_ui_block():
|
|||
content_container,
|
||||
name_widget,
|
||||
model_type_dropdown,
|
||||
tags_textbox,
|
||||
tags_dropdown,
|
||||
model_version_dropdown,
|
||||
preview_url_textbox,
|
||||
preview_gallery,
|
||||
|
|
@ -482,7 +539,7 @@ def civitai_import_ui_block():
|
|||
outputs=[
|
||||
name_widget,
|
||||
model_type_dropdown,
|
||||
tags_textbox,
|
||||
tags_dropdown,
|
||||
model_version_dropdown,
|
||||
preview_url_textbox,
|
||||
preview_gallery,
|
||||
|
|
@ -505,7 +562,7 @@ def civitai_import_ui_block():
|
|||
name_widget,
|
||||
use_name_as_filename,
|
||||
model_type_dropdown,
|
||||
tags_textbox,
|
||||
tags_dropdown,
|
||||
model_version_dropdown,
|
||||
preview_url_textbox,
|
||||
files_dropdown,
|
||||
|
|
@ -522,7 +579,7 @@ def civitai_import_ui_block():
|
|||
name_widget,
|
||||
use_name_as_filename,
|
||||
model_type_dropdown,
|
||||
tags_textbox,
|
||||
tags_dropdown,
|
||||
model_version_dropdown,
|
||||
preview_url_textbox,
|
||||
files_dropdown,
|
||||
|
|
|
|||
|
|
@ -285,16 +285,35 @@ def _on_remove_all_records_click():
|
|||
return "All records has been removed."
|
||||
|
||||
|
||||
def _on_add_tag_to_all_records_click(tag):
|
||||
records = env.storage.get_all_records()
|
||||
records_updated_count = 0
|
||||
for record in records:
|
||||
if tag not in record.groups:
|
||||
record.groups.append(tag)
|
||||
env.storage.update_record(record)
|
||||
records_updated_count += 1
|
||||
|
||||
return f'{records_updated_count} records has been updated.'
|
||||
|
||||
|
||||
def _ui_debug_utils():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
remove_duplicates_button = gr.Button("Remove Records duplicate")
|
||||
remove_all_records = gr.Button("Remove all Records")
|
||||
add_all_records_tag_text = gr.Textbox(label='Tag to add all to records:',
|
||||
value='',
|
||||
max_lines=1,
|
||||
info='This tag will be added to all records.')
|
||||
add_tag_to_all_records_button = gr.Button("Add tag to all records")
|
||||
with gr.Column():
|
||||
debug_html_output = gr.HTML()
|
||||
|
||||
remove_duplicates_button.click(fn=_on_remove_duplicates_click, outputs=[debug_html_output])
|
||||
remove_all_records.click(fn=_on_remove_all_records_click, outputs=[debug_html_output])
|
||||
add_tag_to_all_records_button.click(fn=_on_add_tag_to_all_records_click,
|
||||
inputs=[add_all_records_tag_text], outputs=[debug_html_output])
|
||||
|
||||
|
||||
def debug_ui_block():
|
||||
|
|
|
|||
Loading…
Reference in New Issue