mirror of https://github.com/vladmandic/automatic
import gc
import os
import json
from collections import OrderedDict
from copy import deepcopy
from math import ceil
from typing import Any, Dict, Callable

import safetensors
import safetensors.torch  # safetensors does not import its torch submodule implicitly
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import yaml
from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    DDPMScheduler,
    UNet2DConditionModel,
)

def remove_all_forward_hooks(model: torch.nn.Module) -> None:
    for _name, child in model._modules.items():  # pylint: disable=protected-access
        if child is not None:
            if hasattr(child, "_forward_hooks"):
                child._forward_hooks: Dict[int, Callable] = OrderedDict()
            remove_all_forward_hooks(child)

# Inspired by transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock
class SparseMoeBlock(nn.Module):
    def __init__(self, config, experts):
        super().__init__()
        self.hidden_dim = config["hidden_size"]
        self.num_experts = config["num_local_experts"]
        self.top_k = config["num_experts_per_tok"]
        self.out_dim = config.get("out_dim", self.hidden_dim)

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList([deepcopy(exp) for exp in experts])

    def forward(self, hidden_states: torch.Tensor, scale=None) -> torch.Tensor:  # pylint: disable=unused-argument
        batch_size, sequence_length, f_map_sz = hidden_states.shape
        hidden_states = hidden_states.view(-1, f_map_sz)

        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
        # unlike Mixtral's per-token routing, logits are summed over all tokens
        # and a single top-k expert set is selected for the whole batch
        _, selected_experts = torch.topk(
            router_logits.sum(dim=0, keepdim=True), self.top_k, dim=1
        )
        routing_weights = F.softmax(
            router_logits[:, selected_experts[0]], dim=1, dtype=torch.float
        )

        # cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, self.out_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # run every selected expert on all tokens and accumulate the weighted outputs
        for i, expert_idx in enumerate(selected_experts[0].tolist()):
            expert_layer = self.experts[expert_idx]
            current_hidden_states = routing_weights[:, i].view(
                batch_size * sequence_length, -1
            ) * expert_layer(hidden_states)
            final_hidden_states = final_hidden_states + current_hidden_states

        return final_hidden_states.reshape(batch_size, sequence_length, self.out_dim)
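
# A minimal sketch of SparseMoeBlock in isolation, with hypothetical nn.Linear
# experts and shapes chosen only for illustration. Note that, unlike Mixtral's
# per-token routing, the block applies one top-k expert set to the whole batch:
#
#   experts = [nn.Linear(64, 64) for _ in range(4)]
#   cfg = {"hidden_size": 64, "num_local_experts": 4, "num_experts_per_tok": 2}
#   moe = SparseMoeBlock(cfg, experts)
#   out = moe(torch.randn(2, 16, 64))  # -> (2, 16, 64), routing-weighted sum of 2 experts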

def getActivation(activation, name):
    def hook(model, inp, output):  # pylint: disable=unused-argument
        activation[name] = inp
    return hook
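
# A small sketch of how these hooks are used during gate initialization: the
# hook stores the module's *input* tuple under a key, which get_hidden_states
# later reduces to one vector per layer. The layer and tensor are illustrative:
#
#   acts = {}
#   layer = nn.Linear(8, 8)
#   handle = layer.register_forward_hook(getActivation(acts, "probe"))
#   layer(torch.randn(4, 8))
#   acts["probe"][0].shape  # torch.Size([4, 8]) -- the inputs, not the outputs
#   handle.remove()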

class SegMoEPipeline:
    def __init__(self, config_or_path, **kwargs) -> Any:
        """
        Instantiates the SegMoEPipeline. SegMoEPipeline implements the Segmind Mixture of Diffusion Experts, efficiently combining Stable Diffusion and Stable Diffusion XL models.

        Usage:

        from segmoe import SegMoEPipeline
        pipeline = SegMoEPipeline(config_or_path, **kwargs)

        config_or_path: Path to a config file, a directory containing a SegMoE checkpoint, or the HF card of a SegMoE checkpoint.

        Other Keyword Arguments:
        torch_dtype: Data type to load the pipeline in. (Default: torch.float16)
        variant: Variant of the model. (Default: fp16)
        device: Device to load the model on. (Default: cuda)
        Other args supported by diffusers.DiffusionPipeline are also supported.

        For more details visit https://github.com/segmind/segmoe.
        """
        self.torch_dtype = kwargs.pop("torch_dtype", torch.float16)
        self.use_safetensors = kwargs.pop("use_safetensors", True)
        self.variant = kwargs.pop("variant", "fp16")
        self.device = kwargs.pop("device", "cuda")
        if os.path.isfile(config_or_path):
            self.load_from_scratch(config_or_path, **kwargs)
        else:
            if not os.path.isdir(config_or_path):
                cached_folder = DiffusionPipeline.download(config_or_path)
            else:
                cached_folder = config_or_path
            unet = self.create_empty(cached_folder)
            unet.load_state_dict(
                safetensors.torch.load_file(
                    f"{cached_folder}/unet/diffusion_pytorch_model.safetensors"
                )
            )
            self.pipe = DiffusionPipeline.from_pretrained(
                cached_folder,
                unet=unet,
                torch_dtype=self.torch_dtype,
                use_safetensors=self.use_safetensors,
            )
            self.pipe.to(self.device)
            self.pipe.unet.to(
                device=self.device,
                dtype=self.torch_dtype,
                memory_format=torch.channels_last,
            )

    def to(self, *args, **kwargs):  # passthrough so callers can treat this like a diffusers pipeline
        self.pipe.to(*args, **kwargs)

    def load_from_scratch(self, config: str, **kwargs) -> None:
        # Load Config
        with open(config, "r", encoding='utf8') as f:
            config = yaml.load(f, Loader=yaml.SafeLoader)
        self.config = config
        if self.config.get("num_experts", None):
            self.num_experts = self.config["num_experts"]
        elif self.config.get("experts", None):
            self.num_experts = len(self.config["experts"])
        elif self.config.get("loras", None):
            self.num_experts = len(self.config["loras"])
        else:
            self.num_experts = 1
        num_experts_per_tok = self.config.get("num_experts_per_tok", 1)
        self.config["num_experts_per_tok"] = num_experts_per_tok
        moe_layers = self.config.get("moe_layers", "attn")
        self.config["moe_layers"] = moe_layers
        # Load Base Model
        if self.config["base_model"].startswith(
            "https://civitai.com/api/download/models/"
        ):
            os.makedirs("base", exist_ok=True)
            if not os.path.isfile("base/model.safetensors"):
                os.system(
                    "wget -O base/model.safetensors "
                    + self.config["base_model"]
                    + " --content-disposition"
                )
            self.config["base_model"] = "base/model.safetensors"
            self.pipe = DiffusionPipeline.from_single_file(
                self.config["base_model"], torch_dtype=self.torch_dtype
            )
        else:
            try:
                self.pipe = DiffusionPipeline.from_pretrained(
                    self.config["base_model"],
                    torch_dtype=self.torch_dtype,
                    use_safetensors=self.use_safetensors,
                    variant=self.variant,
                    **kwargs,
                )
            except Exception:
                self.pipe = DiffusionPipeline.from_pretrained(
                    self.config["base_model"], torch_dtype=self.torch_dtype, **kwargs
                )
        if self.pipe.__class__ == StableDiffusionPipeline:
            self.up_idx_start = 1
            self.up_idx_end = len(self.pipe.unet.up_blocks)
            self.down_idx_start = 0
            self.down_idx_end = len(self.pipe.unet.down_blocks) - 1
        elif self.pipe.__class__ == StableDiffusionXLPipeline:
            self.up_idx_start = 0
            self.up_idx_end = len(self.pipe.unet.up_blocks) - 1
            self.down_idx_start = 1
            self.down_idx_end = len(self.pipe.unet.down_blocks)
        else:
            raise ValueError(f"base model class not supported: {self.pipe.__class__.__name__}")
        self.config["up_idx_start"] = self.up_idx_start
        self.config["up_idx_end"] = self.up_idx_end
        self.config["down_idx_start"] = self.down_idx_start
        self.config["down_idx_end"] = self.down_idx_end

        # TODO: Add Support for Scheduler Selection
        self.pipe.scheduler = DDPMScheduler.from_config(self.pipe.scheduler.config)

        # Load Experts
        experts = []
        positive = []
        negative = []
        if self.config.get("experts", None):
            for i, exp in enumerate(self.config["experts"]):
                positive.append(exp["positive_prompt"])
                negative.append(exp["negative_prompt"])
                if exp["source_model"].startswith(
                    "https://civitai.com/api/download/models/"
                ):
                    try:
                        os.makedirs(f"expert_{i}", exist_ok=True)
                        if not os.path.isfile(f"expert_{i}/model.safetensors"):
                            os.system(
                                f"wget {exp['source_model']} -O "
                                + f"expert_{i}/model.safetensors"
                                + " --content-disposition"
                            )
                        exp["source_model"] = f"expert_{i}/model.safetensors"
                        expert = DiffusionPipeline.from_single_file(
                            exp["source_model"],
                        ).to(self.device, self.torch_dtype)
                    except Exception as e:
                        print(f"Expert {i} {exp['source_model']} failed to load")
                        print("Error:", e)
                        positive.pop()
                        negative.pop()
                        continue
                else:
                    try:
                        expert = DiffusionPipeline.from_pretrained(
                            exp["source_model"],
                            torch_dtype=self.torch_dtype,
                            use_safetensors=self.use_safetensors,
                            variant=self.variant,
                            **kwargs,
                        )
                        # TODO: Add Support for Scheduler Selection
                        expert.scheduler = DDPMScheduler.from_config(
                            expert.scheduler.config
                        )
                    except Exception:
                        expert = DiffusionPipeline.from_pretrained(
                            exp["source_model"], torch_dtype=self.torch_dtype, **kwargs
                        )
                        expert.scheduler = DDPMScheduler.from_config(
                            expert.scheduler.config
                        )
                if exp.get("loras", None):
                    for j, lora in enumerate(exp["loras"]):
                        if lora.get("positive_prompt", None):
                            positive[-1] += " " + lora["positive_prompt"]
                        if lora.get("negative_prompt", None):
                            negative[-1] += " " + lora["negative_prompt"]
                        if lora["source_model"].startswith(
                            "https://civitai.com/api/download/models/"
                        ):
                            try:
                                os.makedirs(f"expert_{i}/lora_{j}", exist_ok=True)
                                if not os.path.isfile(
                                    f"expert_{i}/lora_{j}/pytorch_lora_weights.safetensors"
                                ):
                                    os.system(
                                        f"wget {lora['source_model']} -O "
                                        + f"expert_{i}/lora_{j}/pytorch_lora_weights.safetensors"
                                        + " --content-disposition"
                                    )
                                lora["source_model"] = f"expert_{i}/lora_{j}"
                                expert.load_lora_weights(lora["source_model"])
                                if len(exp["loras"]) == 1:
                                    expert.fuse_lora()
                            except Exception as e:
                                print(
                                    f"Expert {i} LoRA {j} {lora['source_model']} failed to load"
                                )
                                print("Error:", e)
                        else:
                            expert.load_lora_weights(lora["source_model"])
                            if len(exp["loras"]) == 1:
                                expert.fuse_lora()
                experts.append(expert)
        else:
            experts = [deepcopy(self.pipe) for _ in range(self.num_experts)]
        if self.config.get("experts", None):
            if self.config.get("loras", None):
                for i, lora in enumerate(self.config["loras"]):
                    if lora["source_model"].startswith(
                        "https://civitai.com/api/download/models/"
                    ):
                        try:
                            os.makedirs(f"lora_{i}", exist_ok=True)
                            if not os.path.isfile(
                                f"lora_{i}/pytorch_lora_weights.safetensors"
                            ):
                                os.system(
                                    f"wget {lora['source_model']} -O "
                                    + f"lora_{i}/pytorch_lora_weights.safetensors"
                                    + " --content-disposition"
                                )
                            lora["source_model"] = f"lora_{i}"
                            self.pipe.load_lora_weights(lora["source_model"])
                            if len(self.config["loras"]) == 1:
                                self.pipe.fuse_lora()
                        except Exception as e:
                            print(f"LoRA {i} {lora['source_model']} failed to load")
                            print("Error:", e)
                    else:
                        self.pipe.load_lora_weights(lora["source_model"])
                        if len(self.config["loras"]) == 1:
                            self.pipe.fuse_lora()
        else:
            if self.config.get("loras", None):
                # distribute the LoRAs as evenly as possible across the expert copies
                j = []
                n_loras = len(self.config["loras"])
                i = 0
                positive = [""] * len(experts)
                negative = [""] * len(experts)
                while n_loras:
                    n = ceil(n_loras / (len(experts) - i))
                    j += [i] * n
                    n_loras -= n
                    i += 1
                for i, lora in enumerate(self.config["loras"]):
                    positive[j[i]] += lora["positive_prompt"] + " "
                    negative[j[i]] += lora["negative_prompt"] + " "
                    if lora["source_model"].startswith(
                        "https://civitai.com/api/download/models/"
                    ):
                        try:
                            os.makedirs(f"lora_{i}", exist_ok=True)
                            if not os.path.isfile(
                                f"lora_{i}/pytorch_lora_weights.safetensors"
                            ):
                                os.system(
                                    f"wget {lora['source_model']} -O "
                                    + f"lora_{i}/pytorch_lora_weights.safetensors"
                                    + " --content-disposition"
                                )
                            lora["source_model"] = f"lora_{i}"
                            experts[j[i]].load_lora_weights(lora["source_model"])
                            experts[j[i]].fuse_lora()
                        except Exception as e:
                            print(f"LoRA {i} {lora['source_model']} failed to load")
                            print("Error:", e)
                    else:
                        experts[j[i]].load_lora_weights(lora["source_model"])
                        experts[j[i]].fuse_lora()

        # Replace FF and Attention Layers with Sparse MoE Layers
        for blocks, expert_unet_blocks, start, end in (
            (self.pipe.unet.down_blocks, [e.unet.down_blocks for e in experts], self.down_idx_start, self.down_idx_end),
            (self.pipe.unet.up_blocks, [e.unet.up_blocks for e in experts], self.up_idx_start, self.up_idx_end),
        ):
            for i in range(start, end):
                for j in range(len(blocks[i].attentions)):
                    for k in range(len(blocks[i].attentions[j].transformer_blocks)):
                        block = blocks[i].attentions[j].transformer_blocks[k]
                        expert_blocks = [
                            eb[i].attentions[j].transformer_blocks[k]
                            for eb in expert_unet_blocks
                        ]
                        if moe_layers != "attn":
                            # FF Layers
                            config = {
                                "hidden_size": next(block.ff.parameters()).size()[-1],
                                "num_experts_per_tok": num_experts_per_tok,
                                "num_local_experts": len(experts),
                            }
                            block.ff = SparseMoeBlock(config, [deepcopy(eb.ff) for eb in expert_blocks])
                        if moe_layers != "ff":
                            # Self-attention q/k/v (square projections, so one config covers all three)
                            config = {
                                "hidden_size": block.attn1.to_q.weight.size()[-1],
                                "num_experts_per_tok": num_experts_per_tok,
                                "num_local_experts": len(experts),
                            }
                            block.attn1.to_q = SparseMoeBlock(config, [deepcopy(eb.attn1.to_q) for eb in expert_blocks])
                            block.attn1.to_k = SparseMoeBlock(config, [deepcopy(eb.attn1.to_k) for eb in expert_blocks])
                            block.attn1.to_v = SparseMoeBlock(config, [deepcopy(eb.attn1.to_v) for eb in expert_blocks])
                            # Cross-attention query (projects from the hidden dim)
                            config = {
                                "hidden_size": block.attn2.to_q.weight.size()[-1],
                                "num_experts_per_tok": num_experts_per_tok,
                                "num_local_experts": len(experts),
                            }
                            block.attn2.to_q = SparseMoeBlock(config, [deepcopy(eb.attn2.to_q) for eb in expert_blocks])
                            # Cross-attention k/v project from the text-encoder dim, so out_dim differs
                            config = {
                                "hidden_size": block.attn2.to_k.weight.size()[-1],
                                "out_dim": block.attn2.to_k.weight.size()[0],
                                "num_experts_per_tok": num_experts_per_tok,
                                "num_local_experts": len(experts),
                            }
                            block.attn2.to_k = SparseMoeBlock(config, [deepcopy(eb.attn2.to_k) for eb in expert_blocks])
                            block.attn2.to_v = SparseMoeBlock(config, [deepcopy(eb.attn2.to_v) for eb in expert_blocks])

        # Routing Weight Initialization
        if self.config.get("init", "hidden") == "hidden":
            gate_params = self.get_gate_params(experts, positive, negative)
            for prefix, blocks, start, end in (
                ("d", self.pipe.unet.down_blocks, self.down_idx_start, self.down_idx_end),
                ("u", self.pipe.unet.up_blocks, self.up_idx_start, self.up_idx_end),
            ):
                for i in range(start, end):
                    for j in range(len(blocks[i].attentions)):
                        for k in range(len(blocks[i].attentions[j].transformer_blocks)):
                            block = blocks[i].attentions[j].transformer_blocks[k]
                            # FF Layers
                            if moe_layers != "attn":
                                block.ff.gate.weight = nn.Parameter(gate_params[f"{prefix}{i}a{j}t{k}"])
                            # Attns
                            if moe_layers != "ff":
                                block.attn1.to_q.gate.weight = nn.Parameter(gate_params[f"sattnq{prefix}{i}a{j}t{k}"])
                                block.attn1.to_k.gate.weight = nn.Parameter(gate_params[f"sattnk{prefix}{i}a{j}t{k}"])
                                block.attn1.to_v.gate.weight = nn.Parameter(gate_params[f"sattnv{prefix}{i}a{j}t{k}"])
                                block.attn2.to_q.gate.weight = nn.Parameter(gate_params[f"cattnq{prefix}{i}a{j}t{k}"])
                                block.attn2.to_k.gate.weight = nn.Parameter(gate_params[f"cattnk{prefix}{i}a{j}t{k}"])
                                block.attn2.to_v.gate.weight = nn.Parameter(gate_params[f"cattnv{prefix}{i}a{j}t{k}"])
        self.config["num_experts"] = len(experts)
        remove_all_forward_hooks(self.pipe.unet)
        try:
            del experts
            del expert
        except Exception:
            pass
        # Move Model to Device
        self.pipe.to(self.device)
        self.pipe.unet.to(
            device=self.device,
            dtype=self.torch_dtype,
            memory_format=torch.channels_last,
        )
        gc.collect()
        torch.cuda.empty_cache()

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """
        Run inference with the SegMoEPipeline.

        Calls diffusers.DiffusionPipeline forward with the given arguments. See https://github.com/segmind/segmoe#usage for detailed usage.
        """
        return self.pipe(*args, **kwds)
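
    # A hypothetical end-to-end call; arguments are forwarded verbatim to the
    # wrapped diffusers pipeline, so any of its parameters can be used:
    #
    #   pipeline = SegMoEPipeline("config.yaml", device="cuda")
    #   image = pipeline("cosmic canvas, swirling galaxies", num_inference_steps=25).images[0]
    #   image.save("out.png")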

    def create_empty(self, path):
        with open(f"{path}/unet/config.json", encoding='utf8') as f:
            config = json.load(f)
        self.config = config["segmoe_config"]
        unet = UNet2DConditionModel.from_config(config)
        num_experts_per_tok = self.config["num_experts_per_tok"]
        num_experts = self.config["num_experts"]
        moe_layers = self.config["moe_layers"]
        self.up_idx_start = self.config["up_idx_start"]
        self.up_idx_end = self.config["up_idx_end"]
        self.down_idx_start = self.config["down_idx_start"]
        self.down_idx_end = self.config["down_idx_end"]
        # Mirror the layer replacement done in load_from_scratch, but seed each
        # SparseMoeBlock with copies of the freshly initialized module; the real
        # weights are loaded from the checkpoint afterwards
        for blocks, start, end in (
            (unet.down_blocks, self.down_idx_start, self.down_idx_end),
            (unet.up_blocks, self.up_idx_start, self.up_idx_end),
        ):
            for i in range(start, end):
                for j in range(len(blocks[i].attentions)):
                    for k in range(len(blocks[i].attentions[j].transformer_blocks)):
                        block = blocks[i].attentions[j].transformer_blocks[k]
                        if moe_layers != "attn":
                            # FF Layers
                            moe_config = {
                                "hidden_size": next(block.ff.parameters()).size()[-1],
                                "num_experts_per_tok": num_experts_per_tok,
                                "num_local_experts": num_experts,
                            }
                            block.ff = SparseMoeBlock(moe_config, [block.ff] * num_experts)
                        if moe_layers != "ff":
                            # Self-attention q/k/v
                            moe_config = {
                                "hidden_size": block.attn1.to_q.weight.size()[-1],
                                "num_experts_per_tok": num_experts_per_tok,
                                "num_local_experts": num_experts,
                            }
                            block.attn1.to_q = SparseMoeBlock(moe_config, [block.attn1.to_q] * num_experts)
                            block.attn1.to_k = SparseMoeBlock(moe_config, [block.attn1.to_k] * num_experts)
                            block.attn1.to_v = SparseMoeBlock(moe_config, [block.attn1.to_v] * num_experts)
                            # Cross-attention query
                            moe_config = {
                                "hidden_size": block.attn2.to_q.weight.size()[-1],
                                "num_experts_per_tok": num_experts_per_tok,
                                "num_local_experts": num_experts,
                            }
                            block.attn2.to_q = SparseMoeBlock(moe_config, [block.attn2.to_q] * num_experts)
                            # Cross-attention k/v project from the text-encoder dim, so out_dim differs
                            moe_config = {
                                "hidden_size": block.attn2.to_k.weight.size()[-1],
                                "out_dim": block.attn2.to_k.weight.size()[0],
                                "num_experts_per_tok": num_experts_per_tok,
                                "num_local_experts": num_experts,
                            }
                            block.attn2.to_k = SparseMoeBlock(moe_config, [block.attn2.to_k] * num_experts)
                            block.attn2.to_v = SparseMoeBlock(moe_config, [block.attn2.to_v] * num_experts)
        return unet

    def save_pretrained(self, path):
        """
        Save SegMoEPipeline to disk.

        Usage:
        pipeline.save_pretrained(path)

        Parameters:
        path: Path to the directory to save the model in.
        """
        for param in self.pipe.unet.parameters():
            param.data = param.data.contiguous()
        self.pipe.unet.config["segmoe_config"] = self.config
        self.pipe.save_pretrained(path)
        safetensors.torch.save_file(
            self.pipe.unet.state_dict(),
            f"{path}/unet/diffusion_pytorch_model.safetensors",
        )

    def cast_hook(self, pipe, dicts):
        # register input-capturing hooks on every layer that will become a MoE
        # block; the key prefixes here must match those read in load_from_scratch
        for prefix, blocks, start, end in (
            ("d", pipe.unet.down_blocks, self.down_idx_start, self.down_idx_end),
            ("u", pipe.unet.up_blocks, self.up_idx_start, self.up_idx_end),
        ):
            for i in range(start, end):
                for j in range(len(blocks[i].attentions)):
                    for k in range(len(blocks[i].attentions[j].transformer_blocks)):
                        block = blocks[i].attentions[j].transformer_blocks[k]
                        # FF Layers
                        block.ff.register_forward_hook(getActivation(dicts, f"{prefix}{i}a{j}t{k}"))
                        # Self Attns
                        block.attn1.to_q.register_forward_hook(getActivation(dicts, f"sattnq{prefix}{i}a{j}t{k}"))
                        block.attn1.to_k.register_forward_hook(getActivation(dicts, f"sattnk{prefix}{i}a{j}t{k}"))
                        block.attn1.to_v.register_forward_hook(getActivation(dicts, f"sattnv{prefix}{i}a{j}t{k}"))
                        # Cross Attns
                        block.attn2.to_q.register_forward_hook(getActivation(dicts, f"cattnq{prefix}{i}a{j}t{k}"))
                        block.attn2.to_k.register_forward_hook(getActivation(dicts, f"cattnk{prefix}{i}a{j}t{k}"))
                        block.attn2.to_v.register_forward_hook(getActivation(dicts, f"cattnv{prefix}{i}a{j}t{k}"))

    @torch.no_grad()
    def get_hidden_states(self, model, positive, negative, average: bool = True):
        intermediate = {}
        self.cast_hook(model, intermediate)
        with torch.no_grad():
            _ = model(positive, negative_prompt=negative, num_inference_steps=25)
        hidden = {}
        for key in intermediate:
            hidden_states = intermediate[key][0][-1]
            if average:
                # use average over sequence
                hidden_states = hidden_states.sum(dim=0) / hidden_states.shape[0]
            else:
                # take last value
                hidden_states = hidden_states[-1]
            hidden[key] = hidden_states.to(self.device)
        del intermediate
        gc.collect()
        torch.cuda.empty_cache()
        return hidden

    @torch.no_grad()
    def get_gate_params(
        self,
        experts,
        positive,
        negative,
    ):
        gate_vects = {}
        for i, expert in enumerate(tqdm.tqdm(experts, desc="Expert Prompts")):
            expert.to(self.device)
            expert.unet.to(
                device=self.device,
                dtype=self.torch_dtype,
                memory_format=torch.channels_last,
            )
            hidden_states = self.get_hidden_states(expert, positive[i], negative[i])
            del expert
            gc.collect()
            torch.cuda.empty_cache()
            for h in hidden_states:
                if i == 0:
                    gate_vects[h] = []
                # L2-normalize so every expert contributes a unit-length gate vector
                hidden_states[h] /= (
                    hidden_states[h].norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
                )
                gate_vects[h].append(hidden_states[h])
        for h in gate_vects:
            gate_vects[h] = torch.stack(
                gate_vects[h], dim=0
            )  # (num_experts, hidden_size), matching the gate's weight shape

        return gate_vects
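
# A minimal smoke test, left as a comment so importing this module stays
# side-effect free. It assumes a local config.yaml as sketched in
# load_from_scratch and a CUDA device:
#
#   if __name__ == "__main__":
#       pipeline = SegMoEPipeline("config.yaml", device="cuda")
#       pipeline.save_pretrained("segmoe_checkpoint")
#       pipeline = SegMoEPipeline("segmoe_checkpoint")
#       image = pipeline("an astronaut riding a horse", negative_prompt="blurry").images[0]
#       image.save("astronaut.png")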