add load-checkpoint api endpoint and test all models script

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/4115/head
Vladimir Mandic 2025-08-09 14:06:54 -04:00
parent 44a20a47c1
commit 73049f7bb8
6 changed files with 117 additions and 14 deletions

View File

@ -93,6 +93,8 @@ And (*as always*) many bugfixes and improvements to existing features!
- **prompt enhance** add `allura-org/Gemma-3-Glitter-4B` model support
- remove **LDSR**
- remove `api-only` cli option
- **API**
- add `/sdapi/v1/checkpoint` POST endpoint to load a model directly
- **Fixes**
- refactor legacy processing loop
- fix settings components mismatch
@ -111,6 +113,7 @@ And (*as always*) many bugfixes and improvements to existing features!
- fix *Flux.1-Kontext-Dev* with variable resolution
- use `utf_16_be` as primary metadata decoding
- fix `sd35` width/height alignment
- fix `nudenet` api
- reapply offloading on ipadapter load
- api set default script-name
- avoid forced gc and rely on thresholds

110
cli/test-all-models.py Normal file → Executable file
View File

@ -1,7 +1,24 @@
#!/usr/bin/env python
import io
import os
import time
import base64
import logging
import requests
import urllib3
import pathvalidate
from PIL import Image
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
log = logging.getLogger(__name__)
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
output_folder = 'outputs/compare'
models = [
"sd_xl_base_1.0",
"tempestByVlad_baseV01",
"huggingface/stabilityai/stable-cascade",
"sdxl-base-v10-vaefix",
"tempest-by-vlad-0.1",
"stabilityai/stable-diffusion-3.5-medium",
"stabilityai/stable-diffusion-3.5-large",
"black-forest-labs/FLUX.1-dev",
@ -10,24 +27,97 @@ models = [
"vladmandic/chroma-unlocked-v50",
"vladmandic/chroma-unlocked-v50-annealed",
"Qwen/Qwen-Image",
"ostris/Flex.2-preview",
"HiDream-ai/HiDream-I1-Dev",
"HiDream-ai/HiDream-I1-Full",
"briaai/BRIA-3.2",
"nvidia/Cosmos-Predict2-2B-Text2Image",
"nvidia/Cosmos-Predict2-14B-Text2Image",
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
"Wan-AI/Wan2.1-T2V-14B-Diffusers",
"OmniGen2/OmniGen2",
"stabilityai/stable-cascade",
"ostris/Flex.2-preview",
"Freepik/F-Lite",
"Freepik/F-Lite-Texture",
"Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers",
"Efficient-Large-Model/SANA1.5_4.8B_1024px_diffusers",
"nvidia/Cosmos-Predict2-2B-Text2Image",
"nvidia/Cosmos-Predict2-14B-Text2Image",
"OmniGen2/OmniGen2",
"fal/AuraFlow-v0.3",
"PixArt-alpha/PixArt-XL-2-1024-MS",
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
"Alpha-VLLM/Lumina-Next-SFT-diffusers",
"Alpha-VLLM/Lumina-Image-2.0",
"HiDream-ai/HiDream-I1-Dev",
"HiDream-ai/HiDream-I1-Full",
"Kwai-Kolors/Kolors-diffusers",
"briaai/BRIA-3.2",
"THUDM/CogView4-6B",
"kandinsky-community/kandinsky-2-1",
"kandinsky-community/kandinsky-2-2-decoder",
"kandinsky-community/kandinsky-3",
# "juggernautXL_juggXIByRundiffusion.safetensors@https://civitai.com/api/download/models/782002",
# "playground-v2.5-1024px-aesthetic.fp16.safetensors@https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/resolve/main/playground-v2.5-1024px-aesthetic.fp16.safetensors?download=true",
]
styles = [
'Fixed Astronaut',
'Fixed Bear',
'Fixed Steampunk City',
'Fixed Road sign',
'Fixed Futuristic hypercar',
'Fixed Pirate Ship in Space',
'Fixed Fallout girl',
'Fixed Kneeling on Bed',
'Fixed Girl in Sin City',
'Fixed Girl in a city',
'Fixed Lady in Tokyo',
'Fixed MadMax selfie',
'Fixed SDNext Neon',
]
def request(endpoint: str, dct: dict = None, method: str = 'POST'):
    """Send a request to the SD.Next server API and return the parsed JSON response.

    Args:
        endpoint: API path, e.g. '/sdapi/v1/txt2img'.
        dct: optional JSON body to send with the request.
        method: HTTP method, 'POST' or 'GET' (case-insensitive).

    Returns:
        Parsed JSON dict on success, or a dict with 'error'/'reason'/'url' keys
        when the server responds with a non-200 status.
    """
    def auth():
        # basic auth is used only when both credentials are present in the environment
        if sd_username is not None and sd_password is not None:
            return requests.auth.HTTPBasicAuth(sd_username, sd_password)
        return None
    sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
    sd_username = os.environ.get('SDAPI_USR', None)
    sd_password = os.environ.get('SDAPI_PWD', None)
    # bind the HTTP callable to a new name instead of shadowing the `method` parameter
    http_call = requests.post if method.upper() == 'POST' else requests.get
    req = http_call(f'{sd_url}{endpoint}', json=dct, timeout=300, verify=False, auth=auth())
    if req.status_code != 200:
        return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
    return req.json()
def generate(): # pylint: disable=redefined-outer-name
    """Generate and save one image for every model/style combination.

    For each model: load it via the checkpoint API, verify it actually loaded,
    then render every style and save the result under `output_folder`.
    Existing output files are skipped so the sweep can be resumed.
    """
    os.makedirs(output_folder, exist_ok=True)  # image.save fails if the folder is missing
    for m, model in enumerate(models):
        # sanitize once per model; style names are sanitized per iteration below
        model_name = pathvalidate.sanitize_filename(model, replacement_text='_')
        log.info(f'model: name="{model}" n={m+1}/{len(models)}')
        for s, style in enumerate(styles):
            style_name = pathvalidate.sanitize_filename(style, replacement_text='_')
            fn = os.path.join(output_folder, f'{model_name}__{style_name}.jpg')
            if os.path.exists(fn):
                continue  # already generated in a previous run
            request(f'/sdapi/v1/checkpoint?sd_model_checkpoint={model}', method='POST')
            loaded = request('/sdapi/v1/checkpoint', method='GET')
            # the server may report the loaded model under any of these keys;
            # the `or ''` fallback avoids `in None` raising TypeError for absent keys
            if not any(model in (loaded.get(key) or '') for key in ('checkpoint', 'title', 'name')):
                log.error(f' model: error="{model}"')
                continue
            log.info(f' style: name="{style}" n={s+1}/{len(styles)} fn="{fn}"')
            t0 = time.time()
            data = request('/sdapi/v1/txt2img', { 'styles': [style] })
            t1 = time.time()
            if 'images' in data and len(data['images']) > 0:
                # take the payload AFTER an optional 'data:...;base64,' prefix
                # (the previous [0] index kept the prefix and dropped the image data)
                b64 = data['images'][0].split(',', 1)[-1]
                image = Image.open(io.BytesIO(base64.b64decode(b64)))
                info = data['info']
                log.info(f' image: size={image.size} time={t1-t0:.2f} info="{len(info)}" fn="{fn}"')
                image.save(fn)
if __name__ == "__main__":
    # entry point: announce configuration, run the full sweep, report completion
    log.info('test-all-models')
    log.info(f'output="{output_folder}" models={len(models)} styles={len(styles)}')
    log.info('start...')
    generate()
    log.info('done...')

View File

@ -91,6 +91,7 @@ class Api:
self.add_api_route("/sdapi/v1/interrogate", endpoints.post_interrogate, methods=["POST"])
self.add_api_route("/sdapi/v1/vqa", endpoints.post_vqa, methods=["POST"])
self.add_api_route("/sdapi/v1/checkpoint", endpoints.get_checkpoint, methods=["GET"])
self.add_api_route("/sdapi/v1/checkpoint", endpoints.set_checkpoint, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-checkpoints", endpoints.post_refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/unload-checkpoint", endpoints.post_unload_checkpoint, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])

View File

@ -140,6 +140,14 @@ def get_checkpoint():
checkpoint['hash'] = shared.sd_model.sd_checkpoint_info.shorthash
return checkpoint
def set_checkpoint(sd_model_checkpoint: str, force: bool = False):
    """Load the requested model checkpoint via the server's model loader.

    Args:
        sd_model_checkpoint: checkpoint name/title to set as the active model.
        force: when True, unload the current model weights first to force a full reload.

    Returns:
        dict with 'ok' True when a model instance exists after the reload.
    """
    from modules import sd_models
    if force:
        # drop the currently loaded weights so reload cannot short-circuit
        sd_models.unload_model_weights(op='model')
    shared.opts.sd_model_checkpoint = sd_model_checkpoint
    model = sd_models.reload_model_weights()
    return { 'ok': model is not None }
def post_refresh_checkpoints():
    """Re-scan the available checkpoints; returns an empty dict as acknowledgement."""
    shared.refresh_checkpoints()
    return dict()

View File

@ -58,7 +58,7 @@ def banned_words(
def register_api():
    """Register the censor/guard endpoints on the shared API instance."""
    from modules.shared import api as api_instance
    # NOTE(review): the four '/sdapi/v1//...' double-slash registrations appear to be
    # pre-fix residue from the rendered diff; the corrected single-slash routes follow
    # below — confirm only the single-slash set belongs in the final file.
    api_instance.add_api_route("/sdapi/v1//nudenet", nudenet_censor, methods=["POST"], response_model=dict)
    api_instance.add_api_route("/sdapi/v1//prompt-lang", prompt_check, methods=["POST"], response_model=dict)
    api_instance.add_api_route("/sdapi/v1//image-guard", image_guard, methods=["POST"], response_model=dict)
    api_instance.add_api_route("/sdapi/v1//prompt-banned", banned_words, methods=["POST"], response_model=list)
    api_instance.add_api_route("/sdapi/v1/nudenet", nudenet_censor, methods=["POST"], response_model=dict)
    api_instance.add_api_route("/sdapi/v1/prompt-lang", prompt_check, methods=["POST"], response_model=dict)
    api_instance.add_api_route("/sdapi/v1/image-guard", image_guard, methods=["POST"], response_model=dict)
    api_instance.add_api_route("/sdapi/v1/prompt-banned", banned_words, methods=["POST"], response_model=list)

View File

@ -201,6 +201,7 @@ def get_closet_checkpoint_match(s: str) -> CheckpointInfo:
checkpoint_info = CheckpointInfo(model_name) # create a virutal model info
checkpoint_info.type = 'huggingface'
return checkpoint_info
if s.startswith('huggingface/'):
model_name = s.replace('huggingface/', '')
checkpoint_info = CheckpointInfo(model_name) # create a virutal model info