21 lines
514 B
Python
21 lines
514 B
Python
import torch
|
|
|
|
# Split one checkpoint into two files: every state_dict entry whose key
# contains 'decoder.fusion_layer' goes to the "cfw_only" file, and the
# checkpoint with those entries removed goes to the "vae_only" file.
vae_path = 'models/vqgan_cfw_00011.ckpt'

# map_location='cpu' lets the checkpoint be loaded without a GPU.
with open(vae_path, 'rb') as ckpt_file:
    vae_ckpt = torch.load(ckpt_file, map_location='cpu')

state_dict = vae_ckpt['state_dict']

# Collect the matching keys up front so the dict is never mutated while it
# is being iterated; echo each selected key for a quick visual check.
prune_keys = [name for name in state_dict if 'decoder.fusion_layer' in name]
for name in prune_keys:
    print(name)

# Move the selected entries out of the checkpoint into their own mapping.
# `state_dict` aliases vae_ckpt['state_dict'], so popping here also removes
# the entries from the checkpoint that is saved below.
vae_cfw = {name: state_dict.pop(name) for name in prune_keys}

torch.save(vae_ckpt, 'models/vqgan_cfw_00011_vae_only.ckpt')
torch.save(vae_cfw, 'models/vqgan_cfw_00011_cfw_only.ckpt')
|