automatic/cli/civitai-search.py

218 lines
7.8 KiB
Python
Executable File

#!/usr/bin/env python
from dataclasses import dataclass
import os
import sys
import json
import time
import logging
# Module-level switches read by the classes and helpers defined later in this file.
full_dct = False # when True, every parsed object keeps its raw source dict in `.dct` (otherwise `.dct` is {})
full_html = False # when True, Model/ModelVersion keep the raw description markup in `.html` (otherwise '')
debug = False # when True, print_models additionally dumps each object's `.dct` through rich.print
# Configure root logging once at import time; `log` is the logger used throughout this file.
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
log = logging.getLogger(__name__)
@dataclass(init=False)
class ModelImage():
    """Image entry parsed from a CivitAI API response.

    The fields are declared at class level so that the dataclass-generated
    ``__eq__`` and ``__repr__`` operate on real data; without declared fields
    every instance would compare equal to every other instance.

    Args:
        dct: image record as a dict, or a JSON string that decodes to one.
    """
    id: int
    url: str
    width: int
    height: int
    type: str
    dct: dict # raw source record, only retained when the module-level `full_dct` is true

    def __init__(self, dct: dict):
        if isinstance(dct, str):
            dct = json.loads(dct)
        self.id = dct.get('id', 0)
        self.url = dct.get('url', '')
        self.width = dct.get('width', 0)
        self.height = dct.get('height', 0)
        self.type = dct.get('type', 'Unknown')
        self.dct = dct if full_dct else {}

    def __str__(self):
        return f'ModelImage(id={self.id} url="{self.url}" width={self.width} height={self.height} type="{self.type}")'
@dataclass(init=False)
class ModelFile():
    """Downloadable file entry parsed from a CivitAI API response.

    The fields are declared at class level so that the dataclass-generated
    ``__eq__`` and ``__repr__`` operate on real data; without declared fields
    every instance would compare equal to every other instance.

    Args:
        dct: file record as a dict, or a JSON string that decodes to one.
    """
    id: int
    size: int # bytes, computed from the record's `sizeKB`
    name: str
    type: str
    hashes: list[str]
    url: str
    dct: dict # raw source record, only retained when the module-level `full_dct` is true

    def __init__(self, dct: dict):
        if isinstance(dct, str):
            dct = json.loads(dct)
        self.id = dct.get('id', 0)
        # `or 0` also covers an explicit null, which would otherwise fail the multiplication
        self.size = int(1024 * (dct.get('sizeKB') or 0))
        self.name = dct.get('name', 'Unknown')
        self.type = dct.get('type', 'Unknown')
        # materialize to a real list: a dict_values view is neither indexable nor JSON-serializable
        self.hashes = list((dct.get('hashes') or {}).values())
        self.url = dct.get('downloadUrl', '')
        self.dct = dct if full_dct else {}

    def __str__(self):
        return f'ModelFile(id={self.id} name="{self.name}" size={self.size} type="{self.type}" url="{self.url}")'
@dataclass(init=False)
class ModelVersion():
    """Single version of a model, parsed from a CivitAI API response.

    The fields are declared at class level so that the dataclass-generated
    ``__eq__`` and ``__repr__`` operate on real data; without declared fields
    every instance would compare equal to every other instance.

    Args:
        dct: version record as a dict, or a JSON string that decodes to one.
    """
    id: int
    name: str
    base: str
    mtime: str
    downloads: int
    availability: str
    html: str # raw description markup, only retained when the module-level `full_html` is true
    desc: str # description with markup stripped
    files: 'list[ModelFile]'
    images: 'list[ModelImage]'
    dct: dict # raw source record, only retained when the module-level `full_dct` is true

    def __init__(self, dct: dict):
        import bs4 # imported at call time, so importing this module does not require bs4
        if isinstance(dct, str):
            dct = json.loads(dct)
        description = dct.get('description', '') or ''
        self.id = dct.get('id', 0)
        self.name = dct.get('name', 'Unknown')
        self.base = dct.get('baseModel', 'Unknown')
        self.mtime = dct.get('publishedAt', '')
        # `or {}` / `or []` also cover explicit nulls, which `.get(key, default)` passes through
        self.downloads = (dct.get('stats') or {}).get('downloadCount', 0)
        self.availability = dct.get('availability', 'Unknown')
        self.html = description if full_html else ''
        self.desc = bs4.BeautifulSoup(description, features="html.parser").get_text()
        self.files = [ModelFile(f) for f in dct.get('files') or []]
        self.images = [ModelImage(i) for i in dct.get('images') or []]
        self.dct = dct if full_dct else {}

    def __str__(self):
        return f'ModelVersion(id={self.id} name="{self.name}" base="{self.base}" mtime="{self.mtime}" downloads={self.downloads} availability={self.availability} desc="{self.desc[:30]}...")'
@dataclass(init=False)
class Model():
    """Model entry parsed from a CivitAI API response, including its versions.

    The fields are declared at class level so that the dataclass-generated
    ``__eq__`` and ``__repr__`` operate on real data; without declared fields
    every instance would compare equal to every other instance.

    Args:
        dct: model record as a dict, or a JSON string that decodes to one.
    """
    id: int
    url: str # model page URL built from `id`
    type: str
    name: str
    html: str # raw description markup, only retained when the module-level `full_html` is true
    desc: str # description with markup stripped
    tags: list[str]
    nsfw: bool
    level: int # value of the record's `nsfwLevel`, defaults to 0
    availability: str
    downloads: int
    creator: str
    versions: 'list[ModelVersion]'
    dct: dict # raw source record, only retained when the module-level `full_dct` is true

    def __init__(self, dct: dict):
        import bs4 # imported at call time, so importing this module does not require bs4
        if isinstance(dct, str):
            dct = json.loads(dct)
        description = dct.get('description', '') or ''
        self.id = dct.get('id', 0)
        self.url = f'https://civitai.com/models/{self.id}'
        self.type = dct.get('type', 'Unknown')
        self.name = dct.get('name', 'Unknown')
        self.html = description if full_html else ''
        self.desc = bs4.BeautifulSoup(description, features="html.parser").get_text()
        # `or []` / `or {}` also cover explicit nulls, which `.get(key, default)` passes through
        self.tags = dct.get('tags') or []
        self.nsfw = dct.get('nsfw', False)
        self.level = dct.get('nsfwLevel', 0)
        self.availability = dct.get('availability', 'Unknown')
        self.downloads = (dct.get('stats') or {}).get('downloadCount', 0)
        self.creator = (dct.get('creator') or {}).get('username', 'Unknown')
        self.versions = [ModelVersion(v) for v in dct.get('modelVersions') or []]
        self.dct = dct if full_dct else {}

    def __str__(self):
        return f'Model(id={self.id} type={self.type} name="{self.name}" versions={len(self.versions)} nsfw={self.nsfw}/{self.level} downloads={self.downloads} author="{self.creator}" tags={self.tags} desc="{self.desc[:30]}...")'
def search_civitai(
    query:str,
    tag:str = '', # optional:tag name
    types:str = '', # (Checkpoint, TextualInversion, Hypernetwork, AestheticGradient, LORA, Controlnet, Poses)
    sort:str = '', # (Highest Rated, Most Downloaded, Newest)
    period:str = '', # (AllTime, Year, Month, Week, Day)
    nsfw:bool = None, # optional:bool
    limit:int = 0,
    base:list[str] = None, # optional list of base model names; None or empty means no filter
    token:str = None,
    exact:bool = True,
):
    """Query the CivitAI models endpoint and parse the response into Model objects.

    Args:
        query: search text; an empty query is rejected without a request.
        tag: sent as the `tag` parameter when non-empty.
        types: sent as the `types` parameter when non-empty.
        sort: sent as the `sort` parameter when non-empty.
        period: sent as the `period` parameter when non-empty.
        nsfw: sent as 'true'/'false' when not None.
        limit: sent as the `limit` parameter when greater than zero.
        base: names joined with ',' and sent as `baseModels` when non-empty.
        token: bearer token; falls back to the CIVITAI_TOKEN environment variable.
        exact: when true, prefer models whose model, version or file name
            contains the query (case-insensitive).

    Returns:
        The name-matching models if `exact` found any, otherwise every parsed
        model. An empty list on empty query, request failure, non-200 status
        or an undecodable response body.
    """
    import requests
    from urllib.parse import urlencode
    if not query:
        log.error('CivitAI: empty query')
        return []
    t0 = time.time()
    dct = { 'query': query }
    if tag:
        dct['tag'] = tag
    if nsfw is not None:
        dct['nsfw'] = 'true' if nsfw else 'false'
    if limit > 0:
        dct['limit'] = limit
    if types:
        dct['types'] = types
    if sort:
        dct['sort'] = sort
    if period:
        dct['period'] = period
    if base:
        dct['baseModels'] = ','.join(base)
    encoded = urlencode(dct)
    headers = {}
    if token is None:
        token = os.environ.get('CIVITAI_TOKEN', None)
    if token:
        headers['Authorization'] = f'Bearer {token}'
    url = 'https://civitai.com/api/v1/models'
    uri = f'{url}?{encoded}'
    # report whether a token is actually being sent; an empty token sets no header
    log.info(f'CivitAI request: uri="{uri}" dct={dct} token={"Authorization" in headers}')
    try:
        result = requests.get(uri, headers=headers, timeout=60)
    except requests.RequestException as e:
        # connection/timeout failures are reported the same way as HTTP errors below
        log.error(f'CivitAI: uri={uri} error={e}')
        return []
    if result.status_code != 200:
        log.error(f'CivitAI: code={result.status_code} reason={result.reason} uri={result.url}')
        return []
    try:
        items = result.json().get('items') or []
    except ValueError as e:
        # body was not valid JSON
        log.error(f'CivitAI: code={result.status_code} uri={result.url} error={e}')
        return []
    models: list[Model] = [Model(item) for item in items]
    exact_models: list[Model] = []
    if exact:
        needle = query.lower()
        for model in models:
            names = [model.name]
            names += [v.name for v in model.versions]
            names += [f.name for v in model.versions for f in v.files]
            # `name or ''` tolerates records whose name came back as null
            if any(needle in (name or '').lower() for name in names):
                exact_models.append(model)
    t1 = time.time()
    log.info(f'CivitAI result: code={result.status_code} exact={len(exact_models)} total={len(models)} time={t1-t0:.2f}')
    return exact_models if exact_models else models
def models_to_dct(all_models:list, model_id:int=None):
    """Convert model objects into a list of plain nested dicts.

    Each model's attribute dict is shallow-copied and its `versions` entry is
    replaced by a list of version dicts, whose `files` and `images` entries
    are in turn replaced by shallow copies of each item's attribute dict.

    Args:
        all_models: objects exposing `id` and `versions`; each version exposes
            `files` and `images`.
        model_id: when not None, only models whose `id` equals it are included.

    Returns:
        A new list of dicts; the input objects are not modified.
    """
    def version_to_dct(version):
        entry = dict(vars(version))
        entry['files'] = [dict(vars(item)) for item in version.files]
        entry['images'] = [dict(vars(item)) for item in version.images]
        return entry

    selected = [m for m in all_models if model_id is None or m.id == model_id]
    result = []
    for current in selected:
        entry = dict(vars(current))
        entry['versions'] = [version_to_dct(v) for v in current.versions]
        result.append(entry)
    return result
def print_models(models: list[Model]):
    """Log each model followed by its versions and, per version, its files then images.

    When the module-level `debug` flag is set, every object's `.dct` is also
    printed through rich right after its log line.
    """
    if debug:
        from rich import print as dbg
    else:
        def dbg(*args, **kwargs): # no-op stand-in when debug output is off
            return None

    def walk(entries):
        # yields (label, object) pairs in the order they are reported
        for model in entries:
            yield 'Model', model
            for version in model.versions:
                yield 'ModelVersion', version
                for file in version.files:
                    yield 'ModelFile', file
                for image in version.images:
                    yield 'ModelImage', image

    for label, obj in walk(models):
        log.info(f' {obj}')
        dbg(label, obj.dct)
if __name__ == "__main__":
    # Script entry: all command-line words after the script name form one search query.
    del sys.argv[0]
    search_text = ' '.join(sys.argv)
    found = search_civitai(
        query=search_text,
        # tag = '',
        # types = '',
        # sort = 'Most Downloaded',
        # period = 'Year',
        # nsfw = True,
        # base = [],
        # exact= True,
        # limit=100,
    )
    print_models(found)