mirror of https://github.com/vladmandic/automatic
test: add api tests for generation params and detailer endpoints
Generation tests cover scheduler params, color grading, and latent corrections. Detailer tests cover model enumeration and object detection. Both require a running SD.Next instance. (branch: pull/4694/head)
parent
0ec1e9bca2
commit
813b61eda8
|
|
@ -0,0 +1,653 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
API tests for YOLO Detailer endpoints.
|
||||
|
||||
Tests:
|
||||
- GET /sdapi/v1/detailers — model enumeration
|
||||
- POST /sdapi/v1/detect — object detection on test images
|
||||
- POST /sdapi/v1/txt2img — generation with detailer enabled
|
||||
|
||||
Requires a running SD.Next instance with a model loaded.
|
||||
|
||||
Usage:
|
||||
python test/test-detailer-api.py [--url URL] [--image PATH]
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import base64
|
||||
import argparse
|
||||
import requests
|
||||
import urllib3
|
||||
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
# Reference model cover images with faces (best for detailer testing)
# NOTE(review): paths are relative to the working directory — presumably the
# SD.Next repo root; confirm before running from elsewhere.
FACE_TEST_IMAGES = [
    'models/Reference/ponyRealism_V23.jpg', # realistic woman, clear face
    'models/Reference/HiDream-ai--HiDream-I1-Fast.jpg', # realistic man, clear face + text
    'models/Reference/stabilityai--stable-diffusion-xl-base-1.0.jpg', # realistic woman portrait
    'models/Reference/CalamitousFelicitousness--Anima-sdnext-diffusers.jpg', # anime face (non-realistic test)
]

# Fallback images (no guaranteed faces) — used only when none of the face
# images above exist on disk; detection counts may legitimately be zero.
FALLBACK_IMAGES = [
    'html/sdnext-robot-2k.jpg',
    'html/favicon.png',
]
class DetailerAPITest:
    """Test harness for YOLO Detailer API endpoints."""

    def __init__(self, base_url, image_path=None, timeout=300):
        """Set up result buckets and load test images from disk.

        Args:
            base_url: server root, e.g. 'http://127.0.0.1:7860' (trailing '/' stripped).
            image_path: optional explicit test image; when given, only it is loaded.
            timeout: per-request timeout in seconds used by the _get/_post helpers.
        """
        self.base_url = base_url.rstrip('/')
        self.test_images = {} # name -> base64
        self.timeout = timeout
        # Per-category tallies; record()/skip() mutate the bucket selected by self._category.
        self.results = {
            'enumerate': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []},
            'detect': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []},
            'generate': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []},
            'detailer_params': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []},
        }
        self._category = 'enumerate'
        # Set by test_detailers_list when the server is unreachable; later tests skip on it.
        self._critical_error = None
        self._load_images(image_path)
def _encode_image(self, path):
|
||||
from PIL import Image
|
||||
image = Image.open(path)
|
||||
if image.mode == 'RGBA':
|
||||
image = image.convert('RGB')
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, 'JPEG')
|
||||
return base64.b64encode(buf.getvalue()).decode(), image.size
|
||||
|
||||
    def _load_images(self, image_path=None):
        """Populate self.test_images (basename -> base64 JPEG).

        Preference order: an explicit image_path (exclusive), then every
        FACE_TEST_IMAGES entry that exists on disk, then the first existing
        FALLBACK_IMAGES entry. Leaves the dict empty (with a warning) when
        nothing is found; detect tests then skip.
        """
        if image_path and os.path.exists(image_path):
            b64, size = self._encode_image(image_path)
            name = os.path.basename(image_path)
            self.test_images[name] = b64
            print(f" Test image: {image_path} ({size})")
            return

        # Load all available face test images
        for p in FACE_TEST_IMAGES:
            if os.path.exists(p):
                b64, size = self._encode_image(p)
                name = os.path.basename(p)
                self.test_images[name] = b64
                print(f" Loaded: {name} ({size[0]}x{size[1]})")

        # Fallback if no face images found
        if not self.test_images:
            for p in FALLBACK_IMAGES:
                if os.path.exists(p):
                    b64, size = self._encode_image(p)
                    name = os.path.basename(p)
                    self.test_images[name] = b64
                    print(f" Fallback: {name} ({size[0]}x{size[1]})")
                    break

        if not self.test_images:
            print(" WARNING: No test images found, detect tests will be skipped")
@property
|
||||
def image_b64(self):
|
||||
"""Return the first available test image for backwards compat."""
|
||||
if self.test_images:
|
||||
return next(iter(self.test_images.values()))
|
||||
return None
|
||||
|
||||
    def _get(self, endpoint):
        """GET base_url+endpoint; return parsed JSON, or an {'error', 'reason'} dict.

        Never raises: transport problems are folded into the error-dict
        convention so callers can uniformly test `'error' in data`.
        """
        try:
            r = requests.get(f'{self.base_url}{endpoint}', timeout=self.timeout, verify=False)
            if r.status_code != 200:
                return {'error': r.status_code, 'reason': r.reason}
            return r.json()
        except requests.exceptions.ConnectionError:
            return {'error': 'connection_refused', 'reason': 'Server not running'}
        except Exception as e:
            # broad by design: a test harness should report, not crash
            return {'error': 'exception', 'reason': str(e)}

    def _post(self, endpoint, data):
        """POST JSON to base_url+endpoint; same error-dict convention as _get()."""
        try:
            r = requests.post(f'{self.base_url}{endpoint}', json=data, timeout=self.timeout, verify=False)
            if r.status_code != 200:
                return {'error': r.status_code, 'reason': r.reason}
            return r.json()
        except requests.exceptions.ConnectionError:
            return {'error': 'connection_refused', 'reason': 'Server not running'}
        except Exception as e:
            return {'error': 'exception', 'reason': str(e)}
def record(self, passed, name, detail=''):
|
||||
status = 'PASS' if passed else 'FAIL'
|
||||
self.results[self._category]['passed' if passed else 'failed'] += 1
|
||||
self.results[self._category]['tests'].append((status, name))
|
||||
msg = f' {status}: {name}'
|
||||
if detail:
|
||||
msg += f' ({detail})'
|
||||
print(msg)
|
||||
|
||||
def skip(self, name, reason):
|
||||
self.results[self._category]['skipped'] += 1
|
||||
self.results[self._category]['tests'].append(('SKIP', name))
|
||||
print(f' SKIP: {name} ({reason})')
|
||||
|
||||
    # =========================================================================
    # Tests: Model Enumeration
    # =========================================================================

    def test_detailers_list(self):
        """GET /sdapi/v1/detailers returns a list of available models.

        Returns the raw model list (possibly empty) so later tests can pick a
        model from it. On a server error, sets self._critical_error so every
        subsequent test group skips instead of hammering a dead server.
        """
        self._category = 'enumerate'
        print("\n--- Detailer Model Enumeration ---")

        data = self._get('/sdapi/v1/detailers')
        if 'error' in data:
            self.record(False, 'detailers_list', f"error: {data}")
            self._critical_error = f"Server error: {data}"
            return []

        if not isinstance(data, list):
            self.record(False, 'detailers_list', f"expected list, got {type(data).__name__}")
            return []

        self.record(True, 'detailers_list', f"{len(data)} models found")

        # Verify each entry has expected fields (schema is spot-checked on the
        # first entry only)
        if len(data) > 0:
            sample = data[0]
            has_name = 'name' in sample
            self.record(has_name, 'detailer_entry_has_name', f"sample: {sample}")
            if not has_name:
                self.record(False, 'detailer_entry_schema', "missing 'name' field")

        return data
# =========================================================================
|
||||
# Tests: Detection
|
||||
# =========================================================================
|
||||
|
||||
def _validate_detect_response(self, data, label):
|
||||
"""Validate detection response schema and return detection count."""
|
||||
expected_keys = ['classes', 'labels', 'boxes', 'scores']
|
||||
for key in expected_keys:
|
||||
if key not in data:
|
||||
self.record(False, f'{label}_schema_{key}', f"missing '{key}'")
|
||||
return -1
|
||||
|
||||
# All arrays should have the same length
|
||||
lengths = [len(data[key]) for key in expected_keys]
|
||||
all_same = len(set(lengths)) <= 1
|
||||
if not all_same:
|
||||
self.record(False, f'{label}_array_lengths', f"mismatched: {dict(zip(expected_keys, lengths))}")
|
||||
return -1
|
||||
|
||||
n = lengths[0]
|
||||
|
||||
if n > 0:
|
||||
# Scores should be in [0, 1]
|
||||
scores_valid = all(0 <= s <= 1 for s in data['scores'])
|
||||
if not scores_valid:
|
||||
self.record(False, f'{label}_scores_range', f"scores: {data['scores']}")
|
||||
|
||||
# Boxes should be lists of 4 numbers
|
||||
boxes_valid = all(isinstance(b, list) and len(b) == 4 for b in data['boxes'])
|
||||
if not boxes_valid:
|
||||
self.record(False, f'{label}_boxes_format', "bad box format")
|
||||
|
||||
return n
|
||||
|
||||
# Face detection models to try (in priority order)
|
||||
FACE_MODELS = ['face-yolo8n', 'face-yolo8m', 'anzhc-face-1024-seg-8n']
|
||||
|
||||
def _pick_face_model(self, available_models):
|
||||
"""Pick the best face detection model from available ones."""
|
||||
available_names = [m.get('name', '') for m in available_models] if available_models else []
|
||||
for model in self.FACE_MODELS:
|
||||
if model in available_names:
|
||||
return model
|
||||
return '' # fall back to server default
|
||||
|
||||
    def test_detect_all_images(self, available_models=None):
        """POST /sdapi/v1/detect on each loaded test image with a face model.

        Records one result per image plus an aggregate 'detect_found_faces'
        result that fails only when NO image produced any detection.
        """
        self._category = 'detect'
        print("\n--- Detection Tests (per-image) ---")

        if not self.test_images:
            self.skip('detect_all', 'no test images')
            return

        if self._critical_error:
            self.skip('detect_all', self._critical_error)
            return

        face_model = self._pick_face_model(available_models)
        if face_model:
            print(f" Using face model: {face_model}")
        else:
            print(" No face model available, using server default")

        total_detections = 0
        any_face_found = False

        for img_name, img_b64 in self.test_images.items():
            # truncate for readable test names
            short = img_name.replace('.jpg', '')[:40]
            data = self._post('/sdapi/v1/detect', {'image': img_b64, 'model': face_model})

            if 'error' in data:
                self.record(False, f'detect_{short}', f"error: {data}")
                continue

            n = self._validate_detect_response(data, f'detect_{short}')
            if n < 0:
                continue

            labels = data.get('labels', [])
            scores = data.get('scores', [])
            detail_parts = [f"{n} detections"]
            if labels:
                detail_parts.append(f"labels={labels}")
            if scores:
                detail_parts.append(f"top_score={max(scores):.3f}")

            self.record(True, f'detect_{short}', ', '.join(detail_parts))
            total_detections += n
            if n > 0:
                any_face_found = True

        self.record(any_face_found, 'detect_found_faces',
                    f"{total_detections} total detections across {len(self.test_images)} images")

    def test_detect_with_model(self, model_name):
        """POST /sdapi/v1/detect with a specific model on all images.

        NOTE(review): records PASS unconditionally — this is informational
        only; errored requests simply contribute zero detections. Relies on
        self._category still being 'detect' from the previous test.
        """
        if not self.test_images:
            self.skip(f'detect_model_{model_name}', 'no test images')
            return

        total = 0
        for _img_name, img_b64 in self.test_images.items():
            data = self._post('/sdapi/v1/detect', {'image': img_b64, 'model': model_name})
            if 'error' not in data:
                total += len(data.get('scores', []))

        self.record(True, f'detect_model_{model_name}', f"{total} detections across {len(self.test_images)} images")
# =========================================================================
|
||||
# Tests: Generation with Detailer
|
||||
# =========================================================================
|
||||
|
||||
def test_txt2img_with_detailer(self):
|
||||
"""POST /sdapi/v1/txt2img with detailer_enabled=True."""
|
||||
self._category = 'generate'
|
||||
print("\n--- Generation with Detailer ---")
|
||||
|
||||
if self._critical_error:
|
||||
self.skip('txt2img_detailer', self._critical_error)
|
||||
return
|
||||
|
||||
payload = {
|
||||
'prompt': 'a photo of a person, face, portrait',
|
||||
'negative_prompt': '',
|
||||
'steps': 10,
|
||||
'width': 512,
|
||||
'height': 512,
|
||||
'seed': 42,
|
||||
'save_images': False,
|
||||
'send_images': True,
|
||||
'detailer_enabled': True,
|
||||
'detailer_strength': 0.3,
|
||||
'detailer_steps': 5,
|
||||
'detailer_conf': 0.3,
|
||||
'detailer_max': 3,
|
||||
}
|
||||
|
||||
t0 = time.time()
|
||||
# Detailer generation is multi-pass (generate + detect + inpaint per region), use longer timeout
|
||||
try:
|
||||
r = requests.post(f'{self.base_url}/sdapi/v1/txt2img', json=payload, timeout=600, verify=False)
|
||||
if r.status_code != 200:
|
||||
data = {'error': r.status_code, 'reason': r.reason}
|
||||
else:
|
||||
data = r.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
self.record(False, 'txt2img_detailer', f"connection error (is a model loaded?): {e}")
|
||||
return
|
||||
except requests.exceptions.ReadTimeout:
|
||||
self.record(False, 'txt2img_detailer', 'timeout after 600s')
|
||||
return
|
||||
t1 = time.time()
|
||||
|
||||
if 'error' in data:
|
||||
self.record(False, 'txt2img_detailer', f"error: {data} (ensure a model is loaded)")
|
||||
return
|
||||
|
||||
# Should have images
|
||||
has_images = 'images' in data and len(data['images']) > 0
|
||||
self.record(has_images, 'txt2img_detailer_has_images', f"time={t1 - t0:.1f}s")
|
||||
|
||||
if has_images:
|
||||
# Decode and verify image
|
||||
from PIL import Image
|
||||
img_data = data['images'][0].split(',', 1)[0]
|
||||
img = Image.open(io.BytesIO(base64.b64decode(img_data)))
|
||||
self.record(True, 'txt2img_detailer_image_valid', f"size={img.size}")
|
||||
|
||||
# Check info field for detailer metadata
|
||||
if 'info' in data:
|
||||
info = data['info'] if isinstance(data['info'], str) else json.dumps(data['info'])
|
||||
has_detailer_info = 'detailer' in info.lower() or 'Detailer' in info
|
||||
self.record(has_detailer_info, 'txt2img_detailer_metadata',
|
||||
'detailer info found in metadata' if has_detailer_info else 'no detailer metadata (detection may have found nothing)')
|
||||
|
||||
    def test_txt2img_without_detailer(self):
        """POST /sdapi/v1/txt2img baseline without detailer (sanity check).

        Confirms plain generation works at all, so detailer-specific failures
        can be distinguished from a broken server/model.
        """
        if self._critical_error:
            self.skip('txt2img_baseline', self._critical_error)
            return

        payload = {
            'prompt': 'a simple landscape',
            'steps': 5,
            'width': 512,
            'height': 512,
            'seed': 42,
            'save_images': False,
            'send_images': True,
        }

        data = self._post('/sdapi/v1/txt2img', payload)
        if 'error' in data:
            self.record(False, 'txt2img_baseline', f"error: {data}")
            return

        has_images = 'images' in data and len(data['images']) > 0
        self.record(has_images, 'txt2img_baseline', 'generation works without detailer')
    # =========================================================================
    # Tests: Per-Request Detailer Param Validation
    # =========================================================================

    def _txt2img(self, extra_params=None):
        """Helper: generate a portrait with optional param overrides.

        Uses a fixed seed (42) so that two calls with the same params should
        produce comparable images; returns the parsed response or an
        {'error', 'reason'} dict. Uses a 600s timeout because detailer runs
        are multi-pass.
        """
        payload = {
            'prompt': 'a photo of a person, face, portrait',
            'steps': 10,
            'width': 512,
            'height': 512,
            'seed': 42,
            'save_images': False,
            'send_images': True,
        }
        if extra_params:
            payload.update(extra_params)
        try:
            r = requests.post(f'{self.base_url}/sdapi/v1/txt2img', json=payload, timeout=600, verify=False)
            if r.status_code != 200:
                return {'error': r.status_code, 'reason': r.reason}
            return r.json()
        except requests.exceptions.ConnectionError as e:
            return {'error': 'connection_refused', 'reason': str(e)}
        except requests.exceptions.ReadTimeout:
            return {'error': 'timeout', 'reason': 'timeout after 600s'}
def _decode_image(self, data):
|
||||
"""Decode first image from generation response into numpy array."""
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
if 'images' not in data or len(data['images']) == 0:
|
||||
return None
|
||||
img_data = data['images'][0].split(',', 1)[0]
|
||||
img = Image.open(io.BytesIO(base64.b64decode(img_data))).convert('RGB')
|
||||
return np.array(img, dtype=np.float32)
|
||||
|
||||
def _pixel_diff(self, arr_a, arr_b):
|
||||
"""Mean absolute pixel difference between two images."""
|
||||
import numpy as np
|
||||
if arr_a is None or arr_b is None or arr_a.shape != arr_b.shape:
|
||||
return -1.0
|
||||
return float(np.abs(arr_a - arr_b).mean())
|
||||
|
||||
def _get_info(self, data):
|
||||
"""Extract info string from generation response."""
|
||||
if 'info' not in data:
|
||||
return ''
|
||||
info = data['info']
|
||||
return info if isinstance(info, str) else json.dumps(info)
|
||||
|
||||
    def run_detailer_param_tests(self, available_models=None):
        """Verify per-request detailer params change the output.

        Strategy: generate a fixed-seed baseline, then re-generate while
        varying one detailer param at a time and compare mean pixel diffs.
        A diff > 0.5 (on the 0-255 scale) counts as "the param had an effect".
        Finishes with a leak check: a second no-detailer generation must match
        the baseline, proving per-request params did not stick server-side.
        """
        self._category = 'detailer_params'
        print("\n--- Per-Request Detailer Param Validation ---")

        if self._critical_error:
            self.skip('detailer_params_all', self._critical_error)
            return

        # Generate baseline WITHOUT detailer (same seed/prompt as detailer tests)
        print(" Generating baseline (no detailer)...")
        baseline_data = self._txt2img()
        if 'error' in baseline_data:
            self.record(False, 'detailer_baseline', f"error: {baseline_data}")
            return
        baseline = self._decode_image(baseline_data)
        if baseline is None:
            self.record(False, 'detailer_baseline', 'no image')
            return
        self.record(True, 'detailer_baseline')

        # Generate WITH detailer enabled (default params)
        print(" Generating with detailer (defaults)...")
        detailer_default_data = self._txt2img({
            'detailer_enabled': True,
            'detailer_strength': 0.3,
            'detailer_steps': 5,
            'detailer_conf': 0.3,
        })
        if 'error' in detailer_default_data:
            self.record(False, 'detailer_default', f"error: {detailer_default_data}")
            return
        detailer_default = self._decode_image(detailer_default_data)

        # Detailer ON vs OFF should produce different images (if a face was detected)
        diff_on_off = self._pixel_diff(baseline, detailer_default)
        self.record(diff_on_off > 0.5, 'detailer_on_vs_off',
                    f"mean_diff={diff_on_off:.2f}" if diff_on_off > 0.5
                    else f"identical (diff={diff_on_off:.4f}) — no face detected?")

        # -- Strength variation --
        print(" Testing strength variation...")
        strong_data = self._txt2img({
            'detailer_enabled': True,
            'detailer_strength': 0.7,
            'detailer_steps': 5,
            'detailer_conf': 0.3,
        })
        if 'error' not in strong_data:
            strong = self._decode_image(strong_data)
            diff_strong = self._pixel_diff(detailer_default, strong)
            self.record(diff_strong > 0.5, 'detailer_strength_effect',
                        f"strength 0.3 vs 0.7: diff={diff_strong:.2f}")

        # -- Steps variation --
        print(" Testing steps variation...")
        more_steps_data = self._txt2img({
            'detailer_enabled': True,
            'detailer_strength': 0.3,
            'detailer_steps': 20,
            'detailer_conf': 0.3,
        })
        if 'error' not in more_steps_data:
            more_steps = self._decode_image(more_steps_data)
            diff_steps = self._pixel_diff(detailer_default, more_steps)
            self.record(diff_steps > 0.5, 'detailer_steps_effect',
                        f"steps 5 vs 20: diff={diff_steps:.2f}")

        # -- Resolution variation --
        print(" Testing resolution variation...")
        hires_data = self._txt2img({
            'detailer_enabled': True,
            'detailer_strength': 0.3,
            'detailer_steps': 5,
            'detailer_conf': 0.3,
            'detailer_resolution': 512,
        })
        if 'error' not in hires_data:
            hires = self._decode_image(hires_data)
            diff_res = self._pixel_diff(detailer_default, hires)
            self.record(diff_res > 0.5, 'detailer_resolution_effect',
                        f"resolution 1024 vs 512: diff={diff_res:.2f}")

        # -- Segmentation mode --
        # Segmentation requires a -seg model (e.g. anzhc-face-1024-seg-8n).
        # Detection-only models (face-yolo8n) don't produce masks, so the flag has no effect.
        seg_models = [m.get('name', '') for m in (available_models or [])
                      if 'seg' in m.get('name', '').lower() and 'face' in m.get('name', '').lower()]
        if seg_models:
            seg_model = seg_models[0]
            print(f" Testing segmentation mode (model={seg_model})...")
            # bbox baseline with the seg model
            seg_bbox_data = self._txt2img({
                'detailer_enabled': True,
                'detailer_strength': 0.3,
                'detailer_steps': 5,
                'detailer_conf': 0.3,
                'detailer_segmentation': False,
                'detailer_models': [seg_model],
            })
            seg_data = self._txt2img({
                'detailer_enabled': True,
                'detailer_strength': 0.3,
                'detailer_steps': 5,
                'detailer_conf': 0.3,
                'detailer_segmentation': True,
                'detailer_models': [seg_model],
            })
            if 'error' not in seg_data and 'error' not in seg_bbox_data:
                seg_bbox = self._decode_image(seg_bbox_data)
                seg_mask = self._decode_image(seg_data)
                diff_seg = self._pixel_diff(seg_bbox, seg_mask)
                self.record(diff_seg > 0.5, 'detailer_segmentation_effect',
                            f"bbox vs seg mask ({seg_model}): diff={diff_seg:.2f}")
            else:
                err = seg_data if 'error' in seg_data else seg_bbox_data
                self.record(False, 'detailer_segmentation_effect', f"error: {err}")
        else:
            print(" Testing segmentation mode...")
            # sentinel so the metadata loop below skips the seg entry
            seg_data = {'error': 'skipped'}
            self.skip('detailer_segmentation_effect', 'no face-seg model available')

        # -- Confidence threshold --
        print(" Testing confidence threshold...")
        high_conf_data = self._txt2img({
            'detailer_enabled': True,
            'detailer_strength': 0.3,
            'detailer_steps': 5,
            'detailer_conf': 0.95,
        })
        if 'error' not in high_conf_data:
            high_conf = self._decode_image(high_conf_data)
            diff_conf = self._pixel_diff(baseline, high_conf)
            # High confidence may reject detections, making output closer to baseline;
            # either outcome is valid, so this is recorded as informational PASS.
            self.record(True, 'detailer_conf_effect',
                        f"conf=0.95 vs baseline: diff={diff_conf:.2f} "
                        f"(low diff = detections filtered out, high diff = still detected)")

        # -- Custom detailer prompt --
        print(" Testing detailer prompt override...")
        prompt_data = self._txt2img({
            'detailer_enabled': True,
            'detailer_strength': 0.5,
            'detailer_steps': 5,
            'detailer_conf': 0.3,
            'detailer_prompt': 'a detailed close-up face with freckles',
        })
        if 'error' not in prompt_data:
            prompt_result = self._decode_image(prompt_data)
            diff_prompt = self._pixel_diff(detailer_default, prompt_result)
            self.record(diff_prompt > 0.5, 'detailer_prompt_effect',
                        f"custom prompt vs default: diff={diff_prompt:.2f}")

        # -- Metadata verification across params --
        for test_data, label in [
            (detailer_default_data, 'detailer_default'),
            (strong_data if 'error' not in strong_data else None, 'detailer_strong'),
            (more_steps_data if 'error' not in more_steps_data else None, 'detailer_more_steps'),
            (seg_data if 'error' not in seg_data else None, 'detailer_segmentation'),
        ]:
            if test_data is None:
                continue
            info = self._get_info(test_data)
            has_meta = 'detailer' in info.lower() or 'Detailer' in info
            self.record(has_meta, f'{label}_metadata',
                        'detailer info in metadata' if has_meta else 'no detailer metadata')

        # -- Param isolation: generate without detailer after all detailer runs --
        print(" Testing param isolation...")
        after_data = self._txt2img()
        if 'error' not in after_data:
            after = self._decode_image(after_data)
            leak_diff = self._pixel_diff(baseline, after)
            self.record(leak_diff < 0.5, 'detailer_param_isolation',
                        f"post-detailer baseline diff={leak_diff:.4f}" if leak_diff < 0.5
                        else f"LEAK: baseline changed (diff={leak_diff:.2f})")
    # =========================================================================
    # Runner
    # =========================================================================

    def run_all(self):
        """Run every test group in order and print a summary.

        Returns True when no test in any category failed (skips don't count).
        """
        print("=" * 60)
        print("YOLO Detailer API Test Suite")
        print(f"Server: {self.base_url}")
        print("=" * 60)

        # Enumerate
        models = self.test_detailers_list()

        # Detect across all loaded test images
        self.test_detect_all_images(models)
        # Test with first available model if any
        if models and len(models) > 0:
            model_name = models[0].get('name', models[0].get('filename', ''))
            if model_name:
                self.test_detect_with_model(model_name)

        # Generate
        self.test_txt2img_without_detailer()
        self.test_txt2img_with_detailer()

        # Per-request detailer param validation
        self.run_detailer_param_tests(models)

        # Summary
        print("\n" + "=" * 60)
        print("Results")
        print("=" * 60)
        total_passed = 0
        total_failed = 0
        total_skipped = 0
        for cat, data in self.results.items():
            total_passed += data['passed']
            total_failed += data['failed']
            total_skipped += data['skipped']
            status = 'PASS' if data['failed'] == 0 else 'FAIL'
            print(f" {cat}: {data['passed']} passed, {data['failed']} failed, {data['skipped']} skipped [{status}]")
        print(f" Total: {total_passed} passed, {total_failed} failed, {total_skipped} skipped")
        print("=" * 60)
        return total_failed == 0
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='YOLO Detailer API Tests')
|
||||
parser.add_argument('--url', default=os.environ.get('SDAPI_URL', 'http://127.0.0.1:7860'), help='server URL')
|
||||
parser.add_argument('--image', default=None, help='test image path')
|
||||
args = parser.parse_args()
|
||||
test = DetailerAPITest(args.url, args.image)
|
||||
success = test.run_all()
|
||||
sys.exit(0 if success else 1)
|
||||
|
|
@ -0,0 +1,615 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
API tests for generation with scheduler params, color grading, and latent corrections.
|
||||
|
||||
Tests:
|
||||
- GET /sdapi/v1/samplers — sampler enumeration and config
|
||||
- POST /sdapi/v1/txt2img — generation with various samplers
|
||||
- POST /sdapi/v1/txt2img — generation with color grading params
|
||||
- POST /sdapi/v1/txt2img — generation with latent correction params
|
||||
|
||||
Requires a running SD.Next instance with a model loaded.
|
||||
|
||||
Usage:
|
||||
python test/test-generation-api.py [--url URL] [--steps STEPS]
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import base64
|
||||
import argparse
|
||||
import requests
|
||||
import urllib3
|
||||
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
class GenerationAPITest:
    """Test harness for generation API with scheduler and grading params."""

    # Samplers to test — a representative subset covering different scheduler families
    TEST_SAMPLERS = [
        'Euler a',
        'Euler',
        'DPM++ 2M',
        'UniPC',
        'DDIM',
        'DPM++ 2M SDE',
    ]

    def __init__(self, base_url, steps=10, timeout=300):
        """Set up per-category result buckets.

        Args:
            base_url: server root (trailing '/' stripped).
            steps: sampler steps used by the _txt2img helper.
            timeout: per-request timeout in seconds for GET/POST helpers.
        """
        self.base_url = base_url.rstrip('/')
        self.steps = steps
        self.timeout = timeout
        # Per-category tallies; record()/skip() mutate the bucket selected by self._category.
        self.results = {
            'samplers': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []},
            'generation': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []},
            'grading': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []},
            'correction': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []},
            'param_validation': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []},
        }
        self._category = 'samplers'
        # Set when the server is unreachable; later tests skip on it.
        self._critical_error = None
    def _get(self, endpoint):
        """GET base_url+endpoint; return parsed JSON, or an {'error', 'reason'} dict.

        Never raises: transport problems are folded into the error-dict
        convention so callers can uniformly test `'error' in data`.
        """
        try:
            r = requests.get(f'{self.base_url}{endpoint}', timeout=self.timeout, verify=False)
            if r.status_code != 200:
                return {'error': r.status_code, 'reason': r.reason}
            return r.json()
        except requests.exceptions.ConnectionError:
            return {'error': 'connection_refused', 'reason': 'Server not running'}
        except Exception as e:
            # broad by design: a test harness should report, not crash
            return {'error': 'exception', 'reason': str(e)}

    def _post(self, endpoint, data):
        """POST JSON to base_url+endpoint; same error-dict convention as _get()."""
        try:
            r = requests.post(f'{self.base_url}{endpoint}', json=data, timeout=self.timeout, verify=False)
            if r.status_code != 200:
                return {'error': r.status_code, 'reason': r.reason}
            return r.json()
        except requests.exceptions.ConnectionError:
            return {'error': 'connection_refused', 'reason': 'Server not running'}
        except Exception as e:
            return {'error': 'exception', 'reason': str(e)}
def record(self, passed, name, detail=''):
|
||||
status = 'PASS' if passed else 'FAIL'
|
||||
self.results[self._category]['passed' if passed else 'failed'] += 1
|
||||
self.results[self._category]['tests'].append((status, name))
|
||||
msg = f' {status}: {name}'
|
||||
if detail:
|
||||
msg += f' ({detail})'
|
||||
print(msg)
|
||||
|
||||
def skip(self, name, reason):
|
||||
self.results[self._category]['skipped'] += 1
|
||||
self.results[self._category]['tests'].append(('SKIP', name))
|
||||
print(f' SKIP: {name} ({reason})')
|
||||
|
||||
    def _txt2img(self, extra_params=None, prompt='a cat'):
        """Helper: run txt2img with base params + overrides. Returns (data, time).

        Fixed seed (42) keeps runs comparable; `data` follows the error-dict
        convention of _post, `time` is wall-clock seconds for the request.
        """
        payload = {
            'prompt': prompt,
            'steps': self.steps,
            'width': 512,
            'height': 512,
            'seed': 42,
            'save_images': False,
            'send_images': True,
        }
        if extra_params:
            payload.update(extra_params)
        t0 = time.time()
        data = self._post('/sdapi/v1/txt2img', payload)
        return data, time.time() - t0
def _check_generation(self, data, test_name, elapsed):
|
||||
"""Validate a generation response has images."""
|
||||
if 'error' in data:
|
||||
self.record(False, test_name, f"error: {data}")
|
||||
return False
|
||||
has_images = 'images' in data and len(data['images']) > 0
|
||||
self.record(has_images, test_name, f"time={elapsed:.1f}s")
|
||||
return has_images
|
||||
|
||||
def _get_info(self, data):
|
||||
"""Extract info string from generation response."""
|
||||
if 'info' not in data:
|
||||
return ''
|
||||
info = data['info']
|
||||
return info if isinstance(info, str) else json.dumps(info)
|
||||
|
||||
def _decode_image(self, data):
|
||||
"""Decode first image from generation response into numpy array."""
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
if 'images' not in data or len(data['images']) == 0:
|
||||
return None
|
||||
img_data = data['images'][0].split(',', 1)[0]
|
||||
img = Image.open(io.BytesIO(base64.b64decode(img_data))).convert('RGB')
|
||||
return np.array(img, dtype=np.float32)
|
||||
|
||||
def _pixel_diff(self, arr_a, arr_b):
|
||||
"""Mean absolute pixel difference between two images (0-255 scale)."""
|
||||
import numpy as np
|
||||
if arr_a is None or arr_b is None:
|
||||
return -1.0
|
||||
if arr_a.shape != arr_b.shape:
|
||||
return -1.0
|
||||
return float(np.abs(arr_a - arr_b).mean())
|
||||
|
||||
def _channel_means(self, arr):
|
||||
"""Return per-channel means [R, G, B]."""
|
||||
if arr is None:
|
||||
return [0, 0, 0]
|
||||
return [float(arr[:, :, c].mean()) for c in range(3)]
|
||||
|
||||
# =========================================================================
|
||||
# Tests: Sampler Enumeration
|
||||
# =========================================================================
|
||||
|
||||
def test_samplers_list(self):
    """GET /sdapi/v1/samplers returns available samplers with config."""
    self._category = 'samplers'
    print("\n--- Sampler Enumeration ---")

    data = self._get('/sdapi/v1/samplers')
    if 'error' in data:
        self.record(False, 'samplers_list', f"error: {data}")
        self._critical_error = f"Server error: {data}"
        return []
    if not isinstance(data, list):
        self.record(False, 'samplers_list', f"expected list, got {type(data).__name__}")
        return []

    self.record(True, 'samplers_list', f"{len(data)} samplers available")

    # every entry should carry a non-empty name
    sampler_names = [s.get('name', '') for s in data if s.get('name', '')]
    self.record(len(sampler_names) == len(data), 'samplers_have_names',
                f"{len(sampler_names)}/{len(data)} have names")

    # confirm the samplers this suite exercises are actually offered
    for test_sampler in self.TEST_SAMPLERS:
        if test_sampler in sampler_names:
            self.record(True, f'sampler_available_{test_sampler}')
        else:
            self.skip(f'sampler_available_{test_sampler}', 'not in server sampler list')

    return sampler_names
|
||||
|
||||
# =========================================================================
|
||||
# Tests: Generation with Different Samplers
|
||||
# =========================================================================
|
||||
|
||||
def test_samplers_generate(self, available_samplers):
    """Generate with each test sampler and verify success."""
    self._category = 'generation'
    print("\n--- Generation with Different Samplers ---")

    # a failed enumeration means the server is unusable; skip everything
    if self._critical_error:
        for s in self.TEST_SAMPLERS:
            self.skip(f'generate_{s}', self._critical_error)
        return

    for sampler in self.TEST_SAMPLERS:
        if sampler not in available_samplers:
            self.skip(f'generate_{sampler}', 'sampler not available')
        else:
            resp, took = self._txt2img({'sampler_name': sampler})
            self._check_generation(resp, f'generate_{sampler}', took)
|
||||
|
||||
# =========================================================================
|
||||
# Tests: Color Grading Params
|
||||
# =========================================================================
|
||||
|
||||
def test_grading_brightness_contrast(self):
    """Generate with grading brightness and contrast."""
    resp, took = self._txt2img({'grading_brightness': 0.2, 'grading_contrast': 0.3})
    self._check_generation(resp, 'grading_brightness_contrast', took)
|
||||
|
||||
def test_grading_saturation_hue(self):
    """Generate with grading saturation and hue shift."""
    resp, took = self._txt2img({'grading_saturation': 0.5, 'grading_hue': 0.1})
    self._check_generation(resp, 'grading_saturation_hue', took)
|
||||
|
||||
def test_grading_gamma_sharpness(self):
    """Generate with gamma correction and sharpness."""
    resp, took = self._txt2img({'grading_gamma': 0.8, 'grading_sharpness': 0.5})
    self._check_generation(resp, 'grading_gamma_sharpness', took)
|
||||
|
||||
def test_grading_color_temp(self):
    """Generate with warm color temperature."""
    resp, took = self._txt2img({'grading_color_temp': 3500})
    self._check_generation(resp, 'grading_color_temp', took)
|
||||
|
||||
def test_grading_tone(self):
    """Generate with shadows/midtones/highlights adjustments."""
    params = {
        'grading_shadows': 0.3,
        'grading_midtones': -0.1,
        'grading_highlights': 0.2,
    }
    resp, took = self._txt2img(params)
    self._check_generation(resp, 'grading_tone', took)
|
||||
|
||||
def test_grading_effects(self):
    """Generate with vignette and grain."""
    resp, took = self._txt2img({'grading_vignette': 0.5, 'grading_grain': 0.3})
    self._check_generation(resp, 'grading_effects', took)
|
||||
|
||||
def test_grading_split_toning(self):
    """Generate with split toning colors."""
    params = {
        'grading_shadows_tint': '#003366',
        'grading_highlights_tint': '#ffcc00',
        'grading_split_tone_balance': 0.6,
    }
    resp, took = self._txt2img(params)
    self._check_generation(resp, 'grading_split_toning', took)
|
||||
|
||||
def test_grading_combined(self):
    """Generate with multiple grading params at once."""
    params = {
        'grading_brightness': 0.1,
        'grading_contrast': 0.2,
        'grading_saturation': 0.3,
        'grading_gamma': 0.9,
        'grading_color_temp': 5000,
        'grading_vignette': 0.3,
    }
    resp, took = self._txt2img(params)
    self._check_generation(resp, 'grading_combined', took)
|
||||
|
||||
def run_grading_tests(self):
    """Run all grading tests."""
    self._category = 'grading'
    print("\n--- Color Grading Tests ---")

    if self._critical_error:
        self.skip('grading_all', self._critical_error)
        return

    # run each grading case in declaration order
    for case in (
        self.test_grading_brightness_contrast,
        self.test_grading_saturation_hue,
        self.test_grading_gamma_sharpness,
        self.test_grading_color_temp,
        self.test_grading_tone,
        self.test_grading_effects,
        self.test_grading_split_toning,
        self.test_grading_combined,
    ):
        case()
|
||||
|
||||
# =========================================================================
|
||||
# Tests: Latent Correction Params
|
||||
# =========================================================================
|
||||
|
||||
def test_correction_brightness(self):
    """Generate with latent brightness correction."""
    resp, took = self._txt2img({'hdr_brightness': 1.5})
    if self._check_generation(resp, 'correction_brightness', took):
        # the param should be echoed back in the generation metadata
        found = 'Latent brightness' in self._get_info(resp)
        self.record(found, 'correction_brightness_metadata',
                    'found in info' if found else 'not found in info')
|
||||
|
||||
def test_correction_color(self):
    """Generate with latent color centering."""
    resp, took = self._txt2img({'hdr_color': 0.5, 'hdr_mode': 1})
    if self._check_generation(resp, 'correction_color', took):
        # the param should be echoed back in the generation metadata
        found = 'Latent color' in self._get_info(resp)
        self.record(found, 'correction_color_metadata',
                    'found in info' if found else 'not found in info')
|
||||
|
||||
def test_correction_clamp(self):
    """Generate with latent clamping."""
    params = {
        'hdr_clamp': True,
        'hdr_threshold': 0.8,
        'hdr_boundary': 4.0,
    }
    resp, took = self._txt2img(params)
    if self._check_generation(resp, 'correction_clamp', took):
        # the param should be echoed back in the generation metadata
        found = 'Latent clamp' in self._get_info(resp)
        self.record(found, 'correction_clamp_metadata',
                    'found in info' if found else 'not found in info')
|
||||
|
||||
def test_correction_sharpen(self):
    """Generate with latent sharpening."""
    resp, took = self._txt2img({'hdr_sharpen': 1.0})
    if self._check_generation(resp, 'correction_sharpen', took):
        # the param should be echoed back in the generation metadata
        found = 'Latent sharpen' in self._get_info(resp)
        self.record(found, 'correction_sharpen_metadata',
                    'found in info' if found else 'not found in info')
|
||||
|
||||
def test_correction_maximize(self):
    """Generate with latent maximize/normalize."""
    params = {
        'hdr_maximize': True,
        'hdr_max_center': 0.6,
        'hdr_max_boundary': 2.0,
    }
    resp, took = self._txt2img(params)
    if self._check_generation(resp, 'correction_maximize', took):
        # the param should be echoed back in the generation metadata
        found = 'Latent max' in self._get_info(resp)
        self.record(found, 'correction_maximize_metadata',
                    'found in info' if found else 'not found in info')
|
||||
|
||||
def test_correction_combined(self):
    """Generate with multiple correction params."""
    params = {
        'hdr_brightness': 1.0,
        'hdr_color': 0.3,
        'hdr_sharpen': 0.5,
        'hdr_clamp': True,
    }
    resp, took = self._txt2img(params)
    if self._check_generation(resp, 'correction_combined', took):
        info = self._get_info(resp)
        # At least some correction params should appear
        found = [k for k in ['Latent brightness', 'Latent color', 'Latent sharpen', 'Latent clamp'] if k in info]
        self.record(len(found) > 0, 'correction_combined_metadata', f"found: {found}")
|
||||
|
||||
def run_correction_tests(self):
    """Run all latent correction tests."""
    self._category = 'correction'
    print("\n--- Latent Correction Tests ---")

    if self._critical_error:
        self.skip('correction_all', self._critical_error)
        return

    # run each correction case in declaration order
    for case in (
        self.test_correction_brightness,
        self.test_correction_color,
        self.test_correction_clamp,
        self.test_correction_sharpen,
        self.test_correction_maximize,
        self.test_correction_combined,
    ):
        case()
|
||||
|
||||
# =========================================================================
|
||||
# Tests: Per-Request Param Validation (baseline comparison)
|
||||
# =========================================================================
|
||||
|
||||
def _generate_baseline(self):
    """Generate a baseline image with no grading/correction. Cache and reuse.

    Returns (ndarray, response) on success, (None, response) on failure;
    the decoded array is cached so repeated comparisons reuse one generation.
    """
    if getattr(self, '_baseline_arr', None) is not None:
        return self._baseline_arr, self._baseline_data
    data, elapsed = self._txt2img()
    # treat an empty 'images' list as a failure too (not only a missing key),
    # so we never cache a None baseline decoded from an empty response
    if 'error' in data or not data.get('images'):
        return None, data
    self._baseline_arr = self._decode_image(data)
    self._baseline_data = data
    print(f' Baseline generated: time={elapsed:.1f}s mean={self._channel_means(self._baseline_arr)}')
    return self._baseline_arr, data
|
||||
|
||||
def _compare_param(self, name, params, check_fn=None):
    """Generate with params and compare to baseline. Optionally run check_fn(baseline, result)."""
    baseline, _ = self._generate_baseline()
    if baseline is None:
        self.skip(f'param_{name}', 'baseline generation failed')
        return

    resp, _ = self._txt2img(params)
    if 'error' in resp:
        self.record(False, f'param_{name}', f"generation error: {resp}")
        return

    graded = self._decode_image(resp)
    if graded is None:
        self.record(False, f'param_{name}', 'no image in response')
        return

    diff = self._pixel_diff(baseline, graded)
    differs = diff > 0.5  # a mean difference above 0.5/255 counts as a real change
    self.record(differs, f'param_{name}_differs',
                f"mean_diff={diff:.2f}" if differs else f"images identical (diff={diff:.4f})")

    # direction check only makes sense once the images actually differ
    if check_fn and differs:
        try:
            ok, detail = check_fn(baseline, graded, resp)
            self.record(ok, f'param_{name}_direction', detail)
        except Exception as e:
            self.record(False, f'param_{name}_direction', f"check error: {e}")
|
||||
|
||||
def run_param_validation_tests(self):
    """Verify per-request grading/correction params actually change the output.

    Each check generates with one param set and compares the decoded image
    against a cached default-payload baseline via _compare_param, asserting
    both that the output differs and — where the effect has a predictable
    direction — that it moved the right way (e.g. brightness raises the mean).
    Closures passed as check_fn receive (baseline_array, result_array, response).
    """
    self._category = 'param_validation'
    print("\n--- Per-Request Param Validation ---")

    if self._critical_error:
        self.skip('param_validation_all', self._critical_error)
        return

    # local import: numpy is only needed here for the vignette corner check
    import numpy as np

    # -- Grading params --

    # Brightness: positive should increase mean pixel value
    def check_brightness(base, result, _data):
        base_mean = float(base.mean())
        result_mean = float(result.mean())
        return result_mean > base_mean, f"baseline={base_mean:.1f} graded={result_mean:.1f}"
    self._compare_param('grading_brightness', {'grading_brightness': 0.3}, check_brightness)

    # Contrast: should increase standard deviation
    def check_contrast(base, result, _data):
        return float(result.std()) > float(base.std()), \
            f"baseline_std={float(base.std()):.1f} graded_std={float(result.std()):.1f}"
    self._compare_param('grading_contrast', {'grading_contrast': 0.5}, check_contrast)

    # Saturation: desaturation should reduce color channel spread
    # (spread = max - min of per-channel means; a gray image has spread 0)
    def check_desaturation(base, result, _data):
        base_spread = max(self._channel_means(base)) - min(self._channel_means(base))
        result_spread = max(self._channel_means(result)) - min(self._channel_means(result))
        return result_spread < base_spread, \
            f"baseline_spread={base_spread:.1f} graded_spread={result_spread:.1f}"
    self._compare_param('grading_saturation_neg', {'grading_saturation': -0.5}, check_desaturation)

    # Hue shift: just verify it changes
    self._compare_param('grading_hue', {'grading_hue': 0.2})

    # Gamma < 1: should brighten (raise values that are < 1)
    def check_gamma(base, result, _data):
        return float(result.mean()) > float(base.mean()), \
            f"baseline={float(base.mean()):.1f} gamma={float(result.mean()):.1f}"
    self._compare_param('grading_gamma', {'grading_gamma': 0.7}, check_gamma)

    # Sharpness: just verify it changes
    self._compare_param('grading_sharpness', {'grading_sharpness': 0.8})

    # Color temperature warm: red channel mean should increase relative to blue
    def check_warm(base, result, _data):
        base_r, _, base_b = self._channel_means(base)
        res_r, _, res_b = self._channel_means(result)
        base_rb = base_r - base_b
        res_rb = res_r - res_b
        return res_rb > base_rb, f"baseline R-B={base_rb:.1f} warm R-B={res_rb:.1f}"
    self._compare_param('grading_color_temp_warm', {'grading_color_temp': 3000}, check_warm)

    # Color temperature cool: blue should increase relative to red
    def check_cool(base, result, _data):
        base_r, _, base_b = self._channel_means(base)
        res_r, _, res_b = self._channel_means(result)
        base_rb = base_r - base_b
        res_rb = res_r - res_b
        return res_rb < base_rb, f"baseline R-B={base_rb:.1f} cool R-B={res_rb:.1f}"
    self._compare_param('grading_color_temp_cool', {'grading_color_temp': 10000}, check_cool)

    # Vignette: corners should be darker than baseline corners
    # (samples the four h/8-sized corner patches of each image)
    def check_vignette(base, result, _data):
        h, w = base.shape[:2]
        corner_size = h // 8
        base_corners = np.concatenate([
            base[:corner_size, :corner_size].flatten(),
            base[:corner_size, -corner_size:].flatten(),
            base[-corner_size:, :corner_size].flatten(),
            base[-corner_size:, -corner_size:].flatten(),
        ])
        result_corners = np.concatenate([
            result[:corner_size, :corner_size].flatten(),
            result[:corner_size, -corner_size:].flatten(),
            result[-corner_size:, :corner_size].flatten(),
            result[-corner_size:, -corner_size:].flatten(),
        ])
        return float(result_corners.mean()) < float(base_corners.mean()), \
            f"baseline_corners={float(base_corners.mean()):.1f} vignette_corners={float(result_corners.mean()):.1f}"
    self._compare_param('grading_vignette', {'grading_vignette': 0.8}, check_vignette)

    # Grain: just verify it changes (stochastic)
    self._compare_param('grading_grain', {'grading_grain': 0.5})

    # Shadows/midtones/highlights: verify changes
    self._compare_param('grading_shadows', {'grading_shadows': 0.5})
    self._compare_param('grading_highlights', {'grading_highlights': -0.3})

    # CLAHE: should increase local contrast
    self._compare_param('grading_clahe', {'grading_clahe_clip': 2.0})

    # Split toning: verify changes
    self._compare_param('grading_split_toning', {
        'grading_shadows_tint': '#003366',
        'grading_highlights_tint': '#ffcc00',
    })

    # -- Correction params --

    # Latent brightness: should change output and appear in metadata
    # factory: builds a check_fn asserting `key` appears in the info string
    def check_correction_meta(key):
        def _check(_base, _result, data):
            info = self._get_info(data)
            return key in info, f"'{key}' {'found' if key in info else 'missing'} in info"
        return _check
    self._compare_param('hdr_brightness', {'hdr_brightness': 2.0}, check_correction_meta('Latent brightness'))
    self._compare_param('hdr_color', {'hdr_color': 0.8, 'hdr_mode': 1}, check_correction_meta('Latent color'))
    self._compare_param('hdr_sharpen', {'hdr_sharpen': 1.5}, check_correction_meta('Latent sharpen'))
    self._compare_param('hdr_clamp', {'hdr_clamp': True, 'hdr_threshold': 0.7}, check_correction_meta('Latent clamp'))

    # Isolation: verify params from one request don't leak to the next
    # (a fresh default-payload generation must still match the cached baseline)
    data_after, _ = self._txt2img()
    arr_after = self._decode_image(data_after)
    baseline, _ = self._generate_baseline()
    if baseline is not None and arr_after is not None:
        leak_diff = self._pixel_diff(baseline, arr_after)
        no_leak = leak_diff < 0.5
        self.record(no_leak, 'param_isolation',
                    f"post-grading baseline diff={leak_diff:.4f}" if no_leak
                    else f"LEAK: baseline changed after grading requests (diff={leak_diff:.2f})")
|
||||
|
||||
# =========================================================================
|
||||
# Runner
|
||||
# =========================================================================
|
||||
|
||||
def run_all(self):
    """Execute every test group, print per-category totals, return True when nothing failed."""
    banner = "=" * 60
    print(banner)
    print("Generation API Test Suite")
    print(f"Server: {self.base_url}")
    print(f"Steps: {self.steps}")
    print(banner)

    # Samplers
    available = self.test_samplers_list()
    self.test_samplers_generate(available)

    # Grading
    self.run_grading_tests()

    # Corrections
    self.run_correction_tests()

    # Per-request param validation (baseline comparison)
    self.run_param_validation_tests()

    # Summary
    print("\n" + banner)
    print("Results")
    print(banner)
    total_passed = total_failed = total_skipped = 0
    for cat, counts in self.results.items():
        total_passed += counts['passed']
        total_failed += counts['failed']
        total_skipped += counts['skipped']
        status = 'PASS' if counts['failed'] == 0 else 'FAIL'
        print(f" {cat}: {counts['passed']} passed, {counts['failed']} failed, {counts['skipped']} skipped [{status}]")
    print(f" Total: {total_passed} passed, {total_failed} failed, {total_skipped} skipped")
    print(banner)
    return total_failed == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: run the full suite against a live SD.Next server
    parser = argparse.ArgumentParser(description='Generation API Tests (samplers, grading, correction)')
    parser.add_argument('--url', default=os.environ.get('SDAPI_URL', 'http://127.0.0.1:7860'), help='server URL')
    parser.add_argument('--steps', type=int, default=10, help='generation steps (lower = faster tests)')
    args = parser.parse_args()
    suite = GenerationAPITest(args.url, args.steps)
    # exit code 0 only when every recorded test passed
    sys.exit(0 if suite.run_all() else 1)
|
||||
Loading…
Reference in New Issue