diff --git a/test/test-detailer-api.py b/test/test-detailer-api.py new file mode 100644 index 000000000..acc4622d7 --- /dev/null +++ b/test/test-detailer-api.py @@ -0,0 +1,653 @@ +#!/usr/bin/env python +""" +API tests for YOLO Detailer endpoints. + +Tests: +- GET /sdapi/v1/detailers — model enumeration +- POST /sdapi/v1/detect — object detection on test images +- POST /sdapi/v1/txt2img — generation with detailer enabled + +Requires a running SD.Next instance with a model loaded. + +Usage: + python test/test-detailer-api.py [--url URL] [--image PATH] +""" + +import io +import os +import sys +import time +import json +import base64 +import argparse +import requests +import urllib3 + +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +# Reference model cover images with faces (best for detailer testing) +FACE_TEST_IMAGES = [ + 'models/Reference/ponyRealism_V23.jpg', # realistic woman, clear face + 'models/Reference/HiDream-ai--HiDream-I1-Fast.jpg', # realistic man, clear face + text + 'models/Reference/stabilityai--stable-diffusion-xl-base-1.0.jpg', # realistic woman portrait + 'models/Reference/CalamitousFelicitousness--Anima-sdnext-diffusers.jpg', # anime face (non-realistic test) +] + +# Fallback images (no guaranteed faces) +FALLBACK_IMAGES = [ + 'html/sdnext-robot-2k.jpg', + 'html/favicon.png', +] + + +class DetailerAPITest: + """Test harness for YOLO Detailer API endpoints.""" + + def __init__(self, base_url, image_path=None, timeout=300): + self.base_url = base_url.rstrip('/') + self.test_images = {} # name -> base64 + self.timeout = timeout + self.results = { + 'enumerate': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []}, + 'detect': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []}, + 'generate': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []}, + 'detailer_params': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []}, + } + self._category = 'enumerate' + self._critical_error = None + self._load_images(image_path) + + def 
_encode_image(self, path): + from PIL import Image + image = Image.open(path) + if image.mode == 'RGBA': + image = image.convert('RGB') + buf = io.BytesIO() + image.save(buf, 'JPEG') + return base64.b64encode(buf.getvalue()).decode(), image.size + + def _load_images(self, image_path=None): + if image_path and os.path.exists(image_path): + b64, size = self._encode_image(image_path) + name = os.path.basename(image_path) + self.test_images[name] = b64 + print(f" Test image: {image_path} ({size})") + return + + # Load all available face test images + for p in FACE_TEST_IMAGES: + if os.path.exists(p): + b64, size = self._encode_image(p) + name = os.path.basename(p) + self.test_images[name] = b64 + print(f" Loaded: {name} ({size[0]}x{size[1]})") + + # Fallback if no face images found + if not self.test_images: + for p in FALLBACK_IMAGES: + if os.path.exists(p): + b64, size = self._encode_image(p) + name = os.path.basename(p) + self.test_images[name] = b64 + print(f" Fallback: {name} ({size[0]}x{size[1]})") + break + + if not self.test_images: + print(" WARNING: No test images found, detect tests will be skipped") + + @property + def image_b64(self): + """Return the first available test image for backwards compat.""" + if self.test_images: + return next(iter(self.test_images.values())) + return None + + def _get(self, endpoint): + try: + r = requests.get(f'{self.base_url}{endpoint}', timeout=self.timeout, verify=False) + if r.status_code != 200: + return {'error': r.status_code, 'reason': r.reason} + return r.json() + except requests.exceptions.ConnectionError: + return {'error': 'connection_refused', 'reason': 'Server not running'} + except Exception as e: + return {'error': 'exception', 'reason': str(e)} + + def _post(self, endpoint, data): + try: + r = requests.post(f'{self.base_url}{endpoint}', json=data, timeout=self.timeout, verify=False) + if r.status_code != 200: + return {'error': r.status_code, 'reason': r.reason} + return r.json() + except 
requests.exceptions.ConnectionError: + return {'error': 'connection_refused', 'reason': 'Server not running'} + except Exception as e: + return {'error': 'exception', 'reason': str(e)} + + def record(self, passed, name, detail=''): + status = 'PASS' if passed else 'FAIL' + self.results[self._category]['passed' if passed else 'failed'] += 1 + self.results[self._category]['tests'].append((status, name)) + msg = f' {status}: {name}' + if detail: + msg += f' ({detail})' + print(msg) + + def skip(self, name, reason): + self.results[self._category]['skipped'] += 1 + self.results[self._category]['tests'].append(('SKIP', name)) + print(f' SKIP: {name} ({reason})') + + # ========================================================================= + # Tests: Model Enumeration + # ========================================================================= + + def test_detailers_list(self): + """GET /sdapi/v1/detailers returns a list of available models.""" + self._category = 'enumerate' + print("\n--- Detailer Model Enumeration ---") + + data = self._get('/sdapi/v1/detailers') + if 'error' in data: + self.record(False, 'detailers_list', f"error: {data}") + self._critical_error = f"Server error: {data}" + return [] + + if not isinstance(data, list): + self.record(False, 'detailers_list', f"expected list, got {type(data).__name__}") + return [] + + self.record(True, 'detailers_list', f"{len(data)} models found") + + # Verify each entry has expected fields + if len(data) > 0: + sample = data[0] + has_name = 'name' in sample + self.record(has_name, 'detailer_entry_has_name', f"sample: {sample}") + if not has_name: + self.record(False, 'detailer_entry_schema', "missing 'name' field") + + return data + + # ========================================================================= + # Tests: Detection + # ========================================================================= + + def _validate_detect_response(self, data, label): + """Validate detection response schema and return 
detection count.""" + expected_keys = ['classes', 'labels', 'boxes', 'scores'] + for key in expected_keys: + if key not in data: + self.record(False, f'{label}_schema_{key}', f"missing '{key}'") + return -1 + + # All arrays should have the same length + lengths = [len(data[key]) for key in expected_keys] + all_same = len(set(lengths)) <= 1 + if not all_same: + self.record(False, f'{label}_array_lengths', f"mismatched: {dict(zip(expected_keys, lengths))}") + return -1 + + n = lengths[0] + + if n > 0: + # Scores should be in [0, 1] + scores_valid = all(0 <= s <= 1 for s in data['scores']) + if not scores_valid: + self.record(False, f'{label}_scores_range', f"scores: {data['scores']}") + + # Boxes should be lists of 4 numbers + boxes_valid = all(isinstance(b, list) and len(b) == 4 for b in data['boxes']) + if not boxes_valid: + self.record(False, f'{label}_boxes_format', "bad box format") + + return n + + # Face detection models to try (in priority order) + FACE_MODELS = ['face-yolo8n', 'face-yolo8m', 'anzhc-face-1024-seg-8n'] + + def _pick_face_model(self, available_models): + """Pick the best face detection model from available ones.""" + available_names = [m.get('name', '') for m in available_models] if available_models else [] + for model in self.FACE_MODELS: + if model in available_names: + return model + return '' # fall back to server default + + def test_detect_all_images(self, available_models=None): + """POST /sdapi/v1/detect on each loaded test image with a face model.""" + self._category = 'detect' + print("\n--- Detection Tests (per-image) ---") + + if not self.test_images: + self.skip('detect_all', 'no test images') + return + + if self._critical_error: + self.skip('detect_all', self._critical_error) + return + + face_model = self._pick_face_model(available_models) + if face_model: + print(f" Using face model: {face_model}") + else: + print(" No face model available, using server default") + + total_detections = 0 + any_face_found = False + + for 
img_name, img_b64 in self.test_images.items(): + short = img_name.replace('.jpg', '')[:40] + data = self._post('/sdapi/v1/detect', {'image': img_b64, 'model': face_model}) + + if 'error' in data: + self.record(False, f'detect_{short}', f"error: {data}") + continue + + n = self._validate_detect_response(data, f'detect_{short}') + if n < 0: + continue + + labels = data.get('labels', []) + scores = data.get('scores', []) + detail_parts = [f"{n} detections"] + if labels: + detail_parts.append(f"labels={labels}") + if scores: + detail_parts.append(f"top_score={max(scores):.3f}") + + self.record(True, f'detect_{short}', ', '.join(detail_parts)) + total_detections += n + if n > 0: + any_face_found = True + + self.record(any_face_found, 'detect_found_faces', + f"{total_detections} total detections across {len(self.test_images)} images") + + def test_detect_with_model(self, model_name): + """POST /sdapi/v1/detect with a specific model on all images.""" + if not self.test_images: + self.skip(f'detect_model_{model_name}', 'no test images') + return + + total = 0 + for _img_name, img_b64 in self.test_images.items(): + data = self._post('/sdapi/v1/detect', {'image': img_b64, 'model': model_name}) + if 'error' not in data: + total += len(data.get('scores', [])) + + self.record(True, f'detect_model_{model_name}', f"{total} detections across {len(self.test_images)} images") + + # ========================================================================= + # Tests: Generation with Detailer + # ========================================================================= + + def test_txt2img_with_detailer(self): + """POST /sdapi/v1/txt2img with detailer_enabled=True.""" + self._category = 'generate' + print("\n--- Generation with Detailer ---") + + if self._critical_error: + self.skip('txt2img_detailer', self._critical_error) + return + + payload = { + 'prompt': 'a photo of a person, face, portrait', + 'negative_prompt': '', + 'steps': 10, + 'width': 512, + 'height': 512, + 'seed': 42, + 
'save_images': False, + 'send_images': True, + 'detailer_enabled': True, + 'detailer_strength': 0.3, + 'detailer_steps': 5, + 'detailer_conf': 0.3, + 'detailer_max': 3, + } + + t0 = time.time() + # Detailer generation is multi-pass (generate + detect + inpaint per region), use longer timeout + try: + r = requests.post(f'{self.base_url}/sdapi/v1/txt2img', json=payload, timeout=600, verify=False) + if r.status_code != 200: + data = {'error': r.status_code, 'reason': r.reason} + else: + data = r.json() + except requests.exceptions.ConnectionError as e: + self.record(False, 'txt2img_detailer', f"connection error (is a model loaded?): {e}") + return + except requests.exceptions.ReadTimeout: + self.record(False, 'txt2img_detailer', 'timeout after 600s') + return + t1 = time.time() + + if 'error' in data: + self.record(False, 'txt2img_detailer', f"error: {data} (ensure a model is loaded)") + return + + # Should have images + has_images = 'images' in data and len(data['images']) > 0 + self.record(has_images, 'txt2img_detailer_has_images', f"time={t1 - t0:.1f}s") + + if has_images: + # Decode and verify image + from PIL import Image + img_data = data['images'][0].split(',', 1)[0] + img = Image.open(io.BytesIO(base64.b64decode(img_data))) + self.record(True, 'txt2img_detailer_image_valid', f"size={img.size}") + + # Check info field for detailer metadata + if 'info' in data: + info = data['info'] if isinstance(data['info'], str) else json.dumps(data['info']) + has_detailer_info = 'detailer' in info.lower() or 'Detailer' in info + self.record(has_detailer_info, 'txt2img_detailer_metadata', + 'detailer info found in metadata' if has_detailer_info else 'no detailer metadata (detection may have found nothing)') + + def test_txt2img_without_detailer(self): + """POST /sdapi/v1/txt2img baseline without detailer (sanity check).""" + if self._critical_error: + self.skip('txt2img_baseline', self._critical_error) + return + + payload = { + 'prompt': 'a simple landscape', + 'steps': 5, + 
'width': 512, + 'height': 512, + 'seed': 42, + 'save_images': False, + 'send_images': True, + } + + data = self._post('/sdapi/v1/txt2img', payload) + if 'error' in data: + self.record(False, 'txt2img_baseline', f"error: {data}") + return + + has_images = 'images' in data and len(data['images']) > 0 + self.record(has_images, 'txt2img_baseline', 'generation works without detailer') + + # ========================================================================= + # Tests: Per-Request Detailer Param Validation + # ========================================================================= + + def _txt2img(self, extra_params=None): + """Helper: generate a portrait with optional param overrides.""" + payload = { + 'prompt': 'a photo of a person, face, portrait', + 'steps': 10, + 'width': 512, + 'height': 512, + 'seed': 42, + 'save_images': False, + 'send_images': True, + } + if extra_params: + payload.update(extra_params) + try: + r = requests.post(f'{self.base_url}/sdapi/v1/txt2img', json=payload, timeout=600, verify=False) + if r.status_code != 200: + return {'error': r.status_code, 'reason': r.reason} + return r.json() + except requests.exceptions.ConnectionError as e: + return {'error': 'connection_refused', 'reason': str(e)} + except requests.exceptions.ReadTimeout: + return {'error': 'timeout', 'reason': 'timeout after 600s'} + + def _decode_image(self, data): + """Decode first image from generation response into numpy array.""" + import numpy as np + from PIL import Image + if 'images' not in data or len(data['images']) == 0: + return None + img_data = data['images'][0].split(',', 1)[0] + img = Image.open(io.BytesIO(base64.b64decode(img_data))).convert('RGB') + return np.array(img, dtype=np.float32) + + def _pixel_diff(self, arr_a, arr_b): + """Mean absolute pixel difference between two images.""" + import numpy as np + if arr_a is None or arr_b is None or arr_a.shape != arr_b.shape: + return -1.0 + return float(np.abs(arr_a - arr_b).mean()) + + def _get_info(self, 
data): + """Extract info string from generation response.""" + if 'info' not in data: + return '' + info = data['info'] + return info if isinstance(info, str) else json.dumps(info) + + def run_detailer_param_tests(self, available_models=None): + """Verify per-request detailer params change the output.""" + self._category = 'detailer_params' + print("\n--- Per-Request Detailer Param Validation ---") + + if self._critical_error: + self.skip('detailer_params_all', self._critical_error) + return + + # Generate baseline WITHOUT detailer (same seed/prompt as detailer tests) + print(" Generating baseline (no detailer)...") + baseline_data = self._txt2img() + if 'error' in baseline_data: + self.record(False, 'detailer_baseline', f"error: {baseline_data}") + return + baseline = self._decode_image(baseline_data) + if baseline is None: + self.record(False, 'detailer_baseline', 'no image') + return + self.record(True, 'detailer_baseline') + + # Generate WITH detailer enabled (default params) + print(" Generating with detailer (defaults)...") + detailer_default_data = self._txt2img({ + 'detailer_enabled': True, + 'detailer_strength': 0.3, + 'detailer_steps': 5, + 'detailer_conf': 0.3, + }) + if 'error' in detailer_default_data: + self.record(False, 'detailer_default', f"error: {detailer_default_data}") + return + detailer_default = self._decode_image(detailer_default_data) + + # Detailer ON vs OFF should produce different images (if a face was detected) + diff_on_off = self._pixel_diff(baseline, detailer_default) + self.record(diff_on_off > 0.5, 'detailer_on_vs_off', + f"mean_diff={diff_on_off:.2f}" if diff_on_off > 0.5 + else f"identical (diff={diff_on_off:.4f}) — no face detected?") + + # -- Strength variation -- + print(" Testing strength variation...") + strong_data = self._txt2img({ + 'detailer_enabled': True, + 'detailer_strength': 0.7, + 'detailer_steps': 5, + 'detailer_conf': 0.3, + }) + if 'error' not in strong_data: + strong = self._decode_image(strong_data) + 
diff_strong = self._pixel_diff(detailer_default, strong) + self.record(diff_strong > 0.5, 'detailer_strength_effect', + f"strength 0.3 vs 0.7: diff={diff_strong:.2f}") + + # -- Steps variation -- + print(" Testing steps variation...") + more_steps_data = self._txt2img({ + 'detailer_enabled': True, + 'detailer_strength': 0.3, + 'detailer_steps': 20, + 'detailer_conf': 0.3, + }) + if 'error' not in more_steps_data: + more_steps = self._decode_image(more_steps_data) + diff_steps = self._pixel_diff(detailer_default, more_steps) + self.record(diff_steps > 0.5, 'detailer_steps_effect', + f"steps 5 vs 20: diff={diff_steps:.2f}") + + # -- Resolution variation -- + print(" Testing resolution variation...") + hires_data = self._txt2img({ + 'detailer_enabled': True, + 'detailer_strength': 0.3, + 'detailer_steps': 5, + 'detailer_conf': 0.3, + 'detailer_resolution': 512, + }) + if 'error' not in hires_data: + hires = self._decode_image(hires_data) + diff_res = self._pixel_diff(detailer_default, hires) + self.record(diff_res > 0.5, 'detailer_resolution_effect', + f"resolution 1024 vs 512: diff={diff_res:.2f}") + + # -- Segmentation mode -- + # Segmentation requires a -seg model (e.g. anzhc-face-1024-seg-8n). + # Detection-only models (face-yolo8n) don't produce masks, so the flag has no effect. 
+ seg_models = [m.get('name', '') for m in (available_models or []) + if 'seg' in m.get('name', '').lower() and 'face' in m.get('name', '').lower()] + if seg_models: + seg_model = seg_models[0] + print(f" Testing segmentation mode (model={seg_model})...") + # bbox baseline with the seg model + seg_bbox_data = self._txt2img({ + 'detailer_enabled': True, + 'detailer_strength': 0.3, + 'detailer_steps': 5, + 'detailer_conf': 0.3, + 'detailer_segmentation': False, + 'detailer_models': [seg_model], + }) + seg_data = self._txt2img({ + 'detailer_enabled': True, + 'detailer_strength': 0.3, + 'detailer_steps': 5, + 'detailer_conf': 0.3, + 'detailer_segmentation': True, + 'detailer_models': [seg_model], + }) + if 'error' not in seg_data and 'error' not in seg_bbox_data: + seg_bbox = self._decode_image(seg_bbox_data) + seg_mask = self._decode_image(seg_data) + diff_seg = self._pixel_diff(seg_bbox, seg_mask) + self.record(diff_seg > 0.5, 'detailer_segmentation_effect', + f"bbox vs seg mask ({seg_model}): diff={diff_seg:.2f}") + else: + err = seg_data if 'error' in seg_data else seg_bbox_data + self.record(False, 'detailer_segmentation_effect', f"error: {err}") + else: + print(" Testing segmentation mode...") + seg_data = {'error': 'skipped'} + self.skip('detailer_segmentation_effect', 'no face-seg model available') + + # -- Confidence threshold -- + print(" Testing confidence threshold...") + high_conf_data = self._txt2img({ + 'detailer_enabled': True, + 'detailer_strength': 0.3, + 'detailer_steps': 5, + 'detailer_conf': 0.95, + }) + if 'error' not in high_conf_data: + high_conf = self._decode_image(high_conf_data) + diff_conf = self._pixel_diff(baseline, high_conf) + # High confidence may reject detections, making output closer to baseline + self.record(True, 'detailer_conf_effect', + f"conf=0.95 vs baseline: diff={diff_conf:.2f} " + f"(low diff = detections filtered out, high diff = still detected)") + + # -- Custom detailer prompt -- + print(" Testing detailer prompt 
override...") + prompt_data = self._txt2img({ + 'detailer_enabled': True, + 'detailer_strength': 0.5, + 'detailer_steps': 5, + 'detailer_conf': 0.3, + 'detailer_prompt': 'a detailed close-up face with freckles', + }) + if 'error' not in prompt_data: + prompt_result = self._decode_image(prompt_data) + diff_prompt = self._pixel_diff(detailer_default, prompt_result) + self.record(diff_prompt > 0.5, 'detailer_prompt_effect', + f"custom prompt vs default: diff={diff_prompt:.2f}") + + # -- Metadata verification across params -- + for test_data, label in [ + (detailer_default_data, 'detailer_default'), + (strong_data if 'error' not in strong_data else None, 'detailer_strong'), + (more_steps_data if 'error' not in more_steps_data else None, 'detailer_more_steps'), + (seg_data if 'error' not in seg_data else None, 'detailer_segmentation'), + ]: + if test_data is None: + continue + info = self._get_info(test_data) + has_meta = 'detailer' in info.lower() or 'Detailer' in info + self.record(has_meta, f'{label}_metadata', + 'detailer info in metadata' if has_meta else 'no detailer metadata') + + # -- Param isolation: generate without detailer after all detailer runs -- + print(" Testing param isolation...") + after_data = self._txt2img() + if 'error' not in after_data: + after = self._decode_image(after_data) + leak_diff = self._pixel_diff(baseline, after) + self.record(leak_diff < 0.5, 'detailer_param_isolation', + f"post-detailer baseline diff={leak_diff:.4f}" if leak_diff < 0.5 + else f"LEAK: baseline changed (diff={leak_diff:.2f})") + + # ========================================================================= + # Runner + # ========================================================================= + + def run_all(self): + print("=" * 60) + print("YOLO Detailer API Test Suite") + print(f"Server: {self.base_url}") + print("=" * 60) + + # Enumerate + models = self.test_detailers_list() + + # Detect across all loaded test images + self.test_detect_all_images(models) + # Test 
with first available model if any + if models and len(models) > 0: + model_name = models[0].get('name', models[0].get('filename', '')) + if model_name: + self.test_detect_with_model(model_name) + + # Generate + self.test_txt2img_without_detailer() + self.test_txt2img_with_detailer() + + # Per-request detailer param validation + self.run_detailer_param_tests(models) + + # Summary + print("\n" + "=" * 60) + print("Results") + print("=" * 60) + total_passed = 0 + total_failed = 0 + total_skipped = 0 + for cat, data in self.results.items(): + total_passed += data['passed'] + total_failed += data['failed'] + total_skipped += data['skipped'] + status = 'PASS' if data['failed'] == 0 else 'FAIL' + print(f" {cat}: {data['passed']} passed, {data['failed']} failed, {data['skipped']} skipped [{status}]") + print(f" Total: {total_passed} passed, {total_failed} failed, {total_skipped} skipped") + print("=" * 60) + return total_failed == 0 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='YOLO Detailer API Tests') + parser.add_argument('--url', default=os.environ.get('SDAPI_URL', 'http://127.0.0.1:7860'), help='server URL') + parser.add_argument('--image', default=None, help='test image path') + args = parser.parse_args() + test = DetailerAPITest(args.url, args.image) + success = test.run_all() + sys.exit(0 if success else 1) diff --git a/test/test-generation-api.py b/test/test-generation-api.py new file mode 100644 index 000000000..1359d59e5 --- /dev/null +++ b/test/test-generation-api.py @@ -0,0 +1,615 @@ +#!/usr/bin/env python +""" +API tests for generation with scheduler params, color grading, and latent corrections. + +Tests: +- GET /sdapi/v1/samplers — sampler enumeration and config +- POST /sdapi/v1/txt2img — generation with various samplers +- POST /sdapi/v1/txt2img — generation with color grading params +- POST /sdapi/v1/txt2img — generation with latent correction params + +Requires a running SD.Next instance with a model loaded. 
+ +Usage: + python test/test-generation-api.py [--url URL] [--steps STEPS] +""" + +import io +import os +import sys +import json +import time +import base64 +import argparse +import requests +import urllib3 + +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + +class GenerationAPITest: + """Test harness for generation API with scheduler and grading params.""" + + # Samplers to test — a representative subset covering different scheduler families + TEST_SAMPLERS = [ + 'Euler a', + 'Euler', + 'DPM++ 2M', + 'UniPC', + 'DDIM', + 'DPM++ 2M SDE', + ] + + def __init__(self, base_url, steps=10, timeout=300): + self.base_url = base_url.rstrip('/') + self.steps = steps + self.timeout = timeout + self.results = { + 'samplers': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []}, + 'generation': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []}, + 'grading': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []}, + 'correction': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []}, + 'param_validation': {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []}, + } + self._category = 'samplers' + self._critical_error = None + + def _get(self, endpoint): + try: + r = requests.get(f'{self.base_url}{endpoint}', timeout=self.timeout, verify=False) + if r.status_code != 200: + return {'error': r.status_code, 'reason': r.reason} + return r.json() + except requests.exceptions.ConnectionError: + return {'error': 'connection_refused', 'reason': 'Server not running'} + except Exception as e: + return {'error': 'exception', 'reason': str(e)} + + def _post(self, endpoint, data): + try: + r = requests.post(f'{self.base_url}{endpoint}', json=data, timeout=self.timeout, verify=False) + if r.status_code != 200: + return {'error': r.status_code, 'reason': r.reason} + return r.json() + except requests.exceptions.ConnectionError: + return {'error': 'connection_refused', 'reason': 'Server not running'} + except Exception as e: + return {'error': 'exception', 'reason': 
str(e)} + + def record(self, passed, name, detail=''): + status = 'PASS' if passed else 'FAIL' + self.results[self._category]['passed' if passed else 'failed'] += 1 + self.results[self._category]['tests'].append((status, name)) + msg = f' {status}: {name}' + if detail: + msg += f' ({detail})' + print(msg) + + def skip(self, name, reason): + self.results[self._category]['skipped'] += 1 + self.results[self._category]['tests'].append(('SKIP', name)) + print(f' SKIP: {name} ({reason})') + + def _txt2img(self, extra_params=None, prompt='a cat'): + """Helper: run txt2img with base params + overrides. Returns (data, time).""" + payload = { + 'prompt': prompt, + 'steps': self.steps, + 'width': 512, + 'height': 512, + 'seed': 42, + 'save_images': False, + 'send_images': True, + } + if extra_params: + payload.update(extra_params) + t0 = time.time() + data = self._post('/sdapi/v1/txt2img', payload) + return data, time.time() - t0 + + def _check_generation(self, data, test_name, elapsed): + """Validate a generation response has images.""" + if 'error' in data: + self.record(False, test_name, f"error: {data}") + return False + has_images = 'images' in data and len(data['images']) > 0 + self.record(has_images, test_name, f"time={elapsed:.1f}s") + return has_images + + def _get_info(self, data): + """Extract info string from generation response.""" + if 'info' not in data: + return '' + info = data['info'] + return info if isinstance(info, str) else json.dumps(info) + + def _decode_image(self, data): + """Decode first image from generation response into numpy array.""" + import numpy as np + from PIL import Image + if 'images' not in data or len(data['images']) == 0: + return None + img_data = data['images'][0].split(',', 1)[0] + img = Image.open(io.BytesIO(base64.b64decode(img_data))).convert('RGB') + return np.array(img, dtype=np.float32) + + def _pixel_diff(self, arr_a, arr_b): + """Mean absolute pixel difference between two images (0-255 scale).""" + import numpy as np + if 
arr_a is None or arr_b is None: + return -1.0 + if arr_a.shape != arr_b.shape: + return -1.0 + return float(np.abs(arr_a - arr_b).mean()) + + def _channel_means(self, arr): + """Return per-channel means [R, G, B].""" + if arr is None: + return [0, 0, 0] + return [float(arr[:, :, c].mean()) for c in range(3)] + + # ========================================================================= + # Tests: Sampler Enumeration + # ========================================================================= + + def test_samplers_list(self): + """GET /sdapi/v1/samplers returns available samplers with config.""" + self._category = 'samplers' + print("\n--- Sampler Enumeration ---") + + data = self._get('/sdapi/v1/samplers') + if 'error' in data: + self.record(False, 'samplers_list', f"error: {data}") + self._critical_error = f"Server error: {data}" + return [] + + if not isinstance(data, list): + self.record(False, 'samplers_list', f"expected list, got {type(data).__name__}") + return [] + + self.record(True, 'samplers_list', f"{len(data)} samplers available") + + # Check that each sampler has a name + sampler_names = [] + for s in data: + name = s.get('name', '') + if name: + sampler_names.append(name) + + self.record(len(sampler_names) == len(data), 'samplers_have_names', + f"{len(sampler_names)}/{len(data)} have names") + + # Check for our test samplers + for test_sampler in self.TEST_SAMPLERS: + found = test_sampler in sampler_names + if not found: + self.skip(f'sampler_available_{test_sampler}', 'not in server sampler list') + else: + self.record(True, f'sampler_available_{test_sampler}') + + return sampler_names + + # ========================================================================= + # Tests: Generation with Different Samplers + # ========================================================================= + + def test_samplers_generate(self, available_samplers): + """Generate with each test sampler and verify success.""" + self._category = 'generation' + print("\n--- 
Generation with Different Samplers ---")

        # A critical error (e.g. server unreachable) was recorded earlier:
        # skip every sampler test rather than hammering a dead endpoint.
        if self._critical_error:
            for s in self.TEST_SAMPLERS:
                self.skip(f'generate_{s}', self._critical_error)
            return

        for sampler in self.TEST_SAMPLERS:
            # Only exercise samplers the server actually advertises.
            if sampler not in available_samplers:
                self.skip(f'generate_{sampler}', 'sampler not available')
                continue
            data, elapsed = self._txt2img({'sampler_name': sampler})
            self._check_generation(data, f'generate_{sampler}', elapsed)

    # =========================================================================
    # Tests: Color Grading Params
    # =========================================================================

    def test_grading_brightness_contrast(self):
        """Generate with grading brightness and contrast."""
        data, elapsed = self._txt2img({
            'grading_brightness': 0.2,
            'grading_contrast': 0.3,
        })
        self._check_generation(data, 'grading_brightness_contrast', elapsed)

    def test_grading_saturation_hue(self):
        """Generate with grading saturation and hue shift."""
        data, elapsed = self._txt2img({
            'grading_saturation': 0.5,
            'grading_hue': 0.1,
        })
        self._check_generation(data, 'grading_saturation_hue', elapsed)

    def test_grading_gamma_sharpness(self):
        """Generate with gamma correction and sharpness."""
        data, elapsed = self._txt2img({
            'grading_gamma': 0.8,
            'grading_sharpness': 0.5,
        })
        self._check_generation(data, 'grading_gamma_sharpness', elapsed)

    def test_grading_color_temp(self):
        """Generate with warm color temperature."""
        # 3500K is on the warm side; the directional effect on channel means
        # is asserted separately in run_param_validation_tests.
        data, elapsed = self._txt2img({
            'grading_color_temp': 3500,
        })
        self._check_generation(data, 'grading_color_temp', elapsed)

    def test_grading_tone(self):
        """Generate with shadows/midtones/highlights adjustments."""
        data, elapsed = self._txt2img({
            'grading_shadows': 0.3,
            'grading_midtones': -0.1,
            'grading_highlights': 0.2,
        })
        self._check_generation(data, 'grading_tone', elapsed)

    def test_grading_effects(self):
        """Generate with vignette and grain."""
        data, elapsed = self._txt2img({
            'grading_vignette': 0.5,
            'grading_grain': 0.3,
        })
        self._check_generation(data, 'grading_effects', elapsed)

    def test_grading_split_toning(self):
        """Generate with split toning colors."""
        # Tints are hex color strings; balance biases between shadow/highlight tint.
        data, elapsed = self._txt2img({
            'grading_shadows_tint': '#003366',
            'grading_highlights_tint': '#ffcc00',
            'grading_split_tone_balance': 0.6,
        })
        self._check_generation(data, 'grading_split_toning', elapsed)

    def test_grading_combined(self):
        """Generate with multiple grading params at once."""
        data, elapsed = self._txt2img({
            'grading_brightness': 0.1,
            'grading_contrast': 0.2,
            'grading_saturation': 0.3,
            'grading_gamma': 0.9,
            'grading_color_temp': 5000,
            'grading_vignette': 0.3,
        })
        self._check_generation(data, 'grading_combined', elapsed)

    def run_grading_tests(self):
        """Run all grading tests."""
        self._category = 'grading'
        print("\n--- Color Grading Tests ---")

        # One skip entry stands in for the whole category on critical failure.
        if self._critical_error:
            self.skip('grading_all', self._critical_error)
            return

        self.test_grading_brightness_contrast()
        self.test_grading_saturation_hue()
        self.test_grading_gamma_sharpness()
        self.test_grading_color_temp()
        self.test_grading_tone()
        self.test_grading_effects()
        self.test_grading_split_toning()
        self.test_grading_combined()

    # =========================================================================
    # Tests: Latent Correction Params
    # =========================================================================

    def test_correction_brightness(self):
        """Generate with latent brightness correction."""
        data, elapsed = self._txt2img({'hdr_brightness': 1.5})
        ok = self._check_generation(data, 'correction_brightness', elapsed)
        if ok:
            # Correction params should be echoed back in the generation infotext.
            info = self._get_info(data)
            has_param = 'Latent brightness' in info
            self.record(has_param, 'correction_brightness_metadata',
                        'found in info' if has_param else 'not found in info')

    def test_correction_color(self):
        """Generate with latent color centering."""
        data, elapsed = self._txt2img({'hdr_color': 0.5, 'hdr_mode': 1})
        ok = self._check_generation(data, 'correction_color', elapsed)
        if ok:
            info = self._get_info(data)
            has_param = 'Latent color' in info
            self.record(has_param, 'correction_color_metadata',
                        'found in info' if has_param else 'not found in info')

    def test_correction_clamp(self):
        """Generate with latent clamping."""
        data, elapsed = self._txt2img({
            'hdr_clamp': True,
            'hdr_threshold': 0.8,
            'hdr_boundary': 4.0,
        })
        ok = self._check_generation(data, 'correction_clamp', elapsed)
        if ok:
            info = self._get_info(data)
            has_param = 'Latent clamp' in info
            self.record(has_param, 'correction_clamp_metadata',
                        'found in info' if has_param else 'not found in info')

    def test_correction_sharpen(self):
        """Generate with latent sharpening."""
        data, elapsed = self._txt2img({'hdr_sharpen': 1.0})
        ok = self._check_generation(data, 'correction_sharpen', elapsed)
        if ok:
            info = self._get_info(data)
            has_param = 'Latent sharpen' in info
            self.record(has_param, 'correction_sharpen_metadata',
                        'found in info' if has_param else 'not found in info')

    def test_correction_maximize(self):
        """Generate with latent maximize/normalize."""
        data, elapsed = self._txt2img({
            'hdr_maximize': True,
            'hdr_max_center': 0.6,
            'hdr_max_boundary': 2.0,
        })
        ok = self._check_generation(data, 'correction_maximize', elapsed)
        if ok:
            info = self._get_info(data)
            has_param = 'Latent max' in info
            self.record(has_param, 'correction_maximize_metadata',
                        'found in info' if has_param else 'not found in info')

    def test_correction_combined(self):
        """Generate with multiple correction params."""
        data, elapsed = self._txt2img({
            'hdr_brightness': 1.0,
            'hdr_color': 0.3,
            'hdr_sharpen': 0.5,
            'hdr_clamp': True,
        })
        ok = self._check_generation(data, 'correction_combined', elapsed)
        if ok:
            info = self._get_info(data)
            # At least some correction params should appear
            found = [k for k in ['Latent brightness', 'Latent color', 'Latent sharpen', 'Latent clamp'] if k in info]
            self.record(len(found) > 0, 'correction_combined_metadata', f"found: {found}")

    def run_correction_tests(self):
        """Run all latent correction tests."""
        self._category = 'correction'
        print("\n--- Latent Correction Tests ---")

        if self._critical_error:
            self.skip('correction_all', self._critical_error)
            return

        self.test_correction_brightness()
        self.test_correction_color()
        self.test_correction_clamp()
        self.test_correction_sharpen()
        self.test_correction_maximize()
        self.test_correction_combined()

    # =========================================================================
    # Tests: Per-Request Param Validation (baseline comparison)
    # =========================================================================

    def _generate_baseline(self):
        """Generate a baseline image with no grading/correction. Cache and reuse."""
        # Cached on first success so repeated _compare_param calls reuse one
        # server round-trip. hasattr guard: attribute only exists after first call.
        if hasattr(self, '_baseline_arr') and self._baseline_arr is not None:
            return self._baseline_arr, self._baseline_data
        data, elapsed = self._txt2img()
        if 'error' in data or 'images' not in data:
            # Do not cache a failure; report the raw response to the caller.
            return None, data
        self._baseline_arr = self._decode_image(data)
        self._baseline_data = data
        print(f'  Baseline generated: time={elapsed:.1f}s mean={self._channel_means(self._baseline_arr)}')
        return self._baseline_arr, data

    def _compare_param(self, name, params, check_fn=None):
        """Generate with params and compare to baseline. Optionally run check_fn(baseline, result)."""
        # NOTE(review): comparing against a cached baseline assumes generation is
        # deterministic for identical requests (fixed seed in _txt2img defaults) —
        # TODO confirm; a random seed would make every 'differs' check pass vacuously.
        baseline, _ = self._generate_baseline()
        if baseline is None:
            self.skip(f'param_{name}', 'baseline generation failed')
            return

        data, elapsed = self._txt2img(params)
        if 'error' in data:
            self.record(False, f'param_{name}', f"generation error: {data}")
            return

        result = self._decode_image(data)
        if result is None:
            self.record(False, f'param_{name}', 'no image in response')
            return

        diff = self._pixel_diff(baseline, result)
        differs = diff > 0.5  # more than 0.5/255 mean difference
        self.record(differs, f'param_{name}_differs',
                    f"mean_diff={diff:.2f}" if differs else f"images identical (diff={diff:.4f})")

        # Directional check (e.g. "brighter", "warmer") only makes sense when
        # the param demonstrably changed the output at all.
        if check_fn and differs:
            try:
                ok, detail = check_fn(baseline, result, data)
                self.record(ok, f'param_{name}_direction', detail)
            except Exception as e:
                self.record(False, f'param_{name}_direction', f"check error: {e}")

    def run_param_validation_tests(self):
        """Verify per-request grading/correction params actually change the output."""
        self._category = 'param_validation'
        print("\n--- Per-Request Param Validation ---")

        if self._critical_error:
            self.skip('param_validation_all', self._critical_error)
            return

        # Local import: numpy only needed for the pixel-array checks below.
        import numpy as np

        # -- Grading params --

        # Brightness: positive should increase mean pixel value
        def check_brightness(base, result, _data):
            base_mean = float(base.mean())
            result_mean = float(result.mean())
            return result_mean > base_mean, f"baseline={base_mean:.1f} graded={result_mean:.1f}"
        self._compare_param('grading_brightness', {'grading_brightness': 0.3}, check_brightness)

        # Contrast: should increase standard deviation
        def check_contrast(base, result, _data):
            return float(result.std()) > float(base.std()), \
                f"baseline_std={float(base.std()):.1f} graded_std={float(result.std()):.1f}"
        self._compare_param('grading_contrast', {'grading_contrast': 0.5}, check_contrast)

        # Saturation: desaturation should reduce color channel spread
        def check_desaturation(base, result, _data):
            base_spread = max(self._channel_means(base)) - min(self._channel_means(base))
            result_spread = max(self._channel_means(result)) - min(self._channel_means(result))
            return result_spread < base_spread, \
                f"baseline_spread={base_spread:.1f} graded_spread={result_spread:.1f}"
        self._compare_param('grading_saturation_neg', {'grading_saturation': -0.5}, check_desaturation)

        # Hue shift: just verify it changes
        self._compare_param('grading_hue', {'grading_hue': 0.2})

        # Gamma < 1: should brighten (raise values that are < 1)
        def check_gamma(base, result, _data):
            return float(result.mean()) > float(base.mean()), \
                f"baseline={float(base.mean()):.1f} gamma={float(result.mean()):.1f}"
        self._compare_param('grading_gamma', {'grading_gamma': 0.7}, check_gamma)

        # Sharpness: just verify it changes
        self._compare_param('grading_sharpness', {'grading_sharpness': 0.8})

        # Color temperature warm: red channel mean should increase relative to blue
        def check_warm(base, result, _data):
            base_r, _, base_b = self._channel_means(base)
            res_r, _, res_b = self._channel_means(result)
            base_rb = base_r - base_b
            res_rb = res_r - res_b
            return res_rb > base_rb, f"baseline R-B={base_rb:.1f} warm R-B={res_rb:.1f}"
        self._compare_param('grading_color_temp_warm', {'grading_color_temp': 3000}, check_warm)

        # Color temperature cool: blue should increase relative to red
        def check_cool(base, result, _data):
            base_r, _, base_b = self._channel_means(base)
            res_r, _, res_b = self._channel_means(result)
            base_rb = base_r - base_b
            res_rb = res_r - res_b
            return res_rb < base_rb, f"baseline R-B={base_rb:.1f} cool R-B={res_rb:.1f}"
        self._compare_param('grading_color_temp_cool', {'grading_color_temp': 10000}, check_cool)

        # Vignette: corners should be darker than baseline corners
        def check_vignette(base, result, _data):
            # Sample the four h//8-sized corner patches of each image and
            # compare their pooled mean brightness.
            h, w = base.shape[:2]
            corner_size = h // 8
            base_corners = np.concatenate([
                base[:corner_size, :corner_size].flatten(),
                base[:corner_size, -corner_size:].flatten(),
                base[-corner_size:, :corner_size].flatten(),
                base[-corner_size:, -corner_size:].flatten(),
            ])
            result_corners = np.concatenate([
                result[:corner_size, :corner_size].flatten(),
                result[:corner_size, -corner_size:].flatten(),
                result[-corner_size:, :corner_size].flatten(),
                result[-corner_size:, -corner_size:].flatten(),
            ])
            return float(result_corners.mean()) < float(base_corners.mean()), \
                f"baseline_corners={float(base_corners.mean()):.1f} vignette_corners={float(result_corners.mean()):.1f}"
        self._compare_param('grading_vignette', {'grading_vignette': 0.8}, check_vignette)

        # Grain: just verify it changes (stochastic)
        self._compare_param('grading_grain', {'grading_grain': 0.5})

        # Shadows/midtones/highlights: verify changes
        self._compare_param('grading_shadows', {'grading_shadows': 0.5})
        self._compare_param('grading_highlights', {'grading_highlights': -0.3})

        # CLAHE: should increase local contrast
        self._compare_param('grading_clahe', {'grading_clahe_clip': 2.0})

        # Split toning: verify changes
        self._compare_param('grading_split_toning', {
            'grading_shadows_tint': '#003366',
            'grading_highlights_tint': '#ffcc00',
        })

        # -- Correction params --

        # Latent brightness: should change output and appear in metadata
        def check_correction_meta(key):
            # Closure factory: returns a check_fn asserting `key` shows up in infotext.
            def _check(_base, _result, data):
                info = self._get_info(data)
                return key in info, f"'{key}' {'found' if key in info else 'missing'} in info"
            return _check
        self._compare_param('hdr_brightness', {'hdr_brightness': 2.0}, check_correction_meta('Latent brightness'))
        self._compare_param('hdr_color', {'hdr_color': 0.8, 'hdr_mode': 1}, check_correction_meta('Latent color'))
        self._compare_param('hdr_sharpen', {'hdr_sharpen': 1.5}, check_correction_meta('Latent sharpen'))
        self._compare_param('hdr_clamp', {'hdr_clamp': True, 'hdr_threshold': 0.7}, check_correction_meta('Latent clamp'))

        # Isolation: verify params from one request don't leak to the next
        # (a plain request after many graded ones must still match the cached baseline).
        data_after, _ = self._txt2img()
        arr_after = self._decode_image(data_after)
        baseline, _ = self._generate_baseline()
        if baseline is not None and arr_after is not None:
            leak_diff = self._pixel_diff(baseline, arr_after)
            no_leak = leak_diff < 0.5
            self.record(no_leak, 'param_isolation',
                        f"post-grading baseline diff={leak_diff:.4f}" if no_leak
                        else f"LEAK: baseline changed after grading requests (diff={leak_diff:.2f})")

    # =========================================================================
    # Runner
    # =========================================================================

    def run_all(self):
        """Run every test category and print a pass/fail summary.

        Returns True when no test in any category failed (skips don't count).
        """
        print("=" * 60)
        print("Generation API Test Suite")
        print(f"Server: {self.base_url}")
        print(f"Steps: {self.steps}")
        print("=" * 60)

        # Samplers
        available = self.test_samplers_list()
        self.test_samplers_generate(available)

        # Grading
        self.run_grading_tests()

        # Corrections
        self.run_correction_tests()

        # Per-request param validation (baseline comparison)
        self.run_param_validation_tests()

        # Summary
        print("\n" + "=" * 60)
        print("Results")
        print("=" * 60)
        total_passed = 0
        total_failed = 0
        total_skipped = 0
        for cat, data in self.results.items():
            total_passed += data['passed']
            total_failed += data['failed']
            total_skipped += data['skipped']
            status = 'PASS' if data['failed'] == 0 else 'FAIL'
            print(f"  {cat}: {data['passed']} passed, {data['failed']} failed, {data['skipped']} skipped [{status}]")
        print(f"  Total: {total_passed} passed, {total_failed} failed, {total_skipped} skipped")
        print("=" * 60)
        return total_failed == 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Generation API Tests (samplers, grading, correction)')
    parser.add_argument('--url', default=os.environ.get('SDAPI_URL', 'http://127.0.0.1:7860'), help='server URL')
    parser.add_argument('--steps', type=int, default=10, help='generation steps (lower = faster tests)')
    args = parser.parse_args()
    # NOTE(review): GenerationAPITest's constructor is not visible in this chunk;
    # presumably its signature is (base_url, steps) — verify it matches this call.
    test = GenerationAPITest(args.url, args.steps)
    success = test.run_all()
    # Exit code 0 only on a fully green run, for CI integration.
    sys.exit(0 if success else 1)