automatic/extensions-builtin/Lora/network_oft.py

50 lines
1.9 KiB
Python

import torch
import diffusers.models.lora as diffusers_lora
import network
from modules import devices
class ModuleTypeOFT(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
"""
weights.w.items()
alpha : tensor(0.0010, dtype=torch.bfloat16)
oft_blocks : tensor([[[ 0.0000e+00, 1.4400e-04, 1.7319e-03, ..., -8.8882e-04,
5.7373e-03, -4.4250e-03],
[-1.4400e-04, 0.0000e+00, 8.6594e-04, ..., 1.5945e-03,
-8.5449e-04, 1.9684e-03], ...etc...
, dtype=torch.bfloat16)"""
if "oft_blocks" in weights.w.keys():
module = NetworkModuleOFT(net, weights)
return module
else:
return None
class NetworkModuleOFT(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.weights = weights.w.get("oft_blocks").to(device=devices.device)
self.dim = self.weights.shape[0] # num blocks
self.alpha = self.multiplier()
self.block_size = self.weights.shape[-1]
def get_weight(self):
block_Q = self.weights - self.weights.transpose(1, 2)
I = torch.eye(self.block_size, device=devices.device).unsqueeze(0).repeat(self.dim, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = self.alpha * block_R + (1 - self.alpha) * I
R = torch.block_diag(*block_R_weighted)
return R
def calc_updown(self, orig_weight):
R = self.get_weight().to(device=devices.device, dtype=orig_weight.dtype)
if orig_weight.dim() == 4:
updown = torch.einsum("oihw, op -> pihw", orig_weight, R) * self.calc_scale()
else:
updown = torch.einsum("oi, op -> pi", orig_weight, R) * self.calc_scale()
return self.finalize_updown(updown, orig_weight, orig_weight.shape)