# sd-webui-stablesr/tools/extract_vaecfw.py
#
# (listing metadata from the original source page: 21 lines, 514 B, Python)

import torch
# Default locations, relative to the directory the script is run from.
VAE_PATH = 'models/vqgan_cfw_00011.ckpt'
VAE_ONLY_PATH = 'models/vqgan_cfw_00011_vae_only.ckpt'
CFW_ONLY_PATH = 'models/vqgan_cfw_00011_cfw_only.ckpt'

# Substring identifying the state-dict entries that get split out.
FUSION_LAYER_MARKER = 'decoder.fusion_layer'


def split_fusion_layers(state_dict, marker=FUSION_LAYER_MARKER):
    """Move every entry whose key contains *marker* out of *state_dict*.

    Args:
        state_dict: Mutable mapping of parameter name -> value. It is
            modified in place: matching entries are removed from it.
        marker: Substring that selects the keys to extract.

    Returns:
        A new dict holding the removed entries, in their original order.
    """
    # Snapshot the matching keys first: removing entries while iterating
    # the dict itself would raise RuntimeError.
    fusion_keys = [key for key in state_dict if marker in key]
    return {key: state_dict.pop(key) for key in fusion_keys}


def main(vae_path=VAE_PATH, vae_only_path=VAE_ONLY_PATH,
         cfw_only_path=CFW_ONLY_PATH):
    """Split one checkpoint into a pruned checkpoint and a fusion-layer dict.

    Loads the checkpoint at *vae_path*, removes the entries of its
    ``'state_dict'`` whose keys contain ``'decoder.fusion_layer'``, prints
    each removed key, then writes two files:

    * *vae_only_path*: the full checkpoint object with those entries removed.
    * *cfw_only_path*: a plain dict containing only the removed entries.

    Raises:
        FileNotFoundError: If *vae_path* does not exist.
        KeyError: If the loaded checkpoint has no ``'state_dict'`` entry.
    """
    # NOTE(security): torch.load unpickles the file and can therefore run
    # arbitrary code; only use this tool on checkpoints from a trusted source.
    with open(vae_path, 'rb') as f:
        vae_ckpt = torch.load(f, map_location='cpu')

    vae_cfw = split_fusion_layers(vae_ckpt['state_dict'])
    for key in vae_cfw:
        print(key)

    torch.save(vae_ckpt, vae_only_path)
    torch.save(vae_cfw, cfw_only_path)


if __name__ == '__main__':
    main()