diff --git a/README.md b/README.md index 9ec481b..f01830f 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,33 @@ -# sd-extension-system-info -System Info Extension for SD WebUI +# Info Tab extensions for SD Automatic WebUI + +Creates a top-level **Info** tab in Automatic WebUI with + +State & memory info are auto-updated every second if tab is visible (no updates are performed when tab is not visible) +All other information is updated once upon WebUI load and can be force refreshed if required + +## Current information: +- Version +- Current Model & VAE +- Current State +- Current Memory statistics + +## System data: +- Platform details +- Torch & CUDA details +- Active CMD flags such as `low-vram` or `med-vram` +- Versions of critical libraries +- Versions of dependent repositories + + ![screenshot](system-info.jpg) + +## Models +- Models +- Hypernetworks +- Embeddings + + ![screenshot](system-info-models.jpg) + +## Info Object +- System object is available as JSON for quick passing of information + + ![screenshot](system-info-json.jpg) diff --git a/install.py b/install.py new file mode 100644 index 0000000..5416cb0 --- /dev/null +++ b/install.py @@ -0,0 +1 @@ +import launch diff --git a/javascript/system-info.js b/javascript/system-info.js new file mode 100644 index 0000000..23c19b3 --- /dev/null +++ b/javascript/system-info.js @@ -0,0 +1,39 @@ +// this would not be needed if automatic run gradio with loop enabled + +let loaded = false; +let interval; + +function refresh() { + const btn = gradioApp().getElementById('system_info_tab_refresh_btn') // we could cache this dom element + if (!btn) return // but ui may get destroyed + btn.click() // actual refresh is done from python code we just trigger it but simulating button click +} + +function onHidden() { // stop refresh interval when tab is not visible + if (!interval) return + clearInterval(interval); + interval = undefined; +} + +function onVisible() { // start refresh interval tab is when visible + if (interval) return // interval already started so dont start it again + interval = setInterval(refresh, 1000); +} + +function initLoading() { // triggered on gradio change to monitor when ui gets sufficiently constructed + if (loaded) return + const block = gradioApp().getElementById('system_info_tab'); + if (!block) return + intersectionObserver = new IntersectionObserver((entries) => { + if (entries[0].intersectionRatio <= 0) onHidden(); + if (entries[0].intersectionRatio > 0) onVisible(); + }); + intersectionObserver.observe(block); // monitor visibility of tab +} + +function initInitial() { // just setup monitor for gradio events + const mutationObserver = new MutationObserver(initLoading) + mutationObserver.observe(gradioApp(), { childList: true, subtree: true }); // monitor changes to gradio +} + +document.addEventListener('DOMContentLoaded', initInitial); diff --git a/preload.py b/preload.py new file mode 100644 index 0000000..b37effc --- /dev/null +++ b/preload.py @@ -0,0 +1,2 @@ +def preload(parser): + pass diff --git a/scripts/system-info.py b/scripts/system-info.py new file mode 100644 index 0000000..6fc9b65 --- /dev/null +++ b/scripts/system-info.py @@ -0,0 +1,307 @@ +import datetime +import os +import platform +import subprocess +import time + +import accelerate +import gradio as gr +import psutil +import pytorch_lightning +import safetensors +import torch +import transformers +from modules import paths, script_callbacks, sd_hijack, sd_models, sd_samplers, shared + +data = {} + +def get_cuda(): + if not torch.cuda.is_available(): + return {} + else: + try: + return { + 'version': torch.version.cuda, + 'devices': torch.cuda.device_count(), + 'current': torch.cuda.get_device_name(torch.cuda.current_device()), + 'arch': torch.cuda.get_arch_list()[-1], + 'capability': torch.cuda.get_device_capability(shared.device), + } + except Exception as e: + return { 'error': e } + +def get_uptime(): + s = vars(shared.state) + return time.strftime('%c', time.localtime(s.get('server_start', time.time()))) + +def get_state(): + s = vars(shared.state) + flags = 'skipped ' if s.get('skipped', False) else '' + flags += 'interrupted ' if s.get('interrupted', False) else '' + flags += 'needs restart' if s.get('need_restart', False) else '' + return { + 'started': time.strftime('%c', time.localtime(s.get('time_start', time.time()))), + 'step': f'{s.get("sampling_step", 0)} / {s.get("sampling_steps", 0)}', + 'jobs': f'{s.get("job_no", 0)} / {s.get("job_count", 0)}', # pylint: disable=consider-using-f-string + 'flags': flags, + 'job': s.get('job', ''), + 'text': s.get('textinfo', ''), + } + +def get_memory(): + def gb(val: float): + return round(val / 1024 / 1024 / 1024, 2) + mem = {} + try: + process = psutil.Process(os.getpid()) + res = process.memory_info() + ram_total = 100 * res.rss / process.memory_percent() + ram = { 'free': gb(ram_total - res.rss), 'used': gb(res.rss), 'total': gb(ram_total) } + mem.update({ 'ram': ram }) + except Exception as e: + mem.update({ 'ram': e }) + try: + if torch.cuda.is_available(): + s = torch.cuda.mem_get_info() + gpu = { 'free': gb(s[0]), 'used': gb(s[1] - s[0]), 'total': gb(s[1]) } + s = dict(torch.cuda.memory_stats(shared.device)) + allocated = { 'current': gb(s['allocated_bytes.all.current']), 'peak': gb(s['allocated_bytes.all.peak']) } + reserved = { 'current': gb(s['reserved_bytes.all.current']), 'peak': gb(s['reserved_bytes.all.peak']) } + active = { 'current': gb(s['active_bytes.all.current']), 'peak': gb(s['active_bytes.all.peak']) } + inactive = { 'current': gb(s['inactive_split_bytes.all.current']), 'peak': gb(s['inactive_split_bytes.all.peak']) } + warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] } + mem.update({ + 'gpu': gpu, + 'gpu-active': active, + 'gpu-allocated': allocated, + 'gpu-reserved': reserved, + 'gpu-inactive': inactive, + 'events': warnings, + }) + except: + pass + return mem + +def get_optimizations(): + ram = [] + if shared.cmd_opts.medvram: + ram.append('medvram') + if shared.cmd_opts.lowvram: + ram.append('lowvram') + if shared.cmd_opts.lowram: + ram.append('lowram') + if len(ram) == 0: + ram.append('none') + return ram + +def get_libs(): + return { + 'xformers': shared.xformers_available, + 'accelerate': accelerate.__version__, + 'transformers': transformers.__version__, + 'safetensors': safetensors.__version__, + 'lightning': pytorch_lightning.__version__, + } + +def get_repos(): + repos = {} + for key, val in paths.paths.items(): + try: + cmd = f'git -C {val} log --pretty=format:"%h %ad" -1 --date=short' + res = subprocess.run(f'{cmd} {val}', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True) + stdout = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else '' + words = stdout.split(' ') + repos[key] = f'[{words[0]}] {words[1]}' + except: + repos[key] = '(unknown)' + return repos + +def get_model(): + try: + return { + 'configured': shared.opts.data['sd_model_checkpoint'], + 'current': shared.sd_model.sd_checkpoint_info.title, + 'configuration': os.path.basename(sd_models.find_checkpoint_config(shared.sd_model.sd_checkpoint_info)), + } + except: + return { 'error': 'no model config found' } + +def get_vae(): + try: + return { + 'configured': shared.opts.sd_vae, + 'current': os.path.basename(shared.sd_vae.loaded_vae_file), + } + except: + return { 'error': 'no vae config found' } + +def get_platform(): + try: + return { + 'host': platform.node(), + 'arch': platform.machine(), + 'cpu': platform.processor(), + 'system': platform.system(), + 'platform': platform.platform(aliased = True, terse = False), + 'release': platform.release(), + 'version': platform.version(), + 'python': platform.python_version(), + } + except Exception as e: + return { 'error': e } + +def get_torch(): + return { + 'version': torch.__version__, + 'precision': shared.cmd_opts.precision + (' fp32' if shared.cmd_opts.no_half else ' fp16'), + } + +def get_version(): + try: + res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True) + ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else '' + githash, updated = ver.split(' ') + res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True) + origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else '' + res = subprocess.run('git branch --show-current', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True) + branch = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else '' + return { + 'updated': updated, + 'hash': githash, + 'origin': origin.replace('\n', ''), + 'branch': branch.replace('\n', ''), + } + except: + return {} + +def get_embeddings(): + return [f'{v} ({sd_hijack.model_hijack.embedding_db.word_embeddings[v].vectors})' for i, v in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings)] + +def get_skipped(): + return [k for k in sd_hijack.model_hijack.embedding_db.skipped_embeddings.keys()] + +def get_crossattention(): + try: + return sd_hijack.model_hijack.optimization_method + except: + return 'unknown' + +def get_models(): + return [x.title for x in sd_models.checkpoints_list.values()] + +def get_samplers(): + return [sampler[0] for sampler in sd_samplers.all_samplers] + +def get_full_data(): + global data # pylint: disable=global-statement + data = { + 'date': datetime.datetime.now().strftime('%c'), + 'timestamp': datetime.datetime.now().strftime('%X'), + 'uptime': get_uptime(), + 'version': get_version(), + 'model': get_model(), + 'vae': get_vae(), + 'torch': get_torch(), + 'cuda': get_cuda(), + 'state': get_state(), + 'memory': get_memory(), + 'optimizations': get_optimizations(), + 'libs': get_libs(), + 'repos': get_repos(), + 'models': get_models(), + 'hypernetworks': [name for name in shared.hypernetworks], + 'embeddings': get_embeddings(), + 'skipped': get_skipped(), + 'schedulers': get_samplers(), + 'platform': get_platform(), + 'crossattention': get_crossattention(), + 'api': shared.cmd_opts.api, + 'webui': not shared.cmd_opts.nowebui, + } + return data + +def get_quick_data(): + data['timestamp'] = datetime.datetime.now().strftime('%X') + data['state'] = get_state() + data['memory'] = get_memory() + +def list2text(lst: list): + return '\n'.join(lst) + +def dict2str(d: dict): + arr = [f'{name}: {d[name]}' for i, name in enumerate(d)] + return ' '.join(arr) + +def dict2text(d: dict): + arr = ['{name}: {val}'.format(name = name, val = d[name] if not type(d[name]) is dict else dict2str(d[name])) for i, name in enumerate(d)] # pylint: disable=consider-using-f-string + return list2text(arr) + +def refresh_info_quick(): + get_quick_data() + return dict2text(data['state']), dict2text(data['memory']), data['timestamp'], data + +def refresh_info_full(): + get_full_data() + return dict2text(data['state']), dict2text(data['memory']), data['models'], data['hypernetworks'], data['embeddings'], data['skipped'], dict2text(data['model']), dict2text(data['vae']), data['timestamp'], data + +def on_ui_tabs(): + get_full_data() + with gr.Blocks(analytics_enabled = False) as system_info_tab: + with gr.Row(elem_id = 'system_info_tab'): + with gr.Column(scale = 9): + with gr.Box(): + with gr.Row(): + with gr.Column(): + gr.Textbox(data['uptime'], label = 'Server start time', lines = 1) + gr.Textbox(dict2text(data['version']), label = 'Version', lines = len(data['version'])) + with gr.Column(): + model = gr.Textbox(dict2text(data['model']), label = 'Model', lines = len(data['model'])) + vae = gr.Textbox(dict2text(data['vae']), label = 'VAE', lines = len(data['vae'])) + with gr.Column(): + state = gr.Textbox(dict2text(data['state']), label = 'State', lines = len(data['state'])) + with gr.Column(): + memory = gr.Textbox(dict2text(data['memory']), label = 'Memory', lines = len(data['memory'])) + with gr.Box(): + with gr.Accordion('System data', open = True, visible = True): + with gr.Row(): + with gr.Column(): + gr.Textbox(dict2text(data['platform']), label = 'Platform', lines = len(data['platform'])) + with gr.Column(): + gr.Textbox(dict2text(data['torch']), label = 'Torch', lines = len(data['torch'])) + gr.Textbox(dict2text(data['cuda']), label = 'CUDA', lines = len(data['cuda'])) + with gr.Row(): + gr.Textbox(list2text(data['optimizations']), label = 'Memory optimization') + gr.Textbox(data['crossattention'], label = 'Cross-attention') + gr.Textbox((data['api']), label = 'API') + with gr.Column(): + gr.Textbox(dict2text(data['libs']), label = 'Libs', lines = len(data['libs'])) + gr.Textbox(dict2text(data['repos']), label = 'Repos', lines = len(data['repos'])) + with gr.Box(): + with gr.Accordion('Models...', open = False, visible = True): + with gr.Row(): + with gr.Column(): + models = gr.JSON(data['models'], label = 'Models', lines = len(data['models'])) + hypernetworks = gr.JSON(data['hypernetworks'], label = 'Hypernetworks', lines = len(data['hypernetworks'])) + with gr.Column(): + embeddings = gr.JSON(data['embeddings'], label = 'Embeddings: loaded', lines = len(data['embeddings'])) + skipped = gr.JSON(data['skipped'], label = 'Embeddings: skipped', lines = len(data['embeddings'])) + with gr.Box(): + with gr.Accordion('Info object', open = False, visible = True): + # reduce json data to avoid private info + data.pop('models', None) + data.pop('embeddings', None) + data.pop('skipped', None) + data.pop('hypernetworks', None) + data.pop('schedulers', None) + json = gr.JSON(data) + with gr.Column(scale = 1, min_width = 120): + timestamp = gr.Text(data['timestamp'], label = '', elem_id = 'system_info_tab_last_update') + refresh_quick = gr.Button('Refresh state', elem_id = 'system_info_tab_refresh_btn', visible = False).style(full_width = False) # quick refresh is used from js interval + refresh_quick.click(refresh_info_quick, inputs = [], outputs = [state, memory, timestamp, json]) + refresh_full = gr.Button('Refresh data', elem_id = 'system_info_tab_refresh_full_btn').style(full_width = False) + refresh_full.click(refresh_info_full, inputs = [], outputs = [state, memory, models, hypernetworks, embeddings, skipped, model, vae, timestamp, json]) + interrupt = gr.Button('Send interrupt', elem_id = 'system_info_tab_interrupt_btn') + interrupt.click(shared.state.interrupt, inputs = [], outputs = []) + return (system_info_tab, 'System Info', 'system_info_tab'), + +script_callbacks.on_ui_tabs(on_ui_tabs) diff --git a/system-info-json.jpg b/system-info-json.jpg new file mode 100644 index 0000000..4e5d3b2 Binary files /dev/null and b/system-info-json.jpg differ diff --git a/system-info-models.jpg b/system-info-models.jpg new file mode 100644 index 0000000..9d31700 Binary files /dev/null and b/system-info-models.jpg differ diff --git a/system-info.jpg b/system-info.jpg new file mode 100644 index 0000000..428114e Binary files /dev/null and b/system-info.jpg differ