pull/31/merge
チャールズ 2024-07-08 00:48:46 +09:00 committed by GitHub
commit 2d48e3629b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 79 additions and 24 deletions

View File

@ -12,7 +12,8 @@ import numpy as np
from pixelization.models.networks import define_G
import pixelization.models.c2pGen
import gdown
import colorsys
pixelize_code = [
233356.8125, -27387.5918, -32866.8008, 126575.0312, -181590.0156,
@ -107,20 +108,19 @@ class Model(torch.nn.Module):
missing = False
models = (
(path_pixelart_vgg19, "https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM"),
(path_160_net_G_A, "https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az"),
(path_alias_net, "https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_"),
)
if not os.path.exists(path_pixelart_vgg19):
print(f"Missing {path_pixelart_vgg19} - download it from https://drive.google.com/uc?id=1VRYKQOsNlE1w1LXje3yTRU5THN2MGdMM")
missing = True
for path, url in models:
if not os.path.exists(path):
gdown.download(url, path)
if not os.path.exists(path_160_net_G_A):
print(f"Missing {path_160_net_G_A} - download it from https://drive.google.com/uc?id=1i_8xL3stbLWNF4kdQJ50ZhnRFhSDh3Az")
missing = True
if not os.path.exists(path):
missing = True
if not os.path.exists(path_alias_net):
print(f"Missing {path_alias_net} - download it from https://drive.google.com/uc?id=17f2rKnZOpnO9ATwRXgqLz5u5AZsyDvq_")
missing = True
assert not missing, f'Missing checkpoints for pixelization - see console for download links. Download checkpoints manually and place them in {path_checkpoints}.'
assert not missing, 'Missing checkpoints for pixelization - see console for doqwnload links.'
with torch.no_grad():
self.G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, [0])
@ -136,7 +136,6 @@ class Model(torch.nn.Module):
alias_state["module." + str(p)] = alias_state.pop(p)
self.alias_net.load_state_dict(alias_state)
def process(img):
ow, oh = img.size
@ -150,22 +149,71 @@ def process(img):
img = img.crop((left, top, right, bottom))
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Split the RGBA image into RGB and alpha channels
img_rgba = img.convert('RGBA')
r, g, b, a = img_rgba.split()
return trans(img)[None, :, :, :]
# Convert RGB to tensor and normalize
rgb_img = Image.merge('RGB', (r, g, b))
trans_rgb = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
rgb_tensor = trans_rgb(rgb_img)
# Convert alpha channel to tensor (scale from 0-255 to 0-1)
alpha_tensor = transforms.ToTensor()(a)[None, :, :] # Add an extra dimension for batch size
def to_image(tensor, pixel_size, upscale_after):
return rgb_tensor[None, :, :, :], alpha_tensor
def to_image(tensor, alpha_tensor, pixel_size, upscale_after, original_img, copy_hue, copy_sat):
    """Convert the network output tensor into an RGBA PIL image.

    Args:
        tensor: Network output; element 0 is read as a channel-first image
            with values mapped from [-1, 1] to [0, 255].
        alpha_tensor: Alpha plane scaled to 0-1; element 0 is read and its
            last two dimensions are used as (height, width).
        pixel_size: Integer factor used to scale the result back up when
            ``upscale_after`` is set.
        upscale_after: When true, enlarge the pixelized result by
            ``pixel_size`` using nearest-neighbour resampling.
        original_img: PIL image whose hue/saturation may be copied back.
        copy_hue: Take the hue of every pixel from ``original_img``.
        copy_sat: Take the saturation of every pixel from ``original_img``.

    Returns:
        The pixelized PIL image with the alpha channel attached.
    """
    img = tensor.data[0].cpu().float().numpy()
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
    img = img.astype(np.uint8)
    img = Image.fromarray(img)

    # Downscale exactly once: a second 4x resize here would leave the result
    # at 1/16 of the network output instead of 1/4.
    width = img.size[0] // 4
    height = img.size[1] // 4
    img = img.resize((width, height), resample=Image.Resampling.NEAREST)

    # Resize the alpha channel to match the new dimensions.
    alpha = alpha_tensor.data[0].cpu().numpy()
    # Round instead of truncating so that k/255 maps back to k rather than k-1
    # when the float product lands just below an integer.
    alpha = np.rint(alpha * 255).astype(np.uint8)
    # Keep only (height, width); unlike squeeze() this does not collapse a
    # 1-pixel-wide or 1-pixel-tall plane into a 1-D array.
    alpha = alpha.reshape(alpha.shape[-2:])
    alpha_img = Image.fromarray(alpha, mode='L')
    alpha_img = alpha_img.resize((width, height), resample=Image.Resampling.NEAREST)

    if copy_hue or copy_sat:
        original_img = original_img.resize((width, height), resample=Image.Resampling.NEAREST)
        img = color_image(img, original_img, copy_hue, copy_sat)

    # Merge the processed RGB image with the alpha channel.
    img.putalpha(alpha_img)

    if upscale_after:
        img = img.resize((img.size[0] * pixel_size, img.size[1] * pixel_size), resample=Image.Resampling.NEAREST)

    return img
def color_image(img, original_img, copy_hue, copy_sat):
    """Recolor ``img`` using hue and/or saturation taken from ``original_img``.

    Both images are converted to RGB. For every pixel the value (brightness)
    of ``img`` is kept, while hue and saturation come from ``original_img``
    when the corresponding flag is set and from ``img`` otherwise.

    Args:
        img: PIL image that supplies the brightness (and any channel that is
            not copied).
        original_img: PIL image that supplies hue/saturation. It is indexed
            with the coordinates of ``img``, so it must be at least as large.
        copy_hue: Use the hue of ``original_img``.
        copy_sat: Use the saturation of ``original_img``.

    Returns:
        A new RGB PIL image of the same size as ``img``.
    """
    img = img.convert("RGB")
    if not (copy_hue or copy_sat):
        # Nothing to transfer; skip the lossy RGB -> HSV -> RGB round trip.
        return img

    original_img = original_img.convert("RGB")
    colored_img = Image.new("RGB", img.size)

    # Pixel-access objects avoid a method call per pixel for getpixel/putpixel.
    source_px = img.load()
    reference_px = original_img.load()
    result_px = colored_img.load()

    rgb_to_hsv = colorsys.rgb_to_hsv
    hsv_to_rgb = colorsys.hsv_to_rgb

    for x in range(img.width):
        for y in range(img.height):
            ref_r, ref_g, ref_b = reference_px[x, y]
            ref_h, ref_s, _ = rgb_to_hsv(ref_r / 255, ref_g / 255, ref_b / 255)

            r, g, b = source_px[x, y]
            h, s, v = rgb_to_hsv(r / 255, g / 255, b / 255)

            r, g, b = hsv_to_rgb(ref_h if copy_hue else h, ref_s if copy_sat else s, v)
            # round() rather than int(): truncation turns tiny float errors
            # from the HSV round trip into an off-by-one darker channel.
            result_px[x, y] = (round(r * 255), round(g * 255), round(b * 255))

    return colored_img
class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
name = "Pixelization"
@ -175,16 +223,21 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
def ui(self):
    """Build the "Pixelize" accordion and return its controls.

    Returns:
        A dict mapping argument names ("enable", "upscale_after",
        "pixel_size", "copy_hue", "copy_sat") to the created UI components.
    """
    with ui_components.InputAccordion(False, label="Pixelize") as enable:
        # NOTE(review): the source lost its indentation, so the Row/Column
        # nesting below is reconstructed - confirm against the real layout.
        with gr.Row():
            # Create the checkbox once; a second, rebound instance would be
            # orphaned yet still rendered.
            upscale_after = gr.Checkbox(False, label="Keep resolution")
            copy_hue = gr.Checkbox(False, label="Restore hue")
            copy_sat = gr.Checkbox(False, label="Restore saturation")
        with gr.Column():
            pixel_size = gr.Slider(minimum=1, maximum=16, step=1, label="Pixel size", value=4, elem_id="pixelization_pixel_size")

    return {
        "enable": enable,
        "upscale_after": upscale_after,
        "pixel_size": pixel_size,
        "copy_hue": copy_hue,
        "copy_sat": copy_sat,
    }
def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale_after, pixel_size):
def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, upscale_after, pixel_size, copy_hue, copy_sat):
if not enable:
return
@ -196,20 +249,22 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
self.model.to(devices.device)
pp.image = pp.image.resize((pp.image.width * 4 // pixel_size, pp.image.height * 4 // pixel_size)).convert('RGB')
pp.image = pp.image.resize((pp.image.width * 4 // pixel_size, pp.image.height * 4 // pixel_size))
original_img = pp.image.copy()
with torch.no_grad():
in_t = process(pp.image).to(devices.device)
in_t, alpha_t = process(pp.image)
in_t = in_t.to(devices.device)
alpha_t = alpha_t.to(devices.device)
feature = self.model.G_A_net.module.RGBEnc(in_t)
code = torch.asarray(pixelize_code, device=devices.device).reshape((1, 256, 1, 1))
code = torch.tensor(pixelize_code, device=devices.device).reshape((1, 256, 1, 1))
adain_params = self.model.G_A_net.module.MLP(code)
images = self.model.G_A_net.module.RGBDec(feature, adain_params)
out_t = self.model.alias_net(images)
pp.image = to_image(out_t, pixel_size=pixel_size, upscale_after=upscale_after)
pp.image = to_image(out_t, alpha_t, pixel_size=pixel_size, upscale_after=upscale_after, original_img=original_img, copy_hue=copy_hue, copy_sat=copy_sat)
self.model.to(devices.cpu)
pp.info["Pixelization pixel size"] = pixel_size