update
parent
b30e324552
commit
d38db466a6
|
|
@ -8,20 +8,10 @@ import datetime
|
|||
import logging
|
||||
from hashlib import sha256
|
||||
from html.parser import HTMLParser
|
||||
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
except Exception:
|
||||
pass
|
||||
import accelerate
|
||||
import gradio as gr
|
||||
import psutil
|
||||
import transformers
|
||||
|
||||
from modules import paths, script_callbacks, sd_hijack, sd_models, sd_samplers, shared, extensions, devices
|
||||
from modules.ui_components import FormRow
|
||||
|
||||
from benchmark import run_benchmark, submit_benchmark # pylint: disable=E0401,E0611,C0411
|
||||
|
||||
|
||||
|
|
@ -34,12 +24,12 @@ data = {
|
|||
'date': '',
|
||||
'timestamp': '',
|
||||
'uptime': '',
|
||||
'version': '',
|
||||
'version': {},
|
||||
'torch': '',
|
||||
'gpu': {},
|
||||
'state': {},
|
||||
'memory': {},
|
||||
'optimizations': '',
|
||||
'optimizations': [],
|
||||
'libs': {},
|
||||
'repos': {},
|
||||
'device': {},
|
||||
|
|
@ -53,13 +43,15 @@ data = {
|
|||
'extensions': [],
|
||||
'platform': '',
|
||||
'crossattention': '',
|
||||
'backend': getattr(devices, 'backend', ''),
|
||||
'pipeline': shared.opts.data.get('sd_backend', ''),
|
||||
}
|
||||
|
||||
### benchmark globals
|
||||
|
||||
bench_text = ''
|
||||
bench_file = os.path.join(os.path.dirname(__file__), 'benchmark-data-local.json')
|
||||
bench_headers = ['timestamp', 'performance', 'version', 'system', 'libraries', 'gpu', 'optimizations', 'model', 'username', 'note', 'hash']
|
||||
bench_headers = ['timestamp', 'performance', 'version', 'system', 'libraries', 'gpu', 'pipeline', 'model', 'username', 'note', 'hash']
|
||||
bench_data = []
|
||||
console_logging = None
|
||||
|
||||
|
|
@ -85,9 +77,10 @@ def get_user():
|
|||
def get_gpu():
|
||||
if not torch.cuda.is_available():
|
||||
try:
|
||||
import intel_extension_for_pytorch # pylint: disable=import-error, unused-import
|
||||
return {
|
||||
'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} ({str(torch.xpu.device_count())})',
|
||||
'ipex': str(ipex.__version__),
|
||||
'ipex': get_package_version('intel-extension-for-pytorch'),
|
||||
}
|
||||
except Exception:
|
||||
return {}
|
||||
|
|
@ -213,16 +206,18 @@ def get_optimizations():
|
|||
return ram
|
||||
|
||||
|
||||
def get_package_version(pkg: str):
|
||||
import pkg_resources
|
||||
spec = pkg_resources.working_set.by_key.get(pkg, None) # more reliable than importlib
|
||||
version = pkg_resources.get_distribution(pkg).version if spec is not None else ''
|
||||
return version
|
||||
|
||||
|
||||
def get_libs():
|
||||
try:
|
||||
import xformers # pylint: disable=import-outside-toplevel, import-error
|
||||
xversion = xformers.__version__
|
||||
except Exception:
|
||||
xversion = 'unavailable'
|
||||
return {
|
||||
'xformers': xversion,
|
||||
'accelerate': accelerate.__version__,
|
||||
'transformers': transformers.__version__,
|
||||
'xformers': get_package_version('xformers'),
|
||||
'diffusers': get_package_version('diffusers'),
|
||||
'transformers': get_package_version('transformers'),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -269,23 +264,27 @@ def get_torch():
|
|||
|
||||
|
||||
def get_version():
|
||||
version = None
|
||||
if version is None:
|
||||
try:
|
||||
res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
githash, updated = ver.split(' ')
|
||||
res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
branch = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
version = {
|
||||
'updated': updated,
|
||||
'hash': githash,
|
||||
'url': origin.replace('\n', '') + '/tree/' + branch.replace('\n', '')
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
version = {}
|
||||
try:
|
||||
res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
githash, updated = ver.split(' ')
|
||||
res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
branch = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
url = origin.replace('\n', '') + '/tree/' + branch.replace('\n', '')
|
||||
app = origin.replace('\n', '').split('/')[-1]
|
||||
if app == 'automatic':
|
||||
app = 'SD.next'
|
||||
version = {
|
||||
'app': app,
|
||||
'updated': updated,
|
||||
'hash': githash,
|
||||
'url': url
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
return version
|
||||
|
||||
|
||||
|
|
@ -299,10 +298,8 @@ def get_skipped():
|
|||
|
||||
def get_crossattention():
|
||||
try:
|
||||
ca = sd_hijack.model_hijack.optimization_method
|
||||
if ca is None:
|
||||
return 'none'
|
||||
else: return ca
|
||||
ca = sd_hijack.model_hijack.optimization_method or getattr(shared.opts, 'cross_attention_optimization', 'none')
|
||||
return ca
|
||||
except Exception:
|
||||
return 'unknown'
|
||||
|
||||
|
|
@ -369,6 +366,8 @@ def get_full_data():
|
|||
'extensions': get_extensions(),
|
||||
'platform': get_platform(),
|
||||
'crossattention': get_crossattention(),
|
||||
'backend': getattr(devices, 'backend', ''),
|
||||
'pipeline': shared.opts.data.get('sd_backend', ''),
|
||||
}
|
||||
return data
|
||||
|
||||
|
|
@ -400,7 +399,7 @@ def refresh_info_quick(_old_data = None):
|
|||
|
||||
def refresh_info_full():
|
||||
get_full_data()
|
||||
return data['uptime'], dict2text(data['version']), dict2text(data['state']), dict2text(data['memory']), dict2text(data['platform']), data['torch'], dict2text(data['gpu']), list2text(data['optimizations']), data['crossattention'], dict2text(data['libs']), dict2text(data['repos']), dict2text(data['device']), data['models'], data['hypernetworks'], data['embeddings'], data['skipped'], data['loras'], data['lycos'], data['timestamp'], data
|
||||
return data['uptime'], dict2text(data['version']), dict2text(data['state']), dict2text(data['memory']), dict2text(data['platform']), data['torch'], dict2text(data['gpu']), list2text(data['optimizations']), data['crossattention'], data['backend'], data['pipeline'], dict2text(data['libs']), dict2text(data['repos']), dict2text(data['device']), data['models'], data['hypernetworks'], data['embeddings'], data['skipped'], data['loras'], data['lycos'], data['timestamp'], data
|
||||
|
||||
|
||||
### ui definition
|
||||
|
|
@ -423,6 +422,9 @@ def on_ui_tabs():
|
|||
with gr.Row():
|
||||
with gr.Column():
|
||||
platformtxt = gr.Textbox(dict2text(data['platform']), label = 'Platform', lines = len(data['platform']))
|
||||
with gr.Row():
|
||||
backendtxt = gr.Textbox(data['backend'], label = 'Backend')
|
||||
pipelinetxt = gr.Textbox(data['pipeline'], label = 'Pipeline')
|
||||
with gr.Column():
|
||||
torchtxt = gr.Textbox(data['torch'], label = 'Torch', lines = 1)
|
||||
gputxt = gr.Textbox(dict2text(data['gpu']), label = 'GPU', lines = len(data['gpu']))
|
||||
|
|
@ -431,7 +433,7 @@ def on_ui_tabs():
|
|||
attentiontxt = gr.Textbox(data['crossattention'], label = 'Cross-attention')
|
||||
with gr.Column():
|
||||
libstxt = gr.Textbox(dict2text(data['libs']), label = 'Libs', lines = len(data['libs']))
|
||||
repostxt = gr.Textbox(dict2text(data['repos']), label = 'Repos', lines = len(data['repos']))
|
||||
repostxt = gr.Textbox(dict2text(data['repos']), label = 'Repos', lines = len(data['repos']), visible = False)
|
||||
devtxt = gr.Textbox(dict2text(data['device']), label = 'Device Info', lines = len(data['device']))
|
||||
with gr.Box():
|
||||
with gr.Accordion('Benchmarks...', open = True, visible = True):
|
||||
|
|
@ -443,7 +445,7 @@ def on_ui_tabs():
|
|||
username = gr.Textbox(get_user, label = 'Username', placeholder='enter username for submission', elem_id='system_info_tab_username')
|
||||
note = gr.Textbox('', label = 'Note', placeholder='enter any additional notes', elem_id='system_info_tab_note')
|
||||
with gr.Column(scale=1):
|
||||
with FormRow():
|
||||
with gr.Row():
|
||||
global console_logging # pylint: disable=global-statement
|
||||
console_logging = gr.Checkbox(label = 'Console logging', value = False, elem_id = 'system_info_tab_console', interactive = True)
|
||||
warmup = gr.Checkbox(label = 'Perform warmup', value = True, elem_id = 'system_info_tab_warmup')
|
||||
|
|
@ -497,7 +499,7 @@ def on_ui_tabs():
|
|||
refresh_full_btn = gr.Button('Refresh data', elem_id = 'system_info_tab_refresh_full_btn', variant='primary').style()
|
||||
refresh_full_btn.click(refresh_info_full, show_progress = False,
|
||||
inputs = [],
|
||||
outputs = [uptimetxt, versiontxt, statetxt, memorytxt, platformtxt, torchtxt, gputxt, opttxt, attentiontxt, libstxt, repostxt, devtxt, models, hypernetworks, embeddings, skipped, loras, lycos, timestamp, js]
|
||||
outputs = [uptimetxt, versiontxt, statetxt, memorytxt, platformtxt, torchtxt, gputxt, opttxt, attentiontxt, backendtxt, pipelinetxt, libstxt, repostxt, devtxt, models, hypernetworks, embeddings, skipped, loras, lycos, timestamp, js]
|
||||
)
|
||||
interrupt_btn = gr.Button('Send interrupt', elem_id = 'system_info_tab_interrupt_btn', variant='primary')
|
||||
interrupt_btn.click(shared.state.interrupt, inputs = [], outputs = [])
|
||||
|
|
@ -563,7 +565,7 @@ def bench_init(username: str, note: str, warmup: bool, level: str, extra: bool):
|
|||
d[3] = dict2str(data['platform'])
|
||||
d[4] = f"torch:{data['torch']} {dict2str(data['libs'])}"
|
||||
d[5] = dict2str(data['gpu']) + f' {str(round(mem))}GB'
|
||||
d[6] = data['crossattention'] + ' ' + ','.join(data['optimizations'])
|
||||
d[6] = (data['pipeline'] + ' ' + data['crossattention'] + ' ' + ','.join(data['optimizations'])).strip()
|
||||
d[7] = shared.opts.data['sd_model_checkpoint']
|
||||
d[8] = username
|
||||
d[9] = note
|
||||
|
|
@ -607,13 +609,75 @@ def bench_refresh():
|
|||
return gr.HTML.update(value = bench_text)
|
||||
|
||||
|
||||
def on_app_started(_block, app): # register api
|
||||
### API
|
||||
|
||||
from typing import Optional # pylint: disable=wrong-import-order
|
||||
from fastapi import FastAPI, Depends # pylint: disable=wrong-import-order
|
||||
from pydantic import BaseModel, Field # pylint: disable=wrong-import-order,no-name-in-module
|
||||
|
||||
|
||||
class StatusReq(BaseModel): # definition of http request
|
||||
state: bool = Field(title="State", description="Get server state", default=False)
|
||||
memory: bool = Field(title="Memory", description="Get server memory status", default=False)
|
||||
full: bool = Field(title="FullInfo", description="Get full server info", default=False)
|
||||
refresh: bool = Field(title="FullInfo", description="Force refresh server info", default=False)
|
||||
|
||||
class StatusRes(BaseModel): # definition of http response
|
||||
version: dict = Field(title="Version", description="Server version")
|
||||
uptime: str = Field(title="Uptime", description="Server uptime")
|
||||
timestamp: str = Field(title="Timestamp", description="Data timestamp")
|
||||
state: Optional[dict] = Field(title="State", description="Server state")
|
||||
memory: Optional[dict] = Field(title="Memory", description="Server memory status")
|
||||
platform: Optional[dict] = Field(title="Platform", description="Server platform")
|
||||
torch: Optional[str] = Field(title="Torch", description="Torch version")
|
||||
gpu: Optional[dict] = Field(title="GPU", description="GPU info")
|
||||
optimizations: Optional[list] = Field(title="Optimizations", description="Memory optimizations")
|
||||
crossatention: Optional[str] = Field(title="CrossAttention", description="Cross-attention optimization")
|
||||
device: Optional[dict] = Field(title="Device", description="Device info")
|
||||
backend: Optional[str] = Field(title="Backend", description="Backend")
|
||||
pipeline: Optional[str] = Field(title="Pipeline", description="Pipeline")
|
||||
|
||||
def get_status_api(req: StatusReq = Depends()):
|
||||
if req.refresh:
|
||||
get_full_data()
|
||||
else:
|
||||
get_quick_data()
|
||||
res = StatusRes(
|
||||
version = data['version'],
|
||||
timestamp = data['timestamp'],
|
||||
uptime = data['uptime']
|
||||
)
|
||||
if req.state or req.full:
|
||||
res.state = data['state']
|
||||
if req.memory or req.full:
|
||||
res.memory = data['memory']
|
||||
if req.full:
|
||||
res.platform = data['platform']
|
||||
res.torch = data['torch']
|
||||
res.gpu = data['gpu']
|
||||
res.optimizations = data['optimizations']
|
||||
res.crossatention = data['crossattention']
|
||||
res.device = data['device']
|
||||
res.backend = data['backend']
|
||||
res.pipeline = data['pipeline']
|
||||
return res
|
||||
|
||||
|
||||
def register_api(app: FastAPI):
|
||||
app.add_api_route("/sdapi/v1/system-info/status", get_status_api, methods=["GET"], response_model=StatusRes)
|
||||
|
||||
|
||||
### Entry point
|
||||
|
||||
def on_app_started(_block, app): # register api
|
||||
register_api(app)
|
||||
"""
|
||||
@app.get("/sdapi/v1/system-info/status")
|
||||
async def sysinfo_api():
|
||||
get_quick_data()
|
||||
res = { 'state': data['state'], 'memory': data['memory'], 'timestamp': data['timestamp'] }
|
||||
return res
|
||||
"""
|
||||
|
||||
|
||||
script_callbacks.on_ui_tabs(on_ui_tabs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue