mirror of https://github.com/vladmandic/automatic
add load-checkpoint api endpoint and test all models script
Signed-off-by: Vladimir Mandic <mandic00@live.com>pull/4115/head
parent
44a20a47c1
commit
73049f7bb8
|
|
@ -93,6 +93,8 @@ And (*as always*) many bugfixes and improvements to existing features!
|
|||
- **prompt enhance** add `allura-org/Gemma-3-Glitter-4B` model support
|
||||
- remove **LDSR**
|
||||
- remove `api-only` cli option
|
||||
- **API**
|
||||
- add `/sdapi/v1/checkpoint` POST endpoint to simply load a model
|
||||
- **Fixes**
|
||||
- refactor legacy processing loop
|
||||
- fix settings components mismatch
|
||||
|
|
@ -111,6 +113,7 @@ And (*as always*) many bugfixes and improvements to existing features!
|
|||
- fix *Flux.1-Kontext-Dev* with variable resolution
|
||||
- use `utf_16_be` as primary metadata decoding
|
||||
- fix `sd35` width/height alignment
|
||||
- fix `nudenet` api
|
||||
- reapply offloading on ipadapter load
|
||||
- api set default script-name
|
||||
- avoid forced gc and rely on thresholds
|
||||
|
|
|
|||
|
|
@ -1,7 +1,24 @@
|
|||
#!/usr/bin/env python
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import logging
|
||||
import requests
|
||||
import urllib3
|
||||
import pathvalidate
|
||||
from PIL import Image
|
||||
|
||||
|
||||
logging.basicConfig(level = logging.INFO, format = '%(asctime)s %(levelname)s: %(message)s')
|
||||
log = logging.getLogger(__name__)
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
output_folder = 'outputs/compare'
|
||||
models = [
|
||||
"sd_xl_base_1.0",
|
||||
"tempestByVlad_baseV01",
|
||||
"huggingface/stabilityai/stable-cascade",
|
||||
"sdxl-base-v10-vaefix",
|
||||
"tempest-by-vlad-0.1",
|
||||
"stabilityai/stable-diffusion-3.5-medium",
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
|
|
@ -10,24 +27,97 @@ models = [
|
|||
"vladmandic/chroma-unlocked-v50",
|
||||
"vladmandic/chroma-unlocked-v50-annealed",
|
||||
"Qwen/Qwen-Image",
|
||||
"ostris/Flex.2-preview",
|
||||
"HiDream-ai/HiDream-I1-Dev",
|
||||
"HiDream-ai/HiDream-I1-Full",
|
||||
"briaai/BRIA-3.2",
|
||||
"nvidia/Cosmos-Predict2-2B-Text2Image",
|
||||
"nvidia/Cosmos-Predict2-14B-Text2Image",
|
||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
|
||||
"Wan-AI/Wan2.1-T2V-14B-Diffusers",
|
||||
"OmniGen2/OmniGen2",
|
||||
"stabilityai/stable-cascade",
|
||||
"ostris/Flex.2-preview",
|
||||
"Freepik/F-Lite",
|
||||
"Freepik/F-Lite-Texture",
|
||||
"Efficient-Large-Model/SANA1.5_1.6B_1024px_diffusers",
|
||||
"Efficient-Large-Model/SANA1.5_4.8B_1024px_diffusers",
|
||||
"nvidia/Cosmos-Predict2-2B-Text2Image",
|
||||
"nvidia/Cosmos-Predict2-14B-Text2Image",
|
||||
"OmniGen2/OmniGen2",
|
||||
"fal/AuraFlow-v0.3",
|
||||
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
||||
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers",
|
||||
"Alpha-VLLM/Lumina-Image-2.0",
|
||||
"HiDream-ai/HiDream-I1-Dev",
|
||||
"HiDream-ai/HiDream-I1-Full",
|
||||
"Kwai-Kolors/Kolors-diffusers",
|
||||
"briaai/BRIA-3.2",
|
||||
"THUDM/CogView4-6B",
|
||||
"kandinsky-community/kandinsky-2-1",
|
||||
"kandinsky-community/kandinsky-2-2-decoder",
|
||||
"kandinsky-community/kandinsky-3",
|
||||
# "juggernautXL_juggXIByRundiffusion.safetensors@https://civitai.com/api/download/models/782002",
|
||||
# "playground-v2.5-1024px-aesthetic.fp16.safetensors@https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/resolve/main/playground-v2.5-1024px-aesthetic.fp16.safetensors?download=true",
|
||||
]
|
||||
styles = [
|
||||
'Fixed Astronaut',
|
||||
'Fixed Bear',
|
||||
'Fixed Steampunk City',
|
||||
'Fixed Road sign',
|
||||
'Fixed Futuristic hypercar',
|
||||
'Fixed Pirate Ship in Space',
|
||||
'Fixed Fallout girl',
|
||||
'Fixed Kneeling on Bed',
|
||||
'Fixed Girl in Sin City',
|
||||
'Fixed Girl in a city',
|
||||
'Fixed Lady in Tokyo',
|
||||
'Fixed MadMax selfie',
|
||||
'Fixed SDNext Neon',
|
||||
]
|
||||
|
||||
|
||||
def request(endpoint: str, dct: dict = None, method: str = 'POST'):
|
||||
def auth():
|
||||
if sd_username is not None and sd_password is not None:
|
||||
return requests.auth.HTTPBasicAuth(sd_username, sd_password)
|
||||
return None
|
||||
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860")
|
||||
sd_username = os.environ.get('SDAPI_USR', None)
|
||||
sd_password = os.environ.get('SDAPI_PWD', None)
|
||||
method = requests.post if method.upper() == 'POST' else requests.get
|
||||
req = method(f'{sd_url}{endpoint}', json = dct, timeout=300, verify=False, auth=auth())
|
||||
if req.status_code != 200:
|
||||
return { 'error': req.status_code, 'reason': req.reason, 'url': req.url }
|
||||
else:
|
||||
return req.json()
|
||||
|
||||
|
||||
def generate(): # pylint: disable=redefined-outer-name
|
||||
for m, model in enumerate(models):
|
||||
model_name = pathvalidate.sanitize_filename(model, replacement_text='_')
|
||||
log.info(f'model: name="{model}" n={m+1}/{len(models)}')
|
||||
for s, style in enumerate(styles):
|
||||
model_name = pathvalidate.sanitize_filename(model, replacement_text='_')
|
||||
style_name = pathvalidate.sanitize_filename(style, replacement_text='_')
|
||||
fn = os.path.join(output_folder, f'{model_name}__{style_name}.jpg')
|
||||
if os.path.exists(fn):
|
||||
continue
|
||||
request(f'/sdapi/v1/checkpoint?sd_model_checkpoint={model}', method='POST')
|
||||
loaded = request('/sdapi/v1/checkpoint', method='GET')
|
||||
if not (model in loaded.get('checkpoint') or model in loaded.get('title') or model in loaded.get('name')):
|
||||
log.error(f' model: error="{model}"')
|
||||
continue
|
||||
log.info(f' style: name="{style}" n={s+1}/{len(styles)} fn="{fn}"')
|
||||
t0 = time.time()
|
||||
data = request('/sdapi/v1/txt2img', { 'styles': [style] })
|
||||
t1 = time.time()
|
||||
if 'images' in data and len(data['images']) > 0:
|
||||
b64 = data['images'][0].split(',',1)[0]
|
||||
image = Image.open(io.BytesIO(base64.b64decode(b64)))
|
||||
info = data['info']
|
||||
log.info(f' image: size={image.size} time={t1-t0:.2f} info="{len(info)}" fn="{fn}"')
|
||||
image.save(fn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
log.info('test-all-models')
|
||||
log.info(f'output="{output_folder}" models={len(models)} styles={len(styles)}')
|
||||
log.info('start...')
|
||||
generate()
|
||||
log.info('done...')
|
||||
|
|
|
|||
|
|
@ -91,6 +91,7 @@ class Api:
|
|||
self.add_api_route("/sdapi/v1/interrogate", endpoints.post_interrogate, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/vqa", endpoints.post_vqa, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/checkpoint", endpoints.get_checkpoint, methods=["GET"])
|
||||
self.add_api_route("/sdapi/v1/checkpoint", endpoints.set_checkpoint, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", endpoints.post_refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/unload-checkpoint", endpoints.post_unload_checkpoint, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
|
||||
|
|
|
|||
|
|
@ -140,6 +140,14 @@ def get_checkpoint():
|
|||
checkpoint['hash'] = shared.sd_model.sd_checkpoint_info.shorthash
|
||||
return checkpoint
|
||||
|
||||
def set_checkpoint(sd_model_checkpoint: str, force:bool=False):
|
||||
from modules import sd_models
|
||||
if force:
|
||||
sd_models.unload_model_weights(op='model')
|
||||
shared.opts.sd_model_checkpoint = sd_model_checkpoint
|
||||
model = sd_models.reload_model_weights()
|
||||
return { 'ok': model is not None }
|
||||
|
||||
def post_refresh_checkpoints():
|
||||
shared.refresh_checkpoints()
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ def banned_words(
|
|||
|
||||
def register_api():
|
||||
from modules.shared import api as api_instance
|
||||
api_instance.add_api_route("/sdapi/v1//nudenet", nudenet_censor, methods=["POST"], response_model=dict)
|
||||
api_instance.add_api_route("/sdapi/v1//prompt-lang", prompt_check, methods=["POST"], response_model=dict)
|
||||
api_instance.add_api_route("/sdapi/v1//image-guard", image_guard, methods=["POST"], response_model=dict)
|
||||
api_instance.add_api_route("/sdapi/v1//prompt-banned", banned_words, methods=["POST"], response_model=list)
|
||||
api_instance.add_api_route("/sdapi/v1/nudenet", nudenet_censor, methods=["POST"], response_model=dict)
|
||||
api_instance.add_api_route("/sdapi/v1/prompt-lang", prompt_check, methods=["POST"], response_model=dict)
|
||||
api_instance.add_api_route("/sdapi/v1/image-guard", image_guard, methods=["POST"], response_model=dict)
|
||||
api_instance.add_api_route("/sdapi/v1/prompt-banned", banned_words, methods=["POST"], response_model=list)
|
||||
|
|
|
|||
|
|
@ -201,6 +201,7 @@ def get_closet_checkpoint_match(s: str) -> CheckpointInfo:
|
|||
checkpoint_info = CheckpointInfo(model_name) # create a virutal model info
|
||||
checkpoint_info.type = 'huggingface'
|
||||
return checkpoint_info
|
||||
|
||||
if s.startswith('huggingface/'):
|
||||
model_name = s.replace('huggingface/', '')
|
||||
checkpoint_info = CheckpointInfo(model_name) # create a virutal model info
|
||||
|
|
|
|||
Loading…
Reference in New Issue