diff --git a/CHANGELOG.md b/CHANGELOG.md index cd70c6d08..c3c9740b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ - if lora contains no tags, lora name itself will be used as a tag - if prompt contains `_tags_` it will be used as placeholder for replacement, otherwise tags will be appended - used tags are also logged and registered in image metadata + - loras are no longer filtered per detected type vs loaded model type as its unreliable + - loras display in networks now shows possible version in top-left corner - correct using of `extra_networks_default_multiplier` if not scale is specified - always keep lora on gpu - **text encoder**: diff --git a/cli/model-metadata.py b/cli/model-metadata.py index c4c5b6411..0f6311e97 100755 --- a/cli/model-metadata.py +++ b/cli/model-metadata.py @@ -7,22 +7,27 @@ from rich import print # pylint: disable=redefined-builtin def read_metadata(fn): res = {} + if not fn.lower().endswith(".safetensors"): + return with open(fn, mode="rb") as f: - metadata_len = f.read(8) - metadata_len = int.from_bytes(metadata_len, "little") - json_start = f.read(2) - if metadata_len <= 2 or json_start not in (b'{"', b"{'"): - print(f"Not a valid safetensors file: {fn}") - json_data = json_start + f.read(metadata_len-2) - json_obj = json.loads(json_data) - for k, v in json_obj.get("__metadata__", {}).items(): - res[k] = v - if isinstance(v, str) and v[0:1] == '{': - try: - res[k] = json.loads(v) - except Exception: - pass - print(f"{fn}: {json.dumps(res, indent=4)}") + try: + metadata_len = f.read(8) + metadata_len = int.from_bytes(metadata_len, "little") + json_start = f.read(2) + if metadata_len <= 2 or json_start not in (b'{"', b"{'"): + print(f"Not a valid safetensors file: {fn}") + json_data = json_start + f.read(metadata_len-2) + json_obj = json.loads(json_data) + for k, v in json_obj.get("__metadata__", {}).items(): + res[k] = v + if isinstance(v, str) and v[0:1] == '{': + try: + res[k] = json.loads(v) + except Exception: + pass + print(f"{fn}: {json.dumps(res, indent=4)}") + except Exception: + print(f"{fn}: cannot read metadata") def main(): diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 0a65fd927..05bb8cc87 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -13,7 +13,10 @@ class SdVersion(enum.Enum): Unknown = 1 SD1 = 2 SD2 = 3 + SD3 = 3 SDXL = 4 + SC = 5 + F1 = 6 class NetworkOnDisk: @@ -40,13 +43,32 @@ class NetworkOnDisk: self.sd_version = self.detect_version() def detect_version(self): - if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"): - return SdVersion.SDXL - elif str(self.metadata.get('ss_v2', "")) == "True": - return SdVersion.SD2 - elif len(self.metadata): - return SdVersion.SD1 - return SdVersion.Unknown + base = str(self.metadata.get('ss_base_model_version', "")).lower() + arch = str(self.metadata.get('modelspec.architecture', "")).lower() + if base.startswith("sd_v1"): + return 'sd1' + if base.startswith("sdxl"): + return 'xl' + if base.startswith("stable_cascade"): + return 'sc' + if base.startswith("sd3"): + return 'sd3' + if base.startswith("flux"): + return 'f1' + + if arch.startswith("stable-diffusion-v1"): + return 'sd1' + if arch.startswith("stable-diffusion-xl"): + return 'xl' + if arch.startswith("stable-cascade"): + return 'sc' + if arch.startswith("flux"): + return 'f1' + + if str(self.metadata.get('ss_v2', "")) == "True": + return 'sd2' + + return '' def set_hash(self, v): self.hash = v or '' diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index d35c9406b..170c3c7d3 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,7 +1,6 @@ import os import json import concurrent -import network import networks from modules import shared, ui_extra_networks @@ -21,24 +20,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): l = networks.available_networks.get(name) if l is None: shared.log.warning(f'Networks: type=lora registered={len(list(networks.available_networks))} file="{name}" not registered') - print(networks.available_networks) return None try: # path, _ext = os.path.splitext(l.filename) name = os.path.splitext(os.path.relpath(l.filename, shared.cmd_opts.lora_dir))[0] - if not shared.native: - if l.sd_version == network.SdVersion.SDXL: - return None - elif shared.native: - if shared.sd_model_type == 'none': # return all when model is not loaded - pass - elif shared.sd_model_type == 'sdxl': - if l.sd_version == network.SdVersion.SD1 or l.sd_version == network.SdVersion.SD2: - return None - elif shared.sd_model_type == 'sd': - if l.sd_version == network.SdVersion.SDXL: - return None - item = { "type": 'Lora', "name": name, @@ -48,22 +33,24 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "metadata": json.dumps(l.metadata, indent=4) if l.metadata else None, "mtime": os.path.getmtime(l.filename), "size": os.path.getsize(l.filename), + "version": l.sd_version, } info = self.find_info(l.filename) tags = {} - modelspec_tags = l.metadata.get('modelspec.tags', {}) if l.metadata is not None else {} - possible_tags = l.metadata.get('ss_tag_frequency', {}) if l.metadata is not None else {} # tags from model metedata - possible_tags.update(modelspec_tags) - if isinstance(possible_tags, str): - possible_tags = {} - for k, v in possible_tags.items(): - words = k.split('_', 1) if '_' in k else [v, k] - words = [str(w).replace('.json', '') for w in words] - if words[0] == '{}': - words[0] = 0 - tag = ' '.join(words[1:]).lower() - tags[tag] = words[0] + if l.metadata is not None: + modelspec_tags = l.metadata.get('modelspec.tags', {}) + possible_tags = l.metadata.get('ss_tag_frequency', {}) # tags from model metedata + possible_tags.update(modelspec_tags) + if isinstance(possible_tags, str): + possible_tags = {} + for k, v in possible_tags.items(): + words = k.split('_', 1) if '_' in k else [v, k] + words = [str(w).replace('.json', '') for w in words] + if words[0] == '{}': + words[0] = 0 + tag = ' '.join(words[1:]).lower() + tags[tag] = words[0] def find_version(): found_versions = [] @@ -77,7 +64,6 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): found_versions = all_versions return found_versions - find_version() for v in find_version(): # trigger words from info json possible_tags = v.get('trainedWords', []) if isinstance(possible_tags, list): diff --git a/javascript/sdnext.css b/javascript/sdnext.css index a1493d2e0..04bea121b 100644 --- a/javascript/sdnext.css +++ b/javascript/sdnext.css @@ -224,6 +224,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt .extra-network-cards .card .overlay .tag { padding: 2px; margin: 2px; background: rgba(70, 70, 70, 0.60); font-size: var(--text-md); cursor: pointer; display: inline-block; } .extra-network-cards .card .actions>span { padding: 4px; font-size: 34px !important; } .extra-network-cards .card .actions>span:hover { color: var(--highlight-color); } +.extra-network-cards .card .version { position: absolute; top: 0; left: 0; padding: 2px; font-weight: bolder; text-shadow: 1px 1px black; text-transform: uppercase; font-size: 0.8rem; background: gray; opacity: 75%; margin: 2px; line-height: 0.9rem; } .extra-network-cards .card:hover .actions { display: block; } .extra-network-cards .card:hover .overlay .tags { display: block; } .extra-network-cards .card:has(>img[src*="card-no-preview.png"])::before { content: ''; position: absolute; width: 100%; height: 100%; mix-blend-mode: multiply; background-color: var(--data-color); } diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 1ac571d41..430775661 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -32,6 +32,7 @@ card_full = '''