import os
import sys
import time
import json
import platform
import subprocess
import datetime
import logging
from html.parser import HTMLParser
import torch
import gradio as gr
from modules import paths, script_callbacks, sd_models, sd_samplers, shared, extensions, devices, scripts
import benchmark # pylint: disable=wrong-import-order
### system info globals
log = logging.getLogger('sd')
# latest system-info snapshot: initialised with empty placeholders here, rebuilt wholesale by
# get_full_data() and partially refreshed in place by get_quick_data()
data = {
    'date': '',
    'timestamp': '',
    'uptime': '',
    'version': {},
    'torch': '',
    'gpu': {},
    'state': {},
    'memory': {},
    'flags': [],
    'libs': {},
    'repos': {},
    'device': {},
    'schedulers': [],
    'extensions': [],
    'platform': '',
    'crossattention': '',
    'backend': getattr(devices, 'backend', ''),
    # 'native'/'original' when shared exposes a `native` flag, otherwise reports 'a1111'
    'pipeline': ('native' if shared.native else 'original') if hasattr(shared, 'native') else 'a1111',
    'model': {},
}
# model / network name lists; rebound by get_full_data()
networks = {
    'models': [],
    'hypernetworks': [],
    'embeddings': [],
    'skipped': [],
    'loras': [],
    'lycos': [],
}
data_loaded = False  # set to True by get_full_data() once a full collection has completed
### benchmark globals
bench_text = ''
# local benchmark results file, stored next to this script
bench_file = os.path.join(os.path.dirname(__file__), 'benchmark-data-local.json')
# NOTE(review): presumably the column order of benchmark result rows — confirm against the benchmark code
bench_headers = ['timestamp', 'it/s', 'version', 'system', 'libraries', 'gpu', 'flags', 'settings', 'username', 'note', 'hash']
bench_data = []
### system info module
def get_user():
    """Best-effort lookup of the current OS user name.

    Tries ``os.getlogin()``, then the ``USER`` and ``USERNAME`` environment variables,
    then (non-Windows only) the password database. Returns ``''`` if every source fails.
    """
    try:
        return os.getlogin()
    except Exception:  # os.getlogin() can fail, e.g. without a controlling terminal
        pass
    for var_name in ('USER', 'USERNAME'):
        if var_name in os.environ:
            return os.environ[var_name]
    if sys.platform == 'win32':
        return ''
    try:
        import pwd
        return pwd.getpwuid(os.getuid())[0] # pylint: disable=no-member
    except Exception:
        return ''
def get_gpu():
    """Describe the active accelerator for the system-info panel.

    Returns a dict whose keys depend on the backend:

    - ``torch.cuda.is_available()`` is False: OpenVINO ``device``/``openvino`` version when
      ``--use-openvino`` is set, otherwise ``{}`` (also ``{}`` if that lookup fails)
    - XPU available: ``device`` (name and count) and ``ipex`` package version
    - CUDA build: ``device`` (name, count, last arch, capability), ``cuda``, ``cudnn``, ``driver``
    - HIP build: ``device`` (name and count) and ``hip``
    - anything else: ``{'device': 'unknown'}``

    On failure in the accelerator branch returns ``{'error': <message>}``; the message is
    stored as a string so the result stays printable and JSON-serializable with ``data``.
    """
    if not torch.cuda.is_available():
        try:
            if shared.cmd_opts.use_openvino:
                from modules.intel.openvino import get_openvino_device
                return {
                    'device': get_openvino_device(),
                    'openvino': get_package_version("openvino")
                }
            return {}
        except Exception:
            return {}
    try:
        # NOTE(review): XPU is only checked when torch.cuda reports available — presumably
        # IPEX patches torch.cuda on Intel GPUs; confirm
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return {
                'device': f'{torch.xpu.get_device_name(torch.xpu.current_device())} ({str(torch.xpu.device_count())})',
                'ipex': get_package_version('intel-extension-for-pytorch'),
            }
        if torch.version.cuda:
            return {
                'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} ({str(torch.cuda.device_count())}) ({torch.cuda.get_arch_list()[-1]}) {str(torch.cuda.get_device_capability(shared.device))}',
                'cuda': torch.version.cuda,
                'cudnn': torch.backends.cudnn.version(),
                'driver': get_driver(),
            }
        if torch.version.hip:
            return {
                'device': f'{torch.cuda.get_device_name(torch.cuda.current_device())} ({str(torch.cuda.device_count())})',
                'hip': torch.version.hip,
            }
        return {
            'device': 'unknown'
        }
    except Exception as e:
        # keep the error as text: an exception object is not JSON-serializable
        return { 'error': str(e) }
def get_driver():
    """Return the NVIDIA driver version reported by ``nvidia-smi``, or ``''`` when unavailable.

    Only queried when torch sees a CUDA device on a CUDA build. The stripped stdout is
    returned as-is (one line per GPU if nvidia-smi reports several).
    """
    if not (torch.cuda.is_available() and torch.version.cuda):
        return ''
    try:
        # argument list instead of a shell string: no shell parsing, same PATH lookup
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=driver_version', '--format=csv,noheader'],
            check=False, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        )
    except Exception:  # e.g. nvidia-smi not installed / not executable
        return ''
    if result.returncode != 0:
        # nvidia-smi may print its failure message on stdout; never report that as a version
        return ''
    return result.stdout.decode(encoding="utf8", errors="ignore").strip()
def get_uptime():
    """Format the server start time from ``shared.state`` (falling back to now) with ``%c``."""
    started_at = vars(shared.state).get('server_start', time.time())
    return time.strftime('%c', time.localtime(started_at))
class HTMLFilter(HTMLParser):
    """HTML parser that discards markup and concatenates all character data into ``text``."""
    text = ""  # class-level default; the first append gives each instance its own string

    def handle_data(self, data): # pylint: disable=redefined-outer-name
        self.text = "".join((self.text, data))
def get_state():
    """Snapshot the current generation state from ``shared.state`` for display.

    Returns a dict with the start time (``%c``), step and job progress as ``'current / total'``
    strings, space-separated status flags, the current job name, and ``textinfo`` with HTML
    markup stripped and blank lines removed.
    """
    s = vars(shared.state)
    flags = 'skipped ' if s.get('skipped', False) else ''
    flags += 'interrupted ' if s.get('interrupted', False) else ''
    flags += 'needs restart' if s.get('need_restart', False) else ''
    text = s.get('textinfo', '')
    if text is not None and len(text) > 0:
        f = HTMLFilter()
        f.feed(text)
        # close() flushes data the parser still buffers (e.g. text after a trailing bare '&'),
        # which would otherwise be silently dropped
        f.close()
        text = os.linesep.join([line for line in f.text.splitlines() if line])
    return {
        'started': time.strftime('%c', time.localtime(s.get('time_start', time.time()))),
        'step': f'{s.get("sampling_step", 0)} / {s.get("sampling_steps", 0)}',
        'jobs': f'{s.get("job_no", 0)} / {s.get("job_count", 0)}',
        'flags': flags,
        'job': s.get('job', ''),
        'text-info': text,
    }
def get_docker_limit(paths=('/sys/fs/cgroup/memory/memory.limit_in_bytes', '/sys/fs/cgroup/memory.max')):
    """Return the container memory limit in bytes, or ``sys.float_info.max`` when none applies.

    Candidate files are tried in order: the cgroup v1 limit file first, then the cgroup v2
    ``memory.max`` file. The literal ``max`` means "no limit". The first numeric value found
    wins; unreadable or non-numeric files are skipped.

    :param paths: candidate limit files, tried in order (overridable, e.g. for testing)
    :return: limit in bytes as a float
    """
    for path in paths:
        try:
            with open(path, 'r', encoding='utf8') as f:
                value = f.read().strip()
        except (OSError, ValueError):
            continue
        if value == 'max':
            return sys.float_info.max
        try:
            return float(value)
        except ValueError:
            continue
    return sys.float_info.max
def get_memory():
    """Collect process RAM and (when available) accelerator memory statistics, in GB.

    Returns a dict that may contain:

    - ``ram``: free/used/total, or an error message string if psutil is missing or fails
    - on CUDA: ``gpu`` (free/used/total), ``gpu-active``, ``gpu-allocated``, ``gpu-reserved``,
      ``gpu-inactive`` (current/peak each), ``events`` (alloc retries / OOMs), ``utilization``
    - otherwise: ``gpu`` total as reported by OpenVINO, when that query succeeds

    Errors are stored as strings so the result stays printable and JSON-serializable.
    """
    def gb(val: float):
        return round(val / 1024 / 1024 / 1024, 2)
    mem = {}
    try:
        import psutil
        process = psutil.Process(os.getpid())
        res = process.memory_info()
        # estimate total RAM from this process' RSS and its percentage share, capped by any container limit
        ram_total = 100 * res.rss / process.memory_percent()
        ram_total = min(ram_total, get_docker_limit())
        # note: 'free' is total minus this process' RSS, not system-wide free memory
        ram = { 'free': gb(ram_total - res.rss), 'used': gb(res.rss), 'total': gb(ram_total) }
        mem.update({ 'ram': ram })
    except Exception as e:
        mem.update({ 'ram': str(e) })
    if torch.cuda.is_available():
        try:
            s = torch.cuda.mem_get_info()  # (free, total)
            gpu = { 'free': gb(s[0]), 'used': gb(s[1] - s[0]), 'total': gb(s[1]) }
            s = dict(torch.cuda.memory_stats(shared.device))
            allocated = { 'current': gb(s['allocated_bytes.all.current']), 'peak': gb(s['allocated_bytes.all.peak']) }
            reserved = { 'current': gb(s['reserved_bytes.all.current']), 'peak': gb(s['reserved_bytes.all.peak']) }
            active = { 'current': gb(s['active_bytes.all.current']), 'peak': gb(s['active_bytes.all.peak']) }
            inactive = { 'current': gb(s['inactive_split_bytes.all.current']), 'peak': gb(s['inactive_split_bytes.all.peak']) }
            warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
            mem.update({
                'gpu': gpu,
                'gpu-active': active,
                'gpu-allocated': allocated,
                'gpu-reserved': reserved,
                'gpu-inactive': inactive,
                'events': warnings,
                'utilization': 0,
            })
            mem.update({ 'utilization': torch.cuda.utilization() }) # do this one separately as it may fail
        except Exception:
            pass
    else:
        try:
            from openvino.runtime import Core as OpenVINO_Core
            from modules.intel.openvino import get_device as get_raw_openvino_device
            openvino_core = OpenVINO_Core()
            mem.update({
                'gpu': { 'total': gb(openvino_core.get_property(get_raw_openvino_device(), 'GPU_DEVICE_TOTAL_MEM_SIZE')) },
            })
        except Exception:
            pass
    return mem
def get_flags(cmd_opts=None):
    """Return the memory-optimization command-line flags that are enabled.

    :param cmd_opts: options object to inspect; defaults to ``shared.cmd_opts``
    :return: enabled flag labels in a fixed order, or ``['none']`` when none are set
    """
    opts = shared.cmd_opts if cmd_opts is None else cmd_opts
    # (attribute on the options object, label shown in the UI)
    known_flags = (
        ('medvram', 'medvram'),
        ('medvram_sdxl', 'medvram-sdxl'),
        ('lowvram', 'lowvram'),
        ('lowram', 'lowram'),  # was misspelled 'lowvam', so lowram was never reported
    )
    ram = [label for attr, label in known_flags if getattr(opts, attr, False)]
    return ram or ['none']
def get_package_version(pkg: str):
    """Return the installed distribution version of *pkg*, or ``''`` if it cannot be determined."""
    version = ''
    try:
        from importlib import metadata as pkg_metadata
        version = pkg_metadata.version(pkg)
    except Exception:
        version = ''
    return version
def get_libs():
    """Return installed versions of the key ML libraries (``''`` for any that are missing)."""
    tracked = ('xformers', 'diffusers', 'transformers')
    return {name: get_package_version(name) for name in tracked}
def run_git_command(cmd_args: list, cwd: str | None = None):
try:
res = subprocess.run(cmd_args, capture_output=True, cwd = cwd, check=True)
return res.stdout.decode(encoding = 'utf8', errors='ignore').strip() if len(res.stdout) > 0 else ''
except Exception:
return ''
def get_repos():
    """Return ``{name: '[hash] date'}`` for every path in ``paths.paths`` from its last git commit.

    Paths whose git log cannot be read map to ``'(unknown)'``. If enumerating the paths fails
    altogether, returns ``{'error': <message>}`` with the message as a string so the result
    stays printable and JSON-serializable.
    """
    try:
        repos = {}
        for key, val in paths.paths.items():
            try:
                words = run_git_command(['git', 'log', '--pretty=format:%h %ad', '-1', '--date=short'], cwd=val).split(' ')
                repos[key] = f'[{words[0]}] {words[1]}'
            except Exception:
                # run_git_command returns '' on failure, so words[1] raises IndexError here
                repos[key] = '(unknown)'
        return repos
    except Exception as e:
        return { 'error': str(e) }
def get_platform():
    """Describe the host platform: CPU architecture and model, OS name and release, Python version.

    On Windows the release is the full ``platform.platform()`` string; elsewhere it is
    ``platform.release()``. On failure returns ``{'error': <message>}`` with the message as a
    string so the result stays printable and JSON-serializable.
    """
    try:
        if platform.system() == 'Windows':
            release = platform.platform(aliased = True, terse = False)
        else:
            release = platform.release()
        return {
            # 'host': platform.node(),
            'arch': platform.machine(),
            'cpu': platform.processor(),
            'system': platform.system(),
            'release': release,
            # 'platform': platform.platform(aliased = True, terse = False),
            # 'version': platform.version(),
            'python': platform.python_version(),
        }
    except Exception as e:
        return { 'error': str(e) }
def get_torch():
    """Return the torch version as a plain string, preferring ``__long_version__`` when present."""
    version = torch.__long_version__ if hasattr(torch, '__long_version__') else torch.__version__
    return f"{version}"
def get_version():
    """Collect git metadata for the checkout in the current working directory.

    Returns a dict with ``app`` (repo name, ``'SD.next'`` for the ``automatic`` repo),
    ``updated`` (last commit date), ``hash``, ``tag`` (exact tag at HEAD), ``tags`` (all tags
    at HEAD) and ``url`` (origin browse URL for the current branch). Returns ``{}`` if the
    git information cannot be read.
    """
    version = {}
    try:
        githash, updated = run_git_command(['git', 'log', '--pretty=format:%h %ad', '-1', '--date=short']).split(' ')
        tag = run_git_command(['git', 'describe', '--tags', '--exact-match'])
        tags = " ".join(run_git_command(['git', 'tag', '--points-at', 'HEAD']).split())
        origin = run_git_command(['git', 'remote', 'get-url', 'origin'])
        branch = run_git_command(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
        repo = origin.removesuffix('.git')
        url = repo + '/tree/' + branch
        # derive the name from the stripped URL so '.../automatic.git' is recognized; treating ':'
        # as a separator also handles scp-style origins such as 'git@github.com:user/automatic.git'
        app = repo.replace(':', '/').split('/')[-1]
        if app == 'automatic':
            app = 'SD.next'
        version = {
            'app': app,
            'updated': updated,
            'hash': githash,
            'tag': tag,
            'tags': tags,
            'url': url
        }
    except Exception:
        pass
    return version
def get_crossattention():
    """Return the configured cross-attention optimization (``None`` if the option is absent).

    Returns ``'unknown'`` if reading the option raises.
    """
    try:
        selected = getattr(shared.opts, 'cross_attention_optimization', None)
    except Exception:
        selected = 'unknown'
    return selected
def get_model():
    """Return the currently selected model components from settings (``'none'`` when unset)."""
    settings = shared.opts.data
    option_keys = {
        'base': 'sd_model_checkpoint',
        'refiner': 'sd_model_refiner',
        'vae': 'sd_vae',
        'te': 'sd_text_encoder',
        'unet': 'sd_unet',
    }
    return {label: settings.get(key, 'none') for label, key in option_keys.items()}
def get_models():
    """Return the titles of all known checkpoints, sorted alphabetically."""
    titles = [checkpoint.title for checkpoint in sd_models.checkpoints_list.values()]
    titles.sort()
    return titles
def get_samplers():
    """Return the names (first field of each entry) of all registered samplers, sorted."""
    names = [entry[0] for entry in sd_samplers.all_samplers]
    names.sort()
    return names
def get_extensions():
    """Return sorted ``'name (enabled|disabled[ builtin])'`` labels for all extensions."""
    def describe(ext):
        status = 'enabled' if ext.enabled else 'disabled'
        suffix = ' builtin' if ext.is_builtin else ''
        return f"{ext.name} ({status}{suffix})"
    return sorted([describe(ext) for ext in extensions.extensions])
def get_loras():
    """Return the sorted names of LoRAs known to the built-in ``Lora`` extension.

    Best-effort: returns ``[]`` if the extension cannot be imported or queried.
    """
    loras = []
    try:
        # make the built-in extensions dir importable; only append once so repeated
        # full refreshes don't keep adding duplicate sys.path entries
        if extensions.extensions_builtin_dir not in sys.path:
            sys.path.append(extensions.extensions_builtin_dir)
        from Lora import lora # pylint: disable=E0401
        loras = sorted(lora.available_loras.keys())
    except Exception:
        pass
    return loras
def get_device():
    """Return the active torch device and the general/VAE/UNet dtypes from ``devices``, as strings."""
    attribute_map = (
        ('active', 'device'),
        ('dtype', 'dtype'),
        ('vae', 'dtype_vae'),
        ('unet', 'dtype_unet'),
    )
    return {label: str(getattr(devices, attr)) for label, attr in attribute_map}
def get_full_data():
    """Recompute every system-info field and store it in the module globals.

    Rebinds the global ``data`` and ``networks`` dicts, sets ``data_loaded`` to True and
    returns the new ``data``.
    """
    global data, networks, data_loaded # pylint: disable=global-statement
    now = datetime.datetime.now()  # single reading so 'date' and 'timestamp' always agree
    data = {
        'date': now.strftime('%c'),
        'timestamp': now.strftime('%X'),
        'uptime': get_uptime(),
        'version': get_version(),
        'torch': get_torch(),
        'gpu': get_gpu(),
        'state': get_state(),
        'memory': get_memory(),
        'flags': get_flags(),
        'libs': get_libs(),
        'repos': get_repos(),
        'device': get_device(),
        'model': get_model(),
        'schedulers': get_samplers(),
        'extensions': get_extensions(),
        'platform': get_platform(),
        'crossattention': get_crossattention(),
        'backend': getattr(devices, 'backend', ''),
        'pipeline': ('native' if shared.native else 'original') if hasattr(shared, 'native') else 'a1111',
    }
    # keep the same keys as the module-level `networks` so readers never hit a KeyError
    # after a refresh; only models and loras are actually collected here
    networks = {
        'models': get_models(),
        'hypernetworks': [],
        'embeddings': [],
        'skipped': [],
        'loras': get_loras(),
        'lycos': [],
    }
    data_loaded = True
    return data
def get_quick_data():
    """Refresh only the cheap, frequently changing fields of the global ``data`` in place."""
    refreshers = (
        ('timestamp', lambda: datetime.datetime.now().strftime('%X')),
        ('state', get_state),
        ('memory', get_memory),
        ('model', get_model),
    )
    for key, refresh in refreshers:
        data[key] = refresh()
def list2text(lst: list):
    """Join the string items of *lst* into one newline-separated string."""
    separator = '\n'
    return separator.join(lst)
def dict2str(d: dict):
    """Render *d* on a single line as space-separated ``key:value`` pairs."""
    return ' '.join(f'{key}:{value}' for key, value in d.items())
def dict2text(d: dict):
    """Render *d* as one ``key: value`` line per entry.

    Values whose type is exactly ``dict`` are flattened with ``dict2str``; everything else
    (including dict subclasses) is formatted with ``str``.
    """
    lines = []
    for key, value in d.items():
        rendered = dict2str(value) if type(value) is dict else value # pylint: disable=unidiomatic-typecheck
        lines.append(f'{key}: {rendered}')
    return list2text(lines)
def refresh_info_quick(_old_data = None):
    """Quick-refresh callback: update the volatile fields, then return them for display.

    Returns ``(state text, memory text, cross-attention, timestamp, data)``; the argument is ignored.
    """
    get_quick_data()
    state_text = dict2text(data['state'])
    memory_text = dict2text(data['memory'])
    return state_text, memory_text, data['crossattention'], data['timestamp'], data
def refresh_info_full():
    """Full-refresh callback: recompute everything and return it as a flat tuple of display values.

    Dict fields are rendered with ``dict2text``, the flags list with ``list2text``; the model
    and LoRA lists, the timestamp and the raw ``data`` dict are appended at the end.
    """
    info = get_full_data()
    return (
        info['uptime'],
        dict2text(info['version']),
        dict2text(info['state']),
        dict2text(info['memory']),
        dict2text(info['platform']),
        info['torch'],
        dict2text(info['gpu']),
        list2text(info['flags']),
        info['crossattention'],
        info['backend'],
        info['pipeline'],
        dict2text(info['libs']),
        dict2text(info['repos']),
        dict2text(info['device']),
        dict2text(info['model']),
        networks['models'],
        networks['loras'],
        info['timestamp'],
        info,
    )
### ui definition
def create_ui(blocks: gr.Blocks = None):
if not standalone:
from modules.ui import ui_system_tabs # pylint: disable=redefined-outer-name
else:
ui_system_tabs = None
with gr.Blocks(analytics_enabled = False) if standalone else blocks as system_info:
with gr.Row(elem_id = 'system_info'):
with ui_system_tabs or gr.Tabs(elem_id = 'system_info_tabs'):
with gr.TabItem('System Info'):
with gr.Row():
timestamp = gr.Textbox(value=data['timestamp'], label = '', elem_id = 'system_info_tab_last_update', container=False)
refresh_quick_btn = gr.Button('Refresh state', elem_id = 'system_info_tab_refresh_btn', visible = False) # quick refresh is used from js interval
refresh_full_btn = gr.Button('Refresh data', elem_id = 'system_info_tab_refresh_full_btn', variant='primary')
interrupt_btn = gr.Button('Send interrupt', elem_id = 'system_info_tab_interrupt_btn', variant='primary')
with gr.Row():
with gr.Column():
uptimetxt = gr.Textbox(data['uptime'], label = 'Server start time', lines = 1)
versiontxt = gr.Textbox(dict2text(data['version']), label = 'Version', lines = len(data['version']))
with gr.Column():
statetxt = gr.Textbox(dict2text(data['state']), label = 'State', lines = len(data['state']))
with gr.Column():
memorytxt = gr.Textbox(dict2text(data['memory']), label = 'Memory', lines = len(data['memory']))
with gr.Row():
with gr.Column():
platformtxt = gr.Textbox(dict2text(data['platform']), label = 'Platform', lines = len(data['platform']))
with gr.Row():
backendtxt = gr.Textbox(data['backend'], label = 'Backend')
pipelinetxt = gr.Textbox(data['pipeline'], label = 'Pipeline')
with gr.Column():
torchtxt = gr.Textbox(data['torch'], label = 'Torch', lines = 1)
gputxt = gr.Textbox(dict2text(data['gpu']), label = 'GPU', lines = len(data['gpu']))
with gr.Row():
opttxt = gr.Textbox(list2text(data['flags']), label = 'Memory optimization')
attentiontxt = gr.Textbox(data['crossattention'], label = 'Cross-attention')
with gr.Column():
libstxt = gr.Textbox(dict2text(data['libs']), label = 'Libs', lines = len(data['libs']))
repostxt = gr.Textbox(dict2text(data['repos']), label = 'Repos', lines = len(data['repos']), visible = False)
devtxt = gr.Textbox(dict2text(data['device']), label = 'Device Info', lines = len(data['device']))
modeltxt = gr.Textbox(dict2text(data['model']), label = 'Model Info', lines = len(data['model']))
with gr.Row():
gr.HTML('Load