pull/27/head
Vladimir Mandic 2023-07-05 18:13:59 -04:00
parent b30e324552
commit d38db466a6
1 changed files with 113 additions and 49 deletions

View File

@ -8,20 +8,10 @@ import datetime
import logging
from hashlib import sha256
from html.parser import HTMLParser
import torch
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
except Exception:
pass
import accelerate
import gradio as gr
import psutil
import transformers
from modules import paths, script_callbacks, sd_hijack, sd_models, sd_samplers, shared, extensions, devices
from modules.ui_components import FormRow
from benchmark import run_benchmark, submit_benchmark # pylint: disable=E0401,E0611,C0411
@ -34,12 +24,12 @@ data = {
'date': '',
'timestamp': '',
'uptime': '',
'version': '',
'version': {},
'torch': '',
'gpu': {},
'state': {},
'memory': {},
'optimizations': '',
'optimizations': [],
'libs': {},
'repos': {},
'device': {},
@ -53,13 +43,15 @@ data = {
'extensions': [],
'platform': '',
'crossattention': '',
'backend': getattr(devices, 'backend', ''),
'pipeline': shared.opts.data.get('sd_backend', ''),
}
### benchmark globals
bench_text = ''
bench_file = os.path.join(os.path.dirname(__file__), 'benchmark-data-local.json')
bench_headers = ['timestamp', 'performance', 'version', 'system', 'libraries', 'gpu', 'optimizations', 'model', 'username', 'note', 'hash']
bench_headers = ['timestamp', 'performance', 'version', 'system', 'libraries', 'gpu', 'pipeline', 'model', 'username', 'note', 'hash']
bench_data = []
console_logging = None
@ -85,9 +77,10 @@ def get_user():
def get_gpu():
if not torch.cuda.is_available():
try:
import intel_extension_for_pytorch # pylint: disable=import-error, unused-import
return {
'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} ({str(torch.xpu.device_count())})',
'ipex': str(ipex.__version__),
'ipex': get_package_version('intel-extension-for-pytorch'),
}
except Exception:
return {}
@ -213,16 +206,18 @@ def get_optimizations():
return ram
def get_package_version(pkg: str):
import pkg_resources
spec = pkg_resources.working_set.by_key.get(pkg, None) # more reliable than importlib
version = pkg_resources.get_distribution(pkg).version if spec is not None else ''
return version
def get_libs():
try:
import xformers # pylint: disable=import-outside-toplevel, import-error
xversion = xformers.__version__
except Exception:
xversion = 'unavailable'
return {
'xformers': xversion,
'accelerate': accelerate.__version__,
'transformers': transformers.__version__,
'xformers': get_package_version('xformers'),
'diffusers': get_package_version('diffusers'),
'transformers': get_package_version('transformers'),
}
@ -269,23 +264,27 @@ def get_torch():
def get_version():
version = None
if version is None:
try:
res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
githash, updated = ver.split(' ')
res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
branch = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
version = {
'updated': updated,
'hash': githash,
'url': origin.replace('\n', '') + '/tree/' + branch.replace('\n', '')
}
except Exception:
pass
version = {}
try:
res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
githash, updated = ver.split(' ')
res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
branch = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
url = origin.replace('\n', '') + '/tree/' + branch.replace('\n', '')
app = origin.replace('\n', '').split('/')[-1]
if app == 'automatic':
app = 'SD.next'
version = {
'app': app,
'updated': updated,
'hash': githash,
'url': url
}
except Exception:
pass
return version
@ -299,10 +298,8 @@ def get_skipped():
def get_crossattention():
try:
ca = sd_hijack.model_hijack.optimization_method
if ca is None:
return 'none'
else: return ca
ca = sd_hijack.model_hijack.optimization_method or getattr(shared.opts, 'cross_attention_optimization', 'none')
return ca
except Exception:
return 'unknown'
@ -369,6 +366,8 @@ def get_full_data():
'extensions': get_extensions(),
'platform': get_platform(),
'crossattention': get_crossattention(),
'backend': getattr(devices, 'backend', ''),
'pipeline': shared.opts.data.get('sd_backend', ''),
}
return data
@ -400,7 +399,7 @@ def refresh_info_quick(_old_data = None):
def refresh_info_full():
get_full_data()
return data['uptime'], dict2text(data['version']), dict2text(data['state']), dict2text(data['memory']), dict2text(data['platform']), data['torch'], dict2text(data['gpu']), list2text(data['optimizations']), data['crossattention'], dict2text(data['libs']), dict2text(data['repos']), dict2text(data['device']), data['models'], data['hypernetworks'], data['embeddings'], data['skipped'], data['loras'], data['lycos'], data['timestamp'], data
return data['uptime'], dict2text(data['version']), dict2text(data['state']), dict2text(data['memory']), dict2text(data['platform']), data['torch'], dict2text(data['gpu']), list2text(data['optimizations']), data['crossattention'], data['backend'], data['pipeline'], dict2text(data['libs']), dict2text(data['repos']), dict2text(data['device']), data['models'], data['hypernetworks'], data['embeddings'], data['skipped'], data['loras'], data['lycos'], data['timestamp'], data
### ui definition
@ -423,6 +422,9 @@ def on_ui_tabs():
with gr.Row():
with gr.Column():
platformtxt = gr.Textbox(dict2text(data['platform']), label = 'Platform', lines = len(data['platform']))
with gr.Row():
backendtxt = gr.Textbox(data['backend'], label = 'Backend')
pipelinetxt = gr.Textbox(data['pipeline'], label = 'Pipeline')
with gr.Column():
torchtxt = gr.Textbox(data['torch'], label = 'Torch', lines = 1)
gputxt = gr.Textbox(dict2text(data['gpu']), label = 'GPU', lines = len(data['gpu']))
@ -431,7 +433,7 @@ def on_ui_tabs():
attentiontxt = gr.Textbox(data['crossattention'], label = 'Cross-attention')
with gr.Column():
libstxt = gr.Textbox(dict2text(data['libs']), label = 'Libs', lines = len(data['libs']))
repostxt = gr.Textbox(dict2text(data['repos']), label = 'Repos', lines = len(data['repos']))
repostxt = gr.Textbox(dict2text(data['repos']), label = 'Repos', lines = len(data['repos']), visible = False)
devtxt = gr.Textbox(dict2text(data['device']), label = 'Device Info', lines = len(data['device']))
with gr.Box():
with gr.Accordion('Benchmarks...', open = True, visible = True):
@ -443,7 +445,7 @@ def on_ui_tabs():
username = gr.Textbox(get_user, label = 'Username', placeholder='enter username for submission', elem_id='system_info_tab_username')
note = gr.Textbox('', label = 'Note', placeholder='enter any additional notes', elem_id='system_info_tab_note')
with gr.Column(scale=1):
with FormRow():
with gr.Row():
global console_logging # pylint: disable=global-statement
console_logging = gr.Checkbox(label = 'Console logging', value = False, elem_id = 'system_info_tab_console', interactive = True)
warmup = gr.Checkbox(label = 'Perform warmup', value = True, elem_id = 'system_info_tab_warmup')
@ -497,7 +499,7 @@ def on_ui_tabs():
refresh_full_btn = gr.Button('Refresh data', elem_id = 'system_info_tab_refresh_full_btn', variant='primary').style()
refresh_full_btn.click(refresh_info_full, show_progress = False,
inputs = [],
outputs = [uptimetxt, versiontxt, statetxt, memorytxt, platformtxt, torchtxt, gputxt, opttxt, attentiontxt, libstxt, repostxt, devtxt, models, hypernetworks, embeddings, skipped, loras, lycos, timestamp, js]
outputs = [uptimetxt, versiontxt, statetxt, memorytxt, platformtxt, torchtxt, gputxt, opttxt, attentiontxt, backendtxt, pipelinetxt, libstxt, repostxt, devtxt, models, hypernetworks, embeddings, skipped, loras, lycos, timestamp, js]
)
interrupt_btn = gr.Button('Send interrupt', elem_id = 'system_info_tab_interrupt_btn', variant='primary')
interrupt_btn.click(shared.state.interrupt, inputs = [], outputs = [])
@ -563,7 +565,7 @@ def bench_init(username: str, note: str, warmup: bool, level: str, extra: bool):
d[3] = dict2str(data['platform'])
d[4] = f"torch:{data['torch']} {dict2str(data['libs'])}"
d[5] = dict2str(data['gpu']) + f' {str(round(mem))}GB'
d[6] = data['crossattention'] + ' ' + ','.join(data['optimizations'])
d[6] = (data['pipeline'] + ' ' + data['crossattention'] + ' ' + ','.join(data['optimizations'])).strip()
d[7] = shared.opts.data['sd_model_checkpoint']
d[8] = username
d[9] = note
@ -607,13 +609,75 @@ def bench_refresh():
return gr.HTML.update(value = bench_text)
def on_app_started(_block, app): # register api
### API
from typing import Optional # pylint: disable=wrong-import-order
from fastapi import FastAPI, Depends # pylint: disable=wrong-import-order
from pydantic import BaseModel, Field # pylint: disable=wrong-import-order,no-name-in-module
class StatusReq(BaseModel): # definition of http request
state: bool = Field(title="State", description="Get server state", default=False)
memory: bool = Field(title="Memory", description="Get server memory status", default=False)
full: bool = Field(title="FullInfo", description="Get full server info", default=False)
refresh: bool = Field(title="FullInfo", description="Force refresh server info", default=False)
class StatusRes(BaseModel): # definition of http response
version: dict = Field(title="Version", description="Server version")
uptime: str = Field(title="Uptime", description="Server uptime")
timestamp: str = Field(title="Timestamp", description="Data timestamp")
state: Optional[dict] = Field(title="State", description="Server state")
memory: Optional[dict] = Field(title="Memory", description="Server memory status")
platform: Optional[dict] = Field(title="Platform", description="Server platform")
torch: Optional[str] = Field(title="Torch", description="Torch version")
gpu: Optional[dict] = Field(title="GPU", description="GPU info")
optimizations: Optional[list] = Field(title="Optimizations", description="Memory optimizations")
crossatention: Optional[str] = Field(title="CrossAttention", description="Cross-attention optimization")
device: Optional[dict] = Field(title="Device", description="Device info")
backend: Optional[str] = Field(title="Backend", description="Backend")
pipeline: Optional[str] = Field(title="Pipeline", description="Pipeline")
def get_status_api(req: StatusReq = Depends()):
if req.refresh:
get_full_data()
else:
get_quick_data()
res = StatusRes(
version = data['version'],
timestamp = data['timestamp'],
uptime = data['uptime']
)
if req.state or req.full:
res.state = data['state']
if req.memory or req.full:
res.memory = data['memory']
if req.full:
res.platform = data['platform']
res.torch = data['torch']
res.gpu = data['gpu']
res.optimizations = data['optimizations']
res.crossatention = data['crossattention']
res.device = data['device']
res.backend = data['backend']
res.pipeline = data['pipeline']
return res
def register_api(app: FastAPI):
app.add_api_route("/sdapi/v1/system-info/status", get_status_api, methods=["GET"], response_model=StatusRes)
### Entry point
def on_app_started(_block, app): # register api
register_api(app)
"""
@app.get("/sdapi/v1/system-info/status")
async def sysinfo_api():
get_quick_data()
res = { 'state': data['state'], 'memory': data['memory'], 'timestamp': data['timestamp'] }
return res
"""
script_callbacks.on_ui_tabs(on_ui_tabs)