172 lines
5.7 KiB
Python
172 lines
5.7 KiB
Python
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from PIL import Image
|
|
import fnmatch
|
|
import cv2
|
|
|
|
import sys
|
|
|
|
import numpy as np
|
|
from modules import devices
|
|
from einops import rearrange
|
|
from annotator.annotator_path import models_path
|
|
|
|
import torchvision
|
|
from torchvision.models import MobileNet_V2_Weights
|
|
from torchvision import transforms
|
|
|
|
# One 3-component colour per segmentation class
# (presumably RGB order -- confirm against the code consuming the output).
COLOR_BACKGROUND = (255, 255, 0)
COLOR_HAIR = (0, 0, 255)
COLOR_EYE = (255, 0, 0)
COLOR_MOUTH = (255, 255, 255)
COLOR_FACE = (0, 255, 0)
COLOR_SKIN = (0, 255, 255)
COLOR_CLOTHES = (255, 0, 255)

# Lookup table indexed by the arg-max class id of the network output;
# position i holds the colour painted for class i.
PALETTE = [
    COLOR_BACKGROUND,
    COLOR_HAIR,
    COLOR_EYE,
    COLOR_MOUTH,
    COLOR_FACE,
    COLOR_SKIN,
    COLOR_CLOTHES,
]
|
|
|
|
class UNet(nn.Module):
    """U-Net style segmentation network with a MobileNetV2 encoder.

    The encoder is assembled from consecutive slices of torchvision's
    ``mobilenet_v2(...).features`` (created with the ImageNet-pretrained
    weights).  The decoder upsamples stage by stage; after every stage but
    the last, the result is resized to the spatial size of the encoder
    output of matching depth and concatenated with it along the channel
    axis.  ``forward`` returns ``Softmax2d`` probabilities with
    ``NUM_SEG_CLASSES`` channels.
    """

    def __init__(self):
        super().__init__()
        # Number of output channels produced by the final decoder conv.
        self.NUM_SEG_CLASSES = 7

        backbone = torchvision.models.mobilenet_v2(
            weights=MobileNet_V2_Weights.IMAGENET1K_V1
        ).features

        def encoder_stage(first, last):
            # Re-wrap the selected layers in a fresh Sequential so they are
            # numbered from 0 inside each stage (keeps state_dict keys stable).
            return nn.Sequential(*[backbone[i] for i in range(first, last + 1)])

        def decoder_stage(in_ch, out_ch):
            # 2x nearest upsample -> 3x3 conv -> instance norm -> leaky ReLU -> dropout
            return nn.Sequential(
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.LeakyReLU(0.1),
                nn.Dropout(p=0.2),
            )

        # Encoder
        self.en_block0 = encoder_stage(0, 1)     # in_ch=3   out_ch=16
        self.en_block1 = encoder_stage(2, 3)     # in_ch=16  out_ch=24
        self.en_block2 = encoder_stage(4, 6)     # in_ch=24  out_ch=32
        self.en_block3 = encoder_stage(7, 13)    # in_ch=32  out_ch=96
        self.en_block4 = encoder_stage(14, 16)   # in_ch=96  out_ch=160

        # Decoder (inputs after the first stage are doubled by the skip concat)
        self.de_block4 = decoder_stage(160, 96)      # in_ch=160   out_ch=96
        self.de_block3 = decoder_stage(96 * 2, 32)   # in_ch=96x2  out_ch=32
        self.de_block2 = decoder_stage(32 * 2, 24)   # in_ch=32x2  out_ch=24
        self.de_block1 = decoder_stage(24 * 2, 16)   # in_ch=24x2  out_ch=16

        # Final stage maps to class scores and normalises them per pixel.
        self.de_block0 = nn.Sequential(              # in_ch=16x2  out_ch=7
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(16 * 2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
            nn.Softmax2d(),
        )

    def forward(self, x):
        """Run the encoder, then decode while merging the encoder skips.

        Returns the output of ``de_block0`` (per-pixel softmax over
        ``NUM_SEG_CLASSES`` channels).
        """
        # Encoder pass: remember every intermediate output for the skips.
        skips = []
        feat = x
        for stage in (self.en_block0, self.en_block1, self.en_block2, self.en_block3):
            feat = stage(feat)
            skips.append(feat)
        feat = self.en_block4(feat)

        # Decoder pass: deepest skip first.
        for stage in (self.de_block4, self.de_block3, self.de_block2, self.de_block1):
            skip = skips.pop()
            feat = stage(feat)
            # The 2x upsample may not reproduce the skip's exact size
            # (odd dimensions), so resize before concatenating.
            feat = F.interpolate(feat, size=skip.size()[2:], mode='bilinear', align_corners=True)
            feat = torch.cat((feat, skip), 1)

        return self.de_block0(feat)
|
|
|
|
|
|
class AnimeFaceSegment:
    """Annotator that turns an image into a colour-coded face-part map.

    The ``UNet`` checkpoint is downloaded (if missing) and loaded lazily on
    the first call.  Calling the instance runs the network on the input and
    paints every pixel with the ``PALETTE`` colour of its most probable class.
    """

    # Directory that holds (or will receive) the downloaded ``UNet.pth``.
    model_dir = os.path.join(models_path, "anime_face_segment")

    def __init__(self):
        # The network is created on demand by load_model().
        self.model = None
        self.device = devices.get_device_for("controlnet")

    def load_model(self):
        """Fetch the checkpoint if needed, then build ``UNet`` and load it.

        On success ``self.model`` holds the network in eval mode on
        ``self.device``.
        """
        remote_model_path = "https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/resolve/main/Annotators/UNet.pth"
        modelpath = os.path.join(self.model_dir, "UNet.pth")
        if not os.path.exists(modelpath):
            # Imported lazily so basicsr is only required when a download happens.
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=self.model_dir)

        net = UNet()
        # NOTE(security): torch.load unpickles the file, and this file comes
        # from the network.  Consider ``weights_only=True`` once the minimum
        # supported torch version provides it -- TODO confirm.
        ckpt = torch.load(modelpath, map_location=self.device)
        # Drop the "module." key prefix (presumably left by nn.DataParallel)
        # so the keys match a bare UNet.  Mutating in place keeps any extra
        # attributes carried by the loaded mapping.
        for key in list(ckpt):
            if 'module.' in key:
                ckpt[key.replace('module.', '')] = ckpt.pop(key)
        net.load_state_dict(ckpt)
        net.eval()
        self.model = net.to(self.device)

    def unload_model(self):
        """Move the network (if loaded) to the CPU; it stays cached in ``self.model``."""
        if self.model is not None:
            self.model.cpu()

    def __call__(self, input_image):
        """Segment ``input_image`` and return the colour-coded class map.

        Args:
            input_image: array accepted by ``PIL.Image.fromarray``
                (presumably an H x W x 3 uint8 image -- confirm with callers).

        Returns:
            ``np.ndarray`` of dtype ``uint8`` and shape ``(H', W', 3)`` where
            ``H' x W'`` is the size after ``transforms.Resize(512)``; each
            pixel holds the ``PALETTE`` entry of its arg-max class.
        """
        if self.model is None:
            self.load_model()
        # unload_model() may have parked the network on the CPU.
        self.model.to(self.device)

        transform = transforms.Compose([
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
        ])
        img = Image.fromarray(input_image)
        with torch.no_grad():
            batch = transform(img).unsqueeze(dim=0).to(self.device)
            # Batch axis removed -> (NUM_SEG_CLASSES, H', W') probabilities.
            seg = self.model(batch).squeeze(dim=0)

        # Most probable class per pixel, then one vectorised palette lookup
        # instead of a Python loop over every pixel.
        class_map = seg.detach().cpu().numpy().argmax(axis=0)
        return np.asarray(PALETTE, dtype=np.uint8)[class_map]