validate benchmark results

pull/48/head
Vladimir Mandic 2023-12-14 10:06:17 -05:00
parent 1841cf7627
commit ce1b2a4a10
2 changed files with 23 additions and 4 deletions

View File

@ -4,6 +4,8 @@ from modules import shared
from modules.processing import StableDiffusionProcessingTxt2Img, Processed, process_images from modules.processing import StableDiffusionProcessingTxt2Img, Processed, process_images
log = logging.getLogger('sd')
args = { args = {
'sd_model': None, 'sd_model': None,
'prompt': 'postapocalyptic steampunk city, exploration, cinematic, realistic, hyper detailed, photorealistic maximum detail, volumetric light, (((focus))), wide-angle, (((brightly lit))), (((vegetation))), lightning, vines, destruction, devastation, wartorn, ruins', 'prompt': 'postapocalyptic steampunk city, exploration, cinematic, realistic, hyper detailed, photorealistic maximum detail, volumetric light, (((focus))), wide-angle, (((brightly lit))), (((vegetation))), lightning, vines, destruction, devastation, wartorn, ruins',
@ -29,16 +31,31 @@ def run_benchmark(batch: int, extra: bool):
args['sd_model'] = shared.sd_model args['sd_model'] = shared.sd_model
args['batch_size'] = batch args['batch_size'] = batch
args['steps'] = 20 if not extra else 50 args['steps'] = 20 if not extra else 50
mp = 0
p = StableDiffusionProcessingTxt2Img(**args) p = StableDiffusionProcessingTxt2Img(**args)
t0 = time.time() t0 = time.time()
try: try:
_processed: Processed = process_images(p) processed: Processed = process_images(p)
if processed is None or processed.images is None:
log.error(f'SD-System-Info benchmark error: {batch} no results')
return 'error'
if len(processed.images) != batch:
log.error(f'SD-System-Info benchmark error: {batch} results mismatch')
return 'error'
for image in processed.images:
if image.width * image.height < 65536:
log.error(f'SD-System-Info benchmark error: {batch} image too small')
return 'error'
mp += image.width * image.height
except Exception as e: except Exception as e:
print(f'SD-System-Info: benchmark error: {batch} {e}') log.error(f'SD-System-Info benchmark error: {batch} {e}')
return 'error' return 'error'
t1 = time.time() t1 = time.time()
shared.state.end() shared.state.end()
its = args['steps'] * batch / (t1 - t0) its = args['steps'] * batch / (t1 - t0)
mp = mp / 1024 / 1024
mps = mp / (t1 - t0)
log.debug(f'SD-System-Info benchmark: batch={batch} time={t1-t0:.2f} steps={args["steps"]} its={its:.2f} mps={mps:.2f}')
return round(its, 2) return round(its, 2)

View File

@ -565,9 +565,9 @@ def bench_submit(username: str):
def bench_run(batches: list = [1], extra: bool = False): def bench_run(batches: list = [1], extra: bool = False):
results = [] results = []
for batch in batches: for batch in batches:
log.debug(f'SD-System-Info: benchmark starting: batch-size={batch}') log.debug(f'SD-System-Info: benchmark starting: batch={batch}')
res = run_benchmark(batch, extra) res = run_benchmark(batch, extra)
log.info(f'SD-System-Info: benchmark batch-size={batch} it/s={res}') log.info(f'SD-System-Info: benchmark batch={batch} its={res}')
results.append(str(res)) results.append(str(res))
its = ' / '.join(results) its = ' / '.join(results)
return its return its
@ -671,6 +671,7 @@ class StatusReq(BaseModel): # definition of http request
full: bool = Field(title="FullInfo", description="Get full server info", default=False) full: bool = Field(title="FullInfo", description="Get full server info", default=False)
refresh: bool = Field(title="FullInfo", description="Force refresh server info", default=False) refresh: bool = Field(title="FullInfo", description="Force refresh server info", default=False)
class StatusRes(BaseModel): # definition of http response class StatusRes(BaseModel): # definition of http response
version: dict = Field(title="Version", description="Server version") version: dict = Field(title="Version", description="Server version")
uptime: str = Field(title="Uptime", description="Server uptime") uptime: str = Field(title="Uptime", description="Server uptime")
@ -686,6 +687,7 @@ class StatusRes(BaseModel): # definition of http response
backend: Optional[str] = Field(title="Backend", description="Backend") backend: Optional[str] = Field(title="Backend", description="Backend")
pipeline: Optional[str] = Field(title="Pipeline", description="Pipeline") pipeline: Optional[str] = Field(title="Pipeline", description="Pipeline")
def get_status_api(req: StatusReq = Depends()): def get_status_api(req: StatusReq = Depends()):
if req.refresh: if req.refresh:
get_full_data() get_full_data()