# automatic/modules/images_sharpfin.py

"""Sharpfin wrapper for high-quality image resize and tensor conversion.
Provides drop-in replacements for torchvision.transforms.functional operations
with higher quality resampling (Magic Kernel Sharp 2021), sRGB linearization,
and Triton GPU acceleration when available.
Non-CUDA devices fall back to PIL/torch.nn.functional automatically.
"""
import sys
import torch
import numpy as np
from PIL import Image
from installer import log
_sharpfin_checked = False
_sharpfin_ok = False
_triton_ok = False
def check_sharpfin():
    """Probe sharpfin availability once and cache results in module globals.

    Sets ``_sharpfin_ok`` when ``modules.sharpfin.functional`` imports cleanly
    and ``_triton_ok`` when its Triton backend reports available. Subsequent
    calls are no-ops. Never raises: a missing/broken sharpfin install is
    recorded as unavailable so callers can fall back to PIL/torch kernels.
    """
    global _sharpfin_checked, _sharpfin_ok, _triton_ok # pylint: disable=global-statement
    if not _sharpfin_checked:
        try:
            from modules.sharpfin.functional import scale # pylint: disable=unused-import
            _sharpfin_ok = True
        except Exception as e:
            # previously unguarded: an import failure crashed every caller and
            # _sharpfin_ok could never record the failure
            _sharpfin_ok = False
            log.debug(f'Sharpfin: unavailable {e}')
        try:
            from modules.sharpfin import TRITON_AVAILABLE
            _triton_ok = TRITON_AVAILABLE
        except Exception:
            _triton_ok = False
        _sharpfin_checked = True
# Maps user-facing resize-quality option names to sharpfin ResizeKernel enum
# member names (resolved via getattr in get_kernel)
KERNEL_MAP = {
    "Sharpfin MKS2021": "MAGIC_KERNEL_SHARP_2021",
    "Sharpfin Lanczos3": "LANCZOS3",
    "Sharpfin Mitchell": "MITCHELL",
    "Sharpfin Catmull-Rom": "CATMULL_ROM",
}
def get_kernel(kernel=None):
    """Resolve a kernel option name to a sharpfin ``ResizeKernel`` member.

    Uses the configured ``shared.opts.resize_quality`` when *kernel* is None.
    Returns None when the name is "PIL Lanczos" or otherwise unknown, which
    selects the PIL fallback path in the callers.
    """
    if kernel is None:
        from modules import shared
        selected = shared.opts.resize_quality
    else:
        selected = kernel
    enum_name = KERNEL_MAP.get(selected)
    if enum_name is None: # covers "PIL Lanczos" and any unrecognized name
        return None
    from modules.sharpfin.util import ResizeKernel
    return getattr(ResizeKernel, enum_name)
def get_linearize(linearize=None, is_mask=False):
    """Decide whether the resize should linearize sRGB values.

    Masks are never linearized; an explicit *linearize* argument wins;
    otherwise the global ``shared.opts.resize_linearize_srgb`` setting applies.
    """
    if is_mask:
        return False
    if linearize is None:
        from modules import shared
        return shared.opts.resize_linearize_srgb
    return linearize
def allow_sharpfin(device=None):
    """Return True when sharpfin should be used for *device*.

    Sharpfin is tuned for CUDA (with optional Triton acceleration); other
    devices (CPU, MPS, OpenVINO) are better served by torch/PIL kernels.
    Falls back to the global ``devices.device`` when *device* is None.
    """
    if device is None:
        from modules import devices
        device = devices.device
    return getattr(device, 'type', None) == 'cuda'
def resize(image, target_size, *, kernel=None, linearize=None, device=None, dtype=None):
    """Resize PIL.Image or torch.Tensor, returning same type.

    Args:
        image: PIL.Image or torch.Tensor [B,C,H,W] / [C,H,W]
        target_size: (width, height) for PIL, (H, W) for tensor
        kernel: Override kernel name, or None for settings
        linearize: Override sRGB linearization, or None for settings
        device: Override compute device
        dtype: Override compute dtype
    """
    check_sharpfin()
    if isinstance(image, torch.Tensor):
        # tensor path defaults linearization off (latent/mask data)
        tensor_linearize = False if linearize is None else linearize
        return resize_tensor(image, target_size, kernel=kernel, linearize=tensor_linearize)
    if isinstance(image, Image.Image):
        return resize_pil(image, target_size, kernel=kernel, linearize=linearize, device=device, dtype=dtype)
    # unknown type: pass through unchanged
    return image
def _want_sparse(dev, rk, both_down):
    """Check if Triton sparse acceleration should be attempted."""
    # sparse path only applies to full downscales with a working Triton backend
    if not _triton_ok or not both_down:
        return False
    return dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021'
def _scale_pil(scale_fn, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up):
    """Run sharpfin scale with sparse fallback. Returns result tensor.

    Args:
        scale_fn: sharpfin ``scale`` callable
        tensor: batched input tensor (presumably [B,C,H,W] — see resize_pil)
        out_res: (H, W) target resolution
        rk: ResizeKernel enum member
        dev, dt: compute device and dtype
        do_linear: whether to apply sRGB linearization
        src_h, src_w: source height/width
        h, w: target height/width
        both_down / both_up: True when both axes shrink / grow

    On any sparse-path failure the module-global ``_triton_ok`` is cleared so
    future calls skip Triton entirely, then the dense path is retried.
    """
    global _triton_ok # pylint: disable=global-statement
    if both_down or both_up:
        # uniform direction: single resize call covers both axes
        use_sparse = _want_sparse(dev, rk, both_down)
        if use_sparse:
            try:
                return scale_fn(tensor, out_res, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=True)
            except Exception:
                # disable Triton for the rest of the process and fall through to dense
                _triton_ok = False
                log.info("Sharpfin: Triton sparse disabled, using dense path")
        return scale_fn(tensor, out_res, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
    # Mixed axis: split into two single-axis resizes
    if h > src_h: # H up, W down
        # upscale H dense first, then attempt sparse for the W downscale
        intermediate = scale_fn(tensor, (h, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
        use_sparse = _want_sparse(dev, rk, True)
        if use_sparse:
            try:
                return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=True)
            except Exception:
                _triton_ok = False
                log.info("Sharpfin: Triton sparse disabled, using dense path")
        return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
    # H down, W up
    use_sparse = _want_sparse(dev, rk, True)
    if use_sparse:
        try:
            # sparse only for the downscale step; the W upscale stays dense
            intermediate = scale_fn(tensor, (h, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=True)
            return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
        except Exception:
            _triton_ok = False
            log.info("Sharpfin: Triton sparse disabled, using dense path")
    # dense retry recomputes the intermediate since the sparse attempt may have failed mid-way
    intermediate = scale_fn(tensor, (h, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
    return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
def resize_pil(image: Image.Image, target_size: tuple[int, int], *, kernel=None, linearize=None, device=None, dtype=None):
    """Resize a PIL Image via sharpfin, falling back to PIL on error.

    Args:
        image: Source PIL image; mode 'L' is treated as a mask (never linearized)
        target_size: (width, height) in PIL convention
        kernel: Override kernel name, or None for settings
        linearize: Override sRGB linearization, or None for settings
        device: Override compute device
        dtype: Override compute dtype (default float16)

    Falls back to PIL Lanczos when the device is not CUDA or the configured
    kernel selects the PIL path; otherwise round-trips through a tensor and
    the sharpfin scale pipeline.
    """
    fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
    w, h = target_size
    is_mask = image.mode == 'L'
    if (image.width == w) and (image.height == h):
        log.debug(f'Resize image: skip={w}x{h} fn={fn}')
        return image
    from modules import devices
    dev = device if device is not None else devices.device
    if not allow_sharpfin(dev):
        log.debug(f'Resize image: method=PIL source={image.width}x{image.height} target={w}x{h} device={dev} fn={fn}')
        return image.resize((w, h), resample=Image.Resampling.LANCZOS)
    rk = get_kernel(kernel)
    if rk is None:
        # fixed log typo: was "method=PI", inconsistent with the fallback branch above
        log.debug(f'Resize image: method=PIL source={image.width}x{image.height} target={w}x{h} kernel=None fn={fn}')
        return image.resize((w, h), resample=Image.Resampling.LANCZOS)
    from modules.sharpfin.functional import scale
    dt = dtype or torch.float16
    do_linear = get_linearize(linearize, is_mask=is_mask)
    log.debug(f'Resize image: method=sharpfin source={image.width}x{image.height} target={w}x{h} kernel={rk} device={dev} linearize={do_linear} fn={fn}')
    tensor = to_tensor(image)
    if tensor.dim() == 3:
        tensor = tensor.unsqueeze(0) # scale expects batched [B,C,H,W]
    tensor = tensor.to(device=dev, dtype=dt)
    out_res = (h, w) # sharpfin uses (H, W)
    src_h, src_w = tensor.shape[-2], tensor.shape[-1]
    both_down = (h <= src_h and w <= src_w)
    both_up = (h >= src_h and w >= src_w)
    result = _scale_pil(scale, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up)
    return to_pil(result)
def resize_tensor(tensor: torch.Tensor, target_size: tuple[int, int], *, kernel=None, linearize=False):
    """Resize tensor [B,C,H,W] or [C,H,W] -> Tensor. For in-pipeline tensor resizes.

    Args:
        tensor: Input tensor
        target_size: (H, W) tuple
        kernel: Override kernel name
        linearize: sRGB linearization (default False for latent/mask data)

    Uses torch interpolate on non-CUDA devices or when the PIL kernel is
    selected; otherwise the sharpfin scale pipeline, with the Triton sparse
    path attempted only for full downscales with the MKS2021 kernel.
    """
    fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
    check_sharpfin()
    from modules import devices
    dev = devices.device
    if not allow_sharpfin(dev):
        # torch fallback: bilinear for upscale, area for downscale
        mode = 'bilinear' if (target_size[0] * target_size[1]) > (tensor.shape[-2] * tensor.shape[-1]) else 'area'
        log.debug(f'Resize tensor: method=torch mode={mode} shape={tensor.shape} target={target_size} fn={fn}')
        inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
        result = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=True)
        return result.squeeze(0) if tensor.dim() == 3 else result
    rk = get_kernel(kernel)
    if rk is None:
        mode = 'bilinear' if (target_size[0] * target_size[1]) > (tensor.shape[-2] * tensor.shape[-1]) else 'area'
        log.debug(f'Resize tensor: method=torch mode={mode} shape={tensor.shape} target={target_size} kernel=None fn={fn}')
        inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
        result = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=True)
        return result.squeeze(0) if tensor.dim() == 3 else result
    from modules.sharpfin.functional import scale
    dt = torch.float16
    squeezed = False
    if tensor.dim() == 3:
        tensor = tensor.unsqueeze(0) # scale expects batched [B,C,H,W]
        squeezed = True
    src_h, src_w = tensor.shape[-2], tensor.shape[-1]
    th, tw = target_size
    both_down = (th <= src_h and tw <= src_w)
    both_up = (th >= src_h and tw >= src_w)
    if both_down or both_up:
        use_sparse = _triton_ok and dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021' and both_down
        log.debug(f'Resize tensor: method=sharpfin shape={tensor.shape} target={target_size} direction={both_up}:{both_down} kernel={rk} sparse={use_sparse} fn={fn}')
        result = scale(tensor, target_size, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=use_sparse)
    else:
        # Mixed axis (one up, one down): resize H first, then W, dense both times.
        # The original branched on `th > src_h` here with two byte-identical
        # bodies; the dead conditional is collapsed.
        log.debug(f'Resize tensor: method=sharpfin shape={tensor.shape} target={target_size} direction={both_up}:{both_down} kernel={rk} sparse=False fn={fn}')
        intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
        result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
    if squeezed:
        result = result.squeeze(0)
    return result
def to_tensor(image: "Image.Image | np.ndarray"):
    """PIL Image or ndarray -> float32 CHW tensor [0,1]. Pure torch, no torchvision.

    uint8 inputs are scaled by 1/255; other dtypes are cast to float32 as-is.
    Grayscale 2D inputs are promoted to a single channel.

    Raises:
        TypeError: if *image* is neither a PIL Image nor an np.ndarray.
    """
    # The original isinstance ladder was inverted: `if not isinstance(image,
    # Image.Image)` captured ndarrays (making the elif unreachable) and sent
    # actual PIL images to the TypeError branch.
    # Annotation is quoted so it is not evaluated at definition time.
    if isinstance(image, np.ndarray):
        pic = image.copy()
    elif isinstance(image, Image.Image):
        pic = np.array(image, copy=True)
    else:
        raise TypeError(f"Expected PIL Image or np.ndarray, got {type(image)}")
    if pic.ndim == 2:
        pic = pic[:, :, np.newaxis] # grayscale HW -> HW1
    tensor = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous() # HWC -> CHW
    if tensor.dtype == torch.uint8:
        return tensor.to(torch.float32).div_(255.0)
    return tensor.to(torch.float32)
def to_pil(tensor: torch.Tensor | np.ndarray):
    """Float CHW/HWC or BCHW/BHWC tensor [0,1] -> PIL Image. Pure torch, no torchvision."""
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.detach().cpu()
    elif isinstance(tensor, np.ndarray):
        tensor = torch.from_numpy(tensor)
    else:
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")
    try:
        if tensor.dim() == 4:
            # heuristic: channels-last when the trailing dim looks like a channel count
            channels_last = tensor.shape[-1] in (1, 3, 4) and tensor.shape[-1] < tensor.shape[-2]
            if channels_last: # BHWC -> BCHW
                tensor = tensor.permute(0, 3, 1, 2)
            tensor = tensor[0] # first item of batch
        elif tensor.dim() == 3:
            channels_last = tensor.shape[-1] in (1, 3, 4) and tensor.shape[-1] < tensor.shape[-2] and tensor.shape[-1] < tensor.shape[-3]
            if channels_last: # HWC -> CHW
                tensor = tensor.permute(2, 0, 1)
        if tensor.dtype != torch.uint8:
            tensor = (tensor.clamp(0, 1) * 255).round().to(torch.uint8)
        arr = tensor.permute(1, 2, 0).numpy() # CHW -> HWC for PIL
        n_chan = arr.shape[2]
        if n_chan == 1:
            arr = arr[:, :, 0]
            mode = 'L'
        else:
            mode = 'RGB' if n_chan == 3 else 'RGBA'
        image = Image.fromarray(arr, mode=mode)
    except Exception as e:
        # best-effort: return a solid placeholder instead of propagating
        image = Image.new('RGB', (tensor.shape[-1], tensor.shape[-2]), color=(152, 32, 48))
        fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
        log.error(f'Convert: source={type(tensor)} target={image} fn={fn} {e}')
    return image
def pil_to_tensor(image):
    """PIL Image -> uint8 CHW tensor (no float scaling). Replaces TF.pil_to_tensor."""
    if not isinstance(image, Image.Image):
        raise TypeError(f"Expected PIL Image, got {type(image)}")
    arr = np.array(image, copy=True)
    if arr.ndim == 2:
        arr = np.expand_dims(arr, axis=2) # grayscale HW -> HW1
    chw = np.ascontiguousarray(arr.transpose((2, 0, 1))) # HWC -> CHW
    return torch.from_numpy(chw)
def normalize(tensor, mean, std, inplace=False):
    """Channel-wise normalization (tensor - mean) / std. Replaces TF.normalize.

    1-D mean/std are reshaped to broadcast over CHW; the (possibly cloned)
    tensor is modified in place and returned.
    """
    out = tensor if inplace else tensor.clone()
    m = torch.as_tensor(mean, dtype=out.dtype, device=out.device)
    s = torch.as_tensor(std, dtype=out.dtype, device=out.device)
    if m.ndim == 1:
        m = m.view(-1, 1, 1)
    if s.ndim == 1:
        s = s.view(-1, 1, 1)
    return out.sub_(m).div_(s)