# mirror of https://github.com/bmaltais/kohya_ss
# (159 lines, 7.2 KiB, Python)
import safetensors.torch
|
|
import json
|
|
from collections import OrderedDict
|
|
import sys # To redirect stdout
|
|
import traceback
|
|
|
|
class Logger(object):
    """Tee-style stream that duplicates writes to the real stdout and a log file.

    Intended to be installed as ``sys.stdout`` so that everything ``print``ed
    still appears on the terminal while also being captured to *filename*.

    Can be used as a context manager so the log file is closed even on error.
    """

    def __init__(self, filename="loha_analysis_output.txt"):
        # Capture the real stdout at construction time, before the caller
        # reassigns sys.stdout to this instance.
        self.terminal = sys.stdout
        self.log = open(filename, "w", encoding='utf-8')

    def write(self, message):
        """Write *message* to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams.

        Needed for the Python 3 stream interface; callers such as
        shutil/os.system issue flush on sys.stdout.
        """
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file (the terminal stream is left open)."""
        self.log.close()

    # Context-manager protocol (new, backward-compatible): guarantees the
    # log file is released even if the caller's code raises.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False  # never suppress exceptions
|
def analyze_safetensors_file(filepath, output_filename="loha_analysis_output.txt"):
    """Analyze a .safetensors file and write a human-readable report.

    Loads the checkpoint and prints — tee'd to the terminal and to
    *output_filename* via ``Logger`` — three sections:

    * every tensor key grouped by module prefix (text before the final dot),
      with shape/dtype and, for ``.alpha``/``.dim`` entries and other tiny
      tensors, their scalar values;
    * keys that do not follow the ``module.parameter`` naming pattern;
    * the metadata dict from the safetensors header, with ``ss_network_args``
      additionally pretty-printed as parsed JSON when possible.

    Args:
        filepath: Path to the .safetensors file to inspect.
        output_filename: Destination text file for the report.

    Any exception during analysis is caught and written to the report with a
    full traceback; stdout is always restored and the log file closed.
    """
    original_stdout = sys.stdout
    logger = Logger(filename=output_filename)
    sys.stdout = logger  # from here on, print() goes to terminal + log file

    try:
        print(f"--- Analyzing: {filepath} ---\n")
        print(f"--- Output will be saved to: {output_filename} ---\n")

        # Load the tensors to get their structure
        state_dict = safetensors.torch.load_file(filepath, device="cpu")  # Load to CPU to avoid potential CUDA issues

        print("--- Tensor Information ---")
        if not state_dict:
            print("No tensors found in the state dictionary.")
        else:
            # Sort keys for consistent output
            sorted_keys = sorted(state_dict.keys())

            # First, identify all unique module prefixes for better grouping.
            # (Set comprehension de-duplicates directly; the redundant
            # list()-inside-sorted() wrapper of the original is dropped.)
            module_prefixes = sorted({".".join(key.split(".")[:-1]) for key in sorted_keys if "." in key})

            for prefix in module_prefixes:
                if not prefix:  # Skip keys that don't seem to be part of a module (e.g. global metadata tensors if any)
                    continue
                print(f"\nModule: {prefix}")
                for key in sorted_keys:
                    if key.startswith(prefix + "."):
                        tensor = state_dict[key]
                        print(f"  - Key: {key}")
                        print(f"    Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}")  # Output shape as list for clarity
                        if key.endswith((".alpha", ".dim")):
                            try:
                                value = tensor.item()
                                # Check if value is float and format if it is
                                if isinstance(value, float):
                                    print(f"      Value: {value:.8f}")  # Format float to a certain precision
                                else:
                                    print(f"      Value: {value}")
                            except Exception as e:
                                print(f"      Value: Could not extract scalar value ({tensor}, error: {e})")
                        elif tensor.numel() < 10:  # Print small tensors' values
                            print(f"      Values (first few): {tensor.flatten()[:10].tolist()}")

            # Print keys that might not fit the module pattern (e.g., older formats or single tensors)
            print("\n--- Other Tensor Keys (if any, not fitting typical module.parameter pattern) ---")
            other_keys_found = False
            for key in sorted_keys:
                if not any(key.startswith(p + ".") for p in module_prefixes if p):
                    other_keys_found = True
                    tensor = state_dict[key]
                    print(f"  - Key: {key}")
                    print(f"    Shape: {list(tensor.shape)}, Dtype: {tensor.dtype}")
                    if key.endswith((".alpha", ".dim")) or tensor.numel() == 1:
                        try:
                            value = tensor.item()
                            if isinstance(value, float):
                                print(f"      Value: {value:.8f}")
                            else:
                                print(f"      Value: {value}")
                        except Exception as e:
                            print(f"      Value: Could not extract scalar value ({tensor}, error: {e})")

            if not other_keys_found:
                print("No other keys found.")

            print(f"\nTotal tensor keys found: {len(state_dict)}")

        print("\n--- Metadata (from safetensors header) ---")
        metadata_content = OrderedDict()
        malformed_metadata_keys = []
        try:
            # Use safe_open to access the metadata separately from the tensors
            with safetensors.safe_open(filepath, framework="pt", device="cpu") as f:
                metadata_keys = f.metadata()
                if metadata_keys is None:
                    print("No metadata dictionary found in the file header (f.metadata() returned None).")
                else:
                    for k in metadata_keys.keys():
                        try:
                            metadata_content[k] = metadata_keys.get(k)
                        except Exception as e:
                            # Record the failure but keep going; the placeholder
                            # string below marks the entry in the main listing too.
                            malformed_metadata_keys.append((k, str(e)))
                            metadata_content[k] = f"[Error reading value: {e}]"
        except Exception as e:
            print(f"Could not open or read metadata using safe_open: {e}")
            traceback.print_exc(file=sys.stdout)

        if not metadata_content and not malformed_metadata_keys:
            print("No metadata content extracted.")
        else:
            for key, value in metadata_content.items():
                print(f"- {key}: {value}")
                # ss_network_args is stored as a JSON string; pretty-print it
                # unless reading it already failed above.
                if key == "ss_network_args" and value and not value.startswith("[Error"):
                    try:
                        parsed_args = json.loads(value)
                        print("  Parsed ss_network_args:")
                        for arg_key, arg_value in parsed_args.items():
                            print(f"    - {arg_key}: {arg_value}")
                    except json.JSONDecodeError:
                        print("  (ss_network_args is not a valid JSON string)")

        if malformed_metadata_keys:
            print("\n--- Malformed Metadata Keys (could not be read) ---")
            for key, error_msg in malformed_metadata_keys:
                print(f"- {key}: Error: {error_msg}")

        print("\n--- End of Analysis ---")

    except Exception as e:
        # Top-level boundary: report the failure into the log rather than crash.
        print("\n!!! An error occurred during analysis !!!")  # was an f-string with no placeholders
        print(str(e))
        traceback.print_exc(file=sys.stdout)  # Print full traceback to the log file
    finally:
        sys.stdout = original_stdout  # Restore standard output
        logger.close()
        print(f"\nAnalysis complete. Output saved to: {output_filename}")
|
if __name__ == "__main__":
    # Script entry point: prompt for the checkpoint to inspect, announce
    # where the report will be written, then run the analysis.
    source_path = input("Enter the path to your working LoHA .safetensors file: ")

    # Fixed report name; edit here to change the destination. A per-input
    # name could instead be derived from the checkpoint's basename via
    # os.path.splitext(os.path.basename(source_path)).
    report_name = "loha_analysis_results.txt"

    print(f"The analysis will be saved to: {report_name}")
    analyze_safetensors_file(source_path, output_filename=report_name)