147 lines
5.7 KiB
Python
147 lines
5.7 KiB
Python
#!/usr/bin/env python3
|
|
|
|
"""
|
|
This module has the EMA class used to store a copy of the exponentially decayed
|
|
model params.
|
|
|
|
Typical usage of EMA class involves initializing an object using an existing
|
|
model (random or from a seed model) and setting the config like ema_decay,
|
|
ema_start_update which determine how the EMA model is updated. After every
|
|
update of the model i.e. at the end of the train_step, the EMA should be updated
|
|
by passing the new model to the EMA.step function. The EMA model state dict
|
|
can be stored in the extra state under the key of "ema" and dumped
|
|
into a checkpoint and loaded. The EMA object can be passed to tasks
|
|
by setting task.uses_ema property.
|
|
EMA is a smoothed/ensemble model which might have better performance
|
|
when used for inference or further fine-tuning. EMA class has a
|
|
apply function to load the EMA params into a model and use it
|
|
like a regular model.
|
|
"""
|
|
|
|
import copy
|
|
import os
|
|
import shutil
|
|
|
|
import safetensors.torch
|
|
import torch
|
|
from diffusers import UNet2DConditionModel
|
|
|
|
|
|
class EMAModel:
    """
    Exponential-moving-average (EMA) shadow of a model's weights.

    The instance owns a deep copy of the model it was built from (``self.model``)
    and a dict of the tensors being averaged (``self.params``). Calling ``step``
    with the live model blends its current weights into the averaged tensors and
    loads the result back into ``self.model``.
    """

    def __init__(self, model: "UNet2DConditionModel", decay: float = 0.9999, device=None, dtype=None):
        """
        @param model: model to initialize the EMA with
        @param decay: Decay rate to use
        @param device: If provided, copy EMA to this device (e.g. gpu), else EMA is in the same device as the model.
        @param dtype: If provided, cast the EMA copy to this dtype.
        """
        self.decay = decay
        # Deep copy so that EMA updates never write into the model being trained.
        self.model = copy.deepcopy(model)
        if device is not None or dtype is not None:
            self.model.to(device, dtype=dtype)

        self.params = {}
        self.build_params()

        self.update_freq_counter = 0

    def __call__(self, *args, **kwargs):
        """ Forward the call to the EMA model """
        return self.model(*args, **kwargs)

    def get_model(self):
        """ Return the EMA model """
        return self.model

    def build_params(self, state_dict=None):
        """
        Register the tensors that ``step`` averages into.

        If state dict is passed, the EMA params are taken from the provided
        state dict. Otherwise they are taken from the current EMA model
        state dict.

        Keys that are already tracked are overwritten in place with ``copy_``.
        New keys store the given tensor object itself: no clone and no dtype
        conversion is performed here.
        """
        if state_dict is None:
            state_dict = self.model.state_dict()

        # for non-float params (like registered symbols), they are copied into this dict and covered in each update
        for param_key, value in state_dict.items():
            if param_key in self.params:
                self.params[param_key].copy_(value)
            else:
                self.params[param_key] = value

    def load(self, state_dict, build_params=False):
        """
        Load data from a state_dict

        @param state_dict: weights to load into the EMA model (non-strict, so
            missing or unexpected keys are ignored by ``load_state_dict``)
        @param build_params: if True, also refresh ``self.params`` from ``state_dict``
        """
        self.model.load_state_dict(state_dict, strict=False)
        if build_params:
            self.build_params(state_dict)

    def get_decay(self):
        """ Return the decay rate """
        return self.decay

    @torch.no_grad()
    def step(self, new_model):
        """
        One update of the EMA model based on new model weights

        Floating point entries are updated as ``ema = decay * ema + (1 - decay) * new``.
        Non floating point entries are overwritten with the new value.

        @param new_model: model whose current state dict is blended into the EMA
        @raise ValueError: if a tensor's shape (or, for non-float tensors, dtype)
            differs between the new model and the EMA
        """
        decay = self.decay
        reparam_suffix = ".parametrizations.weight.original"

        ema_state_dict = {}
        ema_params = self.params
        for key, param in new_model.state_dict().items():
            if ".parametrizations." in key:
                if reparam_suffix not in key:
                    # Skip extra values used in reparametrization
                    continue
                # Handle reparametrized parameters: read the weight from the
                # owning submodule and file it under the plain ".weight" key.
                param = new_model.get_submodule(key.replace(reparam_suffix, "")).weight
                key = key.replace(reparam_suffix, ".weight")

            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                # Keep the new entry; otherwise it would be re-seeded from the
                # live model on every step and never actually averaged.
                ema_params[key] = ema_param

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param"
                    + " for '{}': {} vs. {}".format(key, param.shape, ema_param.shape)
                )
            if "version" in key:
                # Do not decay a model.version pytorch param
                continue

            # for non-float params (like registered symbols), they are covered in each update
            if not torch.is_floating_point(ema_param):
                if ema_param.dtype != param.dtype:
                    raise ValueError(
                        "incompatible tensor dtypes between model param and ema param"
                        + " for '{}': {} vs. {}".format(key, param.dtype, ema_param.dtype)
                    )
                ema_param.copy_(param)
            else:
                ema_param.mul_(decay)
                # Move to the EMA tensor's device as well as dtype: the EMA may
                # have been placed on a different device than the live model.
                ema_param.add_(
                    param.to(device=ema_param.device, dtype=ema_param.dtype),
                    alpha=1 - decay,
                )
            ema_state_dict[key] = ema_param
        self.load(ema_state_dict, build_params=False)

    def apply(self, model):
        """
        Load the model parameters from EMA model.
        Useful for inference or fine-tuning from the EMA model.

        @param model: model that receives the EMA weights (non-strict load)
        @return: the same ``model`` object
        """
        model.load_state_dict(self.model.state_dict(), strict=False)
        return model

    def save_pretrained(self, model_path, safe_serialization=True):
        """
        Save the EMA model into ``model_path``.

        After the model's own ``save_pretrained`` runs, a safetensors weight
        file is written directly if none is present, any ``.bin`` weight file
        in the directory is removed, and a missing ``config.json`` is copied
        from the matching directory without the ``ema_`` prefix when that
        file exists.

        @param model_path: target directory
        @param safe_serialization: forwarded to the model's ``save_pretrained``
        """
        model_file = os.path.join(model_path, "diffusion_pytorch_model.safetensors")
        # Build the .bin path explicitly: a substring replace on the full path
        # would also rewrite directory names that contain "safetensors".
        model_bin = os.path.join(model_path, "diffusion_pytorch_model.bin")
        self.model.save_pretrained(model_path, safe_serialization=safe_serialization)
        if not os.path.exists(model_file):
            print("Yeah, regular save_pretrained is not working.")
            safetensors.torch.save_file(self.model.state_dict(), model_file)
        if os.path.exists(model_bin):
            os.remove(model_bin)

        model_config_path = os.path.join(model_path, "config.json")
        if os.path.exists(model_config_path):
            return

        # Prefer stripping "ema_" from the final directory name only, so that
        # unrelated parent directories containing "ema_" stay untouched; keep
        # the whole-path replacement as a fallback for existing layouts.
        parent_dir, leaf_dir = os.path.split(os.path.normpath(model_path))
        candidates = (
            os.path.join(parent_dir, leaf_dir.replace("ema_", ""), "config.json"),
            model_config_path.replace("ema_", ""),
        )
        for unet_config_path in candidates:
            if os.path.exists(unet_config_path):
                shutil.copyfile(unet_config_path, model_config_path)
                break
|