display lora tag in networks and do not filter

pull/3454/head
Vladimir Mandic 2024-09-22 15:25:10 -04:00
parent 2d0ad9ae61
commit 3ceb998338
6 changed files with 69 additions and 51 deletions

View File

@ -36,6 +36,8 @@
- if lora contains no tags, lora name itself will be used as a tag
- if prompt contains `_tags_` it will be used as placeholder for replacement, otherwise tags will be appended
- used tags are also logged and registered in image metadata
- loras are no longer filtered per detected type vs loaded model type as its unreliable
- loras display in networks now shows possible version in top-left corner
- correct using of `extra_networks_default_multiplier` if not scale is specified
- always keep lora on gpu
- **text encoder**:

View File

@ -7,22 +7,27 @@ from rich import print # pylint: disable=redefined-builtin
def read_metadata(fn):
res = {}
if not fn.lower().endswith(".safetensors"):
return
with open(fn, mode="rb") as f:
metadata_len = f.read(8)
metadata_len = int.from_bytes(metadata_len, "little")
json_start = f.read(2)
if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
print(f"Not a valid safetensors file: {fn}")
json_data = json_start + f.read(metadata_len-2)
json_obj = json.loads(json_data)
for k, v in json_obj.get("__metadata__", {}).items():
res[k] = v
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
except Exception:
pass
print(f"{fn}: {json.dumps(res, indent=4)}")
try:
metadata_len = f.read(8)
metadata_len = int.from_bytes(metadata_len, "little")
json_start = f.read(2)
if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
print(f"Not a valid safetensors file: {fn}")
json_data = json_start + f.read(metadata_len-2)
json_obj = json.loads(json_data)
for k, v in json_obj.get("__metadata__", {}).items():
res[k] = v
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
except Exception:
pass
print(f"{fn}: {json.dumps(res, indent=4)}")
except Exception:
print(f"{fn}: cannot read metadata")
def main():

View File

@ -13,7 +13,10 @@ class SdVersion(enum.Enum):
Unknown = 1
SD1 = 2
SD2 = 3
SD3 = 3
SDXL = 4
SC = 5
F1 = 6
class NetworkOnDisk:
@ -40,13 +43,32 @@ class NetworkOnDisk:
self.sd_version = self.detect_version()
def detect_version(self):
if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
return SdVersion.SDXL
elif str(self.metadata.get('ss_v2', "")) == "True":
return SdVersion.SD2
elif len(self.metadata):
return SdVersion.SD1
return SdVersion.Unknown
base = str(self.metadata.get('ss_base_model_version', "")).lower()
arch = str(self.metadata.get('modelspec.architecture', "")).lower()
if base.startswith("sd_v1"):
return 'sd1'
if base.startswith("sdxl"):
return 'xl'
if base.startswith("stable_cascade"):
return 'sc'
if base.startswith("sd3"):
return 'sd3'
if base.startswith("flux"):
return 'f1'
if arch.startswith("stable-diffusion-v1"):
return 'sd1'
if arch.startswith("stable-diffusion-xl"):
return 'xl'
if arch.startswith("stable-cascade"):
return 'sc'
if arch.startswith("flux"):
return 'f1'
if str(self.metadata.get('ss_v2', "")) == "True":
return 'sd2'
return ''
def set_hash(self, v):
self.hash = v or ''

View File

@ -1,7 +1,6 @@
import os
import json
import concurrent
import network
import networks
from modules import shared, ui_extra_networks
@ -21,24 +20,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
l = networks.available_networks.get(name)
if l is None:
shared.log.warning(f'Networks: type=lora registered={len(list(networks.available_networks))} file="{name}" not registered')
print(networks.available_networks)
return None
try:
# path, _ext = os.path.splitext(l.filename)
name = os.path.splitext(os.path.relpath(l.filename, shared.cmd_opts.lora_dir))[0]
if not shared.native:
if l.sd_version == network.SdVersion.SDXL:
return None
elif shared.native:
if shared.sd_model_type == 'none': # return all when model is not loaded
pass
elif shared.sd_model_type == 'sdxl':
if l.sd_version == network.SdVersion.SD1 or l.sd_version == network.SdVersion.SD2:
return None
elif shared.sd_model_type == 'sd':
if l.sd_version == network.SdVersion.SDXL:
return None
item = {
"type": 'Lora',
"name": name,
@ -48,22 +33,24 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
"metadata": json.dumps(l.metadata, indent=4) if l.metadata else None,
"mtime": os.path.getmtime(l.filename),
"size": os.path.getsize(l.filename),
"version": l.sd_version,
}
info = self.find_info(l.filename)
tags = {}
modelspec_tags = l.metadata.get('modelspec.tags', {}) if l.metadata is not None else {}
possible_tags = l.metadata.get('ss_tag_frequency', {}) if l.metadata is not None else {} # tags from model metedata
possible_tags.update(modelspec_tags)
if isinstance(possible_tags, str):
possible_tags = {}
for k, v in possible_tags.items():
words = k.split('_', 1) if '_' in k else [v, k]
words = [str(w).replace('.json', '') for w in words]
if words[0] == '{}':
words[0] = 0
tag = ' '.join(words[1:]).lower()
tags[tag] = words[0]
if l.metadata is not None:
modelspec_tags = l.metadata.get('modelspec.tags', {})
possible_tags = l.metadata.get('ss_tag_frequency', {}) # tags from model metedata
possible_tags.update(modelspec_tags)
if isinstance(possible_tags, str):
possible_tags = {}
for k, v in possible_tags.items():
words = k.split('_', 1) if '_' in k else [v, k]
words = [str(w).replace('.json', '') for w in words]
if words[0] == '{}':
words[0] = 0
tag = ' '.join(words[1:]).lower()
tags[tag] = words[0]
def find_version():
found_versions = []
@ -77,7 +64,6 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
found_versions = all_versions
return found_versions
find_version()
for v in find_version(): # trigger words from info json
possible_tags = v.get('trainedWords', [])
if isinstance(possible_tags, list):

View File

@ -224,6 +224,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
.extra-network-cards .card .overlay .tag { padding: 2px; margin: 2px; background: rgba(70, 70, 70, 0.60); font-size: var(--text-md); cursor: pointer; display: inline-block; }
.extra-network-cards .card .actions>span { padding: 4px; font-size: 34px !important; }
.extra-network-cards .card .actions>span:hover { color: var(--highlight-color); }
.extra-network-cards .card .version { position: absolute; top: 0; left: 0; padding: 2px; font-weight: bolder; text-shadow: 1px 1px black; text-transform: uppercase; font-size: 0.8rem; background: gray; opacity: 75%; margin: 2px; line-height: 0.9rem; }
.extra-network-cards .card:hover .actions { display: block; }
.extra-network-cards .card:hover .overlay .tags { display: block; }
.extra-network-cards .card:has(>img[src*="card-no-preview.png"])::before { content: ''; position: absolute; width: 100%; height: 100%; mix-blend-mode: multiply; background-color: var(--data-color); }

View File

@ -32,6 +32,7 @@ card_full = '''
<div class='tags'></div>
<div class='name'>{title}</div>
</div>
<div class='version'>{version}</div>
<div class='actions'>
<span class='details' title="Get details" onclick="showCardDetails(event)">&#x1f6c8;</span>
<div class='additional'><ul></ul></div>
@ -40,7 +41,7 @@ card_full = '''
</div>
'''
card_list = '''
<div class='card card-list' onclick={card_click} title='{name}' data-tab='{tabname}' data-page='{page}' data-name='{name}' data-filename='{filename}' data-tags='{tags}' data-mtime='{mtime}' data-size='{size}' data-search='{search}'>
<div class='card card-list' onclick={card_click} title='{name}' data-tab='{tabname}' data-page='{page}' data-name='{name}' data-filename='{filename}' data-tags='{tags}' data-mtime='{mtime}' data-version='{version}' data-size='{size}' data-search='{search}'>
<div style='display: flex'>
<span class='details' title="Get details" onclick="showCardDetails(event)">&#x1f6c8;</span>&nbsp;
<div class='name' style='flex-flow: column'>{title}&nbsp;
@ -308,6 +309,7 @@ class ExtraNetworksPage:
"card_click": item.get("onclick", '"' + html.escape(f'return cardClicked({item.get("prompt", None)}, {"true" if self.allow_negative_prompt else "false"})') + '"'),
"mtime": item.get("mtime", 0),
"size": item.get("size", 0),
"version": item.get("version", ''),
"color": random_bright_color(),
}
alias = item.get("alias", None)