commit
28e5c1aa02
|
|
@ -9,7 +9,7 @@ import logging
|
|||
from html.parser import HTMLParser
|
||||
import torch
|
||||
import gradio as gr
|
||||
from modules import paths, script_callbacks, sd_models, sd_samplers, shared, extensions, devices
|
||||
from modules import paths, script_callbacks, sd_models, sd_samplers, shared, extensions, devices, scripts
|
||||
import benchmark # pylint: disable=wrong-import-order
|
||||
|
||||
|
||||
|
|
@ -45,6 +45,7 @@ networks = {
|
|||
'loras': [],
|
||||
'lycos': [],
|
||||
}
|
||||
data_loaded = False
|
||||
|
||||
### benchmark globals
|
||||
|
||||
|
|
@ -244,15 +245,20 @@ def get_libs():
|
|||
}
|
||||
|
||||
|
||||
def run_git_command(cmd_args: list, cwd: str = None):
|
||||
try:
|
||||
res = subprocess.run(cmd_args, capture_output=True, cwd = cwd, check=True)
|
||||
return res.stdout.decode(encoding = 'utf8', errors='ignore').strip() if len(res.stdout) > 0 else ''
|
||||
except Exception:
|
||||
return ''
|
||||
|
||||
|
||||
def get_repos():
|
||||
try:
|
||||
repos = {}
|
||||
for key, val in paths.paths.items():
|
||||
try:
|
||||
cmd = f'git -C {val} log --pretty=format:"%h %ad" -1 --date=short'
|
||||
res = subprocess.run(f'{cmd} {val}', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
stdout = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
words = stdout.split(' ')
|
||||
words = run_git_command(['git', 'log', '--pretty=format:%h %ad', '-1', '--date=short'], cwd=val).split(' ')
|
||||
repos[key] = f'[{words[0]}] {words[1]}'
|
||||
except Exception:
|
||||
repos[key] = '(unknown)'
|
||||
|
|
@ -292,21 +298,21 @@ def get_torch():
|
|||
def get_version():
|
||||
version = {}
|
||||
try:
|
||||
res = subprocess.run('git log --pretty=format:"%h %ad" -1 --date=short', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
ver = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
githash, updated = ver.split(' ')
|
||||
res = subprocess.run('git remote get-url origin', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
origin = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
res = subprocess.run('git rev-parse --abbrev-ref HEAD', stdout = subprocess.PIPE, stderr = subprocess.PIPE, shell=True, check=True)
|
||||
branch = res.stdout.decode(encoding = 'utf8', errors='ignore') if len(res.stdout) > 0 else ''
|
||||
url = origin.replace('\n', '').removesuffix('.git') + '/tree/' + branch.replace('\n', '')
|
||||
app = origin.replace('\n', '').split('/')[-1]
|
||||
githash, updated = run_git_command(['git', 'log', '--pretty=format:%h %ad', '-1', '--date=short']).split(' ')
|
||||
tag = run_git_command(['git', 'describe', '--tags', '--exact-match'])
|
||||
tags = " ".join(run_git_command(['git', 'tag', '--points-at', 'HEAD']).split())
|
||||
origin = run_git_command(['git', 'remote', 'get-url', 'origin'])
|
||||
branch = run_git_command(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||
url = origin.removesuffix('.git') + '/tree/' + branch
|
||||
app = origin.split('/')[-1]
|
||||
if app == 'automatic':
|
||||
app = 'SD.next'
|
||||
version = {
|
||||
'app': app,
|
||||
'updated': updated,
|
||||
'hash': githash,
|
||||
'tag': tag,
|
||||
'tags': tags,
|
||||
'url': url
|
||||
}
|
||||
except Exception:
|
||||
|
|
@ -394,6 +400,8 @@ def get_full_data():
|
|||
'models': get_models(),
|
||||
'loras': get_loras(),
|
||||
}
|
||||
global data_loaded # pylint: disable=global-statement
|
||||
data_loaded = True
|
||||
return data
|
||||
|
||||
|
||||
|
|
@ -661,6 +669,105 @@ def bench_refresh():
|
|||
return gr.HTML.update(value = bench_text)
|
||||
|
||||
|
||||
### metadata injection
|
||||
|
||||
AVAILABLE_FIELDS = {
|
||||
'version.app': ('version', 'app'),
|
||||
'version.hash': ('version', 'hash'),
|
||||
'version.tag': ('version', 'tag'),
|
||||
'version.tags': ('version', 'tags'),
|
||||
'version.updated': ('version', 'updated'),
|
||||
'version.url': ('version', 'url'),
|
||||
'platform.arch': ('platform', 'arch'),
|
||||
'platform.cpu': ('platform', 'cpu'),
|
||||
'platform.system': ('platform', 'system'),
|
||||
'platform.release': ('platform', 'release'),
|
||||
'platform.python': ('platform', 'python'),
|
||||
'gpu.device': ('gpu', 'device'),
|
||||
'gpu.cuda': ('gpu', 'cuda'),
|
||||
'gpu.cudnn': ('gpu', 'cudnn'),
|
||||
'gpu.hip': ('gpu', 'hip'),
|
||||
'gpu.ipex': ('gpu', 'ipex'),
|
||||
'gpu.openvino': ('gpu', 'openvino'),
|
||||
'torch.version': ('torch',),
|
||||
'crossattention': ('crossattention',),
|
||||
'backend': ('backend',),
|
||||
'pipeline': ('pipeline',),
|
||||
'flags': ('flags',),
|
||||
'libs.xformers': ('libs', 'xformers'),
|
||||
'libs.diffusers': ('libs', 'diffusers'),
|
||||
'libs.transformers': ('libs', 'transformers'),
|
||||
}
|
||||
|
||||
DEFAULT_METADATA_FIELDS = ['version.app', 'version.hash', 'platform.release', 'gpu.device', 'gpu.cuda', 'torch.version', 'crossattention']
|
||||
|
||||
|
||||
def get_value_from_data(field_path):
|
||||
try:
|
||||
source = data
|
||||
for field_name in field_path:
|
||||
if isinstance(source, dict):
|
||||
source = source.get(field_name)
|
||||
else:
|
||||
return None
|
||||
return source
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def build_system_info_metadata():
|
||||
selected_fields = shared.opts.data.get('system_info_metadata_fields', DEFAULT_METADATA_FIELDS)
|
||||
metadata = {}
|
||||
for field_name in selected_fields:
|
||||
if field_name in AVAILABLE_FIELDS:
|
||||
field_path = AVAILABLE_FIELDS[field_name]
|
||||
value = get_value_from_data(field_path)
|
||||
if value is not None:
|
||||
metadata[field_name] = value if not isinstance(value, list) else ' '.join(value)
|
||||
return metadata if metadata else None
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
shared.options_templates.update(
|
||||
shared.options_section(
|
||||
('infotext', "Infotext", "ui"),
|
||||
{
|
||||
'system_info_metadata_enabled': shared.OptionInfo(
|
||||
False,
|
||||
"Add system information to infotext",
|
||||
gr.Checkbox,
|
||||
),
|
||||
'system_info_metadata_fields': shared.OptionInfo(
|
||||
DEFAULT_METADATA_FIELDS,
|
||||
"System information to include in infotext",
|
||||
gr.Dropdown,
|
||||
{"choices": list(AVAILABLE_FIELDS.keys()), "multiselect": True},
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SystemInfoMetadataScript(scripts.Script):
|
||||
def title(self):
|
||||
return "System Info"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def postprocess_image(self, p, pp):
|
||||
if not shared.opts.data.get('system_info_metadata_enabled', False):
|
||||
return
|
||||
try:
|
||||
if not data_loaded:
|
||||
get_full_data()
|
||||
system_info = build_system_info_metadata()
|
||||
if system_info:
|
||||
p.extra_generation_params["System Info"] = system_info
|
||||
except Exception as e:
|
||||
log.warning(f'sd-extension-system-info: failed to add metadata: {str(e)}')
|
||||
|
||||
|
||||
### API
|
||||
|
||||
from typing import Optional # pylint: disable=wrong-import-order
|
||||
|
|
@ -744,3 +851,4 @@ except Exception:
|
|||
if standalone:
|
||||
script_callbacks.on_ui_tabs(create_ui)
|
||||
script_callbacks.on_app_started(on_app_started)
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
|
|
|
|||
Loading…
Reference in New Issue