mirror of https://github.com/vladmandic/automatic
222 lines
8.7 KiB
Python
222 lines
8.7 KiB
Python
from dataclasses import dataclass
|
|
import os
|
|
import json
|
|
import time
|
|
from installer import install, log
|
|
|
|
|
|
# When True, every parsed object keeps the raw API dictionary it was built from in its `dct` attribute;
# when False that attribute is an empty dict.
full_dct = False

# When True, `Model` and `ModelVersion` keep the unmodified HTML description in their `html` attribute;
# when False that attribute is an empty string (the plain-text `desc` is always populated).
full_html = False

# Known base-model labels. Not referenced elsewhere in this part of the file;
# presumably offered as choices for the `base` search filter -- confirm against callers.
base_models = [
    '', 'ODOR',
    'SD 1.4', 'SD 1.5', 'SD 1.5 LCM', 'SD 1.5 Hyper',
    'SD 2.0', 'SD 2.0 768', 'SD 2.1', 'SD 2.1 768', 'SD 2.1 Unclip',
    'SDXL 0.9', 'SDXL 1.0',
    'SD 3', 'SD 3.5', 'SD 3.5 Medium', 'SD 3.5 Large', 'SD 3.5 Large Turbo',
    'Pony',
    'Flux.1 S', 'Flux.1 D', 'Flux.1 Kontext',
    'AuraFlow',
    'SDXL 1.0 LCM', 'SDXL Distilled', 'SDXL Turbo', 'SDXL Lightning', 'SDXL Hyper',
    'Stable Cascade',
    'SVD', 'SVD XT',
    'Playground v2',
    'PixArt a', 'PixArt E',
    'Hunyuan 1', 'Hunyuan Video',
    'Lumina', 'Kolors', 'Illustrious', 'Mochi', 'LTXV', 'CogVideoX', 'NoobAI',
    'Wan Video', 'Wan Video 1.3B t2v', 'Wan Video 14B t2v', 'Wan Video 14B i2v 480p', 'Wan Video 14B i2v 720p',
    'HiDream', 'OpenAI', 'Imagen4', 'Other',
]
|
|
|
|
|
|
@dataclass
class ModelImage():
    """Preview image entry of a model version.

    Built from a dictionary (or a JSON string that decodes to one). Missing
    keys fall back to the defaults shown in `__init__`.

    The fields are declared at class level so that the `@dataclass`-generated
    `__eq__` and `__repr__` actually look at the instance data; without
    declared fields every instance compares equal and reprs as `ModelImage()`.
    The hand-written `__init__` below is kept (dataclass never overwrites an
    existing `__init__`).
    """
    id: int
    url: str
    width: int
    height: int
    type: str
    dct: dict # raw source dictionary when the module-level `full_dct` flag is set, otherwise {}

    def __init__(self, dct: dict):
        """Populate the image from `dct`, which may also be passed as a JSON string."""
        if isinstance(dct, str):
            dct = json.loads(dct)
        self.id = dct.get('id', 0)
        self.url = dct.get('url', '')
        self.width = dct.get('width', 0)
        self.height = dct.get('height', 0)
        self.type = dct.get('type', 'Unknown')
        self.dct = dct if full_dct else {}

    def __str__(self):
        """Return the compact one-line summary used for logging."""
        return f'ModelImage(id={self.id} url="{self.url}" width={self.width} height={self.height} type="{self.type}")'
|
|
|
|
|
|
@dataclass
class ModelFile():
    """Downloadable file entry of a model version.

    Built from a dictionary (or a JSON string that decodes to one). Missing
    keys, and keys explicitly set to null, fall back to defaults.

    The fields are declared at class level so that the `@dataclass`-generated
    `__eq__` and `__repr__` actually look at the instance data; without
    declared fields every instance compares equal and reprs as `ModelFile()`.
    The hand-written `__init__` below is kept (dataclass never overwrites an
    existing `__init__`).
    """
    id: int
    size: int # 1024 * the source `sizeKB` value, truncated to int
    name: str
    type: str
    hashes: list[str] # values of the source `hashes` mapping, stringified; the keys are dropped
    url: str # taken from the source `downloadUrl` key
    dct: dict # raw source dictionary when the module-level `full_dct` flag is set, otherwise {}

    def __init__(self, dct: dict):
        """Populate the file entry from `dct`, which may also be passed as a JSON string."""
        if isinstance(dct, str):
            dct = json.loads(dct)
        self.id = dct.get('id', 0)
        # `or 0` / `or {}` guard against keys that are present but null in the payload
        self.size = int(1024 * (dct.get('sizeKB') or 0))
        self.name = dct.get('name', 'Unknown')
        self.type = dct.get('type', 'Unknown')
        self.hashes = [str(h) for h in (dct.get('hashes') or {}).values()]
        self.url = dct.get('downloadUrl', '')
        self.dct = dct if full_dct else {}

    def __str__(self):
        """Return the compact one-line summary used for logging."""
        return f'ModelFile(id={self.id} name="{self.name}" size={self.size} type="{self.type}" url="{self.url}")'
|
|
|
|
|
|
@dataclass
class ModelVersion():
    """Single published version of a model, with its files and preview images.

    Built from a dictionary (or a JSON string that decodes to one). Missing
    keys, and container keys explicitly set to null, fall back to defaults.

    The fields are declared at class level so that the `@dataclass`-generated
    `__eq__` and `__repr__` actually look at the instance data; without
    declared fields every instance compares equal and reprs as
    `ModelVersion()`. The hand-written `__init__` below is kept (dataclass
    never overwrites an existing `__init__`).
    """
    id: int
    name: str
    base: str # taken from the source `baseModel` key
    mtime: str # taken from the source `publishedAt` key, stored as received
    downloads: int
    availability: str
    html: str # raw HTML description when the module-level `full_html` flag is set, otherwise ''
    desc: str # description with markup stripped by BeautifulSoup
    files: 'list[ModelFile]'
    images: 'list[ModelImage]'
    dct: dict # raw source dictionary when the module-level `full_dct` flag is set, otherwise {}

    def __init__(self, dct: dict):
        """Populate the version from `dct`, which may also be passed as a JSON string."""
        import bs4 # imported lazily: search_civitai installs bs4 on demand before building objects
        if isinstance(dct, str):
            dct = json.loads(dct)
        description = dct.get('description', '') or '' # description may be present but null
        self.id = dct.get('id', 0)
        self.name = dct.get('name', 'Unknown')
        self.base = dct.get('baseModel', 'Unknown')
        self.mtime = dct.get('publishedAt', '')
        # `or {}` / `or []` guard against keys that are present but null in the payload
        self.downloads = (dct.get('stats') or {}).get('downloadCount', 0)
        self.availability = dct.get('availability', 'Unknown')
        self.html = description if full_html else ''
        self.desc = bs4.BeautifulSoup(description, features="html.parser").get_text()
        self.files = [ModelFile(f) for f in dct.get('files') or []]
        self.images = [ModelImage(i) for i in dct.get('images') or []]
        self.dct = dct if full_dct else {}

    def __str__(self):
        """Return the compact one-line summary used for logging (description cut to 30 characters)."""
        return f'ModelVersion(id={self.id} name="{self.name}" base="{self.base}" mtime="{self.mtime}" downloads={self.downloads} availability={self.availability} desc="{self.desc[:30]}...")'
|
|
|
|
|
|
@dataclass
class Model():
    """Top-level CivitAI model record together with all of its versions.

    Built from a dictionary (or a JSON string that decodes to one). Missing
    keys, and container keys explicitly set to null, fall back to defaults.

    The fields are declared at class level so that the `@dataclass`-generated
    `__eq__` and `__repr__` actually look at the instance data; without
    declared fields every instance compares equal and reprs as `Model()`.
    The hand-written `__init__` below is kept (dataclass never overwrites an
    existing `__init__`).
    """
    id: int
    url: str # model page address derived from `id`
    type: str
    name: str
    html: str # raw HTML description when the module-level `full_html` flag is set, otherwise ''
    desc: str # description with markup stripped by BeautifulSoup
    tags: list[str]
    nsfw: bool
    level: int # taken from the source `nsfwLevel` key; defaults to 0
    availability: str
    downloads: int
    creator: str # username of the source `creator` object
    versions: 'list[ModelVersion]'
    dct: dict # raw source dictionary when the module-level `full_dct` flag is set, otherwise {}

    def __init__(self, dct: dict):
        """Populate the model from `dct`, which may also be passed as a JSON string."""
        import bs4 # imported lazily: search_civitai installs bs4 on demand before building objects
        if isinstance(dct, str):
            dct = json.loads(dct)
        description = dct.get('description', '') or '' # description may be present but null
        self.id = dct.get('id', 0)
        self.url = f'https://civitai.com/models/{self.id}'
        self.type = dct.get('type', 'Unknown')
        self.name = dct.get('name', 'Unknown')
        self.html = description if full_html else ''
        self.desc = bs4.BeautifulSoup(description, features="html.parser").get_text()
        # `or []` / `or {}` guard against keys that are present but null in the payload
        self.tags = dct.get('tags') or []
        self.nsfw = dct.get('nsfw', False)
        self.level = dct.get('nsfwLevel', 0)
        self.availability = dct.get('availability', 'Unknown')
        self.downloads = (dct.get('stats') or {}).get('downloadCount', 0)
        self.creator = (dct.get('creator') or {}).get('username', 'Unknown')
        self.versions = [ModelVersion(v) for v in dct.get('modelVersions') or []]
        self.dct = dct if full_dct else {}

    def __str__(self):
        """Return the compact one-line summary used for logging (description cut to 30 characters)."""
        return f'Model(id={self.id} type={self.type} name="{self.name}" versions={len(self.versions)} nsfw={self.nsfw}/{self.level} downloads={self.downloads} author="{self.creator}" tags={self.tags} desc="{self.desc[:30]}...")'
|
|
|
|
|
|
# Results of the most recent successful `search_civitai` call. That function rebinds this name
# each time a request completes with status 200; its early error returns leave it untouched.
models: list[Model] = [] # global cache for civitai search results
|
|
|
|
|
|
def search_civitai(
    query:str,
    tag:str = '', # optional:tag name
    types:str = '', # (Checkpoint, TextualInversion, Hypernetwork, AestheticGradient, LORA, Controlnet, Poses)
    sort:str = '', # (Highest Rated, Most Downloaded, Newest)
    period:str = '', # (AllTime, Year, Month, Week, Day)
    nsfw:bool = None, # optional:bool
    limit:int = 0,
    base:str = '', # list
    token:str = None,
    exact:bool = True,
):
    """Query the CivitAI models endpoint and return the parsed results.

    A query made only of ASCII digits is treated as a model id and fetched
    from `/api/v1/models/<id>`; any other query is sent as a search with the
    non-empty filters attached as URL parameters.

    Args:
        query: search text or numeric model id; an empty string is rejected.
        tag, types, sort, period: sent as-is when non-empty.
        nsfw: sent as 'true'/'false' when not None.
        limit: sent when greater than zero.
        base: sent as the `baseModels` parameter when non-empty.
        token: bearer token; when None the `CIVITAI_TOKEN` environment
            variable is used. No Authorization header is sent if it is empty.
        exact: when True, keep only models whose name, version names or file
            names contain `query` (case-insensitive). If nothing matches, all
            returned models are kept.

    Returns:
        list of `Model`. An empty list is returned for an empty query or a
        non-200 response. On success the module-level `models` cache is
        rebound to the returned list.

    Exceptions raised by `requests.get` or by JSON decoding are not caught here.
    """
    global models # pylint: disable=global-statement
    import requests
    from urllib.parse import urlencode
    install('bs4') # Ensure BeautifulSoup is installed

    if len(query) == 0:
        log.error('CivitAI: empty query')
        return []

    t0 = time.time()
    # insertion order below determines parameter order in the encoded URL
    dct = { 'query': query }
    if len(tag) > 0:
        dct['tag'] = tag
    if nsfw is not None:
        dct['nsfw'] = 'true' if nsfw else 'false'
    if limit > 0:
        dct['limit'] = limit
    if len(types) > 0:
        dct['types'] = types
    if len(sort) > 0:
        dct['sort'] = sort
    if len(period) > 0:
        dct['period'] = period
    if len(base) > 0:
        dct['baseModels'] = base
    encoded = urlencode(dct)

    headers = {}
    if token is None:
        token = os.environ.get('CIVITAI_TOKEN', None)
    if token is not None and len(token) > 0:
        headers['Authorization'] = f'Bearer {token}'

    url = 'https://civitai.com/api/v1/models'
    # only plain ASCII digits form a valid id path; str.isnumeric() would also accept characters such as '²' or '½'
    if query.isascii() and query.isdigit():
        uri = f'{url}/{query}'
    else:
        uri = f'{url}?{encoded}'

    log.info(f'CivitAI request: uri="{uri}" dct={dct} token={token is not None}')
    result = requests.get(uri, headers=headers, timeout=60)

    if result.status_code != 200:
        log.error(f'CivitAI: code={result.status_code} reason={result.reason} uri={result.url}')
        return []

    dct = result.json()
    # a by-id lookup returns the model object itself rather than an `items` list
    items = dct.get('items', []) if 'items' in dct else [dct]
    all_models: list[Model] = [Model(item) for item in items]

    exact_models: list[Model] = []
    if exact:
        needle = query.lower()
        for model in all_models:
            names = [model.name]
            names += [v.name for v in model.versions]
            names += [f.name for v in model.versions for f in v.files]
            if any(needle in name.lower() for name in names):
                exact_models.append(model)

    t1 = time.time()
    # report the size of this response; the global cache still holds the previous search at this point
    log.info(f'CivitAI result: code={result.status_code} exact={len(exact_models)} total={len(all_models)} time={t1-t0:.2f}')
    models = exact_models if len(exact_models) > 0 else all_models
    return models
|
|
|
|
|
|
def create_model_cards(all_models: list[Model]) -> str:
    """Render the search results as an HTML fragment of clickable cards.

    The fragment consists of an empty `#model-details` container followed by
    a `#model-cards` container holding one card per model. Each card shows
    the model name, its type and one preview image: the first image URL found
    across the model's versions that is non-empty and does not end in `.mp4`,
    or a fixed placeholder thumbnail when there is none.

    Model name, type, id and preview URL come from the remote service, so
    they are HTML-escaped before being interpolated into the markup.

    Args:
        all_models: models to render; an empty list yields empty containers.

    Returns:
        The HTML fragment as a string.
    """
    from html import escape # imported by name because `html` is used as a local variable below

    details = """
        <div id="model-details">
        </div>
    """
    cards = """
        <div id="model-cards" class="extra-network-cards">
            {cards}
        </div>
    """
    card = """
        <div class="card" data-id="{id}" onclick="modelCardClick({id})">
            <div class="overlay"><div class="name">{name}</div></div>
            <div class="version">{type}</div>
            <img class="preview" src="{preview}" alt="{name}" loading="lazy" />
        </div>
    """
    missing = '/sdapi/v1/network/thumb?filename=html/missing.png'
    rendered = []
    for model in all_models:
        # only the first usable preview is shown, so stop scanning as soon as one is found
        preview = next(
            (
                image.url
                for version in model.versions
                for image in version.images
                if image.url and not image.url.lower().endswith('.mp4')
            ),
            missing,
        )
        rendered.append(card.format(
            id=escape(str(model.id)),
            name=escape(str(model.name)),
            type=escape(str(model.type)),
            preview=escape(str(preview)),
        ))
    html = details + cards.format(cards=''.join(rendered))
    return html
|
|
|
|
|
|
def print_models(all_models: list[Model]):
    """Log every model together with its versions, files and images.

    For each object an info line with its string form is written, followed by
    a trace call carrying the object's `dct`. Within a version, all files are
    reported before all images.
    """
    for entry in all_models:
        log.info(f' {entry}')
        log.trace('Model', entry.dct)
        for release in entry.versions:
            log.info(f' {release}')
            log.trace('ModelVersion', release.dct)
            # files first, then images -- same order as two consecutive loops
            children = (
                ('ModelFile', release.files),
                ('ModelImage', release.images),
            )
            for label, members in children:
                for member in members:
                    log.info(f' {member}')
                    log.trace(label, member.dct)
|