validate benchmark results
parent
1841cf7627
commit
ce1b2a4a10
21
benchmark.py
21
benchmark.py
|
|
@ -4,6 +4,8 @@ from modules import shared
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, Processed, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, Processed, process_images
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger('sd')
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
'sd_model': None,
|
'sd_model': None,
|
||||||
'prompt': 'postapocalyptic steampunk city, exploration, cinematic, realistic, hyper detailed, photorealistic maximum detail, volumetric light, (((focus))), wide-angle, (((brightly lit))), (((vegetation))), lightning, vines, destruction, devastation, wartorn, ruins',
|
'prompt': 'postapocalyptic steampunk city, exploration, cinematic, realistic, hyper detailed, photorealistic maximum detail, volumetric light, (((focus))), wide-angle, (((brightly lit))), (((vegetation))), lightning, vines, destruction, devastation, wartorn, ruins',
|
||||||
|
|
@ -29,16 +31,31 @@ def run_benchmark(batch: int, extra: bool):
|
||||||
args['sd_model'] = shared.sd_model
|
args['sd_model'] = shared.sd_model
|
||||||
args['batch_size'] = batch
|
args['batch_size'] = batch
|
||||||
args['steps'] = 20 if not extra else 50
|
args['steps'] = 20 if not extra else 50
|
||||||
|
mp = 0
|
||||||
p = StableDiffusionProcessingTxt2Img(**args)
|
p = StableDiffusionProcessingTxt2Img(**args)
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
try:
|
try:
|
||||||
_processed: Processed = process_images(p)
|
processed: Processed = process_images(p)
|
||||||
|
if processed is None or processed.images is None:
|
||||||
|
log.error(f'SD-System-Info benchmark error: {batch} no results')
|
||||||
|
return 'error'
|
||||||
|
if len(processed.images) != batch:
|
||||||
|
log.error(f'SD-System-Info benchmark error: {batch} results mismatch')
|
||||||
|
return 'error'
|
||||||
|
for image in processed.images:
|
||||||
|
if image.width * image.height < 65536:
|
||||||
|
log.error(f'SD-System-Info benchmark error: {batch} image too small')
|
||||||
|
return 'error'
|
||||||
|
mp += image.width * image.height
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'SD-System-Info: benchmark error: {batch} {e}')
|
log.error(f'SD-System-Info benchmark error: {batch} {e}')
|
||||||
return 'error'
|
return 'error'
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
its = args['steps'] * batch / (t1 - t0)
|
its = args['steps'] * batch / (t1 - t0)
|
||||||
|
mp = mp / 1024 / 1024
|
||||||
|
mps = mp / (t1 - t0)
|
||||||
|
log.debug(f'SD-System-Info benchmark: batch={batch} time={t1-t0:.2f} steps={args["steps"]} its={its:.2f} mps={mps:.2f}')
|
||||||
return round(its, 2)
|
return round(its, 2)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -565,9 +565,9 @@ def bench_submit(username: str):
|
||||||
def bench_run(batches: list = [1], extra: bool = False):
|
def bench_run(batches: list = [1], extra: bool = False):
|
||||||
results = []
|
results = []
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
log.debug(f'SD-System-Info: benchmark starting: batch-size={batch}')
|
log.debug(f'SD-System-Info: benchmark starting: batch={batch}')
|
||||||
res = run_benchmark(batch, extra)
|
res = run_benchmark(batch, extra)
|
||||||
log.info(f'SD-System-Info: benchmark batch-size={batch} it/s={res}')
|
log.info(f'SD-System-Info: benchmark batch={batch} its={res}')
|
||||||
results.append(str(res))
|
results.append(str(res))
|
||||||
its = ' / '.join(results)
|
its = ' / '.join(results)
|
||||||
return its
|
return its
|
||||||
|
|
@ -671,6 +671,7 @@ class StatusReq(BaseModel): # definition of http request
|
||||||
full: bool = Field(title="FullInfo", description="Get full server info", default=False)
|
full: bool = Field(title="FullInfo", description="Get full server info", default=False)
|
||||||
refresh: bool = Field(title="FullInfo", description="Force refresh server info", default=False)
|
refresh: bool = Field(title="FullInfo", description="Force refresh server info", default=False)
|
||||||
|
|
||||||
|
|
||||||
class StatusRes(BaseModel): # definition of http response
|
class StatusRes(BaseModel): # definition of http response
|
||||||
version: dict = Field(title="Version", description="Server version")
|
version: dict = Field(title="Version", description="Server version")
|
||||||
uptime: str = Field(title="Uptime", description="Server uptime")
|
uptime: str = Field(title="Uptime", description="Server uptime")
|
||||||
|
|
@ -686,6 +687,7 @@ class StatusRes(BaseModel): # definition of http response
|
||||||
backend: Optional[str] = Field(title="Backend", description="Backend")
|
backend: Optional[str] = Field(title="Backend", description="Backend")
|
||||||
pipeline: Optional[str] = Field(title="Pipeline", description="Pipeline")
|
pipeline: Optional[str] = Field(title="Pipeline", description="Pipeline")
|
||||||
|
|
||||||
|
|
||||||
def get_status_api(req: StatusReq = Depends()):
|
def get_status_api(req: StatusReq = Depends()):
|
||||||
if req.refresh:
|
if req.refresh:
|
||||||
get_full_data()
|
get_full_data()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue