mirror of https://github.com/vladmandic/automatic
96 lines
3.5 KiB
Python
96 lines
3.5 KiB
Python
import time
|
|
from installer import log
|
|
from modules.civitai.client_civitai import client
|
|
from modules.civitai.models_civitai import CivitModel, CivitSearchResponse
|
|
|
|
|
|
# Hardcoded fallback list — used by Gradio UI if discover_options() fails
|
|
base_models = [
    '',  # empty entry first
    'AuraFlow', 'Chroma', 'CogVideoX',
    # Flux family
    'Flux.1 S', 'Flux.1 D', 'Flux.1 Krea', 'Flux.1 Kontext', 'Flux.2 D',
    'HiDream', 'Hunyuan 1', 'Hunyuan Video', 'Illustrious', 'Kolors', 'LTXV', 'Lumina', 'Mochi', 'NoobAI',
    'PixArt a', 'PixArt E', 'Pony', 'Pony V7', 'Qwen',
    # SD / SDXL family
    'SD 1.4', 'SD 1.5', 'SD 1.5 LCM', 'SD 1.5 Hyper', 'SD 2.0', 'SD 2.1',
    'SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper',
    # Wan Video family
    'Wan Video 1.3B t2v', 'Wan Video 14B t2v', 'Wan Video 14B i2v 480p', 'Wan Video 14B i2v 720p',
    'Wan Video 2.2 TI2V-5B', 'Wan Video 2.2 I2V-A14B', 'Wan Video 2.2 T2V-A14B',
    'Wan Video 2.5 T2V', 'Wan Video 2.5 I2V',
    'ZImageTurbo', 'Other',
]
|
|
|
|
|
|
def search_civitai(
    query: str,
    tag: str = '',
    types: str = '',
    sort: str = '',
    period: str = '',
    nsfw: bool = None,
    limit: int = 0,
    base: str = '',
    token: str = None,
    exact: bool = True,
) -> list[CivitModel]:
    """Search CivitAI, or fetch one model directly when the query is a decimal id.

    Args:
        query: Search text, or a model id written in decimal digits.
        tag: Passed through to ``client.search_models``.
        types: Passed through to ``client.search_models``.
        sort: Passed through to ``client.search_models``.
        period: Passed through to ``client.search_models``.
        nsfw: Passed through to ``client.search_models``.
        limit: Requested result count; values <= 0 fall back to 20.
        base: Base-model filter; sent as a one-element list, or ``None`` when empty.
        token: Passed through to the client for both lookup paths.
        exact: When true, prefer results whose model, version or file name
            contains the query (case-insensitive substring match).

    Returns:
        An empty list for an empty query or a failed id lookup; a one-element
        list for a successful id lookup; otherwise the name-matched results,
        or every returned item when nothing matched (or ``exact`` is false).
    """
    if not query:
        log.error('CivitAI: empty query')
        return []

    t0 = time.time()

    # Decimal query -> single model fetch.
    # isdecimal() rather than isnumeric(): the latter also accepts characters
    # such as '²' or '½', which int() rejects with ValueError.
    if query.isdecimal():
        model = client.get_model(int(query), token=token)
        if model:
            t1 = time.time()
            log.info(f'CivitAI result: id={query} time={t1 - t0:.2f}')
            return [model]
        return []

    response: CivitSearchResponse = client.search_models(
        query=query,
        tag=tag,
        types=types,
        sort=sort,
        period=period,
        base_models=[base] if base else None,
        nsfw=nsfw,
        limit=limit if limit > 0 else 20,
        token=token,
    )

    all_models = response.items
    exact_models: list[CivitModel] = []
    if exact:
        q_lower = query.lower()
        for model in all_models:
            # Candidate names: the model itself, each version, and each file of each version.
            names = [model.name.lower()]
            names.extend(v.name.lower() for v in model.versions)
            names.extend(f.name.lower() for v in model.versions for f in v.files)
            if any(q_lower in name for name in names):
                exact_models.append(model)

    # Fall back to the unfiltered results when the name filter matched nothing.
    result = exact_models if exact_models else all_models
    t1 = time.time()
    log.info(f'CivitAI result: exact={len(exact_models)} total={len(all_models)} time={t1 - t0:.2f}')
    return result
|
|
|
|
|
|
def create_model_cards(all_models: list[CivitModel]) -> str:
    """Render models as an HTML fragment: an empty details container plus one card per model.

    Each card shows the model name, its type, and the first image URL found
    across the model's versions that does not end in ``.mp4``
    (case-insensitive). A placeholder thumbnail is used when no such image
    exists.

    Args:
        all_models: Models to render; each must expose ``id``, ``name``,
            ``type`` and ``versions`` (whose items expose ``images`` with ``url``).

    Returns:
        The HTML string containing the details container followed by the cards container.
    """
    # Imported locally so this function does not depend on a module-level import.
    from html import escape

    details = """
        <div id="model-details">
        </div>
    """
    cards = """
        <div id="model-cards" class="extra-network-cards">
            {cards}
        </div>
    """
    card = """
        <div class="card" data-id="{id}" onclick="modelCardClick({id})">
            <div class="overlay"><div class="name">{name}</div></div>
            <div class="version">{type}</div>
            <img class="preview" src="{preview}" alt="{name}" loading="lazy" />
        </div>
    """
    missing_preview = '/sdapi/v1/network/thumb?filename=html/missing.png'

    rendered = []
    for model in all_models:
        # Only the first usable image is shown, so stop at the first match.
        preview = next(
            (
                image.url
                for version in model.versions
                for image in version.images
                if image.url and not image.url.lower().endswith('.mp4')
            ),
            missing_preview,
        )
        # Values come from a remote API: escape them so quotes or markup in a
        # name cannot break out of the attribute/element they are placed in.
        rendered.append(card.format(
            id=escape(str(model.id)),
            name=escape(str(model.name)),
            type=escape(str(model.type)),
            preview=escape(str(preview)),
        ))

    return details + cards.format(cards=''.join(rendered))
|