# mirror of https://github.com/vladmandic/automatic
"""Sharpfin transform classes for torchvision integration.
|
|
|
|
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
|
|
Imports patched: absolute sharpfin.X -> relative .X, torchvision guarded.
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
try:
    from torchvision.transforms.v2 import Transform
except ImportError:
    # Minimal stand-in so the transform classes below can still be defined
    # when torchvision is not installed.
    class Transform:
        # Mirrors the torchvision attribute the subclasses override; empty
        # because the stub never dispatches on input types.
        _transformed_types = ()

        def __init__(self):
            pass
|
|
|
|
from .util import QuantHandling, ResizeKernel, SharpenKernel, srgb_to_linear, linear_to_srgb
|
|
from . import functional as SFF
|
|
from .cms import apply_srgb
|
|
import math
|
|
from typing import Any, Dict, Tuple
|
|
from PIL import Image
|
|
from .functional import _get_resize_kernel
|
|
from contextlib import nullcontext
|
|
|
|
try:
|
|
from .triton_functional import downscale_sparse
|
|
except ImportError:
|
|
downscale_sparse = None
|
|
|
|
# from Pytorch >= 2.6
# `torch.compiler.set_stance` only exists on PyTorch >= 2.6; resolves to None
# on older versions, and callers below must check before using it.
set_stance = getattr(torch.compiler, "set_stance", None)

# Public re-exports: only the kernel/quantization enums from .util.
__all__ = ["ResizeKernel", "SharpenKernel", "QuantHandling"]
|
|
|
|
class Scale(Transform):
|
|
"""Rescaling transform supporting multiple algorithms with sRGB linearization."""
|
|
_transformed_types = (torch.Tensor,)
|
|
def __init__(self,
|
|
out_res: tuple[int, int] | int,
|
|
device: torch.device | str = torch.device('cpu'),
|
|
dtype: torch.dtype = torch.float32,
|
|
out_dtype: torch.dtype | None = None,
|
|
quantization: QuantHandling = QuantHandling.ROUND,
|
|
generator: torch.Generator | None = None,
|
|
resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
|
|
sharpen_kernel: SharpenKernel | None = None,
|
|
do_srgb_conversion: bool = True,
|
|
use_sparse: bool = False,
|
|
):
|
|
super().__init__()
|
|
if isinstance(device, str):
|
|
device = torch.device(device)
|
|
if not dtype.is_floating_point:
|
|
raise ValueError("dtype must be a floating point type")
|
|
if dtype.itemsize == 1:
|
|
raise ValueError("float8 types are not supported due to severe accuracy issues and limited function support. float16 or float32 is recommended.")
|
|
if out_dtype is not None and not out_dtype.is_floating_point and out_dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]:
|
|
raise ValueError("out_dtype must be a torch float format or a torch unsigned int format")
|
|
if use_sparse:
|
|
assert device.type != 'cpu'
|
|
if resize_kernel != ResizeKernel.MAGIC_KERNEL_SHARP_2021:
|
|
raise NotImplementedError
|
|
self.use_sparse = use_sparse
|
|
|
|
if isinstance(out_res, int):
|
|
out_res = (out_res, out_res)
|
|
self.out_res = out_res
|
|
self.device = device
|
|
self.dtype = dtype
|
|
self.out_dtype = out_dtype if out_dtype is not None else dtype
|
|
self.do_srgb_conversion = do_srgb_conversion
|
|
|
|
if self.out_dtype in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]:
|
|
match quantization:
|
|
case QuantHandling.TRUNCATE:
|
|
self.quantize_function = lambda x: x.mul(torch.iinfo(self.out_dtype).max).to(self.out_dtype)
|
|
case QuantHandling.ROUND:
|
|
self.quantize_function = lambda x: x.mul(torch.iinfo(self.out_dtype).max).round().to(self.out_dtype)
|
|
case QuantHandling.STOCHASTIC_ROUND:
|
|
if generator is not None:
|
|
self.generator = torch.Generator(self.device)
|
|
else:
|
|
self.generator = generator
|
|
self.quantize_function = lambda x: SFF.stochastic_round(x, self.out_dtype, self.generator)
|
|
case QuantHandling.BAYER:
|
|
self.bayer_matrix = torch.tensor(SFF.generate_bayer_matrix(16), dtype=self.dtype, device=self.device) / 255
|
|
self.quantize_function = lambda x: self.apply_bayer_matrix(x)
|
|
case _:
|
|
raise ValueError(f"Unknown quantization handling type {quantization}")
|
|
else:
|
|
self.quantize_function = lambda x: x.to(dtype=out_dtype)
|
|
|
|
self.resize_kernel, self.kernel_window = _get_resize_kernel(resize_kernel)
|
|
|
|
match sharpen_kernel:
|
|
case SharpenKernel.SHARP_2013:
|
|
kernel = torch.tensor([-1, 6, -1], dtype=dtype, device=device) / 4
|
|
self.sharp_2013_kernel = torch.outer(kernel, kernel).view(1, 1, 3, 3).expand(3, -1, -1, -1)
|
|
self.sharpen_step = lambda x: SFF.sharpen_conv2d(x, self.sharp_2013_kernel, 1)
|
|
case SharpenKernel.SHARP_2021:
|
|
kernel = torch.tensor([-1, 6, -35, 204, -35, 6, -1], dtype=dtype, device=device) / 144
|
|
self.sharp_2021_kernel = torch.outer(kernel, kernel).view(1, 1, 7, 7).expand(3, -1, -1, -1)
|
|
self.sharpen_step = lambda x: SFF.sharpen_conv2d(x, self.sharp_2021_kernel, 3)
|
|
case None:
|
|
self.sharpen_step = lambda x: x
|
|
case _:
|
|
raise ValueError(f"Unknown sharpen kernel {sharpen_kernel}")
|
|
|
|
def apply_bayer_matrix(self, x: torch.Tensor):
|
|
H, W = x.shape[-2:]
|
|
b = self.bayer_matrix.repeat(1,1,math.ceil(H/16),math.ceil(W/16))[:,:,:H,:W]
|
|
return (x*255 + b).to(self.out_dtype)
|
|
|
|
@torch.compile(disable=False)
|
|
def downscale(self, image: torch.Tensor, out_res: tuple[int, int]):
|
|
H, W = out_res
|
|
image = image.to(dtype=self.dtype)
|
|
if self.do_srgb_conversion:
|
|
image = srgb_to_linear(image)
|
|
|
|
image = SFF._downscale_axis(image, W, self.kernel_window, self.resize_kernel, self.device, self.dtype)
|
|
image = SFF._downscale_axis(image.mT, H, self.kernel_window, self.resize_kernel, self.device, self.dtype).mT
|
|
|
|
image = self.sharpen_step(image)
|
|
|
|
if self.do_srgb_conversion:
|
|
image = linear_to_srgb(image)
|
|
image = image.clamp(0,1)
|
|
image = self.quantize_function(image)
|
|
return image
|
|
|
|
@torch.compile(disable=False)
|
|
def downscale_sparse(self, image: torch.Tensor, out_res: tuple[int, int]):
|
|
image = image.to(dtype=self.dtype)
|
|
if downscale_sparse is not None:
|
|
image = downscale_sparse(image, out_res)
|
|
image = self.quantize_function(image)
|
|
return image
|
|
|
|
@torch.compile(disable=False)
|
|
def upscale(self, image: torch.Tensor, out_res: tuple[int, int]):
|
|
H, W = out_res
|
|
image = image.to(dtype=self.dtype)
|
|
if self.do_srgb_conversion:
|
|
image = srgb_to_linear(image)
|
|
|
|
image = self.sharpen_step(image)
|
|
|
|
image = SFF._upscale_axis(image, W, self.kernel_window, self.resize_kernel, self.device, self.dtype)
|
|
image = SFF._upscale_axis(image.mT, H, self.kernel_window, self.resize_kernel, self.device, self.dtype).mT
|
|
|
|
if self.do_srgb_conversion:
|
|
image = linear_to_srgb(image)
|
|
image = image.clamp(0,1)
|
|
image = self.quantize_function(image)
|
|
return image
|
|
|
|
def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> torch.Tensor:
|
|
image = inpt.to(device=self.device)
|
|
context_manager = (
|
|
set_stance("force_eager") if set_stance and self.device.type == "cpu" else nullcontext()
|
|
)
|
|
with context_manager:
|
|
if image.shape[-1] <= self.out_res[-1] and image.shape[-2] <= self.out_res[-2]:
|
|
return self.upscale(image, self.out_res)
|
|
elif image.shape[-1] >= self.out_res[-1] and image.shape[-2] >= self.out_res[-2]:
|
|
if self.use_sparse:
|
|
return self.downscale_sparse(image, self.out_res)
|
|
return self.downscale(image, self.out_res)
|
|
else:
|
|
raise ValueError("Mixed axis resizing (e.g. scaling one axis up and the other down) is not supported. File a bug report with your use case if needed.")
|
|
|
|
class ApplyCMS(Transform):
    """Apply color management to a PIL Image to standardize it to sRGB color space."""

    _transformed_types = (Image.Image,)

    def _transform(self, inpt: Image.Image, params: Dict[str, Any]) -> Image.Image:
        """Standardize `inpt` to sRGB via the vendored CMS helper.

        Raises:
            TypeError: if `inpt` is not a PIL Image.
        """
        if isinstance(inpt, Image.Image):
            return apply_srgb(inpt)
        raise TypeError(f"inpt should be PIL Image. Got {type(inpt)}")
|
|
|
|
class AlphaComposite(Transform):
    """Flatten images with transparency onto a solid background color."""

    _transformed_types = (Image.Image,)

    def __init__(
        self,
        background: Tuple[int, int, int] = (255, 255, 255)
    ):
        """Store the RGB background color used for compositing."""
        super().__init__()
        self.background = background

    def _transform(self, inpt: Image.Image, params: Dict[str, Any]) -> Image.Image:
        """Composite `inpt` over the background; opaque images pass through.

        Raises:
            TypeError: if `inpt` is not a PIL Image.
        """
        if not isinstance(inpt, Image.Image):
            raise TypeError(f"inpt should be PIL Image. Got {type(inpt)}")
        if not inpt.has_transparency_data:
            # Nothing to composite.
            return inpt

        backdrop = Image.new("RGB", inpt.size, self.background).convert('RGBA')
        composited = Image.alpha_composite(backdrop, inpt)
        return composited.convert('RGB')
|
|
|
|
class AspectRatioCrop(Transform):
    """Center-crop a PIL Image to match a reference aspect ratio."""

    _transformed_types = (Image.Image,)

    def __init__(
        self,
        width: int,
        height: int,
    ):
        """Record the reference dimensions defining the target aspect ratio."""
        super().__init__()
        self.ref_width = width
        self.ref_height = height
        self.aspect_ratio = width / height

    def _transform(self, inpt: Image.Image, params: Dict[str, Any]) -> Image.Image:
        """Crop equally from the longer-axis edges so the ratio matches.

        Raises:
            TypeError: if `inpt` is not a PIL Image.
        """
        if not isinstance(inpt, Image.Image):
            raise TypeError(f"inpt should be PIL Image. Got {type(inpt)}")

        w, h = inpt.width, inpt.height
        box = [0, 0, w, h]
        source_ratio = w / h

        if source_ratio > self.aspect_ratio:
            # Too wide: trim equal amounts from left and right edges.
            target_w = int(round(h / self.ref_height * self.ref_width))
            margin = (w - target_w) // 2
            box[0] += margin
            box[2] -= margin
        elif source_ratio < self.aspect_ratio:
            # Too tall: trim equal amounts from top and bottom edges.
            target_h = int(round(w / self.ref_width * self.ref_height))
            margin = (h - target_h) // 2
            box[1] += margin
            box[3] -= margin

        return inpt.crop(tuple(box))
|