# automatic/modules/image/sharpfin.py
# 219 lines, 10 KiB, Python

"""Sharpfin wrapper for high-quality image resize and tensor conversion.
Provides drop-in replacements for torchvision.transforms.functional operations
with higher quality resampling (Magic Kernel Sharp 2021), sRGB linearization,
and Triton GPU acceleration when available.
Non-CUDA devices fall back to PIL/torch.nn.functional automatically.
"""
import sys
import torch
from PIL import Image
from installer import log
from modules.image.convert import to_tensor, to_pil
_sharpfin_checked = False
_sharpfin_ok = False
_triton_ok = False
def check_sharpfin():
global _sharpfin_checked, _sharpfin_ok, _triton_ok # pylint: disable=global-statement
if not _sharpfin_checked:
from modules.sharpfin.functional import scale # pylint: disable=unused-import
_sharpfin_ok = True
try:
from modules.sharpfin import TRITON_AVAILABLE
_triton_ok = TRITON_AVAILABLE
except Exception:
_triton_ok = False
_sharpfin_checked = True
KERNEL_MAP = {
"Sharpfin MKS2021": "MAGIC_KERNEL_SHARP_2021",
"Sharpfin Lanczos3": "LANCZOS3",
"Sharpfin Mitchell": "MITCHELL",
"Sharpfin Catmull-Rom": "CATMULL_ROM",
}
def get_kernel(kernel=None):
"""Resolve kernel name to ResizeKernel enum. Returns None for PIL fallback."""
if kernel is not None:
name = kernel
else:
from modules import shared
name = shared.opts.resize_quality
if name == "PIL Lanczos" or name not in KERNEL_MAP:
return None
from modules.sharpfin.util import ResizeKernel
return getattr(ResizeKernel, KERNEL_MAP[name])
def get_linearize(linearize=None, is_mask=False):
"""Determine sRGB linearization setting."""
if is_mask:
return False
if linearize is not None:
return linearize
from modules import shared
return shared.opts.resize_linearize_srgb
def allow_sharpfin(device=None):
"""Determine if sharpfin should be used based on device."""
if device is None:
from modules import devices
device = devices.device
# Sharpfin is optimized for CUDA with Triton, for other devices (CPU, MPS, OpenVINO), use torch/PIL optimized kernels
return hasattr(device, 'type') and device.type == 'cuda'
def resize(image, target_size, *, kernel=None, linearize=None, device=None, dtype=None):
"""Resize PIL.Image or torch.Tensor, returning same type.
Args:
image: PIL.Image or torch.Tensor [B,C,H,W] / [C,H,W]
target_size: (width, height) for PIL, (H, W) for tensor
kernel: Override kernel name, or None for settings
linearize: Override sRGB linearization, or None for settings
device: Override compute device
dtype: Override compute dtype
"""
check_sharpfin()
if isinstance(image, Image.Image):
return resize_pil(image, target_size, kernel=kernel, linearize=linearize, device=device, dtype=dtype)
elif isinstance(image, torch.Tensor):
return resize_tensor(image, target_size, kernel=kernel, linearize=linearize if linearize is not None else False)
return image
def _want_sparse(dev, rk, both_down):
"""Check if Triton sparse acceleration should be attempted."""
return _triton_ok and dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021' and both_down
def _scale_pil(scale_fn, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up):
"""Run sharpfin scale with sparse fallback. Returns result tensor."""
global _triton_ok # pylint: disable=global-statement
if both_down or both_up:
use_sparse = _want_sparse(dev, rk, both_down)
if use_sparse:
try:
return scale_fn(tensor, out_res, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=True)
except Exception:
_triton_ok = False
log.info("Sharpfin: Triton sparse disabled, using dense path")
return scale_fn(tensor, out_res, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
# Mixed axis: split into two single-axis resizes
if h > src_h: # H up, W down
intermediate = scale_fn(tensor, (h, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
use_sparse = _want_sparse(dev, rk, True)
if use_sparse:
try:
return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=True)
except Exception:
_triton_ok = False
log.info("Sharpfin: Triton sparse disabled, using dense path")
return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
# H down, W up
use_sparse = _want_sparse(dev, rk, True)
if use_sparse:
try:
intermediate = scale_fn(tensor, (h, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=True)
return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
except Exception:
_triton_ok = False
log.info("Sharpfin: Triton sparse disabled, using dense path")
intermediate = scale_fn(tensor, (h, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
def resize_pil(image: Image.Image, target_size: tuple[int, int], *, kernel=None, linearize=None, device=None, dtype=None):
"""Resize a PIL Image via sharpfin, falling back to PIL on error."""
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
w, h = target_size
is_mask = image.mode == 'L'
if (image.width == w) and (image.height == h):
log.debug(f'Resize image: skip={w}x{h} fn={fn}')
return image
from modules import devices
dev = device if device is not None else devices.device
if not allow_sharpfin(dev):
log.debug(f'Resize image: method=PIL source={image.width}x{image.height} target={w}x{h} device={dev} fn={fn}')
return image.resize((w, h), resample=Image.Resampling.LANCZOS)
rk = get_kernel(kernel)
if rk is None:
log.debug(f'Resize image: method=PIL source={image.width}x{image.height} target={w}x{h} kernel=None fn={fn}')
return image.resize((w, h), resample=Image.Resampling.LANCZOS)
from modules.sharpfin.functional import scale
dt = dtype or torch.float16
do_linear = get_linearize(linearize, is_mask=is_mask)
log.debug(f'Resize image: method=sharpfin source={image.width}x{image.height} target={w}x{h} kernel={rk} device={dev} linearize={do_linear} fn={fn}')
tensor = to_tensor(image)
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)
tensor = tensor.to(device=dev, dtype=dt)
out_res = (h, w) # sharpfin uses (H, W)
src_h, src_w = tensor.shape[-2], tensor.shape[-1]
both_down = (h <= src_h and w <= src_w)
both_up = (h >= src_h and w >= src_w)
result = _scale_pil(scale, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up)
return to_pil(result)
def resize_tensor(tensor: torch.Tensor, target_size: tuple[int, int], *, kernel=None, linearize=False):
"""Resize tensor [B,C,H,W] or [C,H,W] -> Tensor. For in-pipeline tensor resizes.
Args:
tensor: Input tensor
target_size: (H, W) tuple
kernel: Override kernel name
linearize: sRGB linearization (default False for latent/mask data)
"""
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
check_sharpfin()
from modules import devices
dev = devices.device
if not allow_sharpfin(dev):
mode = 'bilinear' if (target_size[0] * target_size[1]) > (tensor.shape[-2] * tensor.shape[-1]) else 'area'
log.debug(f'Resize tensor: method=torch mode={mode} shape={tensor.shape} target={target_size} fn={fn}')
inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
result = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=True)
return result.squeeze(0) if tensor.dim() == 3 else result
rk = get_kernel(kernel)
if rk is None:
mode = 'bilinear' if (target_size[0] * target_size[1]) > (tensor.shape[-2] * tensor.shape[-1]) else 'area'
log.debug(f'Resize tensor: method=torch mode={mode} shape={tensor.shape} target={target_size} kernel=None fn={fn}')
inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
result = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=True)
return result.squeeze(0) if tensor.dim() == 3 else result
from modules.sharpfin.functional import scale
dt = torch.float16
squeezed = False
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)
squeezed = True
src_h, src_w = tensor.shape[-2], tensor.shape[-1]
th, tw = target_size
both_down = (th <= src_h and tw <= src_w)
both_up = (th >= src_h and tw >= src_w)
if both_down or both_up:
use_sparse = _triton_ok and dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021' and both_down
log.debug(f'Resize tensor: method=sharpfin shape={tensor.shape} target={target_size} direction={both_up}:{both_down} kernel={rk} sparse={use_sparse} fn={fn}')
result = scale(tensor, target_size, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=use_sparse)
else:
log.debug(f'Resize tensor: method=sharpfin shape={tensor.shape} target={target_size} direction={both_up}:{both_down} kernel={rk} sparse=False fn={fn}')
intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
if squeezed:
result = result.squeeze(0)
return result