sd_smartprocess/upscalers/spandrel/spandrel_upscaler_base.py

224 lines
7.2 KiB
Python

from __future__ import annotations
import logging
import os
import re
import time
import traceback
import zipfile
from abc import abstractmethod
from pathlib import Path
from urllib.request import urlretrieve
import PIL
import cv2
import numpy as np
import torch
from PIL import Image
from spandrel import (
ImageModelDescriptor,
ModelDescriptor, ModelLoader,
)
from modules import shared
from modules.upscaler import Upscaler, UpscalerData, NEAREST
logger = logging.getLogger(__name__)
# Matches Google Drive "file view" share links and captures the file id.
# The dots in the host name are escaped so that only the literal host
# ``drive.google.com`` matches (an unescaped ``.`` would accept any character).
_GOOGLE_DRIVE_FILE_RE = re.compile(
    r"^https://drive\.google\.com/file/d/([a-zA-Z0-9_\-]+)/view(?:\?.*)?$"
)


def convert_google_drive_link(url: str) -> str:
    """Rewrite a Google Drive "file view" URL into a direct-download URL.

    Args:
        url: Any URL. Only URLs of the form
            ``https://drive.google.com/file/d/<id>/view[?query]`` are rewritten.

    Returns:
        ``https://drive.google.com/uc?export=download&confirm=1&id=<id>`` for a
        matching URL; otherwise ``url`` unchanged.
    """
    m = _GOOGLE_DRIVE_FILE_RE.match(url)
    if not m:
        return url
    file_id = m.group(1)
    return "https://drive.google.com/uc?export=download&confirm=1&id=" + file_id
def download_file(url: str, filename: Path | str) -> None:
    """Download ``url`` to ``filename`` via a temporary ``.part-<timestamp>`` file.

    The URL is first passed through ``convert_google_drive_link``. Data is
    written to a sibling temporary file (the target's suffix replaced with
    ``.part-<unix time>``) and moved onto ``filename`` only after the download
    finished, so an interrupted download does not leave a truncated target.

    Args:
        url: Source URL.
        filename: Destination path. Missing parent directories are created.

    Raises:
        Whatever ``urlretrieve`` or the filesystem operations raise; the
        temporary file is removed in that case.
    """
    filename = Path(filename)
    # parents=True: the destination may live in a not-yet-existing nested folder.
    filename.parent.mkdir(parents=True, exist_ok=True)
    url = convert_google_drive_link(url)
    temp_filename = filename.with_suffix(f".part-{int(time.time())}")
    try:
        logger.info("Downloading %s to %s", url, filename)
        urlretrieve(url, filename=temp_filename)
        # replace() overwrites an existing target on every platform, whereas
        # rename() raises on Windows when the target already exists.
        temp_filename.replace(filename)
    finally:
        # After a successful move the temp file is gone; only a failed
        # download leaves something to clean up.
        try:
            temp_filename.unlink()
        except FileNotFoundError:
            pass
def extract_file_from_zip(
    zip_path: Path | str,
    rel_model_path: str,
    filename: Path | str,
):
    """Extract a single archive member to ``filename``.

    Args:
        zip_path: Path of the zip archive to read.
        rel_model_path: Name of the member inside the archive.
        filename: Destination path. Missing parent directories are created.

    Raises:
        KeyError: If ``rel_model_path`` is not in the archive. The destination
            file is not created in that case.
    """
    import shutil

    filename = Path(filename)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        # Open the member first so a missing entry fails before any output
        # file is created.
        with zip_ref.open(rel_model_path) as src:
            filename.parent.mkdir(parents=True, exist_ok=True)
            with open(filename, "wb") as dst:
                # Stream in chunks instead of holding the whole member in memory.
                shutil.copyfileobj(src, dst)
def image_to_tensor(img: np.ndarray, device: str, half) -> torch.Tensor:
    """Turn an image array into a batched, channel-first float tensor.

    The array is cast to float32 and divided by 255 (so 8-bit input presumably
    ends up in [0, 1]). A 2-D array gets a trailing channel axis. Anything that
    is not single-channel is passed through ``cv2.COLOR_BGR2RGB``, i.e. the
    input is treated as BGR.

    Args:
        img: Image as ``(H, W)`` or ``(H, W, C)``.
        device: Target device handed to ``Tensor.to``.
        half: Optional dtype; when not ``None`` the tensor is cast to it.

    Returns:
        Tensor of shape ``(1, C, H, W)``.
    """
    scaled = img.astype(np.float32) / 255.0
    if scaled.ndim == 2:
        scaled = scaled[:, :, np.newaxis]
    if scaled.shape[2] != 1:
        scaled = cv2.cvtColor(scaled, cv2.COLOR_BGR2RGB)
    channel_first = scaled.transpose(2, 0, 1)
    result = torch.from_numpy(channel_first).to(device)
    if half is not None:
        result = result.to(half)
    return result.unsqueeze(0)
def tensor_to_image(tensor: torch.Tensor) -> np.ndarray:
    """Convert a model output tensor into an 8-bit image array.

    Values are scaled by 255, rounded and clipped to [0, 255]. Multi-channel
    results are passed through ``cv2.COLOR_RGB2BGR``; single-channel results
    are returned as a 2-D array without any colour conversion.

    Args:
        tensor: ``(1, C, H, W)``, ``(C, H, W)`` or ``(H, W)`` tensor.

    Returns:
        ``uint8`` array of shape ``(H, W, C)``, or ``(H, W)`` when ``C == 1``.
    """
    # Work in float32 so half-precision outputs are not scaled/rounded in float16.
    chw = tensor.detach().cpu().float()
    # Remove only the batch axis. A blanket squeeze() would also drop a
    # single-channel axis or a size-1 spatial axis and break the transpose below.
    if chw.dim() == 4:
        chw = chw.squeeze(0)
    elif chw.dim() == 2:
        chw = chw.unsqueeze(0)
    image = np.transpose(chw.numpy(), (1, 2, 0))
    image = np.clip((image * 255.0).round(), 0, 255)
    image = image.astype(np.uint8)
    if image.shape[2] == 1:
        return image[:, :, 0]
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    return image
def image_inference_tensor(
    model: ImageModelDescriptor, tensor: torch.Tensor
) -> torch.Tensor:
    """Run ``model`` on ``tensor`` in eval mode with gradient tracking disabled.

    Args:
        model: Callable model; ``eval()`` is invoked on it before the call.
        tensor: Input passed to the model unchanged.

    Returns:
        Whatever ``model(tensor)`` returns.
    """
    model.eval()
    with torch.no_grad():
        prediction = model(tensor)
    return prediction
def image_inference(
    model: ImageModelDescriptor,
    image: np.ndarray,
    device: str,
    half: torch.dtype | None,
) -> np.ndarray:
    """Run ``model`` on an image array and return the result as an image array.

    Chains ``image_to_tensor`` -> ``image_inference_tensor`` -> ``tensor_to_image``.

    Args:
        model: Model to run.
        image: Input image array, forwarded to ``image_to_tensor``.
        device: Device the input tensor is moved to.
        half: Dtype the input tensor is cast to, or ``None`` to keep float32.
            (This is a dtype such as ``torch.half``, not a string.)

    Returns:
        The array produced by ``tensor_to_image`` for the model output.
    """
    input_tensor = image_to_tensor(image, device, half)
    output_tensor = image_inference_tensor(model, input_tensor)
    return tensor_to_image(output_tensor)
def get_h_w_c(image: np.ndarray) -> tuple[int, int, int]:
    """Return ``(height, width, channels)`` of an image array.

    A 2-D array is reported as having a single channel; otherwise the third
    axis length is used as the channel count.
    """
    dims = image.shape
    channels = 1 if len(dims) == 2 else dims[2]
    return dims[0], dims[1], channels
class SpandrelUpscaler(Upscaler):
    """Upscaler that runs a model loaded through spandrel's ``ModelLoader``.

    ``do_upscale`` loads the model file it is given and then upscales the
    image with ``internal_upscale``, which adapts the image to the model's
    ``size_requirements`` before inference and undoes the square padding
    afterwards.
    """

    model_url = ""
    model_type = ""
    model_file = ""
    scale = 4

    def __init__(self, create_dirs=False):
        super().__init__(create_dirs)
        self.name = "Spandrel"
        self.scale = 1
        self.scalers = []

    def do_upscale(self, img: Image.Image, selected_model: str):
        """Load ``selected_model`` (a model file path) and upscale ``img`` with it."""
        self.load_model(selected_model)
        return self.internal_upscale(img)

    def load_model(self, path: str):
        """Load the model at ``path`` into ``self.model`` and put it in eval mode.

        The model is moved to CUDA when available, otherwise to the CPU. Half
        precision is enabled only on CUDA and only if the model reports
        ``supports_half``.
        """
        self.model = ModelLoader().load_from_file(path)
        print(f"Model size reqs: {self.model.size_requirements}")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(device)
        # Half precision is restricted to the GPU; many CPU kernels do not
        # implement float16.
        if device == "cuda" and self.model.supports_half:
            self.model.to(torch.half)
        self.model.eval()

    def preprocess(self, image: Image.Image) -> Image.Image:
        """Adapt ``image`` to the model's ``size_requirements``.

        Steps, in order:
          1. ``square``: centre the image on a square RGB canvas whose side is
             the longer image side.
          2. ``minimum``: enlarge (nearest neighbour) so the shorter side is at
             least ``minimum``.
          3. ``multiple_of``: resize (nearest neighbour) so both sides are
             rounded up to the next multiple.
        """
        square = self.model.size_requirements.square
        minimum = self.model.size_requirements.minimum
        multiple_of = self.model.size_requirements.multiple_of
        if square:
            # Pad the shorter side to make the image square
            size = max(image.width, image.height)
            new_image = Image.new("RGB", (size, size))
            new_image.paste(image, ((size - image.width) // 2, (size - image.height) // 2))
            image = new_image
        if minimum > 1:
            # Every side has to reach the minimum, so scale by the shorter one.
            shortest = min(image.width, image.height)
            if shortest < minimum:
                new_width = max(minimum, round(minimum * image.width / shortest))
                new_height = max(minimum, round(minimum * image.height / shortest))
                image = image.resize((new_width, new_height), resample=NEAREST)
        if multiple_of > 1:
            # Round up (ceil division) so the result can neither shrink below
            # the minimum nor collapse to zero. Note that
            # ``multiple_of * w // multiple_of`` is just ``w``.
            new_width = -(-image.width // multiple_of) * multiple_of
            new_height = -(-image.height // multiple_of) * multiple_of
            if (new_width, new_height) != image.size:
                image = image.resize((new_width, new_height), resample=NEAREST)
        return image

    def postprocess(self, image: Image.Image, original_width: int, original_height: int) -> Image.Image:
        """Crop a square-padded result back to the original aspect ratio.

        Only acts when the model requires square input; the crop is centred,
        mirroring the centred padding done in ``preprocess``.
        """
        square = self.model.size_requirements.square
        if square:
            original_aspect_ratio = original_width / original_height
            current_aspect_ratio = image.width / image.height
            if current_aspect_ratio > original_aspect_ratio:
                # Image is wider than the original, crop width
                new_width = int(original_aspect_ratio * image.height)
                left = (image.width - new_width) // 2
                image = image.crop((left, 0, left + new_width, image.height))
            elif current_aspect_ratio < original_aspect_ratio:
                # Image is taller than the original, crop height
                new_height = int(image.width / original_aspect_ratio)
                top = (image.height - new_height) // 2
                image = image.crop((0, top, image.width, top + new_height))
        return image

    def internal_upscale(self, image: Image.Image):
        """Upscale ``image`` with the currently loaded model.

        Images larger than 2048 px on either side are processed on the CPU to
        avoid running out of GPU memory. If inference fails, the error is
        printed and the caller's original image is returned unchanged.
        """
        source_image = image
        original_width = image.width
        original_height = image.height
        # check() is True when the size already satisfies the model, so
        # pre- and post-processing are both needed exactly when it is False.
        meets_requirements = self.model.size_requirements.check(image.width, image.height)
        if not meets_requirements:
            image = self.preprocess(image)
        # PIL yields RGB(A) channel order, whereas the conversion helpers used
        # by image_inference follow the cv2 convention (BGR in, BGR out).
        pixels = np.array(image)
        image_h, image_w, image_c = get_h_w_c(pixels)
        if image_c == 4:
            # Alpha is not fed to the model.
            pixels = cv2.cvtColor(pixels, cv2.COLOR_RGBA2RGB)
            image_c = 3
        if image_c == 3:
            if self.model.input_channels == 1:
                pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
            else:
                pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # If the image size is already greater than 2048, we'll likely OOM on GPU, so do it on CPU
        if image_h > 2048 or image_w > 2048:
            device = "cpu"
        try:
            # Keep model and input dtype in sync: half only on CUDA, float32 on CPU.
            use_half = device == "cuda" and self.model.supports_half
            self.model.to(device)
            self.model.to(torch.half if use_half else torch.float)
            half = torch.half if use_half else None
            output = image_inference(self.model, pixels, device, half)
            if output.ndim == 3 and output.shape[2] == 3:
                # Back from the helpers' BGR order to PIL's RGB order.
                output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
            # Convert output to PIL format
            output = Image.fromarray(output)
            if not meets_requirements:
                output = self.postprocess(output, original_width, original_height)
            return output
        except Exception as e:
            # Best effort: report the failure and hand back the untouched input.
            print(f"Failed to upscale image: {e}")
            traceback.print_exc()
            return source_image

    def unload(self):
        """Drop the reference to the loaded model."""
        self.model = None

    def load(self):
        """No-op; models are loaded on demand by ``do_upscale``."""
        pass