mirror of https://github.com/vladmandic/automatic
display lora tag in networks and do not filter
parent
2d0ad9ae61
commit
3ceb998338
|
|
@ -36,6 +36,8 @@
|
|||
- if lora contains no tags, lora name itself will be used as a tag
|
||||
- if prompt contains `_tags_` it will be used as placeholder for replacement, otherwise tags will be appended
|
||||
- used tags are also logged and registered in image metadata
|
||||
- loras are no longer filtered per detected type vs loaded model type as its unreliable
|
||||
- loras display in networks now shows possible version in top-left corner
|
||||
- correct using of `extra_networks_default_multiplier` if not scale is specified
|
||||
- always keep lora on gpu
|
||||
- **text encoder**:
|
||||
|
|
|
|||
|
|
@ -7,22 +7,27 @@ from rich import print # pylint: disable=redefined-builtin
|
|||
|
||||
def read_metadata(fn):
|
||||
res = {}
|
||||
if not fn.lower().endswith(".safetensors"):
|
||||
return
|
||||
with open(fn, mode="rb") as f:
|
||||
metadata_len = f.read(8)
|
||||
metadata_len = int.from_bytes(metadata_len, "little")
|
||||
json_start = f.read(2)
|
||||
if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
|
||||
print(f"Not a valid safetensors file: {fn}")
|
||||
json_data = json_start + f.read(metadata_len-2)
|
||||
json_obj = json.loads(json_data)
|
||||
for k, v in json_obj.get("__metadata__", {}).items():
|
||||
res[k] = v
|
||||
if isinstance(v, str) and v[0:1] == '{':
|
||||
try:
|
||||
res[k] = json.loads(v)
|
||||
except Exception:
|
||||
pass
|
||||
print(f"{fn}: {json.dumps(res, indent=4)}")
|
||||
try:
|
||||
metadata_len = f.read(8)
|
||||
metadata_len = int.from_bytes(metadata_len, "little")
|
||||
json_start = f.read(2)
|
||||
if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
|
||||
print(f"Not a valid safetensors file: {fn}")
|
||||
json_data = json_start + f.read(metadata_len-2)
|
||||
json_obj = json.loads(json_data)
|
||||
for k, v in json_obj.get("__metadata__", {}).items():
|
||||
res[k] = v
|
||||
if isinstance(v, str) and v[0:1] == '{':
|
||||
try:
|
||||
res[k] = json.loads(v)
|
||||
except Exception:
|
||||
pass
|
||||
print(f"{fn}: {json.dumps(res, indent=4)}")
|
||||
except Exception:
|
||||
print(f"{fn}: cannot read metadata")
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
|||
|
|
@ -13,7 +13,10 @@ class SdVersion(enum.Enum):
|
|||
Unknown = 1
|
||||
SD1 = 2
|
||||
SD2 = 3
|
||||
SD3 = 3
|
||||
SDXL = 4
|
||||
SC = 5
|
||||
F1 = 6
|
||||
|
||||
|
||||
class NetworkOnDisk:
|
||||
|
|
@ -40,13 +43,32 @@ class NetworkOnDisk:
|
|||
self.sd_version = self.detect_version()
|
||||
|
||||
def detect_version(self):
|
||||
if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
|
||||
return SdVersion.SDXL
|
||||
elif str(self.metadata.get('ss_v2', "")) == "True":
|
||||
return SdVersion.SD2
|
||||
elif len(self.metadata):
|
||||
return SdVersion.SD1
|
||||
return SdVersion.Unknown
|
||||
base = str(self.metadata.get('ss_base_model_version', "")).lower()
|
||||
arch = str(self.metadata.get('modelspec.architecture', "")).lower()
|
||||
if base.startswith("sd_v1"):
|
||||
return 'sd1'
|
||||
if base.startswith("sdxl"):
|
||||
return 'xl'
|
||||
if base.startswith("stable_cascade"):
|
||||
return 'sc'
|
||||
if base.startswith("sd3"):
|
||||
return 'sd3'
|
||||
if base.startswith("flux"):
|
||||
return 'f1'
|
||||
|
||||
if arch.startswith("stable-diffusion-v1"):
|
||||
return 'sd1'
|
||||
if arch.startswith("stable-diffusion-xl"):
|
||||
return 'xl'
|
||||
if arch.startswith("stable-cascade"):
|
||||
return 'sc'
|
||||
if arch.startswith("flux"):
|
||||
return 'f1'
|
||||
|
||||
if str(self.metadata.get('ss_v2', "")) == "True":
|
||||
return 'sd2'
|
||||
|
||||
return ''
|
||||
|
||||
def set_hash(self, v):
|
||||
self.hash = v or ''
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
import json
|
||||
import concurrent
|
||||
import network
|
||||
import networks
|
||||
from modules import shared, ui_extra_networks
|
||||
|
||||
|
|
@ -21,24 +20,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||
l = networks.available_networks.get(name)
|
||||
if l is None:
|
||||
shared.log.warning(f'Networks: type=lora registered={len(list(networks.available_networks))} file="{name}" not registered')
|
||||
print(networks.available_networks)
|
||||
return None
|
||||
try:
|
||||
# path, _ext = os.path.splitext(l.filename)
|
||||
name = os.path.splitext(os.path.relpath(l.filename, shared.cmd_opts.lora_dir))[0]
|
||||
if not shared.native:
|
||||
if l.sd_version == network.SdVersion.SDXL:
|
||||
return None
|
||||
elif shared.native:
|
||||
if shared.sd_model_type == 'none': # return all when model is not loaded
|
||||
pass
|
||||
elif shared.sd_model_type == 'sdxl':
|
||||
if l.sd_version == network.SdVersion.SD1 or l.sd_version == network.SdVersion.SD2:
|
||||
return None
|
||||
elif shared.sd_model_type == 'sd':
|
||||
if l.sd_version == network.SdVersion.SDXL:
|
||||
return None
|
||||
|
||||
item = {
|
||||
"type": 'Lora',
|
||||
"name": name,
|
||||
|
|
@ -48,22 +33,24 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||
"metadata": json.dumps(l.metadata, indent=4) if l.metadata else None,
|
||||
"mtime": os.path.getmtime(l.filename),
|
||||
"size": os.path.getsize(l.filename),
|
||||
"version": l.sd_version,
|
||||
}
|
||||
info = self.find_info(l.filename)
|
||||
|
||||
tags = {}
|
||||
modelspec_tags = l.metadata.get('modelspec.tags', {}) if l.metadata is not None else {}
|
||||
possible_tags = l.metadata.get('ss_tag_frequency', {}) if l.metadata is not None else {} # tags from model metedata
|
||||
possible_tags.update(modelspec_tags)
|
||||
if isinstance(possible_tags, str):
|
||||
possible_tags = {}
|
||||
for k, v in possible_tags.items():
|
||||
words = k.split('_', 1) if '_' in k else [v, k]
|
||||
words = [str(w).replace('.json', '') for w in words]
|
||||
if words[0] == '{}':
|
||||
words[0] = 0
|
||||
tag = ' '.join(words[1:]).lower()
|
||||
tags[tag] = words[0]
|
||||
if l.metadata is not None:
|
||||
modelspec_tags = l.metadata.get('modelspec.tags', {})
|
||||
possible_tags = l.metadata.get('ss_tag_frequency', {}) # tags from model metedata
|
||||
possible_tags.update(modelspec_tags)
|
||||
if isinstance(possible_tags, str):
|
||||
possible_tags = {}
|
||||
for k, v in possible_tags.items():
|
||||
words = k.split('_', 1) if '_' in k else [v, k]
|
||||
words = [str(w).replace('.json', '') for w in words]
|
||||
if words[0] == '{}':
|
||||
words[0] = 0
|
||||
tag = ' '.join(words[1:]).lower()
|
||||
tags[tag] = words[0]
|
||||
|
||||
def find_version():
|
||||
found_versions = []
|
||||
|
|
@ -77,7 +64,6 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||
found_versions = all_versions
|
||||
return found_versions
|
||||
|
||||
find_version()
|
||||
for v in find_version(): # trigger words from info json
|
||||
possible_tags = v.get('trainedWords', [])
|
||||
if isinstance(possible_tags, list):
|
||||
|
|
|
|||
|
|
@ -224,6 +224,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
|
|||
.extra-network-cards .card .overlay .tag { padding: 2px; margin: 2px; background: rgba(70, 70, 70, 0.60); font-size: var(--text-md); cursor: pointer; display: inline-block; }
|
||||
.extra-network-cards .card .actions>span { padding: 4px; font-size: 34px !important; }
|
||||
.extra-network-cards .card .actions>span:hover { color: var(--highlight-color); }
|
||||
.extra-network-cards .card .version { position: absolute; top: 0; left: 0; padding: 2px; font-weight: bolder; text-shadow: 1px 1px black; text-transform: uppercase; font-size: 0.8rem; background: gray; opacity: 75%; margin: 2px; line-height: 0.9rem; }
|
||||
.extra-network-cards .card:hover .actions { display: block; }
|
||||
.extra-network-cards .card:hover .overlay .tags { display: block; }
|
||||
.extra-network-cards .card:has(>img[src*="card-no-preview.png"])::before { content: ''; position: absolute; width: 100%; height: 100%; mix-blend-mode: multiply; background-color: var(--data-color); }
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ card_full = '''
|
|||
<div class='tags'></div>
|
||||
<div class='name'>{title}</div>
|
||||
</div>
|
||||
<div class='version'>{version}</div>
|
||||
<div class='actions'>
|
||||
<span class='details' title="Get details" onclick="showCardDetails(event)">🛈</span>
|
||||
<div class='additional'><ul></ul></div>
|
||||
|
|
@ -40,7 +41,7 @@ card_full = '''
|
|||
</div>
|
||||
'''
|
||||
card_list = '''
|
||||
<div class='card card-list' onclick={card_click} title='{name}' data-tab='{tabname}' data-page='{page}' data-name='{name}' data-filename='{filename}' data-tags='{tags}' data-mtime='{mtime}' data-size='{size}' data-search='{search}'>
|
||||
<div class='card card-list' onclick={card_click} title='{name}' data-tab='{tabname}' data-page='{page}' data-name='{name}' data-filename='{filename}' data-tags='{tags}' data-mtime='{mtime}' data-version='{version}' data-size='{size}' data-search='{search}'>
|
||||
<div style='display: flex'>
|
||||
<span class='details' title="Get details" onclick="showCardDetails(event)">🛈</span>
|
||||
<div class='name' style='flex-flow: column'>{title}
|
||||
|
|
@ -308,6 +309,7 @@ class ExtraNetworksPage:
|
|||
"card_click": item.get("onclick", '"' + html.escape(f'return cardClicked({item.get("prompt", None)}, {"true" if self.allow_negative_prompt else "false"})') + '"'),
|
||||
"mtime": item.get("mtime", 0),
|
||||
"size": item.get("size", 0),
|
||||
"version": item.get("version", ''),
|
||||
"color": random_bright_color(),
|
||||
}
|
||||
alias = item.get("alias", None)
|
||||
|
|
|
|||
Loading…
Reference in New Issue