# automatic/test/test-generation-api.py
#
# 616 lines
# 25 KiB
# Python

#!/usr/bin/env python
"""
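API tests for generation with scheduler params, color grading, and latent corrections.

Tests:
- GET /sdapi/v1/samplers — sampler enumeration and config
- POST /sdapi/v1/txt2img — generation with various samplers
- POST /sdapi/v1/txt2img — generation with color grading params
- POST /sdapi/v1/txt2img — generation with latent correction params
- POST /sdapi/v1/txt2img — per-request param validation against a baseline image

Requires a running SD.Next instance with a model loaded.
The param validation tests additionally need numpy and Pillow to decode and compare images.

Usage:
    python test/test-generation-api.py [--url URL] [--steps STEPS]

The server URL defaults to the SDAPI_URL environment variable, or http://127.0.0.1:7860.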
"""
import io
import os
import sys
import json
import time
import base64
import argparse
import requests
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
class GenerationAPITest:
    """Test harness for generation API with scheduler and grading params.

    Every check is recorded under the currently active category of
    ``self.results`` as PASS, FAIL or SKIP and echoed to stdout. HTTP problems
    are never raised to the caller: ``_get`` and ``_post`` return a dict with
    an ``error`` key instead, which the individual tests turn into failures.
    """

    # Samplers to test — a representative subset covering different scheduler families
    TEST_SAMPLERS = [
        'Euler a',
        'Euler',
        'DPM++ 2M',
        'UniPC',
        'DDIM',
        'DPM++ 2M SDE',
    ]
    # Result categories, in the order they are printed in the summary
    _CATEGORIES = ('samplers', 'generation', 'grading', 'correction', 'param_validation')
    # Mean absolute pixel difference (0-255 scale) separating "same" from "changed"
    DIFF_THRESHOLD = 0.5

    def __init__(self, base_url, steps=10, timeout=300):
        """Store connection settings and create empty per-category counters.

        Args:
            base_url: Server root URL; trailing slashes are stripped.
            steps: Value sent as ``steps`` with every txt2img request.
            timeout: Timeout in seconds passed to every HTTP request.
        """
        self.base_url = base_url.rstrip('/')
        self.steps = steps
        self.timeout = timeout
        self.results = {
            category: {'passed': 0, 'failed': 0, 'skipped': 0, 'tests': []}
            for category in self._CATEGORIES
        }
        self._category = 'samplers'
        self._critical_error = None
        # Baseline image cache used by the param validation tests
        self._baseline_arr = None
        self._baseline_data = None

    # =========================================================================
    # HTTP and bookkeeping helpers
    # =========================================================================
    def _request(self, method, endpoint, **kwargs):
        """Call a ``requests`` function on ``endpoint``; return JSON or an error dict.

        Any failure (non-200 status, refused connection, undecodable body) is
        reported as ``{'error': ..., 'reason': ...}`` so callers only need to
        check for the ``error`` key.
        """
        try:
            r = method(f'{self.base_url}{endpoint}', timeout=self.timeout, verify=False, **kwargs)
            if r.status_code != 200:
                return {'error': r.status_code, 'reason': r.reason}
            return r.json()
        except requests.exceptions.ConnectionError:
            return {'error': 'connection_refused', 'reason': 'Server not running'}
        except Exception as e:  # harness boundary: report as a test error instead of crashing
            return {'error': 'exception', 'reason': str(e)}

    def _get(self, endpoint):
        """GET ``endpoint`` and return the decoded JSON, or an error dict."""
        return self._request(requests.get, endpoint)

    def _post(self, endpoint, data):
        """POST ``data`` as JSON to ``endpoint``; return decoded JSON or an error dict."""
        return self._request(requests.post, endpoint, json=data)

    def record(self, passed, name, detail=''):
        """Count a PASS/FAIL for ``name`` in the active category and print it."""
        status = 'PASS' if passed else 'FAIL'
        bucket = self.results[self._category]
        bucket['passed' if passed else 'failed'] += 1
        bucket['tests'].append((status, name))
        msg = f' {status}: {name}'
        if detail:
            msg += f' ({detail})'
        print(msg)

    def skip(self, name, reason):
        """Count a SKIP for ``name`` in the active category and print the reason."""
        bucket = self.results[self._category]
        bucket['skipped'] += 1
        bucket['tests'].append(('SKIP', name))
        print(f' SKIP: {name} ({reason})')

    def _txt2img(self, extra_params=None, prompt='a cat'):
        """Helper: run txt2img with base params + overrides. Returns (data, time)."""
        payload = {
            'prompt': prompt,
            'steps': self.steps,
            'width': 512,
            'height': 512,
            'seed': 42,
            'save_images': False,
            'send_images': True,
        }
        if extra_params:
            payload.update(extra_params)
        # perf_counter is monotonic, so the elapsed time cannot be skewed by clock changes
        t0 = time.perf_counter()
        data = self._post('/sdapi/v1/txt2img', payload)
        return data, time.perf_counter() - t0

    @staticmethod
    def _has_images(data):
        """Return True when ``data`` is a response dict with a non-empty ``images`` entry."""
        return isinstance(data, dict) and bool(data.get('images'))

    @staticmethod
    def _image_payload(image):
        """Return the base64 payload of ``image``, dropping any ``data:...;base64,`` prefix."""
        # The payload is what FOLLOWS the comma of a data URI; a bare base64
        # string contains no comma and is returned unchanged.
        return image.split(',', 1)[-1]

    def _check_generation(self, data, test_name, elapsed):
        """Validate a generation response has images; record and return the outcome."""
        if 'error' in data:
            self.record(False, test_name, f"error: {data}")
            return False
        has_images = self._has_images(data)
        self.record(has_images, test_name, f"time={elapsed:.1f}s")
        return has_images

    def _get_info(self, data):
        """Extract info string from generation response ('' when absent)."""
        if 'info' not in data:
            return ''
        info = data['info']
        return info if isinstance(info, str) else json.dumps(info)

    def _decode_image(self, data):
        """Decode first image from generation response into a float32 RGB numpy array.

        Returns None when the response carries no images.
        """
        import numpy as np
        from PIL import Image
        if not self._has_images(data):
            return None
        raw = base64.b64decode(self._image_payload(data['images'][0]))
        img = Image.open(io.BytesIO(raw)).convert('RGB')
        return np.array(img, dtype=np.float32)

    def _pixel_diff(self, arr_a, arr_b):
        """Mean absolute pixel difference between two images (0-255 scale).

        Returns -1.0 when the images cannot be compared (one is missing or the
        shapes differ); callers must treat a negative value as "not comparable".
        """
        import numpy as np
        if arr_a is None or arr_b is None or arr_a.shape != arr_b.shape:
            return -1.0
        return float(np.abs(arr_a - arr_b).mean())

    def _channel_means(self, arr):
        """Return per-channel means [R, G, B] ([0, 0, 0] when there is no image)."""
        if arr is None:
            return [0, 0, 0]
        return [float(arr[:, :, c].mean()) for c in range(3)]

    # =========================================================================
    # Tests: Sampler Enumeration
    # =========================================================================
    def test_samplers_list(self):
        """GET /sdapi/v1/samplers returns available samplers with config.

        Returns the list of sampler names reported by the server (empty on
        failure). A request error is stored as the critical error that makes
        all later test groups skip.
        """
        self._category = 'samplers'
        print("\n--- Sampler Enumeration ---")
        data = self._get('/sdapi/v1/samplers')
        if isinstance(data, dict) and 'error' in data:
            self.record(False, 'samplers_list', f"error: {data}")
            self._critical_error = f"Server error: {data}"
            return []
        if not isinstance(data, list):
            self.record(False, 'samplers_list', f"expected list, got {type(data).__name__}")
            return []
        self.record(True, 'samplers_list', f"{len(data)} samplers available")
        # Check that each sampler has a name (non-dict entries count as unnamed)
        sampler_names = [s['name'] for s in data if isinstance(s, dict) and s.get('name')]
        self.record(len(sampler_names) == len(data), 'samplers_have_names',
                    f"{len(sampler_names)}/{len(data)} have names")
        # Check for our test samplers
        for test_sampler in self.TEST_SAMPLERS:
            if test_sampler in sampler_names:
                self.record(True, f'sampler_available_{test_sampler}')
            else:
                self.skip(f'sampler_available_{test_sampler}', 'not in server sampler list')
        return sampler_names

    # =========================================================================
    # Tests: Generation with Different Samplers
    # =========================================================================
    def test_samplers_generate(self, available_samplers):
        """Generate with each test sampler and verify success."""
        self._category = 'generation'
        print("\n--- Generation with Different Samplers ---")
        if self._critical_error:
            for s in self.TEST_SAMPLERS:
                self.skip(f'generate_{s}', self._critical_error)
            return
        for sampler in self.TEST_SAMPLERS:
            if sampler not in available_samplers:
                self.skip(f'generate_{sampler}', 'sampler not available')
                continue
            self._run_generation_case(f'generate_{sampler}', {'sampler_name': sampler})

    def _run_generation_case(self, test_name, params):
        """Generate with ``params`` and record under ``test_name`` whether images came back."""
        data, elapsed = self._txt2img(params)
        return self._check_generation(data, test_name, elapsed)

    # =========================================================================
    # Tests: Color Grading Params
    # =========================================================================
    def test_grading_brightness_contrast(self):
        """Generate with grading brightness and contrast."""
        self._run_generation_case('grading_brightness_contrast', {
            'grading_brightness': 0.2,
            'grading_contrast': 0.3,
        })

    def test_grading_saturation_hue(self):
        """Generate with grading saturation and hue shift."""
        self._run_generation_case('grading_saturation_hue', {
            'grading_saturation': 0.5,
            'grading_hue': 0.1,
        })

    def test_grading_gamma_sharpness(self):
        """Generate with gamma correction and sharpness."""
        self._run_generation_case('grading_gamma_sharpness', {
            'grading_gamma': 0.8,
            'grading_sharpness': 0.5,
        })

    def test_grading_color_temp(self):
        """Generate with warm color temperature."""
        self._run_generation_case('grading_color_temp', {
            'grading_color_temp': 3500,
        })

    def test_grading_tone(self):
        """Generate with shadows/midtones/highlights adjustments."""
        self._run_generation_case('grading_tone', {
            'grading_shadows': 0.3,
            'grading_midtones': -0.1,
            'grading_highlights': 0.2,
        })

    def test_grading_effects(self):
        """Generate with vignette and grain."""
        self._run_generation_case('grading_effects', {
            'grading_vignette': 0.5,
            'grading_grain': 0.3,
        })

    def test_grading_split_toning(self):
        """Generate with split toning colors."""
        self._run_generation_case('grading_split_toning', {
            'grading_shadows_tint': '#003366',
            'grading_highlights_tint': '#ffcc00',
            'grading_split_tone_balance': 0.6,
        })

    def test_grading_combined(self):
        """Generate with multiple grading params at once."""
        self._run_generation_case('grading_combined', {
            'grading_brightness': 0.1,
            'grading_contrast': 0.2,
            'grading_saturation': 0.3,
            'grading_gamma': 0.9,
            'grading_color_temp': 5000,
            'grading_vignette': 0.3,
        })

    def run_grading_tests(self):
        """Run all grading tests."""
        self._category = 'grading'
        print("\n--- Color Grading Tests ---")
        if self._critical_error:
            self.skip('grading_all', self._critical_error)
            return
        self.test_grading_brightness_contrast()
        self.test_grading_saturation_hue()
        self.test_grading_gamma_sharpness()
        self.test_grading_color_temp()
        self.test_grading_tone()
        self.test_grading_effects()
        self.test_grading_split_toning()
        self.test_grading_combined()

    # =========================================================================
    # Tests: Latent Correction Params
    # =========================================================================
    def _run_correction_case(self, test_name, params, info_key):
        """Generate with ``params``; on success also record whether ``info_key`` is in the info string."""
        data, elapsed = self._txt2img(params)
        if not self._check_generation(data, test_name, elapsed):
            return
        has_param = info_key in self._get_info(data)
        self.record(has_param, f'{test_name}_metadata',
                    'found in info' if has_param else 'not found in info')

    def test_correction_brightness(self):
        """Generate with latent brightness correction."""
        self._run_correction_case('correction_brightness', {'hdr_brightness': 1.5}, 'Latent brightness')

    def test_correction_color(self):
        """Generate with latent color centering."""
        self._run_correction_case('correction_color', {'hdr_color': 0.5, 'hdr_mode': 1}, 'Latent color')

    def test_correction_clamp(self):
        """Generate with latent clamping."""
        self._run_correction_case('correction_clamp', {
            'hdr_clamp': True,
            'hdr_threshold': 0.8,
            'hdr_boundary': 4.0,
        }, 'Latent clamp')

    def test_correction_sharpen(self):
        """Generate with latent sharpening."""
        self._run_correction_case('correction_sharpen', {'hdr_sharpen': 1.0}, 'Latent sharpen')

    def test_correction_maximize(self):
        """Generate with latent maximize/normalize."""
        self._run_correction_case('correction_maximize', {
            'hdr_maximize': True,
            'hdr_max_center': 0.6,
            'hdr_max_boundary': 2.0,
        }, 'Latent max')

    def test_correction_combined(self):
        """Generate with multiple correction params."""
        data, elapsed = self._txt2img({
            'hdr_brightness': 1.0,
            'hdr_color': 0.3,
            'hdr_sharpen': 0.5,
            'hdr_clamp': True,
        })
        if not self._check_generation(data, 'correction_combined', elapsed):
            return
        info = self._get_info(data)
        # At least some correction params should appear
        expected = ['Latent brightness', 'Latent color', 'Latent sharpen', 'Latent clamp']
        found = [k for k in expected if k in info]
        self.record(len(found) > 0, 'correction_combined_metadata', f"found: {found}")

    def run_correction_tests(self):
        """Run all latent correction tests."""
        self._category = 'correction'
        print("\n--- Latent Correction Tests ---")
        if self._critical_error:
            self.skip('correction_all', self._critical_error)
            return
        self.test_correction_brightness()
        self.test_correction_color()
        self.test_correction_clamp()
        self.test_correction_sharpen()
        self.test_correction_maximize()
        self.test_correction_combined()

    # =========================================================================
    # Tests: Per-Request Param Validation (baseline comparison)
    # =========================================================================
    def _generate_baseline(self):
        """Generate a baseline image with no grading/correction. Cache and reuse.

        Returns (array, response); the array is None when generation failed or
        returned no image, in which case nothing is cached.
        """
        if self._baseline_arr is not None:
            return self._baseline_arr, self._baseline_data
        data, elapsed = self._txt2img()
        if 'error' in data or not self._has_images(data):
            return None, data
        arr = self._decode_image(data)
        self._baseline_arr = arr
        self._baseline_data = data
        print(f' Baseline generated: time={elapsed:.1f}s mean={self._channel_means(arr)}')
        return arr, data

    def _record_differs(self, name, baseline, result):
        """Record ``param_<name>_differs``; return True only when the images measurably differ."""
        diff = self._pixel_diff(baseline, result)
        if diff < 0:
            # -1.0 is the "not comparable" sentinel, not a tiny difference
            self.record(False, f'param_{name}_differs', 'images not comparable (shape mismatch)')
            return False
        differs = diff > self.DIFF_THRESHOLD  # more than 0.5/255 mean difference
        self.record(differs, f'param_{name}_differs',
                    f"mean_diff={diff:.2f}" if differs else f"images identical (diff={diff:.4f})")
        return differs

    def _record_isolation(self, baseline, after):
        """Record ``param_isolation``; return True when ``after`` still matches the baseline."""
        leak_diff = self._pixel_diff(baseline, after)
        if leak_diff < 0:
            # A shape mismatch must not be mistaken for "no difference"
            self.record(False, 'param_isolation', 'images not comparable (shape mismatch)')
            return False
        no_leak = leak_diff < self.DIFF_THRESHOLD
        self.record(no_leak, 'param_isolation',
                    f"post-grading baseline diff={leak_diff:.4f}" if no_leak
                    else f"LEAK: baseline changed after grading requests (diff={leak_diff:.2f})")
        return no_leak

    def _compare_param(self, name, params, check_fn=None):
        """Generate with params and compare to baseline.

        Records ``param_<name>_differs``; when the image changed and
        ``check_fn`` is given, also records ``param_<name>_direction`` from
        ``check_fn(baseline, result, data)``, which must return (ok, detail).
        """
        baseline, _ = self._generate_baseline()
        if baseline is None:
            self.skip(f'param_{name}', 'baseline generation failed')
            return
        data, _elapsed = self._txt2img(params)
        if 'error' in data:
            self.record(False, f'param_{name}', f"generation error: {data}")
            return
        result = self._decode_image(data)
        if result is None:
            self.record(False, f'param_{name}', 'no image in response')
            return
        if not self._record_differs(name, baseline, result) or not check_fn:
            return
        try:
            ok, detail = check_fn(baseline, result, data)
            self.record(ok, f'param_{name}_direction', detail)
        except Exception as e:  # a broken check is a failed test, not a crashed suite
            self.record(False, f'param_{name}_direction', f"check error: {e}")

    def run_param_validation_tests(self):
        """Verify per-request grading/correction params actually change the output."""
        self._category = 'param_validation'
        print("\n--- Per-Request Param Validation ---")
        if self._critical_error:
            self.skip('param_validation_all', self._critical_error)
            return
        import numpy as np

        # -- Grading params --
        # Brightness: positive should increase mean pixel value
        def check_brightness(base, result, _data):
            base_mean = float(base.mean())
            result_mean = float(result.mean())
            return result_mean > base_mean, f"baseline={base_mean:.1f} graded={result_mean:.1f}"
        self._compare_param('grading_brightness', {'grading_brightness': 0.3}, check_brightness)

        # Contrast: should increase standard deviation
        def check_contrast(base, result, _data):
            base_std = float(base.std())
            result_std = float(result.std())
            return result_std > base_std, f"baseline_std={base_std:.1f} graded_std={result_std:.1f}"
        self._compare_param('grading_contrast', {'grading_contrast': 0.5}, check_contrast)

        # Saturation: desaturation should reduce color channel spread
        def check_desaturation(base, result, _data):
            base_means = self._channel_means(base)
            result_means = self._channel_means(result)
            base_spread = max(base_means) - min(base_means)
            result_spread = max(result_means) - min(result_means)
            return result_spread < base_spread, \
                f"baseline_spread={base_spread:.1f} graded_spread={result_spread:.1f}"
        self._compare_param('grading_saturation_neg', {'grading_saturation': -0.5}, check_desaturation)

        # Hue shift: just verify it changes
        self._compare_param('grading_hue', {'grading_hue': 0.2})

        # Gamma < 1: should brighten (raise values that are < 1)
        def check_gamma(base, result, _data):
            base_mean = float(base.mean())
            result_mean = float(result.mean())
            return result_mean > base_mean, f"baseline={base_mean:.1f} gamma={result_mean:.1f}"
        self._compare_param('grading_gamma', {'grading_gamma': 0.7}, check_gamma)

        # Sharpness: just verify it changes
        self._compare_param('grading_sharpness', {'grading_sharpness': 0.8})

        def red_minus_blue(arr):
            r, _, b = self._channel_means(arr)
            return r - b

        # Color temperature warm: red channel mean should increase relative to blue
        def check_warm(base, result, _data):
            base_rb = red_minus_blue(base)
            res_rb = red_minus_blue(result)
            return res_rb > base_rb, f"baseline R-B={base_rb:.1f} warm R-B={res_rb:.1f}"
        self._compare_param('grading_color_temp_warm', {'grading_color_temp': 3000}, check_warm)

        # Color temperature cool: blue should increase relative to red
        def check_cool(base, result, _data):
            base_rb = red_minus_blue(base)
            res_rb = red_minus_blue(result)
            return res_rb < base_rb, f"baseline R-B={base_rb:.1f} cool R-B={res_rb:.1f}"
        self._compare_param('grading_color_temp_cool', {'grading_color_temp': 10000}, check_cool)

        def corner_pixels(arr):
            # One eighth of each dimension per corner; at least 1 so that a
            # "-0:" slice can never select the whole image.
            h, w = arr.shape[:2]
            ch, cw = max(1, h // 8), max(1, w // 8)
            return np.concatenate([
                arr[:ch, :cw].flatten(),
                arr[:ch, -cw:].flatten(),
                arr[-ch:, :cw].flatten(),
                arr[-ch:, -cw:].flatten(),
            ])

        # Vignette: corners should be darker than baseline corners
        def check_vignette(base, result, _data):
            base_mean = float(corner_pixels(base).mean())
            result_mean = float(corner_pixels(result).mean())
            return result_mean < base_mean, \
                f"baseline_corners={base_mean:.1f} vignette_corners={result_mean:.1f}"
        self._compare_param('grading_vignette', {'grading_vignette': 0.8}, check_vignette)

        # Grain: just verify it changes (stochastic)
        self._compare_param('grading_grain', {'grading_grain': 0.5})
        # Shadows/midtones/highlights: verify changes
        self._compare_param('grading_shadows', {'grading_shadows': 0.5})
        self._compare_param('grading_highlights', {'grading_highlights': -0.3})
        # CLAHE: only verified to change the output
        self._compare_param('grading_clahe', {'grading_clahe_clip': 2.0})
        # Split toning: verify changes
        self._compare_param('grading_split_toning', {
            'grading_shadows_tint': '#003366',
            'grading_highlights_tint': '#ffcc00',
        })

        # -- Correction params --
        # Each should change the output and appear in metadata
        def check_correction_meta(key):
            def _check(_base, _result, data):
                present = key in self._get_info(data)
                return present, f"'{key}' {'found' if present else 'missing'} in info"
            return _check
        self._compare_param('hdr_brightness', {'hdr_brightness': 2.0}, check_correction_meta('Latent brightness'))
        self._compare_param('hdr_color', {'hdr_color': 0.8, 'hdr_mode': 1}, check_correction_meta('Latent color'))
        self._compare_param('hdr_sharpen', {'hdr_sharpen': 1.5}, check_correction_meta('Latent sharpen'))
        self._compare_param('hdr_clamp', {'hdr_clamp': True, 'hdr_threshold': 0.7}, check_correction_meta('Latent clamp'))

        # Isolation: verify params from one request don't leak to the next.
        # Every outcome is recorded so a failed request cannot hide the check.
        data_after, _ = self._txt2img()
        baseline, _ = self._generate_baseline()
        if baseline is None:
            self.skip('param_isolation', 'baseline generation failed')
            return
        if 'error' in data_after:
            self.record(False, 'param_isolation', f"generation error: {data_after}")
            return
        arr_after = self._decode_image(data_after)
        if arr_after is None:
            self.record(False, 'param_isolation', 'no image in response')
            return
        self._record_isolation(baseline, arr_after)

    # =========================================================================
    # Runner
    # =========================================================================
    def run_all(self):
        """Run every test group, print a summary, and return True when nothing failed."""
        print("=" * 60)
        print("Generation API Test Suite")
        print(f"Server: {self.base_url}")
        print(f"Steps: {self.steps}")
        print("=" * 60)
        # Samplers
        available = self.test_samplers_list()
        self.test_samplers_generate(available)
        # Grading
        self.run_grading_tests()
        # Corrections
        self.run_correction_tests()
        # Per-request param validation (baseline comparison)
        self.run_param_validation_tests()
        # Summary
        print("\n" + "=" * 60)
        print("Results")
        print("=" * 60)
        for cat, data in self.results.items():
            status = 'PASS' if data['failed'] == 0 else 'FAIL'
            print(f" {cat}: {data['passed']} passed, {data['failed']} failed, {data['skipped']} skipped [{status}]")
        total_passed = sum(data['passed'] for data in self.results.values())
        total_failed = sum(data['failed'] for data in self.results.values())
        total_skipped = sum(data['skipped'] for data in self.results.values())
        print(f" Total: {total_passed} passed, {total_failed} failed, {total_skipped} skipped")
        print("=" * 60)
        return total_failed == 0
if __name__ == "__main__":
    # Command-line entry point: parse options, run the whole suite against the
    # selected server, and exit with 0 when no test failed, 1 otherwise.
    cli = argparse.ArgumentParser(description='Generation API Tests (samplers, grading, correction)')
    # The server URL can also be supplied through the SDAPI_URL environment variable
    default_url = os.environ.get('SDAPI_URL', 'http://127.0.0.1:7860')
    cli.add_argument('--url', default=default_url, help='server URL')
    cli.add_argument('--steps', type=int, default=10, help='generation steps (lower = faster tests)')
    opts = cli.parse_args()
    suite = GenerationAPITest(opts.url, opts.steps)
    exit_code = 0 if suite.run_all() else 1
    sys.exit(exit_code)