automatic/pipelines/hdm/xut/modules/time_emb.py

35 lines
1014 B
Python

import math
import torch
import torch.nn as nn
from ..utils import compile_wrapper
class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a ``Linear(dim, dim) + Mish`` projection.

    Each timestep is scaled by ``time_factor`` and multiplied by ``dim // 2``
    frequencies ``exp(-ln(max_period) * i / (dim // 2))``. The features are
    ``[cos(args), sin(args)]``, with one zero channel appended when ``dim`` is
    odd, and are then passed through ``proj``.

    Args:
        dim: Embedding width; input and output size of ``proj``.
        max_period: Base of the frequency schedule (sets the lowest frequency).
        time_factor: Multiplier applied to ``t`` before the sinusoids.
    """

    def __init__(self, dim, max_period=10000, time_factor: float = 1000.0):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.time_factor = time_factor
        # Shape (1, dim // 2). Registered as a persistent buffer, so it is part
        # of the state dict and follows module dtype/device casts.
        self.register_buffer(
            "freqs",
            torch.exp(
                -math.log(max_period)
                * torch.arange(start=0, end=dim // 2, dtype=torch.float32)
                / (dim // 2)
            )[None],
        )
        self.proj = nn.Sequential(nn.Linear(dim, dim), nn.Mish())

    def _fp32_freqs(self):
        """Return the frequency row with at least 32-bit float precision.

        If the module was cast to a narrower float dtype (e.g. ``.half()``),
        the buffer's values were rounded by that cast. In that case they are
        rebuilt in fp32 with the same formula as ``__init__``.
        """
        freqs = self.freqs
        if freqs.element_size() >= 4:
            return freqs
        half = self.dim // 2
        return torch.exp(
            -math.log(self.max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=freqs.device)
            / half
        )[None]

    @compile_wrapper
    def forward(self, t):
        """Embed timesteps ``t`` into vectors of width ``dim``.

        Args:
            t: Tensor of timesteps. A trailing feature axis is appended, so a
                1-D input of length B gives ``(B, dim)`` and a 0-d input gives
                ``(1, dim)``.

        Returns:
            ``proj`` applied to the sinusoidal features. The features are handed
            to ``proj`` in the dtype that ``(time_factor * t) * freqs`` would
            have had without any upcasting.
        """
        # Dtype the plain (non-upcast) product would produce; restored before proj
        # so callers and the Linear layer see the same dtype as before.
        out_dtype = torch.promote_types(
            torch.result_type(t, self.time_factor), self.freqs.dtype
        )
        # The arguments scale with time_factor (1000 by default). In fp16/bf16 the
        # rounding of such values wipes out much of the phase, so evaluate the
        # sinusoids in at least fp32.
        compute_dtype = torch.promote_types(out_dtype, torch.float32)
        t = self.time_factor * t.to(compute_dtype)
        args = t[..., None] * self._fp32_freqs().to(compute_dtype)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            # Odd width: append one zero channel so the last axis equals ``dim``.
            embedding = nn.functional.pad(embedding, (0, 1))
        return self.proj(embedding.to(out_dtype))