mirror of https://github.com/vladmandic/automatic
refactor: integrate sharpfin for high-quality image resize
Vendor sharpfin library (Apache 2.0) and add centralized wrapper module (images_sharpfin.py) replacing torchvision tensor/PIL conversion and resize operations throughout the codebase. - Add modules/sharpfin/ vendored library with MKS2021, Lanczos3, Mitchell, Catmull-Rom kernels and optional Triton sparse acceleration - Add modules/images_sharpfin.py wrapper with to_tensor(), to_pil(), pil_to_tensor(), normalize(), resize(), resize_tensor() - Add resize_quality and resize_linearize_srgb settings - Add MKS2021 and Lanczos3 upscaler entries - Replace torchvision.transforms.functional imports across 18 files - to_pil() auto-detects HWC/BHWC layout, adds .round() before uint8 - Sparse Triton path falls back to dense GPU on compilation failure - Mixed-axis resize splits into two single-axis scale() calls - Masks and non-sRGB data always use linearize=False
pull/4668/head
parent
2c4d0751d9
commit
76aa949a26
|
|
@ -14,7 +14,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from transformers.activations import QuickGELUActivation
|
||||
import torchvision
|
||||
import torchvision.transforms.functional as TVF
|
||||
from modules import images_sharpfin
|
||||
import einops
|
||||
from einops.layers.torch import Rearrange
|
||||
import huggingface_hub
|
||||
|
|
@ -1035,8 +1035,8 @@ def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
|
|||
padded_image.paste(image, (pad_left, pad_top))
|
||||
if max_dim != target_size:
|
||||
padded_image = padded_image.resize((target_size, target_size), Image.Resampling.LANCZOS)
|
||||
image_tensor = TVF.pil_to_tensor(padded_image) / 255.0
|
||||
image_tensor = TVF.normalize(image_tensor, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
|
||||
image_tensor = images_sharpfin.to_tensor(padded_image)
|
||||
image_tensor = images_sharpfin.normalize(image_tensor, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
|
||||
return image_tensor
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import time
|
|||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import shared, upscaler
|
||||
from modules import shared, upscaler, images_sharpfin
|
||||
|
||||
|
||||
def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None):
|
||||
|
|
@ -36,7 +36,7 @@ def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width:
|
|||
def resize(im: Union[Image.Image, torch.Tensor], w, h):
|
||||
w, h = int(w), int(h)
|
||||
if upscaler_name is None or upscaler_name == "None" or (hasattr(im, 'mode') and im.mode == 'L'):
|
||||
return im.resize((w, h), resample=Image.Resampling.LANCZOS) # force for mask
|
||||
return images_sharpfin.resize(im, (w, h), linearize=False) # force for mask
|
||||
if isinstance(im, torch.Tensor):
|
||||
scale = max(w // 8 / im.shape[-1] , h // 8 / im.shape[-2])
|
||||
else:
|
||||
|
|
@ -53,7 +53,7 @@ def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width:
|
|||
shared.log.warning(f"Resize upscaler: invalid={upscaler_name} fallback={selected_upscaler.name}")
|
||||
shared.log.debug(f"Resize upscaler: available={[u.name for u in shared.sd_upscalers]}")
|
||||
if isinstance(im, Image.Image) and (im.width != w or im.height != h): # probably downsample after upscaler created larger image
|
||||
im = im.resize((w, h), resample=Image.Resampling.LANCZOS)
|
||||
im = images_sharpfin.resize(im, (w, h))
|
||||
return im
|
||||
|
||||
def crop(im: Image.Image):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,298 @@
|
|||
"""Sharpfin wrapper for high-quality image resize and tensor conversion.
|
||||
|
||||
Provides drop-in replacements for torchvision.transforms.functional operations
|
||||
with higher quality resampling (Magic Kernel Sharp 2021), sRGB linearization,
|
||||
and Triton GPU acceleration when available.
|
||||
|
||||
All public functions include try/except fallback to PIL/torchvision.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
_sharpfin_checked = False
|
||||
_sharpfin_ok = False
|
||||
_triton_ok = False
|
||||
_log = None
|
||||
|
||||
|
||||
def _get_log():
    """Return the shared application logger, lazily resolved and cached.

    Prefers the project logger from modules.shared; if that import fails
    (e.g. standalone use), falls back to a stdlib logger for this module.
    """
    global _log
    if _log is not None:
        return _log
    try:
        from modules.shared import log as shared_log
        _log = shared_log
    except Exception:
        import logging
        _log = logging.getLogger(__name__)
    return _log
|
||||
|
||||
|
||||
def _check():
    """Probe sharpfin (and optional Triton) availability once; cache results.

    Sets module globals `_sharpfin_ok` and `_triton_ok`. Import failures are
    swallowed so that callers can fall back to PIL/torch paths instead of
    crashing at import time.
    """
    global _sharpfin_checked, _sharpfin_ok, _triton_ok
    if not _sharpfin_checked:
        try:
            from modules.sharpfin.functional import scale # pylint: disable=unused-import
            _sharpfin_ok = True
        except Exception:
            # vendored library missing or broken: disable sharpfin paths
            _sharpfin_ok = False
        try:
            from modules.sharpfin import TRITON_AVAILABLE
            _triton_ok = TRITON_AVAILABLE
        except Exception:
            _triton_ok = False
        _sharpfin_checked = True
|
||||
|
||||
|
||||
def is_available():
    """Return True when the sharpfin functional module imported successfully."""
    _check()  # ensure availability has been probed at least once
    return _sharpfin_ok
|
||||
|
||||
|
||||
KERNEL_MAP = {
|
||||
"Sharpfin MKS2021": "MAGIC_KERNEL_SHARP_2021",
|
||||
"Sharpfin Lanczos3": "LANCZOS3",
|
||||
"Sharpfin Mitchell": "MITCHELL",
|
||||
"Sharpfin Catmull-Rom": "CATMULL_ROM",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_kernel(kernel=None):
|
||||
"""Resolve kernel name to ResizeKernel enum. Returns None for PIL fallback."""
|
||||
if kernel is not None:
|
||||
name = kernel
|
||||
else:
|
||||
try:
|
||||
from modules import shared
|
||||
name = getattr(shared.opts, 'resize_quality', 'Sharpfin MKS2021')
|
||||
except Exception:
|
||||
name = 'Sharpfin MKS2021'
|
||||
if name == "PIL Lanczos" or name not in KERNEL_MAP:
|
||||
return None
|
||||
from modules.sharpfin.util import ResizeKernel
|
||||
return getattr(ResizeKernel, KERNEL_MAP[name])
|
||||
|
||||
|
||||
def _resolve_linearize(linearize=None, is_mask=False):
|
||||
"""Determine sRGB linearization setting."""
|
||||
if is_mask:
|
||||
return False
|
||||
if linearize is not None:
|
||||
return linearize
|
||||
try:
|
||||
from modules import shared
|
||||
return getattr(shared.opts, 'resize_linearize_srgb', True)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
def _get_device_dtype(device=None, dtype=None):
|
||||
"""Get optimal device/dtype for sharpfin operations."""
|
||||
if device is not None and dtype is not None:
|
||||
return device, dtype
|
||||
try:
|
||||
from modules import devices
|
||||
dev = device or devices.device
|
||||
if dev.type == 'cuda':
|
||||
return dev, dtype or torch.float16
|
||||
return dev, dtype or torch.float32
|
||||
except Exception:
|
||||
return device or torch.device('cpu'), dtype or torch.float32
|
||||
|
||||
|
||||
def resize(image, target_size, *, kernel=None, linearize=None, device=None, dtype=None):
    """Resize a PIL.Image or torch.Tensor, returning the same type.

    Args:
        image: PIL.Image or torch.Tensor [B,C,H,W] / [C,H,W]
        target_size: (width, height) for PIL input, (H, W) for tensor input
        kernel: kernel name override, or None to use settings
        linearize: sRGB linearization override, or None to use settings
            (tensor inputs default to False when unspecified)
        device: compute device override
        dtype: compute dtype override

    Unknown input types are returned unchanged.
    """
    _check()
    if isinstance(image, torch.Tensor):
        lin = linearize if linearize is not None else False
        return resize_tensor(image, target_size, kernel=kernel, linearize=lin)
    if isinstance(image, Image.Image):
        return _resize_pil(image, target_size, kernel=kernel, linearize=linearize, device=device, dtype=dtype)
    return image
|
||||
|
||||
|
||||
def _want_sparse(dev, rk, both_down):
    """Return True when the Triton sparse path applies.

    Requires: Triton available, CUDA device, MKS2021 kernel, and a pure
    downscale (sparse matrices are only built for downsampling).
    """
    if not _triton_ok or not both_down:
        return False
    return dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021'
|
||||
|
||||
|
||||
def _scale_pil(scale_fn, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up):
    """Run sharpfin scale with Triton-sparse attempt and dense fallback.

    Args:
        scale_fn: sharpfin's `scale` function (passed in to avoid re-import)
        tensor: input image tensor, expected [B,C,H,W] — produced by _resize_pil
        out_res: (h, w) output resolution
        rk: ResizeKernel enum member
        dev, dt: compute device and dtype forwarded to scale_fn
        do_linear: whether scale_fn performs sRGB linearization
        src_h, src_w: source spatial dims; h, w: target spatial dims
        both_down / both_up: True when both axes shrink / both grow

    Returns the scaled tensor. On any sparse-path failure, permanently
    disables Triton sparse for this process and retries densely.
    """
    global _triton_ok # pylint: disable=global-statement
    if both_down or both_up:
        # Single-pass resize: both axes move in the same direction
        use_sparse = _want_sparse(dev, rk, both_down)
        if use_sparse:
            try:
                return scale_fn(tensor, out_res, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=True)
            except Exception:
                # Triton compile/runtime failure: disable sparse for the rest of the process
                _triton_ok = False
                _get_log().info("Sharpfin: Triton sparse disabled, using dense path")
        return scale_fn(tensor, out_res, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
    # Mixed axis: split into two single-axis resizes
    if h > src_h: # H up, W down
        # Upscale H densely first; the W-down second pass is the sparse candidate
        intermediate = scale_fn(tensor, (h, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
        use_sparse = _want_sparse(dev, rk, True)
        if use_sparse:
            try:
                return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=True)
            except Exception:
                _triton_ok = False
                _get_log().info("Sharpfin: Triton sparse disabled, using dense path")
        return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
    # H down, W up
    # Here the first (H-down) pass is the sparse candidate; W-up pass is always dense
    use_sparse = _want_sparse(dev, rk, True)
    if use_sparse:
        try:
            intermediate = scale_fn(tensor, (h, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=True)
            return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
        except Exception:
            _triton_ok = False
            _get_log().info("Sharpfin: Triton sparse disabled, using dense path")
    # Dense retry of both passes after sparse failure (or when sparse never applied)
    intermediate = scale_fn(tensor, (h, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
    return scale_fn(intermediate, (h, w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)
|
||||
|
||||
|
||||
def _resize_pil(image, target_size, *, kernel=None, linearize=None, device=None, dtype=None):
    """Resize a PIL Image via sharpfin, falling back to PIL Lanczos on error.

    Args:
        image: source PIL.Image
        target_size: (width, height) tuple
        kernel: kernel name override, or None to use settings
        linearize: sRGB linearization override, or None to use settings
        device: compute device override
        dtype: compute dtype override

    Returns a PIL.Image of the requested size (the input itself when already
    at target size).
    """
    w, h = target_size
    if image.width == w and image.height == h:
        return image  # no-op: already at target size
    is_mask = image.mode == 'L'
    rk = _resolve_kernel(kernel)
    if rk is None:
        # "PIL Lanczos" setting (or unknown kernel name): plain PIL resize
        return image.resize((w, h), resample=Image.Resampling.LANCZOS)
    try:
        from modules.sharpfin.functional import scale
        do_linear = _resolve_linearize(linearize, is_mask=is_mask)
        dev, dt = _get_device_dtype(device, dtype)
        tensor = to_tensor(image)
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)  # sharpfin expects batched [B,C,H,W]
        tensor = tensor.to(device=dev, dtype=dt)
        out_res = (h, w)  # sharpfin uses (H, W)
        src_h, src_w = tensor.shape[-2], tensor.shape[-1]
        both_down = (h <= src_h and w <= src_w)
        both_up = (h >= src_h and w >= src_w)
        result = _scale_pil(scale, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up)
        return to_pil(result)
    except Exception as e:
        # Any sharpfin failure (import, device, kernel) degrades to PIL quality
        _get_log().warning(f"Sharpfin resize failed, falling back to PIL: {e}")
        return image.resize((w, h), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
|
||||
def resize_tensor(tensor, target_size, *, kernel=None, linearize=False):
    """Resize tensor [B,C,H,W] or [C,H,W] -> Tensor. For in-pipeline tensor resizes.

    Args:
        tensor: input tensor
        target_size: (H, W) tuple
        kernel: kernel name override, or None to use settings
        linearize: sRGB linearization (default False for latent/mask data)

    Falls back to torch.nn.functional.interpolate when sharpfin is selected
    but fails, or when the "PIL Lanczos" setting is active.
    """
    _check()
    squeezed = tensor.dim() == 3
    work = tensor.unsqueeze(0) if squeezed else tensor  # normalize to [B,C,H,W]

    def _interpolate_fallback():
        # bilinear (antialiased) for upscale, area for downscale;
        # torch rejects antialias=True for mode='area', so only pass it for bilinear
        mode = 'bilinear' if target_size[0] * target_size[1] > work.shape[-2] * work.shape[-1] else 'area'
        if mode == 'bilinear':
            out = torch.nn.functional.interpolate(work, size=target_size, mode=mode, antialias=True)
        else:
            out = torch.nn.functional.interpolate(work, size=target_size, mode=mode)
        return out.squeeze(0) if squeezed else out

    rk = _resolve_kernel(kernel)
    if rk is None:
        # "PIL Lanczos" setting (or unknown kernel name)
        return _interpolate_fallback()
    try:
        from modules.sharpfin.functional import scale
        dev, dt = _get_device_dtype()
        src_h, src_w = work.shape[-2], work.shape[-1]
        th, tw = target_size
        both_down = (th <= src_h and tw <= src_w)
        both_up = (th >= src_h and tw >= src_w)
        if both_down or both_up:
            use_sparse = _want_sparse(dev, rk, both_down)
            result = scale(work, target_size, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=use_sparse)
        else:
            # Mixed axes: split into two single-axis resizes (H first, then W)
            intermediate = scale(work, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
            result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
        return result.squeeze(0) if squeezed else result
    except Exception as e:
        _get_log().warning(f"Sharpfin resize_tensor failed, falling back to F.interpolate: {e}")
        return _interpolate_fallback()
|
||||
|
||||
|
||||
def to_tensor(image):
    """PIL Image -> float32 CHW tensor in [0,1]. Pure torch, no torchvision.

    uint8 pixel data is scaled by 1/255; other source dtypes are cast to
    float32 without rescaling.
    """
    if not isinstance(image, Image.Image):
        raise TypeError(f"Expected PIL Image, got {type(image)}")
    arr = np.array(image, copy=True)
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]  # grayscale: add a channel axis
    out = torch.from_numpy(arr.transpose((2, 0, 1))).contiguous()  # HWC -> CHW
    if out.dtype == torch.uint8:
        return out.to(torch.float32).div_(255.0)
    return out.to(torch.float32)
|
||||
|
||||
|
||||
def to_pil(tensor):
    """Float CHW/HWC or BCHW/BHWC tensor [0,1] -> PIL Image. Pure torch, no torchvision.

    Channel-last layouts are auto-detected heuristically: the last dim must be
    1/3/4 and smaller than the spatial dims. Batched input uses the first
    element only. Float data is clamped to [0,1], rounded, and scaled to uint8.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")
    t = tensor.detach().cpu()
    if t.dim() == 4:
        # BHWC heuristic: channel-sized last dim smaller than height
        if t.shape[-1] in (1, 3, 4) and t.shape[-1] < t.shape[-2]:
            t = t.permute(0, 3, 1, 2)
        t = t[0]  # first batch element
    elif t.dim() == 3:
        # HWC heuristic: channel-sized last dim smaller than both spatial dims
        if t.shape[-1] in (1, 3, 4) and t.shape[-1] < t.shape[-2] and t.shape[-1] < t.shape[-3]:
            t = t.permute(2, 0, 1)
    if t.dtype != torch.uint8:
        # round() avoids the systematic darkening of pure truncation
        t = (t.clamp(0, 1) * 255).round().to(torch.uint8)
    arr = t.permute(1, 2, 0).numpy()  # CHW -> HWC for PIL
    if arr.shape[2] == 1:
        return Image.fromarray(arr[:, :, 0], mode='L')
    return Image.fromarray(arr)
|
||||
|
||||
|
||||
def pil_to_tensor(image):
    """PIL Image -> uint8 CHW tensor (no float scaling). Replaces TF.pil_to_tensor."""
    if not isinstance(image, Image.Image):
        raise TypeError(f"Expected PIL Image, got {type(image)}")
    arr = np.array(image, copy=True)
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]  # grayscale: add a channel axis
    # HWC -> CHW without copying pixel values a second time
    return torch.from_numpy(arr.transpose((2, 0, 1))).contiguous()
|
||||
|
||||
|
||||
def normalize(tensor, mean, std, inplace=False):
    """Channel-wise normalization: (tensor - mean) / std. Replaces TF.normalize.

    Args:
        tensor: CHW (or broadcast-compatible) tensor
        mean: per-channel means (sequence or tensor)
        std: per-channel standard deviations
        inplace: mutate `tensor` directly when True; otherwise work on a clone
    """
    out = tensor if inplace else tensor.clone()
    m = torch.as_tensor(mean, dtype=out.dtype, device=out.device)
    s = torch.as_tensor(std, dtype=out.dtype, device=out.device)
    # reshape 1-D stats to [C,1,1] so they broadcast over H and W
    if m.ndim == 1:
        m = m[:, None, None]
    if s.ndim == 1:
        s = s[:, None, None]
    out.sub_(m).div_(s)
    return out
|
||||
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import numpy as np
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
from PIL import Image
|
||||
from modules import devices
|
||||
from modules import devices, images_sharpfin
|
||||
from installer import log
|
||||
|
||||
|
||||
|
|
@ -96,7 +96,5 @@ class SimpleLama:
|
|||
image, mask = prepare_img_and_mask(image, mask, self.device)
|
||||
with devices.inference_context():
|
||||
inpainted = self.model(image, mask)
|
||||
cur_res = inpainted[0].permute(1, 2, 0).detach().float().cpu().numpy()
|
||||
cur_res = np.clip(cur_res * 255, 0, 255).astype(np.uint8)
|
||||
cur_res = Image.fromarray(cur_res)
|
||||
cur_res = images_sharpfin.to_pil(inpainted[0])
|
||||
return cur_res
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ def setup_model(dirname):
|
|||
self.face_helper.face_parse.to(device)
|
||||
|
||||
def restore(self, np_image, p=None, w=None): # pylint: disable=unused-argument
|
||||
from torchvision.transforms.functional import normalize
|
||||
from modules import images_sharpfin
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
np_image = np_image[:, :, ::-1]
|
||||
original_resolution = np_image.shape[0:2]
|
||||
|
|
@ -84,7 +84,7 @@ def setup_model(dirname):
|
|||
self.face_helper.align_warp_face()
|
||||
for cropped_face in self.face_helper.cropped_faces:
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
images_sharpfin.normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device)
|
||||
try:
|
||||
with devices.inference_context():
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ from typing import List
|
|||
|
||||
import math
|
||||
import torch
|
||||
import torchvision
|
||||
import numpy as np
|
||||
from modules import images_sharpfin
|
||||
|
||||
from PIL import Image
|
||||
from diffusers.utils import CONFIG_NAME
|
||||
|
|
@ -65,11 +65,9 @@ def edge_detect_for_pixelart(image: PipelineImageInput, image_weight: float = 1.
|
|||
greyscale_reshaped = greyscale_reshaped.reshape(batch_size, block_size_sq, block_height, block_width)
|
||||
|
||||
greyscale_range = greyscale_reshaped.amax(dim=1, keepdim=True).sub_(greyscale_reshaped.amin(dim=1, keepdim=True))
|
||||
upsample = torchvision.transforms.Resize((height, width), interpolation=torchvision.transforms.InterpolationMode.BICUBIC)
|
||||
|
||||
range_weight = upsample(greyscale_range)
|
||||
range_weight = images_sharpfin.resize_tensor(greyscale_range, (height, width), linearize=False)
|
||||
range_weight = range_weight.div_(range_weight.max())
|
||||
weight_map = upsample((greyscale > greyscale.median()).to(dtype=torch.float32))
|
||||
weight_map = images_sharpfin.resize_tensor((greyscale > greyscale.median()).to(dtype=torch.float32), (height, width), linearize=False)
|
||||
weight_map = weight_map.unsqueeze(0).add_(range_weight).mul_(image_weight / 2)
|
||||
|
||||
new_image = new_image.mul_(weight_map).addcmul_(min_pool, (1-weight_map))
|
||||
|
|
@ -161,8 +159,7 @@ def encode_jpeg_tensor(img: torch.FloatTensor, block_size: int=16, cbcr_downscal
|
|||
img = img[:, :, :(img.shape[-2]//block_size)*block_size, :(img.shape[-1]//block_size)*block_size] # crop to a multiply of block_size
|
||||
cbcr_block_size = block_size//cbcr_downscale
|
||||
_, _, height, width = img.shape
|
||||
downsample = torchvision.transforms.Resize((height//cbcr_downscale, width//cbcr_downscale), interpolation=torchvision.transforms.InterpolationMode.BICUBIC)
|
||||
down_img = downsample(img[:, 1:,:,:])
|
||||
down_img = images_sharpfin.resize_tensor(img[:, 1:,:,:], (height//cbcr_downscale, width//cbcr_downscale), linearize=False)
|
||||
y = encode_single_channel_dct_2d(img[:, 0, :,:], block_size=block_size, norm=norm)
|
||||
cb = encode_single_channel_dct_2d(down_img[:, 0, :,:], block_size=cbcr_block_size, norm=norm)
|
||||
cr = encode_single_channel_dct_2d(down_img[:, 1, :,:], block_size=cbcr_block_size, norm=norm)
|
||||
|
|
@ -180,9 +177,8 @@ def decode_jpeg_tensor(jpeg_img: torch.FloatTensor, block_size: int=16, cbcr_dow
|
|||
y = decode_single_channel_dct_2d(y, norm=norm)
|
||||
cb = decode_single_channel_dct_2d(cb, norm=norm)
|
||||
cr = decode_single_channel_dct_2d(cr, norm=norm)
|
||||
upsample = torchvision.transforms.Resize((h_blocks*block_size, w_blocks*block_size), interpolation=torchvision.transforms.InterpolationMode.BICUBIC)
|
||||
cb = upsample(cb)
|
||||
cr = upsample(cr)
|
||||
cb = images_sharpfin.resize_tensor(cb, (h_blocks*block_size, w_blocks*block_size), linearize=False)
|
||||
cr = images_sharpfin.resize_tensor(cr, (h_blocks*block_size, w_blocks*block_size), linearize=False)
|
||||
return torch.stack([y,cb,cr], dim=1)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,7 @@ import random
|
|||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import ToPILImage
|
||||
from modules import devices
|
||||
from modules import devices, images_sharpfin
|
||||
from modules.shared import opts, log
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
|
|
@ -14,7 +13,7 @@ MODELS_MAP = {
|
|||
"SeedVR2 7B": "seedvr2_ema_7b_fp16.safetensors",
|
||||
"SeedVR2 7B Sharp": "seedvr2_ema_7b_sharp_fp16.safetensors",
|
||||
}
|
||||
to_pil = ToPILImage()
|
||||
to_pil = images_sharpfin.to_pil
|
||||
|
||||
|
||||
class UpscalerSeedVR(Upscaler):
|
||||
|
|
@ -159,7 +158,7 @@ class UpscalerSeedVR(Upscaler):
|
|||
)
|
||||
t1 = time.time()
|
||||
log.info(f'Upscaler: type="{self.name}" model="{selected_file}" scale={self.scale} cfg={opts.seedvt_cfg_scale} seed={seed} time={t1 - t0:.2f}')
|
||||
img = to_pil(result_tensor.squeeze().permute((2, 0, 1)))
|
||||
img = to_pil(result_tensor.squeeze())
|
||||
|
||||
if opts.upscaler_unload:
|
||||
self.model.dit = None
|
||||
|
|
|
|||
|
|
@ -3,9 +3,8 @@ import os
|
|||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image
|
||||
from modules import shared, devices, processing, sd_models, errors, sd_hijack_hypertile, processing_vae, sd_models_compile, timer, modelstats, extra_networks, attention
|
||||
from modules import shared, devices, processing, sd_models, errors, sd_hijack_hypertile, processing_vae, sd_models_compile, timer, modelstats, extra_networks, attention, images_sharpfin
|
||||
from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled, get_job_name
|
||||
from modules.processing_args import set_pipeline_args
|
||||
from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed
|
||||
|
|
@ -270,9 +269,9 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
|
|||
sd_hijack_hypertile.hypertile_set(p, hr=True)
|
||||
elif torch.is_tensor(output.images) and output.images.shape[-1] == 3: # nhwc
|
||||
if output.images.dim() == 3:
|
||||
output.images = TF.to_pil_image(output.images.permute(2,0,1))
|
||||
output.images = images_sharpfin.to_pil(output.images)
|
||||
elif output.images.dim() == 4:
|
||||
output.images = [TF.to_pil_image(output.images[i].permute(2,0,1)) for i in range(output.images.shape[0])]
|
||||
output.images = [images_sharpfin.to_pil(output.images[i]) for i in range(output.images.shape[0])]
|
||||
|
||||
strength = p.hr_denoising_strength if p.hr_denoising_strength > 0 else p.denoising_strength
|
||||
if (p.hr_upscaler is not None) and (p.hr_upscaler.lower().startswith('latent') or p.hr_force) and strength > 0:
|
||||
|
|
@ -572,7 +571,7 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
|
|||
if hasattr(shared.sd_model, 'unet') and hasattr(shared.sd_model.unet, 'config') and hasattr(shared.sd_model.unet.config, 'in_channels') and shared.sd_model.unet.config.in_channels == 9 and not is_control:
|
||||
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.INPAINTING) # force pipeline
|
||||
if len(getattr(p, 'init_images', [])) == 0:
|
||||
p.init_images = [TF.to_pil_image(torch.rand((3, getattr(p, 'height', 512), getattr(p, 'width', 512))))]
|
||||
p.init_images = [images_sharpfin.to_pil(torch.rand((3, getattr(p, 'height', 512), getattr(p, 'width', 512))))]
|
||||
if not p.prompts:
|
||||
p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
|
||||
if not p.negative_prompts:
|
||||
|
|
|
|||
|
|
@ -334,7 +334,7 @@ def vae_decode(latents, model, output_type='np', vae_type='Full', width=None, he
|
|||
|
||||
def vae_encode(image, model, vae_type='Full'): # pylint: disable=unused-variable
|
||||
jobid = shared.state.begin('VAE Encode')
|
||||
import torchvision.transforms.functional as f
|
||||
from modules import images_sharpfin
|
||||
if shared.state.interrupted or shared.state.skipped:
|
||||
return []
|
||||
if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
|
||||
|
|
@ -342,7 +342,7 @@ def vae_encode(image, model, vae_type='Full'): # pylint: disable=unused-variable
|
|||
if not hasattr(model, 'vae'):
|
||||
shared.log.error('VAE not found in model')
|
||||
return []
|
||||
tensor = f.to_tensor(image.convert("RGB")).unsqueeze(0).to(devices.device, devices.dtype_vae)
|
||||
tensor = images_sharpfin.to_tensor(image.convert("RGB")).unsqueeze(0).to(devices.device, devices.dtype_vae)
|
||||
if vae_type == 'Tiny':
|
||||
latents = taesd_vae_encode(image=tensor)
|
||||
elif vae_type == 'Full' and hasattr(model, 'vae'):
|
||||
|
|
|
|||
|
|
@ -2,9 +2,8 @@ import time
|
|||
import threading
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image
|
||||
from modules import shared, devices, processing, images, sd_samplers, timer
|
||||
from modules import shared, devices, processing, images, sd_samplers, timer, images_sharpfin
|
||||
from modules.vae import sd_vae_approx, sd_vae_taesd, sd_vae_stablecascade
|
||||
|
||||
|
||||
|
|
@ -84,7 +83,7 @@ def single_sample_to_image(sample, approximation=None):
|
|||
x_sample = (255.0 * x_sample).to(torch.uint8)
|
||||
if len(x_sample.shape) == 4:
|
||||
x_sample = x_sample[0]
|
||||
image = TF.to_pil_image(x_sample)
|
||||
image = images_sharpfin.to_pil(x_sample)
|
||||
except Exception as e:
|
||||
warn_once(f'Preview: {e}')
|
||||
image = Image.new(mode="RGB", size=(512, 512))
|
||||
|
|
|
|||
|
|
@ -671,6 +671,10 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
|||
"upscaler_latent_steps": OptionInfo(20, "Upscaler latent steps", gr.Slider, {"minimum": 4, "maximum": 100, "step": 1}),
|
||||
"upscaler_tile_size": OptionInfo(192, "Upscaler tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"upscaler_tile_overlap": OptionInfo(8, "Upscaler tile overlap", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
|
||||
|
||||
"postprocessing_sep_resize": OptionInfo("<h2>Resize</h2>", "", gr.HTML),
|
||||
"resize_quality": OptionInfo("Sharpfin MKS2021", "Image resize algorithm", gr.Dropdown, {"choices": ["PIL Lanczos", "Sharpfin MKS2021", "Sharpfin Lanczos3", "Sharpfin Mitchell", "Sharpfin Catmull-Rom"]}),
|
||||
"resize_linearize_srgb": OptionInfo(True, "Apply sRGB linearization during image resize"),
|
||||
}))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,190 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
Copyright 2024 drhead
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
"""Sharpfin - High quality image resizing with GPU acceleration.
|
||||
|
||||
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
|
||||
Provides Magic Kernel Sharp 2021 resampling, sRGB linearization,
|
||||
and Triton sparse GPU acceleration.
|
||||
"""
|
||||
|
||||
from .util import ResizeKernel, SharpenKernel, QuantHandling, srgb_to_linear, linear_to_srgb
|
||||
|
||||
try:
|
||||
from .functional import scale, _upscale, _downscale, _get_resize_kernel
|
||||
FUNCTIONAL_AVAILABLE = True
|
||||
except Exception:
|
||||
FUNCTIONAL_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from .triton_functional import downscale_sparse
|
||||
TRITON_AVAILABLE = True
|
||||
except Exception:
|
||||
TRITON_AVAILABLE = False
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
"""Sharpfin color management (ICC profile handling).
|
||||
|
||||
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
|
||||
"""
|
||||
|
||||
from io import BytesIO
|
||||
from typing import Any, cast
|
||||
from warnings import warn
|
||||
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
|
||||
import PIL.Image as image
|
||||
import PIL.ImageCms as image_cms
|
||||
|
||||
from PIL.Image import Image
|
||||
from PIL.ImageCms import (
|
||||
Direction, Intent, ImageCmsProfile, PyCMSError,
|
||||
createProfile, getDefaultIntent, isIntentSupported, profileToProfile
|
||||
)
|
||||
from PIL.ImageOps import exif_transpose
|
||||
|
||||
image.MAX_IMAGE_PIXELS = None
|
||||
|
||||
_SRGB = createProfile(colorSpace='sRGB')
|
||||
|
||||
_INTENT_FLAGS = {
|
||||
Intent.PERCEPTUAL: image_cms.FLAGS["HIGHRESPRECALC"],
|
||||
Intent.RELATIVE_COLORIMETRIC: (
|
||||
image_cms.FLAGS["HIGHRESPRECALC"] |
|
||||
image_cms.FLAGS["BLACKPOINTCOMPENSATION"]
|
||||
),
|
||||
Intent.ABSOLUTE_COLORIMETRIC: image_cms.FLAGS["HIGHRESPRECALC"]
|
||||
}
|
||||
|
||||
class CMSWarning(UserWarning):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
path: str | None = None,
|
||||
cms_info: dict[str, Any] | None = None,
|
||||
cause: Exception | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.__cause__ = cause
|
||||
|
||||
self.path = path
|
||||
self.cms_info = cms_info
|
||||
|
||||
def _coalesce_intent(intent: Intent | int) -> Intent:
    """Normalise a raw integer rendering intent to the Intent enum.

    Already-enum values pass through unchanged; integers outside 0..3 are
    rejected.
    """
    if isinstance(intent, Intent):
        return intent

    mapping = {
        0: Intent.PERCEPTUAL,
        1: Intent.RELATIVE_COLORIMETRIC,
        2: Intent.SATURATION,
        3: Intent.ABSOLUTE_COLORIMETRIC,
    }
    resolved = mapping.get(intent)
    if resolved is None:
        raise ValueError("invalid intent")
    return resolved
|
||||
|
||||
def _add_info(info: dict[str, Any], source: object, key: str) -> None:
|
||||
try:
|
||||
if (value := getattr(source, key, None)) is not None:
|
||||
info[key] = value
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def apply_srgb(
    img: Image
) -> Image:
    """Convert `img` to sRGB, honoring an embedded ICC profile if present.

    EXIF orientation is applied first. When the image carries an ICC
    profile, pixels are converted to sRGB via ImageCms.profileToProfile();
    any CMS failure (bad profile, unsupported intent) falls back to assuming
    the pixels are already sRGB and emits a CMSWarning with diagnostics.
    The returned image is always mode "RGB", or "RGBA" when transparency
    data is present.
    """
    # Best-effort source path for diagnostics; only images opened from a
    # file carry .filename.
    if hasattr(img, 'filename'):
        path = img.filename
    else:
        path = ""

    try:
        img.load()

        try:
            exif_transpose(img, in_place=True)
        except Exception:
            pass # corrupt EXIF metadata is fine

        if (icc_raw := img.info.get("icc_profile")) is not None:
            # Diagnostic context attached to any CMSWarning emitted below.
            cms_info: dict[str, Any] = {
                "native_mode": img.mode,
                "transparency": img.has_transparency_data,
            }

            try:
                profile = ImageCmsProfile(BytesIO(icc_raw))
                # Optional profile attributes, collected for diagnostics only;
                # each may be absent or raise on access.
                _add_info(cms_info, profile.profile, "profile_description")
                _add_info(cms_info, profile.profile, "target")
                _add_info(cms_info, profile.profile, "xcolor_space")
                _add_info(cms_info, profile.profile, "connection_space")
                _add_info(cms_info, profile.profile, "colorimetric_intent")
                _add_info(cms_info, profile.profile, "rendering_intent")

                # Normalise exotic source modes to a CMS-friendly working
                # mode, keeping the alpha channel when one exists.
                working_mode = img.mode
                if img.mode.startswith(("RGB", "BGR", "P")):
                    working_mode = "RGBA" if img.has_transparency_data else "RGB"
                elif img.mode.startswith(("L", "I", "F")) or img.mode == "1":
                    working_mode = "LA" if img.has_transparency_data else "L"

                if img.mode != working_mode:
                    cms_info["working_mode"] = working_mode
                    img = img.convert(working_mode)

                # Target mode after the profile conversion.
                mode = "RGBA" if img.has_transparency_data else "RGB"

                # Prefer relative colorimetric; fall back to the profile's
                # declared default intent when unsupported.
                intent = Intent.RELATIVE_COLORIMETRIC
                if isIntentSupported(profile, intent, Direction.INPUT) != 1:
                    intent = _coalesce_intent(getDefaultIntent(profile))

                cms_info["conversion_intent"] = intent

                if (flags := _INTENT_FLAGS.get(intent)) is not None:
                    if img.mode == mode:
                        # Modes already agree: convert the pixels in place.
                        profileToProfile(
                            img,
                            profile,
                            _SRGB,
                            renderingIntent=intent,
                            inPlace=True,
                            flags=flags
                        )
                    else:
                        # Mode change required: the conversion produces a new
                        # image object.
                        img = cast(Image, profileToProfile(
                            img,
                            profile,
                            _SRGB,
                            renderingIntent=intent,
                            outputMode=mode,
                            flags=flags
                        ))
                else:
                    # Intent has no flags entry (e.g. SATURATION): skip the
                    # conversion and assume the pixels are sRGB already.
                    warn(CMSWarning(
                        f"unsupported intent on {path} assuming sRGB: {cms_info}",
                        path=path,
                        cms_info=cms_info
                    ))
            except PyCMSError as ex:
                warn(CMSWarning(
                    f"{ex} on {path}, assuming sRGB: {cms_info}",
                    path=path,
                    cms_info=cms_info,
                    cause=ex,
                ))

    except Exception as ex:
        # Decode or load failure: report and fall through to the best-effort
        # mode normalisation below.
        print(f"{ex} on {path}")

    # Guarantee the caller receives RGB or RGBA.
    if img.has_transparency_data:
        if img.mode != "RGBA":
            try:
                img = img.convert("RGBA")
            except ValueError:
                # Fallback via premultiplied alpha when the direct conversion
                # is rejected.
                img = img.convert("RGBa").convert("RGBA")
    elif img.mode != "RGB":
        img = img.convert("RGB")

    return img
|
||||
|
||||
def put_srgb(img: Image, tensor: Tensor) -> None:
    """Copy the RGB channels of `img` into `tensor` in place.

    Only the first three channels are taken, so any alpha channel is
    dropped. casting="no" makes numpy raise unless the dtypes already match
    exactly — no silent conversion.

    NOTE(review): this assumes `tensor` is a CPU tensor shaped (H, W, 3)
    with a dtype matching the image array's; `.numpy()` raises for
    non-CPU tensors — confirm against callers.
    """
    if img.mode not in ("RGB", "RGBA", "RGBa"):
        raise ValueError(f"Image has non-RGB mode {img.mode}.")

    # Zero-copy write: tensor.numpy() shares memory with the tensor.
    np.copyto(tensor.numpy(), np.asarray(img)[:, :, :3], casting="no")
|
||||
|
|
@ -0,0 +1,285 @@
|
|||
"""Sharpfin functional image scaling operations.
|
||||
|
||||
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
|
||||
Imports patched: absolute sharpfin.X → relative .X, triton import guarded.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from typing import Callable, Tuple
|
||||
import math
|
||||
from contextlib import nullcontext
|
||||
|
||||
from .util import ResizeKernel, linear_to_srgb, srgb_to_linear
|
||||
|
||||
# from Pytorch >= 2.6
|
||||
set_stance = getattr(torch.compiler, "set_stance", None)
|
||||
|
||||
|
||||
def _get_resize_kernel(k: ResizeKernel):
    """Map a ResizeKernel member to its (kernel_fn, support_window) pair.

    The window is the kernel's half-support in source pixels, used by the
    axis resamplers to size their padding.
    """
    table = {
        ResizeKernel.NEAREST: (nearest, 0.5),
        ResizeKernel.BILINEAR: (bilinear, 1.),
        ResizeKernel.MITCHELL: (mitchell, 2.),  # B = 1/3, C = 1/3
        ResizeKernel.CATMULL_ROM: (lambda x: mitchell(x, 0.0, 0.5), 2.),
        ResizeKernel.B_SPLINE: (lambda x: mitchell(x, 1.0, 0.0), 2.),
        ResizeKernel.LANCZOS2: (lambda x: lanczos(x, 2), 2.),
        ResizeKernel.LANCZOS3: (lambda x: lanczos(x, 3), 3.),
        ResizeKernel.MAGIC_KERNEL: (magic_kernel, 1.5),
        ResizeKernel.MAGIC_KERNEL_SHARP_2013: (magic_kernel_sharp_2013, 2.5),
        ResizeKernel.MAGIC_KERNEL_SHARP_2021: (magic_kernel_sharp_2021, 4.5),
    }
    entry = table.get(k)
    if entry is None:
        raise ValueError(f"Unknown resize kernel {k}")
    return entry
|
||||
|
||||
|
||||
### Resampling kernels
|
||||
def nearest(x: torch.Tensor) -> torch.Tensor:
    """Box (nearest-neighbour) kernel: 1 inside |x| <= 0.5, else 0."""
    dist = torch.abs(x)
    return torch.where(dist <= 0.5, 1., 0.)
|
||||
|
||||
def bilinear(x: torch.Tensor) -> torch.Tensor:
    """Triangle (bilinear) kernel: 1 - |x| inside |x| <= 1, else 0."""
    dist = torch.abs(x)
    return torch.where(dist <= 1.0, 1 - dist, 0.)
|
||||
|
||||
def mitchell(x: torch.Tensor, B: float = 1 / 3, C: float = 1 / 3) -> torch.Tensor:
    """Mitchell–Netravali cubic kernel family, support |x| <= 2.

    Note the values are 6x the textbook normalisation (no /6 factor); the
    resamplers renormalise weight columns to sum to 1, so the scale cancels.
    """
    d = torch.abs(x)

    # Piece for 1 < |x| <= 2.
    outer = (-B - 6 * C) * d**3 + (6 * B + 30 * C) * d**2 + (-12 * B - 48 * C) * d + (8 * B + 24 * C)
    # Piece for |x| <= 1.
    inner = (12 - 9 * B - 6 * C) * d**3 + (-18 + 12 * B + 6 * C) * d**2 + (6 - 2 * B)

    w = torch.where(d <= 2, outer, torch.zeros_like(d))
    return torch.where(d <= 1, inner, w)
|
||||
|
||||
def magic_kernel(x: torch.Tensor) -> torch.Tensor:
    """Magic Kernel: piecewise quadratic with support |x| <= 1.5."""
    d = torch.abs(x)

    tail = 0.5 * (d - 1.5) ** 2
    w = torch.where(d <= 1.5, tail, torch.zeros_like(d))
    return torch.where(d <= 0.5, 0.75 - d ** 2, w)
|
||||
|
||||
def magic_kernel_sharp_2013(x: torch.Tensor):
    """Magic Kernel Sharp 2013: piecewise quadratic, support |x| <= 2.5."""
    d = torch.abs(x)

    # Pieces are applied outermost-first; each later `where` overrides the
    # previous one on its (smaller) interval.
    w = torch.where(d <= 2.5, -0.125 * (d - 2.5) ** 2, torch.zeros_like(d))
    w = torch.where(d <= 1.5, (1 - d) * (1.75 - d), w)
    return torch.where(d <= 0.5, 1.0625 - 1.75 * d ** 2, w)
|
||||
|
||||
def magic_kernel_sharp_2021(x: torch.Tensor):
    """Magic Kernel Sharp 2021: piecewise quadratic, support |x| <= 4.5."""
    d = torch.abs(x)

    # Outermost piece first; each narrower interval overrides the previous.
    w = torch.where(d <= 4.5, (-1 / 288) * (d - 4.5) ** 2, torch.zeros_like(d))
    w = torch.where(d <= 3.5, (1 / 36) * (d - 3) * (d - 3.75), w)
    w = torch.where(d <= 2.5, (1 / 6) * (d - 2) * (65 / 24 - d), w)
    w = torch.where(d <= 1.5, (35 / 36) * (d - 1) * (d - 239 / 140), w)
    return torch.where(d <= 0.5, 577 / 576 - (239 / 144) * d ** 2, w)
|
||||
|
||||
def lanczos(x: torch.Tensor, n: int):
    """Lanczos windowed-sinc kernel of order n, support |x| < n."""
    inside = torch.abs(x) < n
    return torch.where(inside, torch.sinc(x) * torch.sinc(x / n), torch.zeros_like(x))
|
||||
|
||||
def sharpen_conv2d(image: torch.Tensor, kernel: torch.Tensor, pad: int) -> torch.Tensor:
    """Depthwise conv2d with replicate edge padding.

    `kernel` must be shaped (C, 1, k, k); grouping by channel count keeps
    the convolution per-channel.
    """
    padded = F.pad(image, (pad, pad, pad, pad), mode='replicate')
    return F.conv2d(padded, kernel, groups=padded.shape[-3])
|
||||
|
||||
### Dithering and related functions.
|
||||
def stochastic_round(
|
||||
x: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
generator: torch.Generator = torch.Generator(),
|
||||
):
|
||||
image = x * torch.iinfo(out_dtype).max
|
||||
image_quant = image.to(out_dtype)
|
||||
quant_error = image - image_quant.to(image.dtype)
|
||||
dither = torch.empty_like(image_quant, dtype=torch.bool)
|
||||
torch.bernoulli(quant_error, generator=generator, out=dither)
|
||||
return image_quant + dither
|
||||
|
||||
def generate_bayer_matrix(n):
    """Generate the n x n Bayer ordered-dithering matrix (n a power of 2).

    Built recursively: each quadrant is 4 * M(n/2) plus a quadrant-specific
    offset, yielding a permutation of 0 .. n*n - 1.

    Args:
        n: matrix size; must be a positive power of two.

    Raises:
        ValueError: if n is not a positive power of two. (The original used
            `assert`, which is stripped under `python -O`.)
    """
    if n <= 0 or (n & (n - 1)) != 0:
        raise ValueError("n must be a power of 2")

    if n == 1:
        return np.array([[0]])  # Base case

    half = generate_bayer_matrix(n // 2)

    return np.block([
        [4 * half + 0, 4 * half + 2],
        [4 * half + 3, 4 * half + 1],
    ])
|
||||
|
||||
### Scaling transforms
|
||||
|
||||
def _downscale_axis(
    image: torch.Tensor,
    size: int,
    resize_kernel: ResizeKernel,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Downscale the last axis of `image` to `size` samples via one matmul.

    Builds a (src + 2*PAD, size) weight matrix by evaluating the resize
    kernel at the distance between every padded source sample center and
    every destination sample center, normalizes each destination column to
    sum to 1, and applies it with a single matrix product. Callers resize
    the other axis by transposing first.
    """
    kernel, window = _get_resize_kernel(resize_kernel)
    k = size / image.shape[-1]  # scale factor (< 1 when downscaling)
    # Padding wide enough that the kernel support never runs off the image.
    PAD = math.ceil((window - 0.5) / k)

    # Optimization note: doing torch.arange like this will compile to doing a int64 arange. Float arange
    # is much slower. So don't try to get clever and "optimize" by adding the +0.5 and *k to this.
    # Source grid is padded to allow "out of range" sampling from the source image.
    coords_source = (torch.arange(-PAD, image.shape[-1]+PAD, 1, dtype=torch.float32, device=device) + 0.5) * k
    coords_dest = (torch.arange(0, size, 1, dtype=torch.float32, device=device) + 0.5)

    # Create a grid of relative distances between each point on this axis.
    coord_grid = torch.empty((coords_source.shape[0], coords_dest.shape[0]), dtype=dtype, device=device)
    # Coord grid always constructed in torch.float32 because float16 precision breaks down for this
    # after 1024.0. This subtraction is the first opportunity we have to safely cast to float16.
    torch.sub(coords_source.unsqueeze(-1), other=coords_dest, out=coord_grid)

    weights = kernel(coord_grid)

    # Normalizing weights to sum to 1 along axis we are resizing on
    weights /= weights.sum(dim=0, keepdim=True)
    # weights /= (1/k)

    # Padded dimension is reduced by the matmul here.
    return F.pad(image, (PAD,PAD,0,0), mode='replicate') @ weights
|
||||
|
||||
def _upscale_axis(
    image: torch.Tensor,
    size: int,
    resize_kernel: ResizeKernel,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Upscale the last axis of `image` to `size` samples via one matmul.

    Mirror of _downscale_axis, except the padding lives on the destination
    grid: out-of-range destination weight columns are folded back into the
    nearest in-range columns and trimmed before per-column normalization.
    """
    kernel, window = _get_resize_kernel(resize_kernel)
    k = size / image.shape[-1]  # scale factor (> 1 when upscaling)
    PAD = math.ceil((window - 0.5) * k)

    # For upsizing, we expect out of range sampling from the destination image.
    coords_source = (torch.arange(0, image.shape[-1], 1, dtype=torch.float32, device=device) + 0.5)
    coords_dest = (torch.arange(-PAD, size+PAD, 1, dtype=torch.float32, device=device) + 0.5) / k

    # Distance grid is built in float32 (precision) and cast to `dtype` by
    # the out= subtraction.
    coord_grid = torch.empty((coords_source.shape[0], coords_dest.shape[0]), dtype=dtype, device=device)
    torch.sub(coords_source.unsqueeze(-1), other=coords_dest, out=coord_grid)

    weights = kernel(coord_grid)

    # We need to explicitly trim padding by summing it into the real area of the destination grid.
    weights[:, PAD] += weights[:, :PAD].sum(dim=1)
    weights[:, -PAD-1] += weights[:, -PAD:].sum(dim=1)
    weights = weights[:, PAD:-PAD]

    # Normalize each destination column to sum to 1.
    weights /= weights.sum(dim=0, keepdim=True)

    return image @ weights
|
||||
|
||||
@torch.compile
def _downscale(
    image: torch.Tensor,
    out_res: tuple[int, int],
    resize_kernel: ResizeKernel,
    device: torch.device,
    dtype: torch.dtype,
    do_srgb_conversion: bool,
):
    """Compiled 2D downscale to out_res=(H, W), optionally in linear light.

    The width axis is resampled first, then the height axis through the
    transposed view `.mT`; the result is clamped to [0, 1] since kernels
    with negative lobes can overshoot.
    """
    target_h, target_w = out_res
    work = image.to(device=device, dtype=dtype)
    if do_srgb_conversion:
        work = srgb_to_linear(work)

    work = _downscale_axis(work, target_w, resize_kernel, device, dtype)
    work = _downscale_axis(work.mT, target_h, resize_kernel, device, dtype).mT

    if do_srgb_conversion:
        work = linear_to_srgb(work)
    return work.clamp(0, 1)
|
||||
|
||||
@torch.compile
def _upscale(
    image: torch.Tensor,
    out_res: tuple[int, int],
    resize_kernel: ResizeKernel,
    device: torch.device,
    dtype: torch.dtype,
    do_srgb_conversion: bool,
):
    """Compiled 2D upscale to out_res=(H, W), optionally in linear light.

    The width axis is resampled first, then the height axis through the
    transposed view `.mT`; the result is clamped to [0, 1] since kernels
    with negative lobes can overshoot.
    """
    target_h, target_w = out_res
    work = image.to(device=device, dtype=dtype)
    if do_srgb_conversion:
        work = srgb_to_linear(work)

    work = _upscale_axis(work, target_w, resize_kernel, device, dtype)
    work = _upscale_axis(work.mT, target_h, resize_kernel, device, dtype).mT

    if do_srgb_conversion:
        work = linear_to_srgb(work)
    return work.clamp(0, 1)
|
||||
|
||||
# Triton sparse downscale - only available with Triton (CUDA)
|
||||
try:
|
||||
from .triton_functional import downscale_sparse
|
||||
except ImportError:
|
||||
downscale_sparse = None
|
||||
|
||||
def scale(
    image: torch.Tensor,
    out_res: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
    do_srgb_conversion: bool = True,
    use_sparse: bool = False,
) -> torch.Tensor:
    """Resize `image` (..., H, W) to out_res=(H, W).

    Dispatches to _upscale when both axes grow (or stay equal) and to
    _downscale when both shrink; mixed-axis resizes are rejected. On CPU,
    torch.compiler.set_stance("force_eager") (PyTorch >= 2.6) disables
    compilation for the call.

    Raises:
        ValueError: for mixed-axis resizes, sparse on CPU, or sparse with a
            non-downscale target. (The original used `assert` for the sparse
            checks, which vanishes under `python -O`.)
        NotImplementedError: sparse path with a kernel other than MKS2021.
        ImportError: sparse path requested but Triton is unavailable.
    """
    if isinstance(device, str):
        device = torch.device(device)
    if use_sparse:
        # Validate the sparse fast path up front with real exceptions.
        if device.type == "cpu":
            raise ValueError("sparse implementation is only for GPU!")
        if resize_kernel != ResizeKernel.MAGIC_KERNEL_SHARP_2021:
            raise NotImplementedError
        if downscale_sparse is None:
            raise ImportError("Triton is required for sparse GPU acceleration")

    context_manager = (
        set_stance("force_eager") if set_stance and device.type == "cpu" else nullcontext()
    )
    with context_manager:
        if image.shape[-1] <= out_res[-1] and image.shape[-2] <= out_res[-2]:
            if use_sparse:
                raise ValueError("sparse implementation only supports downscaling")
            return _upscale(image, out_res, resize_kernel, device, dtype, do_srgb_conversion)
        elif image.shape[-1] >= out_res[-1] and image.shape[-2] >= out_res[-2]:
            if use_sparse:
                return downscale_sparse(image, out_res, resize_kernel)
            return _downscale(image, out_res, resize_kernel, device, dtype, do_srgb_conversion)
        else:
            raise ValueError("Mixed axis resizing (e.g. scaling one axis up and the other down) is not supported. File a bug report with your use case if needed.")
|
||||
|
|
@ -0,0 +1,845 @@
|
|||
"""Sharpfin sparse matrix backend for Triton DDS matmul.
|
||||
|
||||
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
|
||||
Adapted from https://github.com/stanford-futuredata/stk (Apache 2.0)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from typing import Tuple
|
||||
from dataclasses import dataclass
|
||||
from .triton_functional import linear_to_srgb_triton, srgb_to_linear_triton, magic_kernel_sharp_2021_triton, lanczos_triton
|
||||
|
||||
# Code is all adapted from https://github.com/stanford-futuredata/stk, licensed under Apache-2.0
|
||||
# Very reduced set of functions for handling DDS (Dense = Dense @ Sparse) matmul only, with the
|
||||
# DDS kernel modified to be more flexible on input shapes.
|
||||
|
||||
def _validate_matrix(shape, data, row_indices, column_indices, offsets):
|
||||
if data.dim() == 1:
|
||||
data = torch.reshape(data, [data.numel(), 1, 1])
|
||||
|
||||
if data.shape[-2] != data.shape[-1]:
|
||||
raise ValueError(
|
||||
"Expected square blocking in data. "
|
||||
f"Got block shape {[data.shape[-2], data.shape[-1]]}")
|
||||
|
||||
block_size = data.shape[-1]
|
||||
data = data.view([-1, block_size, block_size])
|
||||
|
||||
if data.dim() != 3:
|
||||
raise ValueError(
|
||||
"Expected 3D shape for data (nnz, block, block). "
|
||||
f"Got shape {data.dim()}D shape.")
|
||||
|
||||
block_size = data.shape[1]
|
||||
if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
|
||||
raise ValueError(
|
||||
"Matrix shape must be dividible by blocking. "
|
||||
f"Got shape {shape} with "
|
||||
f"{[block_size, block_size]} blocking.")
|
||||
|
||||
if np.prod(shape) < data.numel():
|
||||
raise ValueError(
|
||||
"Invalid matrix. Number of nonzeros exceeds matrix capacity "
|
||||
f"({data.numel()} v. {np.prod(shape)})")
|
||||
|
||||
if row_indices.dim() != 1:
|
||||
raise ValueError(
|
||||
f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
|
||||
|
||||
if column_indices.dim() != 1:
|
||||
raise ValueError(
|
||||
f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
|
||||
|
||||
if offsets.dim() != 1:
|
||||
raise ValueError(
|
||||
f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
|
||||
|
||||
if row_indices.numel() != data.shape[0]:
|
||||
raise ValueError(
|
||||
"Expected 1 index per nonzero block. "
|
||||
f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
|
||||
|
||||
if column_indices.numel() != data.shape[0]:
|
||||
raise ValueError(
|
||||
"Expected 1 index per nonzero block. "
|
||||
f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
|
||||
|
||||
block_rows = np.prod(shape[:-1]) / block_size
|
||||
if offsets.numel() != block_rows + 1:
|
||||
raise ValueError(
|
||||
"Expected one offset per block row plus one. "
|
||||
f"Got {offsets.numel()} offsets with {block_rows} block rows.")
|
||||
|
||||
is_cuda = (data.is_cuda and
|
||||
row_indices.is_cuda and
|
||||
column_indices.is_cuda and
|
||||
offsets.is_cuda)
|
||||
is_cpu = (not data.is_cuda and
|
||||
not row_indices.is_cuda and
|
||||
not column_indices.is_cuda and
|
||||
not offsets.is_cuda)
|
||||
if not (is_cuda or is_cpu):
|
||||
raise ValueError(
|
||||
"Expected data & meta-data on common device. "
|
||||
f"Got data on {data.device}, row_indices on {row_indices.device} "
|
||||
f"column_indices on {column_indices.device} and "
|
||||
f"offsets on {offsets.device}.")
|
||||
|
||||
if data.dtype != torch.float16:
|
||||
raise ValueError(
|
||||
f"Expected float16 data. Got {data.dtype} data.")
|
||||
if row_indices.dtype != torch.int16:
|
||||
raise ValueError(
|
||||
f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
|
||||
if column_indices.dtype != torch.int16:
|
||||
raise ValueError(
|
||||
f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
|
||||
if offsets.dtype != torch.int32:
|
||||
raise ValueError(
|
||||
f"Expected int32 offsets. Got {offsets.dtype} offsets.")
|
||||
return data
|
||||
|
||||
def _transpose(size, data: torch.Tensor, row_indices: torch.Tensor, column_indices: torch.Tensor, offsets):
|
||||
block_columns = size[1] // data.shape[1]
|
||||
|
||||
gather_indices = column_indices.argsort()
|
||||
column_indices_t = row_indices.gather(0, gather_indices)
|
||||
block_offsets_t = gather_indices.int()
|
||||
|
||||
column_indices_float = column_indices.float()
|
||||
|
||||
zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
|
||||
nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
|
||||
nnz_per_column = nnz_per_column.int()
|
||||
offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
|
||||
return column_indices_t, offsets_t, block_offsets_t
|
||||
|
||||
class SBSCMatrix(torch.nn.Module):
    """Single Block Sparse Column (SBSC) matrix format.

    Stores the matrix as a dense stack of blocks: ``data`` has shape
    ``(num_blocks, block_height, col_width)`` and ``offset`` is the row
    offset between consecutive column blocks.
    """

    def __init__(self, size, data: torch.Tensor, offset: int, block_size: int):
        super().__init__()
        self.data = data
        self.offset = offset
        self.size = size
        # Layout description derived from the block storage itself.
        self.num_blocks = data.shape[0]        # number of stored blocks
        self.col_width = data.shape[2]         # columns covered by each block
        self.col_block_size = block_size       # K-dim tile used by the kernels
|
||||
|
||||
class Matrix(torch.nn.Module):
    """A matrix stored in block compressed sparse row (BCSR) format.

    ``data`` holds the nonzero blocks as an ``(nnz_blocks, blocking, blocking)``
    tensor; ``row_indices``/``column_indices``/``offsets`` give their BCSR
    placement.  The transposed (column-major) metadata is computed on
    construction via ``_transpose`` unless all three transposed index tensors
    are supplied by the caller.
    """

    def __init__(self,
                 size,
                 data: torch.Tensor,
                 row_indices: torch.Tensor,
                 column_indices: torch.Tensor,
                 offsets: torch.Tensor,
                 column_indices_t: torch.Tensor = None,
                 offsets_t: torch.Tensor = None,
                 block_offsets_t: torch.Tensor = None):
        super().__init__()
        self._size = size
        self._data = data
        self._row_indices = row_indices
        self._column_indices = column_indices
        self._offsets = offsets

        # Build the transposed metadata only if it was not passed in.
        if ((column_indices_t is None) or (offsets_t is None) or
            (block_offsets_t is None)):
            column_indices_t, offsets_t, block_offsets_t = _transpose(
                size, data, row_indices, column_indices, offsets)
        self._column_indices_t = column_indices_t
        self._offsets_t = offsets_t
        self._block_offsets_t = block_offsets_t

        self._transposed = False

        # With int16 block indices the largest representable dimension is
        # int16_max blocks, i.e. int16_max * blocking elements.
        max_dim = np.iinfo(np.int16).max * self.blocking
        if column_indices.dtype == torch.int16:
            if size[0] > max_dim or size[1] > max_dim:
                # FIX: the message was a plain string, so {size} was never
                # interpolated; it is now an f-string.
                raise ValueError(
                    f"Sparse matrix with shape {size} exceeds representable "
                    "size with 16-bit indices.")

    def validate(self):
        """Run the full metadata consistency checks (see _validate_matrix)."""
        _validate_matrix(self._size,
                         self._data,
                         self._row_indices,
                         self._column_indices,
                         self._offsets)

    def to(self, device):
        """Move data and all index metadata to `device` in place."""
        self._data = self._data.to(device)
        self._row_indices = self._row_indices.to(device)
        self._column_indices = self._column_indices.to(device)
        self._offsets = self._offsets.to(device)
        self._column_indices_t = self._column_indices_t.to(device)
        self._offsets_t = self._offsets_t.to(device)
        self._block_offsets_t = self._block_offsets_t.to(device)
        return self

    def cuda(self):
        return self.to(torch.cuda.current_device())

    def clone(self):
        """Deep-copy the matrix (data and all index tensors)."""
        return Matrix(
            self.size(),
            self.data.clone(),
            self.row_indices.clone(),
            self.column_indices.clone(),
            self.offsets.clone(),
            self.column_indices_t.clone(),
            self.offsets_t.clone(),
            self.block_offsets_t.clone())

    def t(self):
        """Return a transposed view sharing storage with `self`."""
        if self.dim() != 2:
            raise ValueError(
                "t() expects a tensor with <= 2 dimensions, "
                f"but self is {self.dim()}D.")
        out = Matrix(self.size(),
                     self.data,
                     self.row_indices,
                     self.column_indices,
                     self.offsets,
                     self.column_indices_t,
                     self.offsets_t,
                     self.block_offsets_t)
        out._transposed = not self._transposed
        out._size = torch.Size((self._size[1], self._size[0]))
        return out

    def contiguous(self):
        raise ValueError("Not yet implemented.")

    def is_contiguous(self):
        # A transposed view is, by definition, non-contiguous.
        return not self._transposed

    @property
    def is_cuda(self):
        return self._data.is_cuda

    @property
    def device(self):
        return self._data.device

    def size(self):
        return self._size

    @property
    def shape(self):
        return self.size()

    def dim(self):
        return len(self._size)

    @property
    def data(self):
        return self._data

    @property
    def row_indices(self):
        return self._row_indices

    @property
    def column_indices(self):
        return self._column_indices

    @property
    def offsets(self):
        return self._offsets

    @property
    def offsets_t(self):
        return self._offsets_t

    @property
    def column_indices_t(self):
        return self._column_indices_t

    @property
    def block_offsets_t(self):
        return self._block_offsets_t

    @property
    def dtype(self):
        return self.data.dtype

    @property
    def nnz(self):
        # Number of stored (nonzero-block) elements, not logical elements.
        return self.data.numel()

    @property
    def blocking(self):
        return self.data.shape[1]

    @property
    def requires_grad(self):
        return self.data.requires_grad

    def requires_grad_(self, x):
        self.data.requires_grad_(x)
        return self

    def view(self, *shape):
        """Reshape the leading dimensions; the compressed (last) dim is fixed."""
        assert self.is_contiguous()
        if shape[-1] != self.size()[-1]:
            raise ValueError(
                "Can't change view on compressed dimension. "
                f"{self.size()[-1]} v. {shape[-1]}.")
        if np.prod(shape) != np.prod(self.size()):
            raise ValueError(
                "Mismatch in numel of Matrix and new shape. "
                f"{np.prod(self.size())} v. {np.prod(shape)}")
        return Matrix(shape,
                      self.data,
                      self.row_indices,
                      self.column_indices,
                      self.offsets,
                      self.column_indices_t,
                      self.offsets_t,
                      self.block_offsets_t)

    @property
    def grad(self):
        """Wrap `data.grad` in a Matrix with the same sparsity layout."""
        size = self.size()
        if not self.is_contiguous():
            size = torch.Size((size[1], size[0]))
        out = Matrix(size,
                     self.data.grad,
                     self.row_indices,
                     self.column_indices,
                     self.offsets,
                     self.column_indices_t,
                     self.offsets_t,
                     self.block_offsets_t)
        return out if self.is_contiguous() else out.t()
|
||||
|
||||
@torch.no_grad()
|
||||
def _expand_for_blocking(idxs, blocking):
|
||||
idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
|
||||
|
||||
idxs[:, :, 1] *= blocking
|
||||
idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
|
||||
|
||||
idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
|
||||
idxs = idxs.repeat(1, blocking, 1, 1)
|
||||
|
||||
idxs[:, :, :, 0] *= blocking
|
||||
idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
|
||||
idxs = torch.reshape(idxs, [-1, 2])
|
||||
return idxs
|
||||
|
||||
|
||||
@torch.no_grad()
def to_dense(x):
    """Materialize a sparse ``Matrix`` as a dense tensor of the same shape."""
    assert isinstance(x, Matrix)

    rows, cols = np.prod(x.shape[:-1]), x.shape[-1]

    # Per-block (row, col) coordinates -> flat per-element scatter indices.
    block_coords = torch.stack(
        [x.row_indices.type(torch.int32), x.column_indices.type(torch.int32)],
        dim=1)
    elem_coords = _expand_for_blocking(block_coords, x.blocking)
    flat_idx = (elem_coords[:, 0] * cols + elem_coords[:, 1]).type(torch.int64)

    dense = torch.zeros(rows * cols, dtype=x.dtype, device=x.device)
    dense.scatter_(0, flat_idx, x.data.flatten())
    return dense.reshape(x.size())
|
||||
|
||||
|
||||
@dataclass
class TritonConfig:
    """Default tile sizes and launch parameters for the Triton kernels below."""
    BLOCK_M: int = 128      # M-dimension tile size
    BLOCK_N: int = 128      # N-dimension tile size
    BLOCK_K: int = 32       # K-dimension (reduction) tile size
    BLOCK_SIZE: int = 64    # sparse block edge length
    NUM_STAGES: int = 4     # forwarded to triton.Config(num_stages=...)
    NUM_WARPS: int = 4      # forwarded to triton.Config(num_warps=...)
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _dds_kernel(
    A: tl.tensor, B: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N,
    stride_ac, stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cc, stride_cm, stride_cn,
    row_indices: tl.tensor, column_indices: tl.tensor,
    offsets: tl.tensor, block_offsets_t: tl.tensor,
    fuse_srgb: tl.constexpr, clamp_output: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    # Dense = dense @ block-sparse (DDS) matmul kernel.  Launched by
    # triton_dds() with grid (channels, M tiles, N tiles): each program
    # computes one (BLOCK_M x BLOCK_N) output tile for one channel.
    pid_c = tl.program_id(0)   # channel
    pid_m = tl.program_id(1)   # output tile row
    pid_n = tl.program_id(2)   # output tile column

    num_pid_m = tl.num_programs(1)
    num_pid_n = tl.num_programs(2)

    # Reorder program ids in groups of GROUP_M for better cache locality.
    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)

    offsets += pid_n

    # [start_inx, end_inx) is the range of stored blocks contributing to
    # this block column of the sparse operand.
    start_inx = tl.load(offsets)
    end_inx = tl.load(offsets + 1)

    column_indices += start_inx
    block_offsets_t += start_inx

    BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE

    A_block_ptr = tl.make_block_ptr(
        base=A + pid_c * stride_ac, shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(0, 1)
    )

    rn = tl.arange(0, BLOCK_N)
    rbk = tl.arange(0, BLOCK_K)

    B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16)
    nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)

    bk_sub_incr = BLOCK_K * stride_bk

    # Accumulate contributions from every stored block in this block column.
    for block_inx in range(end_inx - start_inx):
        # The block's column index selects which K-slice of A it multiplies.
        a_col_idx = tl.load(column_indices + block_inx)
        ptr_A = tl.advance(A_block_ptr, (0, a_col_idx * BLOCK_SIZE))

        # block_offsets_t maps the column-sorted position back to the
        # block's physical position inside B's data tensor.
        b_block_offset = tl.load(block_offsets_t + block_inx)
        ptr_B = B + b_block_offset * BLOCK_ELEMENTS

        # Walk the BLOCK_SIZE-wide block in BLOCK_K-wide slabs.
        for sub_block_inx in range(nsub_blocks):
            a = tl.load(ptr_A)
            b = tl.load(ptr_B)

            acc = tl.dot(a, b, acc, out_dtype=tl.float16)

            ptr_A = tl.advance(ptr_A, (0, BLOCK_K))
            ptr_B += bk_sub_incr

    # Optional fused epilogue: linear->sRGB conversion and clamping.
    if fuse_srgb:
        acc = linear_to_srgb_triton(acc)

    if clamp_output:
        acc = tl.clamp(acc, 0.0, 1.0)

    if fuse_srgb or clamp_output:
        acc = acc.to(C.dtype.element_ty)

    C_block_ptr = tl.make_block_ptr(
        base=C + pid_c * stride_cc, shape=(O_M, O_N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0)
    )

    # boundary_check masks the store against the (O_M, O_N) output extent.
    tl.store(C_block_ptr, acc, boundary_check=(0, 1))
|
||||
|
||||
|
||||
def triton_dds(
    lhs: torch.Tensor,
    rhs: Matrix,
    fuse_srgb: bool = False,
    clamp_output: bool = False,
    output_mt: bool = False,
    output_slice: None | Tuple[int,int] = None
):
    """Dense = dense @ block-sparse matmul via the `_dds_kernel` Triton kernel.

    Args:
        lhs: dense operand of shape (channels, M, K), non-transposed.
        rhs: block-sparse `Matrix`; a transposed (non-contiguous) rhs is
            supported by selecting the transposed metadata/strides below.
        fuse_srgb: fuse linear->sRGB conversion into the kernel epilogue.
        clamp_output: fuse a clamp to [0, 1] into the kernel epilogue.
        output_mt: allocate and store the output transposed, (channels, N, M).
        output_slice: optional extent bounding the stored output region.

    Returns:
        A newly allocated output tensor on `lhs.device`.
    """
    assert isinstance(lhs, torch.Tensor)
    assert isinstance(rhs, Matrix)
    assert lhs.ndim == 3
    CH = lhs.shape[0]
    stride_ac = lhs.stride(0)

    M, K = lhs.shape[-2:]
    N = rhs.shape[-1]

    # Allocate the output and derive the store extents/strides for the four
    # combinations of transposed layout x sliced extent.
    if output_mt:
        if output_slice is not None:
            O_N, O_M = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_N, O_M = N, M
            out = torch.empty(
                (*lhs.shape[:-2], rhs.shape[1], lhs.shape[-2]),
                dtype=lhs.dtype,
                device=lhs.device
            )
        # Swapped strides make the kernel write the transposed layout.
        stride_cm, stride_cn = out.stride(-1), out.stride(-2)
        stride_cc = out.stride(-3)
    else:
        if output_slice is not None:
            O_M, O_N = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_M, O_N = M, N
            out = torch.empty(
                (*lhs.shape[:-1], rhs.shape[1]),
                dtype=lhs.dtype,
                device=lhs.device
            )
        stride_cm, stride_cn = out.stride(-2), out.stride(-1)
        stride_cc = out.stride(-3)

    trans_B = not rhs.is_contiguous()
    trans_A = (lhs.stride(-2) > 1 and lhs.stride(-1) > 1)
    # FIX: the original `assert trans_A == False, trans_B == False` used the
    # second comparison as the assert *message*, enforcing nothing about it.
    # Only trans_A must be False; a transposed rhs is handled below.
    assert not trans_A, "lhs must not be transposed"

    assert lhs.shape[-1] <= rhs.shape[0], "incompatible dimensions"

    stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1)

    # A transposed rhs uses the row-major (BCSR) metadata with swapped block
    # strides; a contiguous rhs uses the transposed (CSC-style) metadata.
    if trans_B:
        stride_bk, stride_bn = rhs.data.stride(2), rhs.data.stride(1)
        b_column_indices, b_offsets = rhs.column_indices, rhs.offsets
    else:
        stride_bk, stride_bn = rhs.data.stride(1), rhs.data.stride(2)
        b_column_indices, b_offsets = rhs.column_indices_t, rhs.offsets_t

    grid = lambda META: (CH, triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))

    _dds_kernel[grid](
        lhs, rhs.data, out, M, N, K, O_M, O_N,
        stride_ac, stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cc, stride_cm, stride_cn,
        rhs.row_indices, b_column_indices, b_offsets,
        rhs.block_offsets_t, fuse_srgb, clamp_output,
        GROUP_M=128, ACC_TYPE=tl.float16, BLOCK_M=min(rhs.data.shape[1], 64),
        BLOCK_N=rhs.data.shape[1], BLOCK_SIZE=rhs.data.shape[1], BLOCK_K=min(rhs.data.shape[1], 64)
    )
    return out
|
||||
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=4, num_warps=2),
    ],
    key=['BLOCK_SIZE', 'BLOCK_N'],
)
@triton.jit
def _dds_sbsc_kernel(
    A: tl.tensor, B: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N,
    stride_ac, stride_am, stride_ak,
    stride_bb, stride_bk, stride_bn,
    stride_cc, stride_cm, stride_cn,
    block_offset: tl.constexpr,
    fuse_srgb: tl.constexpr, clamp_output: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    # Dense @ SBSC matmul kernel launched by triton_dds_sbsc() with grid
    # (N tiles, M tiles, channels).  Each program multiplies one BLOCK_M-row
    # slab of A against the single stored block of B for its block column.
    pid_n = tl.program_id(0)   # output tile column (== B block index)
    pid_m = tl.program_id(1)   # output tile row
    pid_c = tl.program_id(2)   # channel

    nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)

    # The regular SBSC structure places block pid_n at row block_offset*pid_n.
    start_row = block_offset * pid_n

    A_block_ptr = tl.make_block_ptr(
        base=A + pid_c * stride_ac, shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, start_row),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(0, 1)
    )

    B_block_ptr = tl.make_block_ptr(
        base=B + pid_n * stride_bb, shape=(BLOCK_SIZE, BLOCK_N),
        strides=(stride_bk, stride_bn),
        offsets=(0, 0),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(0, 1)
    )

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Walk the (BLOCK_SIZE x BLOCK_N) block in BLOCK_K-wide slabs.
    # A is streamed (evict_first); the B block is reused across pid_m
    # programs (evict_last).
    for block_slice in range(nsub_blocks):
        a = tl.load(A_block_ptr, eviction_policy='evict_first', boundary_check=(0,), padding_option='zero')
        b = tl.load(B_block_ptr, eviction_policy='evict_last')

        acc = tl.dot(a, b, acc, out_dtype=tl.float32)

        A_block_ptr = A_block_ptr.advance((0, BLOCK_K))
        B_block_ptr = B_block_ptr.advance((BLOCK_K, 0))

    # Optional fused epilogue: linear->sRGB conversion and clamping.
    if fuse_srgb:
        acc = linear_to_srgb_triton(acc)

    if clamp_output:
        acc = tl.clamp(acc, 0.0, 1.0)

    acc = acc.to(C.dtype.element_ty)

    C_block_ptr = tl.make_block_ptr(
        base=C + pid_c * stride_cc, shape=(O_M, O_N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0)
    )

    # '.cs' (cache-streaming) store: the result is not re-read by this grid.
    tl.store(C_block_ptr, acc, boundary_check=(0, 1), cache_modifier='.cs')
|
||||
|
||||
def triton_dds_sbsc(
    lhs: torch.Tensor,
    rhs: SBSCMatrix,
    fuse_srgb: bool = False,
    clamp_output: bool = False,
    output_mt: bool = False,
    output_slice: None | Tuple[int,int] = None
):
    """Dense = dense @ SBSC matmul via the `_dds_sbsc_kernel` Triton kernel.

    Args:
        lhs: dense operand of shape (channels, M, K).
        rhs: the `SBSCMatrix` operand.
        fuse_srgb: fuse linear->sRGB conversion into the kernel epilogue.
        clamp_output: fuse a clamp to [0, 1] into the kernel epilogue.
        output_mt: allocate and store the output transposed, (channels, N, M).
        output_slice: optional extent bounding the stored output region.

    Returns:
        A newly allocated output tensor on `lhs.device`.
    """
    assert isinstance(lhs, torch.Tensor)
    assert isinstance(rhs, SBSCMatrix)
    assert lhs.ndim == 3
    CH = lhs.shape[0]
    stride_ac = lhs.stride(0)

    M, K = lhs.shape[-2:]
    N = rhs.size[-1]

    # Allocate the output and derive the store extents/strides for the four
    # combinations of transposed layout x sliced extent.
    if output_mt:
        if output_slice is not None:
            O_N, O_M = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_N, O_M = N, M
            out = torch.empty(
                (*lhs.shape[:-2], rhs.size[1], lhs.shape[-2]),
                dtype=lhs.dtype,
                device=lhs.device
            )
        # Swapped strides make the kernel write the transposed layout.
        stride_cm, stride_cn = out.stride(-1), out.stride(-2)
        stride_cc = out.stride(-3)
    else:
        if output_slice is not None:
            O_M, O_N = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_M, O_N = M, N
            out = torch.empty(
                (*lhs.shape[:-1], rhs.size[1]),
                dtype=lhs.dtype,
                device=lhs.device
            )
        stride_cm, stride_cn = out.stride(-2), out.stride(-1)
        stride_cc = out.stride(-3)

    assert lhs.shape[-1] <= rhs.size[0], f"incompatible dimensions: {lhs.shape[-1]} > {rhs.size[0]}"

    stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1)

    stride_bb, stride_bk, stride_bn = rhs.data.stride(0), rhs.data.stride(1), rhs.data.stride(2)

    grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), triton.cdiv(M, META['BLOCK_M']), CH)

    # Tile sizes come straight from the SBSC block layout.
    _dds_sbsc_kernel[grid](
        lhs, rhs.data, out, M, N, K, O_M, O_N,
        stride_ac, stride_am, stride_ak,
        stride_bb, stride_bk, stride_bn,
        stride_cc, stride_cm, stride_cn,
        rhs.offset, fuse_srgb, clamp_output,
        GROUP_M=32, ACC_TYPE=tl.float16, BLOCK_M=32,
        BLOCK_N=rhs.data.shape[2], BLOCK_SIZE=rhs.data.shape[1], BLOCK_K=rhs.col_block_size
    )
    return out
|
||||
|
||||
from triton.language.extra import libdevice
|
||||
|
||||
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=4, num_warps=2),
    ],
    key=['BLOCK_SIZE', 'BLOCK_N'],
)
@triton.jit
def _dds_sbsc_zerorhs_kernel(
    A: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N,
    stride_ac, stride_am, stride_ak,
    stride_cc, stride_cm, stride_cn,
    k, PAD, block_offset: tl.constexpr,
    fuse_srgb: tl.constexpr, gamma_correction: tl.constexpr, clamp_output: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Resampling matmul with a *virtual* rhs: instead of loading stored
    # blocks, the resize-kernel weights (Magic Kernel Sharp 2021) are
    # evaluated on the fly from the source/target coordinate grid, so no
    # rhs tensor ("zerorhs") is materialized.  Launched by
    # triton_dds_zerorhs_sbsc() with grid (N tiles, M tiles, channels).
    pid_n = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_c = tl.program_id(2)

    # NOTE(review): this uses host-side triton.cdiv where the sibling kernels
    # use tl.cdiv; it presumably works only because both operands are
    # constexpr — confirm against the in-kernel form.
    nsub_blocks = triton.cdiv(BLOCK_SIZE, BLOCK_K)

    start_row = block_offset * pid_n

    offs_k = (start_row + tl.arange(0, BLOCK_K)) * stride_ak
    m_range = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    A_mask = (m_range < M)[None, :].broadcast_to(BLOCK_K, BLOCK_M)

    A_M_ptr = A + pid_c * stride_ac + stride_am * m_range

    # Source-space sample positions (pixel centers at +0.5) and target-space
    # output positions; their difference feeds the resize kernel.
    b_k = ((start_row - PAD + tl.arange(0, BLOCK_K)).to(tl.float32) + 0.5) * k
    b_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.float32) + 0.5

    b_base = (b_k[None, :] - b_n[:, None])

    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float16)

    for _ in tl.range(nsub_blocks):
        # Clamp-to-edge addressing: indices below PAD repeat the first
        # sample, indices past the end repeat the last (K - 1).
        # NOTE(review): offs_k carries a stride_ak factor while PAD does not;
        # the clamp mixes units unless stride_ak == 1 — confirm callers only
        # pass a contiguous last axis.
        A_ptr = A_M_ptr[None, :] + tl.minimum(tl.maximum(offs_k, PAD) - PAD, K - 1)[:, None]

        # Evaluate the resize-kernel weights for this slab; scaled by k to
        # normalize for the downscale factor.
        b = magic_kernel_sharp_2021_triton(b_base) * k

        b = b.to(tl.float16)

        a = tl.load(A_ptr, mask=A_mask)

        # Optionally linearize sRGB input before filtering.
        if fuse_srgb == 'input':
            if gamma_correction == 'fast':
                # Fast path: plain gamma-2.2 approximation of sRGB.
                a = libdevice.fast_powf(a, 2.2).to(tl.float16)
            elif gamma_correction == 'srgb':
                a = srgb_to_linear_triton(a).to(tl.float16)

        acc = tl.dot(b, a, acc, out_dtype=tl.float16)

        offs_k += BLOCK_K * stride_ak
        b_base += BLOCK_K * k

    # Optionally re-encode to sRGB on the way out.
    if fuse_srgb == 'output':
        if gamma_correction == 'fast':
            acc = libdevice.fast_powf(acc, 1.0/2.2)
        elif gamma_correction == 'srgb':
            acc = linear_to_srgb_triton(acc)

    if clamp_output:
        acc = tl.clamp(acc, 0.0, 1.0)

    if fuse_srgb == 'output' or clamp_output:
        acc = acc.to(C.dtype.element_ty)

    # Output is written transposed relative to A: (N, M) per channel.
    C_block_ptr = tl.make_block_ptr(
        base=C + pid_c * stride_cc, shape=(O_N, O_M),
        strides=(stride_cn, stride_cm),
        offsets=(pid_n * BLOCK_N, pid_m * BLOCK_M),
        block_shape=(BLOCK_N, BLOCK_M),
        order=(1, 0)
    )

    tl.store(C_block_ptr, acc, boundary_check=(0, 1), cache_modifier='.cs')
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def triton_dds_zerorhs_sbsc(
    lhs: torch.Tensor,
    target_size: int,
    source_size: int,
    kernel_window: float,
    block_specs,
    fuse_srgb: str = '',
    gamma_correction: str = 'fast',
    clamp_output: bool = False,
    output_mt: bool = False,
    output_slice: None | Tuple[int,int] = None
):
    """Resample one axis of `lhs` with a virtual (on-the-fly) resize matrix.

    The resize weights are generated inside `_dds_sbsc_zerorhs_kernel`
    rather than being materialized, using the Magic Kernel Sharp 2021 kernel.

    Args:
        lhs: dense input of shape (channels, M, K); the K axis is resampled.
        target_size: output length of the resampled axis.
        source_size: input length of the resampled axis.
        kernel_window: half-width of the resize kernel's support.
        block_specs: (offset, block_height, num_blocks, col_width) layout of
            the virtual SBSC resize matrix.
        fuse_srgb: '' (off), 'input' (linearize before filtering) or
            'output' (re-encode after filtering).
        gamma_correction: 'fast' (gamma 2.2 approximation) or 'srgb' (exact).
        clamp_output: fuse a clamp to [0, 1] into the kernel epilogue.
        output_mt: allocate and store the output transposed, (channels, N, M).
        output_slice: optional extent bounding the stored output region.

    Returns:
        A newly allocated output tensor on `lhs.device`.
    """
    assert isinstance(lhs, torch.Tensor)

    assert fuse_srgb in ['input', 'output', '']
    assert gamma_correction in ['fast', 'srgb']

    # Scale factor and the clamp-to-edge padding needed to cover the kernel
    # window at the borders.
    k = target_size / source_size

    PAD = math.ceil((kernel_window - 0.5) / k)

    offset, block_height, num_blocks, col_width = block_specs

    assert lhs.ndim == 3
    CH = lhs.shape[0]
    stride_ac = lhs.stride(0)

    M, K = lhs.shape[-2:]
    N = target_size

    # Allocate the output and derive the store extents/strides for the four
    # combinations of transposed layout x sliced extent.
    if output_mt:
        if output_slice is not None:
            O_N, O_M = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_N, O_M = N, M
            out = torch.empty(
                (*lhs.shape[:-2], N, M),
                dtype=lhs.dtype,
                device=lhs.device
            )
        # Swapped strides make the kernel write the transposed layout.
        stride_cm, stride_cn = out.stride(-1), out.stride(-2)
        stride_cc = out.stride(-3)
    else:
        if output_slice is not None:
            O_M, O_N = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_M, O_N = M, N
            out = torch.empty(
                (*lhs.shape[:-2], M, N),
                dtype=lhs.dtype,
                device=lhs.device
            )
        stride_cm, stride_cn = out.stride(-2), out.stride(-1)
        stride_cc = out.stride(-3)

    stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1)

    grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), triton.cdiv(M, META['BLOCK_M']), CH)

    _dds_sbsc_zerorhs_kernel[grid](
        lhs, out, M, N, K, O_M, O_N,
        stride_ac, stride_am, stride_ak,
        stride_cc, stride_cm, stride_cn,
        k, PAD, offset, fuse_srgb, gamma_correction, clamp_output,
        BLOCK_M=32, BLOCK_K=16, BLOCK_N=col_width, BLOCK_SIZE=block_height,
    )
    return out
|
||||
|
|
@ -0,0 +1,234 @@
|
|||
"""Sharpfin transform classes for torchvision integration.
|
||||
|
||||
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
|
||||
Imports patched: absolute sharpfin.X -> relative .X, torchvision guarded.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
    from torchvision.transforms.v2 import Transform
except ImportError:
    # Minimal no-op stand-in so the transform classes below can still be
    # defined when torchvision's v2 transforms API is unavailable.
    class Transform:
        _transformed_types = ()
        def __init__(self):
            pass
|
||||
|
||||
from .util import QuantHandling, ResizeKernel, SharpenKernel, srgb_to_linear, linear_to_srgb
|
||||
from . import functional as SFF
|
||||
from .cms import apply_srgb
|
||||
import math
|
||||
from typing import Any, Dict, Tuple
|
||||
from PIL import Image
|
||||
from .functional import _get_resize_kernel
|
||||
from contextlib import nullcontext
|
||||
|
||||
try:
|
||||
from .triton_functional import downscale_sparse
|
||||
except ImportError:
|
||||
downscale_sparse = None
|
||||
|
||||
# from Pytorch >= 2.6
|
||||
set_stance = getattr(torch.compiler, "set_stance", None)
|
||||
|
||||
__all__ = ["ResizeKernel", "SharpenKernel", "QuantHandling"]
|
||||
|
||||
class Scale(Transform):
|
||||
"""Rescaling transform supporting multiple algorithms with sRGB linearization."""
|
||||
_transformed_types = (torch.Tensor,)
|
||||
def __init__(self,
|
||||
out_res: tuple[int, int] | int,
|
||||
device: torch.device | str = torch.device('cpu'),
|
||||
dtype: torch.dtype = torch.float32,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
quantization: QuantHandling = QuantHandling.ROUND,
|
||||
generator: torch.Generator | None = None,
|
||||
resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
|
||||
sharpen_kernel: SharpenKernel | None = None,
|
||||
do_srgb_conversion: bool = True,
|
||||
use_sparse: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
if not dtype.is_floating_point:
|
||||
raise ValueError("dtype must be a floating point type")
|
||||
if dtype.itemsize == 1:
|
||||
raise ValueError("float8 types are not supported due to severe accuracy issues and limited function support. float16 or float32 is recommended.")
|
||||
if out_dtype is not None and not out_dtype.is_floating_point and out_dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]:
|
||||
raise ValueError("out_dtype must be a torch float format or a torch unsigned int format")
|
||||
if use_sparse:
|
||||
assert device.type != 'cpu'
|
||||
if resize_kernel != ResizeKernel.MAGIC_KERNEL_SHARP_2021:
|
||||
raise NotImplementedError
|
||||
self.use_sparse = use_sparse
|
||||
|
||||
if isinstance(out_res, int):
|
||||
out_res = (out_res, out_res)
|
||||
self.out_res = out_res
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.out_dtype = out_dtype if out_dtype is not None else dtype
|
||||
self.do_srgb_conversion = do_srgb_conversion
|
||||
|
||||
if self.out_dtype in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]:
|
||||
match quantization:
|
||||
case QuantHandling.TRUNCATE:
|
||||
self.quantize_function = lambda x: x.mul(torch.iinfo(self.out_dtype).max).to(self.out_dtype)
|
||||
case QuantHandling.ROUND:
|
||||
self.quantize_function = lambda x: x.mul(torch.iinfo(self.out_dtype).max).round().to(self.out_dtype)
|
||||
case QuantHandling.STOCHASTIC_ROUND:
|
||||
if generator is not None:
|
||||
self.generator = torch.Generator(self.device)
|
||||
else:
|
||||
self.generator = generator
|
||||
self.quantize_function = lambda x: SFF.stochastic_round(x, self.out_dtype, self.generator)
|
||||
case QuantHandling.BAYER:
|
||||
self.bayer_matrix = torch.tensor(SFF.generate_bayer_matrix(16), dtype=self.dtype, device=self.device) / 255
|
||||
self.quantize_function = lambda x: self.apply_bayer_matrix(x)
|
||||
case _:
|
||||
raise ValueError(f"Unknown quantization handling type {quantization}")
|
||||
else:
|
||||
self.quantize_function = lambda x: x.to(dtype=out_dtype)
|
||||
|
||||
self.resize_kernel, self.kernel_window = _get_resize_kernel(resize_kernel)
|
||||
|
||||
match sharpen_kernel:
|
||||
case SharpenKernel.SHARP_2013:
|
||||
kernel = torch.tensor([-1, 6, -1], dtype=dtype, device=device) / 4
|
||||
self.sharp_2013_kernel = torch.outer(kernel, kernel).view(1, 1, 3, 3).expand(3, -1, -1, -1)
|
||||
self.sharpen_step = lambda x: SFF.sharpen_conv2d(x, self.sharp_2013_kernel, 1)
|
||||
case SharpenKernel.SHARP_2021:
|
||||
kernel = torch.tensor([-1, 6, -35, 204, -35, 6, -1], dtype=dtype, device=device) / 144
|
||||
self.sharp_2021_kernel = torch.outer(kernel, kernel).view(1, 1, 7, 7).expand(3, -1, -1, -1)
|
||||
self.sharpen_step = lambda x: SFF.sharpen_conv2d(x, self.sharp_2021_kernel, 3)
|
||||
case None:
|
||||
self.sharpen_step = lambda x: x
|
||||
case _:
|
||||
raise ValueError(f"Unknown sharpen kernel {sharpen_kernel}")
|
||||
|
||||
def apply_bayer_matrix(self, x: torch.Tensor):
|
||||
H, W = x.shape[-2:]
|
||||
b = self.bayer_matrix.repeat(1,1,math.ceil(H/16),math.ceil(W/16))[:,:,:H,:W]
|
||||
return (x*255 + b).to(self.out_dtype)
|
||||
|
||||
@torch.compile(disable=False)
|
||||
def downscale(self, image: torch.Tensor, out_res: tuple[int, int]):
|
||||
H, W = out_res
|
||||
image = image.to(dtype=self.dtype)
|
||||
if self.do_srgb_conversion:
|
||||
image = srgb_to_linear(image)
|
||||
|
||||
image = SFF._downscale_axis(image, W, self.kernel_window, self.resize_kernel, self.device, self.dtype)
|
||||
image = SFF._downscale_axis(image.mT, H, self.kernel_window, self.resize_kernel, self.device, self.dtype).mT
|
||||
|
||||
image = self.sharpen_step(image)
|
||||
|
||||
if self.do_srgb_conversion:
|
||||
image = linear_to_srgb(image)
|
||||
image = image.clamp(0,1)
|
||||
image = self.quantize_function(image)
|
||||
return image
|
||||
|
||||
@torch.compile(disable=False)
|
||||
def downscale_sparse(self, image: torch.Tensor, out_res: tuple[int, int]):
|
||||
image = image.to(dtype=self.dtype)
|
||||
if downscale_sparse is not None:
|
||||
image = downscale_sparse(image, out_res)
|
||||
image = self.quantize_function(image)
|
||||
return image
|
||||
|
||||
@torch.compile(disable=False)
|
||||
def upscale(self, image: torch.Tensor, out_res: tuple[int, int]):
|
||||
H, W = out_res
|
||||
image = image.to(dtype=self.dtype)
|
||||
if self.do_srgb_conversion:
|
||||
image = srgb_to_linear(image)
|
||||
|
||||
image = self.sharpen_step(image)
|
||||
|
||||
image = SFF._upscale_axis(image, W, self.kernel_window, self.resize_kernel, self.device, self.dtype)
|
||||
image = SFF._upscale_axis(image.mT, H, self.kernel_window, self.resize_kernel, self.device, self.dtype).mT
|
||||
|
||||
if self.do_srgb_conversion:
|
||||
image = linear_to_srgb(image)
|
||||
image = image.clamp(0,1)
|
||||
image = self.quantize_function(image)
|
||||
return image
|
||||
|
||||
def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> torch.Tensor:
|
||||
image = inpt.to(device=self.device)
|
||||
context_manager = (
|
||||
set_stance("force_eager") if set_stance and self.device.type == "cpu" else nullcontext()
|
||||
)
|
||||
with context_manager:
|
||||
if image.shape[-1] <= self.out_res[-1] and image.shape[-2] <= self.out_res[-2]:
|
||||
return self.upscale(image, self.out_res)
|
||||
elif image.shape[-1] >= self.out_res[-1] and image.shape[-2] >= self.out_res[-2]:
|
||||
if self.use_sparse:
|
||||
return self.downscale_sparse(image, self.out_res)
|
||||
return self.downscale(image, self.out_res)
|
||||
else:
|
||||
raise ValueError("Mixed axis resizing (e.g. scaling one axis up and the other down) is not supported. File a bug report with your use case if needed.")
|
||||
|
||||
class ApplyCMS(Transform):
    """Normalize a PIL image to the sRGB color space via color management."""
    _transformed_types = (Image.Image,)

    def _transform(self, inpt: Image.Image, params: Dict[str, Any]) -> Image.Image:
        if isinstance(inpt, Image.Image):
            return apply_srgb(inpt)
        raise TypeError(f"inpt should be PIL Image. Got {type(inpt)}")
|
||||
|
||||
class AlphaComposite(Transform):
    """Flatten transparent PIL images onto a solid background color."""
    _transformed_types = (Image.Image,)

    def __init__(self, background: Tuple[int, int, int] = (255, 255, 255)):
        super().__init__()
        self.background = background

    def _transform(self, inpt: Image.Image, params: Dict[str, Any]) -> Image.Image:
        if not isinstance(inpt, Image.Image):
            raise TypeError(f"inpt should be PIL Image. Got {type(inpt)}")
        # Images without any transparency pass through untouched.
        if not inpt.has_transparency_data:
            return inpt

        canvas = Image.new("RGB", inpt.size, self.background).convert('RGBA')
        return Image.alpha_composite(canvas, inpt).convert('RGB')
|
||||
|
||||
class AspectRatioCrop(Transform):
    """Center-crop a PIL image so its aspect ratio matches a reference width/height."""
    _transformed_types = (Image.Image,)

    def __init__(
        self,
        width: int,
        height: int,
    ):
        super().__init__()
        self.ref_width = width
        self.ref_height = height
        self.aspect_ratio = width / height

    def _transform(self, inpt: Image.Image, params: Dict[str, Any]) -> Image.Image:
        if not isinstance(inpt, Image.Image):
            raise TypeError(f"inpt should be PIL Image. Got {type(inpt)}")

        # Crop box starts as the full image: [left, top, right, bottom].
        box = [0, 0, inpt.width, inpt.height]
        input_ratio = inpt.width / inpt.height

        if input_ratio > self.aspect_ratio:
            # Too wide: trim equal amounts from the left and right edges.
            target_width = int(round(inpt.height / self.ref_height * self.ref_width))
            trim = (inpt.width - target_width) // 2
            box[0] += trim
            box[2] -= trim
        elif input_ratio < self.aspect_ratio:
            # Too tall: trim equal amounts from the top and bottom edges.
            target_height = int(round(inpt.width / self.ref_width * self.ref_height))
            trim = (inpt.height - target_height) // 2
            box[1] += trim
            box[3] -= trim

        return inpt.crop(tuple(box))
|
||||
|
|
@ -0,0 +1,708 @@
|
|||
"""Sharpfin Triton-accelerated GPU scaling functions.
|
||||
|
||||
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
|
||||
Imports patched: absolute sharpfin.X -> relative .X
|
||||
Requires: triton (only available on CUDA platforms)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .util import ResizeKernel
|
||||
from typing import Tuple
|
||||
import torch.nn.functional as F
|
||||
from triton.language.extra import libdevice
|
||||
from .util import linear_to_srgb, srgb_to_linear
|
||||
|
||||
# Magic Kernel Sharp with Triton optimizations. Mainly converted to polynomials so that
|
||||
# FMA operators can be used.
|
||||
@triton.jit
def magic_kernel_sharp_2021_triton(x: tl.tensor):
    """Evaluate the Magic Kernel Sharp 2021 resampling kernel at offsets x.

    The piecewise polynomial is written in fused multiply-add (tl.fma) form;
    support is |x| <= 4.5 and the result is zero outside it.
    """
    out = tl.zeros_like(x) # inplace operation doesn't help much.
    x = tl.abs(x)  # kernel is symmetric around zero

    # Breakpoints of the piecewise definition.
    lte_05 = x <= 0.5
    lte_15 = x <= 1.5
    lte_25 = x <= 2.5
    lte_35 = x <= 3.5
    lte_45 = x <= 4.5

    x_sq = x*x # triton would compile like this anyways but it helps readability

    # Each segment is a quadratic in x evaluated via FMA (Horner) form.
    out = tl.where(lte_05, tl.fma(x_sq, -239/144, 577/576), out)
    out = tl.where(lte_15 and not lte_05, tl.fma(x_sq, 35/36, tl.fma(x, -379/144, 239/144)), out)
    out = tl.where(lte_25 and not lte_15, tl.fma(x_sq, -1/6, tl.fma(x, 113/144, -65/72)), out)
    out = tl.where(lte_35 and not lte_25, tl.fma(x_sq, 1/36, tl.fma(x, -3/16, 5/16)), out)
    out = tl.where(lte_45 and not lte_35, tl.fma(x_sq, -1/288, tl.fma(x, 1/32, -9/128)), out)

    return out
|
||||
|
||||
@triton.jit
def sinc_triton(x: tl.tensor):
    """Normalized sinc sin(pi*x)/(pi*x); the 1e-8 epsilon avoids 0/0 at x == 0."""
    y = tl.fma(x, math.pi, 1e-8)
    return libdevice.fast_sinf(y) / y
|
||||
|
||||
@triton.jit
def lanczos_triton(x: tl.tensor, n: tl.constexpr = 3):
    """Lanczos kernel of order n: sinc(x) * sinc(x/n) inside |x| < n, zero outside."""
    return tl.where(
        tl.abs(x) < n,
        sinc_triton(x) * sinc_triton(x/n),
        0
    )
|
||||
|
||||
# NOTE: there is no reason to use libdevice.pow, its only differences are with subnormals
@triton.jit
def linear_to_srgb_triton(x):
    """Encode linear-light values with the sRGB transfer curve (piecewise linear/power)."""
    return tl.where(
        x <= 0.0031308,
        x * 12.92,
        tl.fma(1.055, libdevice.fast_powf(x, 1/2.4), -0.055)
    )
|
||||
|
||||
@triton.jit
def srgb_to_linear_triton(x):
    """Decode sRGB-encoded values to linear light (inverse of linear_to_srgb_triton)."""
    return tl.where(
        x <= 0.04045,
        x / 12.92,
        libdevice.fast_powf(tl.fma(1/1.055, x, 0.055/1.055), 2.4)
    )
|
||||
|
||||
from .sparse_backend import triton_dds, triton_dds_sbsc, triton_dds_zerorhs_sbsc, Matrix, SBSCMatrix
|
||||
|
||||
def _get_resize_kernel_triton(k: ResizeKernel):
    """Map a ResizeKernel member to its Triton kernel function and support window.

    Only LANCZOS3 and MAGIC_KERNEL_SHARP_2021 have Triton implementations;
    the other known kernels raise NotImplementedError, and anything else
    raises ValueError.
    """
    implemented = {
        ResizeKernel.LANCZOS3: (lanczos_triton, 3.),
        ResizeKernel.MAGIC_KERNEL_SHARP_2021: (magic_kernel_sharp_2021_triton, 4.5),
    }
    recognized_but_missing = (
        ResizeKernel.NEAREST,
        ResizeKernel.BILINEAR,
        ResizeKernel.MITCHELL,
        ResizeKernel.CATMULL_ROM,
        ResizeKernel.B_SPLINE,
        ResizeKernel.LANCZOS2,
        ResizeKernel.MAGIC_KERNEL,
        ResizeKernel.MAGIC_KERNEL_SHARP_2013,
    )
    if k in implemented:
        return implemented[k]
    if k in recognized_but_missing:
        raise NotImplementedError
    raise ValueError(f"Unknown resize kernel {k}")
|
||||
|
||||
# Sparse Downscale and support functions.
|
||||
|
||||
# Amanatides, John and Woo, Andrew -- Fast Voxel Traversal
|
||||
def grid_line_tiles(x0, y0, x1, y1, grid_width, grid_height):
    """Return the set of (row, col) grid cells crossed by the segment (x0, y0)-(x1, y1).

    Implements Amanatides & Woo's fast voxel traversal; cells outside the
    grid bounds are visited but discarded.
    """
    delta_x = x1 - x0
    delta_y = y1 - y0

    cell_x, cell_y = math.floor(x0), math.floor(y0)
    last_x, last_y = math.floor(x1), math.floor(y1)

    step_x = 1 if delta_x > 0 else -1
    step_y = 1 if delta_y > 0 else -1

    # Parametric distance along the segment to the next vertical/horizontal
    # grid boundary; infinite for an axis-parallel segment.
    next_x = ((cell_x + (step_x > 0)) - x0) / delta_x if delta_x != 0 else float('inf')
    next_y = ((cell_y + (step_y > 0)) - y0) / delta_y if delta_y != 0 else float('inf')

    inc_x = abs(1 / delta_x) if delta_x != 0 else float('inf')
    inc_y = abs(1 / delta_y) if delta_y != 0 else float('inf')

    visited = set()
    while True:
        if 0 <= cell_x < grid_width and 0 <= cell_y < grid_height:
            visited.add((cell_y, cell_x))
        if cell_x == last_x and cell_y == last_y:
            return visited
        # Advance along whichever axis crosses its next boundary first;
        # ties step the y axis, matching the reference traversal order.
        if next_x < next_y:
            next_x += inc_x
            cell_x += step_x
        else:
            next_y += inc_y
            cell_y += step_y
|
||||
|
||||
def tile_mask_function(dest_size, src_size, kernel_window=4.5, tile_size=64):
    """Build the boolean tile-occupancy mask of the sparse 1-D resize matrix.

    Returns (mask, tiles): a (rows, cols) bool tensor marking which
    tile_size x tile_size tiles contain non-zero kernel weights, and the
    (row, col) coordinates of those tiles as a tensor.
    """
    scale = dest_size / src_size
    pad = math.ceil((kernel_window - 0.5) / scale)

    rows = math.ceil((src_size + 2 * pad) / tile_size)
    cols = math.ceil(dest_size / tile_size)

    # Top and bottom boundary lines of the band of non-zero weights,
    # expressed in tile units.
    boundary_lines = (
        (0, 0.5 / tile_size, dest_size / tile_size, (src_size + 0.5) / tile_size),
        (0, (2 * pad - 0.5) / tile_size, dest_size / tile_size, (src_size + 2 * pad - 0.5) / tile_size),
    )

    covered = set()
    for x0, y0, x1, y1 in boundary_lines:
        covered.update(grid_line_tiles(x0, y0, x1, y1, cols, rows))

    mask = torch.zeros((rows, cols), dtype=torch.bool)
    tiles = torch.tensor(list(covered))
    mask[tiles[:, 0], tiles[:, 1]] = True

    return mask, tiles
|
||||
|
||||
def create_tensor_metadata(
    tile_mask: torch.Tensor,
    tiles: torch.Tensor,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    offsets_t: torch.Tensor,
):
    """Fill pre-allocated (pinned) buffers with block-sparse index metadata.

    Columns of `indices`: 0 = tile row, 1 = tile column, 2 = block offsets of
    the transposed layout, 3 = column indices of the transposed layout.
    `offsets`/`offsets_t` become CSR-style cumulative block counts per row and
    per column respectively (index 0 is left as the initial zero).
    """
    indices[:,:2] = tiles

    torch.argsort(indices[:,1], stable=True, out=indices[:,2]) # block_offsets_t
    torch.take(indices[:,0], indices[:,2], out=indices[:,3]) # col_indices_t

    # reusing the offsets buffer here helps performance
    torch.sum(tile_mask, dim=1, out=offsets[1:])
    torch.sum(tile_mask, dim=0, out=offsets_t[1:])
    torch.cumsum(offsets, dim=0, out=offsets)
    torch.cumsum(offsets_t, dim=0, out=offsets_t)

    return indices, offsets, offsets_t
|
||||
|
||||
# for isolating the one mandatory graph break
@torch.compiler.disable
def _get_nnz_and_buffers(tile_mask):
    """Count occupied tiles and allocate pinned host buffers for sparse metadata.

    Kept out of the compiled graph because .item() forces a host sync — the
    one mandatory graph break. Returned buffers: [indices (num_blocks, 4)
    via .T, offsets (rows+1,), offsets_t (cols+1,)], all pinned for async
    host-to-device copies.
    """
    num_sparse_blocks = torch.sum(tile_mask).item()

    return [
        torch.empty((4, num_sparse_blocks), dtype=torch.int64, pin_memory=True).T, # indices
        torch.zeros((tile_mask.shape[0] + 1,), dtype=torch.int32, pin_memory=True), # offsets
        torch.zeros((tile_mask.shape[1] + 1,), dtype=torch.int32, pin_memory=True) # offsets_t
    ]
|
||||
|
||||
|
||||
def generate_sparse_matrix(dest_size, src_size, kernel_window=4.5, tile_size=64):
    """Construct an empty block-sparse resize matrix for one axis on CUDA.

    Tile occupancy comes from tile_mask_function; the index metadata is built
    in pinned host buffers and copied to the GPU asynchronously. The fp16
    data blocks are uninitialized — they are filled later by
    compute_sparse_coord_grid_kernel.
    """
    tile_mask, tiles = tile_mask_function(dest_size, src_size, kernel_window, tile_size)

    buffers = _get_nnz_and_buffers(tile_mask)
    num_sparse_blocks = buffers[0].shape[0]

    indices, offsets, offsets_t = create_tensor_metadata(
        tile_mask,
        tiles,
        *buffers
    )

    indices = indices.to(device='cuda', dtype=torch.int32, non_blocking=True)

    return Matrix(
        (tile_mask.shape[0] * tile_size, tile_mask.shape[1] * tile_size),
        torch.empty(num_sparse_blocks, tile_size, tile_size, dtype=torch.float16, device='cuda'),
        row_indices=indices[:,0],
        column_indices=indices[:,1],
        offsets=offsets.to(device='cuda', non_blocking=True),
        column_indices_t=indices[:,3],
        offsets_t=offsets_t.to(device='cuda', non_blocking=True),
        block_offsets_t=indices[:,2]
    )
|
||||
|
||||
@triton.jit
def compute_sparse_coord_grid_kernel(
    coords_source_ptr, coords_dest_ptr, sparse_data_ptr,
    row_indices_ptr, col_indices_ptr,
    k: float, M: int, N: int, BLOCK_SIZE: tl.constexpr, SPARSE_BLOCK_SIZE: tl.constexpr
):
    """Fill one tile of a sparse block with MKS2021 resize weights.

    program_id(0) selects the sparse block; program_id(1)/(2) select a
    BLOCK_SIZE x BLOCK_SIZE tile within it. The weight at (i, j) is
    k * mks2021(source_coord[i] - dest_coord[j]).
    """
    SPARSE_BLOCK_NUMEL = SPARSE_BLOCK_SIZE * SPARSE_BLOCK_SIZE
    sparse_block = tl.program_id(0)

    tile_row = tl.program_id(1)
    tile_col = tl.program_id(2)

    # Global coordinate indices of this tile, offset by the block's position.
    row_offsets = tl.load(row_indices_ptr + sparse_block) * SPARSE_BLOCK_SIZE + tile_row * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = tl.load(col_indices_ptr + sparse_block) * SPARSE_BLOCK_SIZE + tile_col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    mask_row = row_offsets < M
    mask_col = col_offsets < N

    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row, other=0.0)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col, other=0.0)

    x = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)

    x = magic_kernel_sharp_2021_triton(x)

    # Scale weights by the resize ratio k.
    x *= k

    sparse_block_ptr = sparse_data_ptr + sparse_block * SPARSE_BLOCK_NUMEL

    local_row_start = tile_row * BLOCK_SIZE
    local_col_start = tile_col * BLOCK_SIZE

    local_rows = local_row_start + tl.arange(0, BLOCK_SIZE)
    local_cols = local_col_start + tl.arange(0, BLOCK_SIZE)

    local_rows_2d = local_rows[:, None]
    local_cols_2d = local_cols[None, :]

    store_offset = local_rows_2d * SPARSE_BLOCK_SIZE + local_cols_2d

    tl.store(sparse_block_ptr + store_offset, x)
|
||||
|
||||
def compute_sparse_coord_grid(target_size, source_size, kernel_window, BLOCK_SIZE=32, SPARSE_BLOCK_SIZE=64):
    """Build and fill the sparse 1-D resize matrix for source_size -> target_size on CUDA."""
    assert SPARSE_BLOCK_SIZE % BLOCK_SIZE == 0

    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    # Pixel-center positions of the (padded) source and the destination,
    # both expressed in destination pixel units.
    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')

    M, N = coords_source.shape[0], coords_dest.shape[0]
    x = generate_sparse_matrix(target_size, source_size, kernel_window, SPARSE_BLOCK_SIZE)

    SPARSE_NUM_BLOCKS = x.data.shape[0]

    # One program per (sparse block, tile row, tile col).
    grid = lambda meta: (SPARSE_NUM_BLOCKS, triton.cdiv(SPARSE_BLOCK_SIZE, meta['BLOCK_SIZE']), triton.cdiv(SPARSE_BLOCK_SIZE, meta['BLOCK_SIZE']))
    compute_sparse_coord_grid_kernel[grid](
        coords_source, coords_dest, x.data,
        x.row_indices, x.column_indices,
        k, M, N, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    return x
|
||||
|
||||
# Dense kernel for downsampling coord_grids
|
||||
|
||||
@triton.jit
def compute_coord_grid_kernel(
    coords_source_ptr, coords_dest_ptr, coord_grid_ptr, k,
    M, N, BLOCK_SIZE: tl.constexpr,
):
    """Fill a dense (M, N) fp16 weight grid with k * mks2021(src_coord - dest_coord)."""
    row_offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    mask_row = row_offsets < M
    mask_col = col_offsets < N

    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col)

    # Pairwise source/destination offsets for this tile.
    x = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)

    x = magic_kernel_sharp_2021_triton(x)

    # Scale weights by the resize ratio k.
    x *= k

    tl.store(coord_grid_ptr + row_offsets[:, None] * N + col_offsets[None, :], x, mask=mask_row[:, None] & mask_col[None, :])
|
||||
|
||||
def compute_coord_grid(target_size, source_size, kernel_window=4.5, BLOCK_SIZE=32):
    """Build a dense (padded_source, target) fp16 MKS2021 weight grid on CUDA."""
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    # Pixel-center positions of the (padded) source and the destination,
    # both expressed in destination pixel units.
    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')

    M, N = coords_source.shape[0], coords_dest.shape[0]
    coord_grid = torch.empty((M, N), dtype=torch.float16, device='cuda')

    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    compute_coord_grid_kernel[grid](coords_source, coords_dest, coord_grid, k, M, N, BLOCK_SIZE)
    return coord_grid
|
||||
|
||||
@triton.jit
def pad_replicate_kernel(
    A, B,
    M_X, N_X,
    M_Y, N_Y,
    M_PAD, N_PAD,
    stride_xc, stride_xm, stride_xn,
    stride_yc, stride_ym, stride_yn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    fuse_linrgb: tl.constexpr
):
    """Copy A (C, M_X, N_X) into padded B (C, M_Y, N_Y) with edge replication.

    Destination positions outside the source read from the nearest clamped
    source pixel (replicate padding). With fuse_linrgb, sRGB values are
    converted to linear light in the same pass.
    """
    pid_c = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_n = tl.program_id(2)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Undo the pad offset and clamp into the valid source range — this is
    # what produces the border replication.
    offs_m_cl = tl.maximum(offs_m, M_PAD) - M_PAD
    offs_m_cl = tl.minimum(offs_m_cl, M_X - 1)
    offs_n_cl = tl.maximum(offs_n, N_PAD) - N_PAD
    offs_n_cl = tl.minimum(offs_n_cl, N_X - 1)

    mask_m = offs_m < M_Y
    mask_n = offs_n < N_Y

    A_ptr = A + pid_c * stride_xc + offs_m_cl[:, None] * stride_xm + offs_n_cl[None, :] * stride_xn
    B_ptr = B + pid_c * stride_yc + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn

    # Load needs no mask: the clamped offsets always land inside the source.
    t = tl.load(A_ptr)
    if fuse_linrgb:
        t = srgb_to_linear_triton(t)

    tl.store(B_ptr, t, mask=mask_m[:, None] & mask_n[None, :])
|
||||
|
||||
def pad_replicate(
    img: torch.Tensor,
    pad_h: int,
    pad_w: int,
    sparse_block_size: int = 0,
    fuse_linrgb: bool = True,
):
    """Replicate-pad a (C, H, W) GPU tensor, optionally fusing sRGB -> linear.

    With sparse_block_size > 0 the output gets pad_h/pad_w on the top/left and
    is then rounded up on the bottom/right to a multiple of the sparse block
    size (the layout the sparse matmul expects); otherwise padding is
    symmetric (pad_h/pad_w on each side).
    """
    C = img.shape[0]

    M_PAD = pad_h
    N_PAD = pad_w

    if sparse_block_size != 0:
        out_H = img.shape[-2] + M_PAD + (-(img.shape[-2] + M_PAD)) % sparse_block_size
        out_W = img.shape[-1] + N_PAD + (-(img.shape[-1] + N_PAD)) % sparse_block_size
    else:
        out_H = img.shape[-2] + M_PAD + M_PAD
        out_W = img.shape[-1] + N_PAD + N_PAD

    out = torch.empty(C, out_H, out_W, dtype=img.dtype, device=img.device)

    # Wide, flat tiles: one row per program, 512 columns.
    BLOCK_M = 1
    BLOCK_N = 512

    grid = lambda META: (
        C,
        (out.shape[1] + META['BLOCK_M'] - 1) // META['BLOCK_M'],
        (out.shape[2] + META['BLOCK_N'] - 1) // META['BLOCK_N'],
    )

    pad_replicate_kernel[grid](
        img, out,
        img.shape[1], img.shape[2],
        out.shape[1], out.shape[2],
        M_PAD, N_PAD,
        img.stride(0), img.stride(1), img.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        fuse_linrgb=fuse_linrgb,
    )
    return out
|
||||
|
||||
def downscale_sparse(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    """Downscale a (C, H, W) CUDA tensor via two block-sparse matmuls.

    Width pass first, then height pass; sRGB -> linear is fused into the
    replicate pad and linear -> sRGB plus clamping into the final matmul.
    NOTE(review): only the window size of resize_kernel is honored here —
    the sparse coord-grid kernel hardcodes the MKS2021 weights, so `kernel`
    is unused.
    """
    kernel, window = _get_resize_kernel_triton(resize_kernel)

    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]

    y_s_w = compute_sparse_coord_grid(T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    y_s_h = compute_sparse_coord_grid(T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)

    # Padding so the kernel window never samples outside the source.
    PAD_W = math.ceil((window - 0.5) / (T_W / S_W))
    PAD_H = math.ceil((window - 0.5) / (T_H / S_H))

    image = pad_replicate(
        image,
        PAD_H,
        PAD_W,
        SPARSE_BLOCK_SIZE,
        fuse_linrgb=do_gamma_handling
    )

    # Width pass; output transposed (output_mt) so the height pass is again a
    # right-multiplication.
    image = triton_dds(
        image,
        y_s_w,
        output_mt=True
    )

    # Height pass with fused linear -> sRGB, clamp, and crop to target size.
    image = triton_dds(
        image,
        y_s_h,
        fuse_srgb=do_gamma_handling,
        clamp_output=True,
        output_mt=True,
        output_slice=(T_H, T_W)
    )

    return image
|
||||
|
||||
def downscale_triton(
    image: torch.Tensor,
    target_size: torch.Size,
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
) -> torch.Tensor:
    """Dense GPU downscale via two matmuls against precomputed weight grids.

    NOTE(review): the channel count is hardcoded to 3 in the view() calls,
    and the crop to target_size happens only on the gamma-handling path —
    confirm callers with do_gamma_handling=False expect the uncropped result.
    `kernel` is unused: compute_coord_grid hardcodes MKS2021 weights.
    """
    kernel, window = _get_resize_kernel_triton(resize_kernel)

    y_s_w = compute_coord_grid(target_size[-1], image.shape[-1], window)
    y_s_h = compute_coord_grid(target_size[-2], image.shape[-2], window)

    PAD_W = math.ceil((window - 0.5) / (target_size[-1] / image.shape[-1]))
    PAD_H = math.ceil((window - 0.5) / (target_size[-2] / image.shape[-2]))

    image = pad_replicate(image, PAD_H, PAD_W, fuse_linrgb=do_gamma_handling)

    # Width pass: flatten to 2-D and multiply by the width weight grid.
    image = image.view(-1, image.shape[-1])
    image = image @ y_s_w
    image = image.view(3, -1, image.shape[-1])
    image = image.mT
    # Height pass: after the transpose the height axis is the matmul axis.
    image = image.reshape(-1, image.shape[-1])
    image = image @ y_s_h
    image = image.view(3, -1, image.shape[-1])
    image = image.mT
    if do_gamma_handling:
        image = linear_to_srgb(image[:, :target_size[0], :target_size[1]])
    image.clamp_(0.,1.)
    return image
|
||||
|
||||
# Single Block Sparse Column implementations.
|
||||
|
||||
def evaluate_line(x, x0, y0, x1, y1):
    """Interpolate (or extrapolate) the y-value at *x* on the line through (x0, y0) and (x1, y1).

    A vertical line (x0 == x1) has no defined y, so +inf is returned.
    """
    run = x1 - x0
    if run == 0:
        return float('inf')
    t = (x - x0) / run
    return y0 + t * (y1 - y0)
|
||||
|
||||
def pad_height_to_multiple(height, multiple):
    """Round *height* up to the nearest multiple of *multiple* and return it as an int."""
    blocks = math.ceil(height / multiple)
    return int(blocks * multiple)
|
||||
|
||||
def generate_sbsc_structure(
    dest_size,
    src_size,
    kernel_window=4.5,
    tile_size=64,
    y_tile_size=32
):
    """Compute the layout of a single-block-sparse-column (SBSC) resize matrix.

    Each column block of width tile_size stores one dense sub-block of height
    max_height; consecutive blocks start a fixed row offset apart. The offset
    is taken near the band's ideal slope and clamped into the feasible range
    where every block still covers its part of the kernel's support band.

    Returns (fixed_offset, max_height, n_blocks, tile_size).
    """
    k = dest_size / src_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    # Top and bottom boundary lines of the band of non-zero weights.
    line1 = (0, 0.5, dest_size, src_size + 0.5)
    line2 = (0, 2 * PAD - 0.5, dest_size, src_size + 2 * PAD - 0.5)

    y_mins = []
    y_maxs = []
    n_blocks = math.ceil(dest_size / tile_size)
    max_height = 0

    # Per column block, find the band's vertical extent, rounded up to a
    # multiple of y_tile_size; track the largest padded height.
    for i in range(n_blocks):
        x0 = i * tile_size
        x1 = min(dest_size - 1, x0 + tile_size - 1)

        yt0 = evaluate_line(x0, *line1)
        yt1 = evaluate_line(x1, *line1)
        yb0 = evaluate_line(x0, *line2)
        yb1 = evaluate_line(x1, *line2)

        y_min = min(yt0, yt1)
        y_max = max(yb0, yb1)

        height = y_max - y_min
        padded = pad_height_to_multiple(height, y_tile_size)

        y_mins.append(y_min)
        y_maxs.append(y_max)
        max_height = max(max_height, padded)

    slope_top = (line1[3] - line1[1]) / (line1[2] - line1[0])
    ideal_step = slope_top * tile_size

    # Feasible range for the fixed per-block row offset: block i must start at
    # or above its band top and its max_height rows must reach the band bottom.
    lower_bounds = []
    upper_bounds = []
    for i in range(1, n_blocks):
        lower_bounds.append((y_maxs[i] - max_height) / i)
        upper_bounds.append(y_mins[i] / i)

    lower = math.ceil(max(lower_bounds)) if lower_bounds else 0
    upper = math.floor(min(upper_bounds)) if upper_bounds else int(round(ideal_step))

    # Prefer the ideal step, clamped into [lower, upper].
    fixed_offset = int(round(ideal_step))
    if fixed_offset < lower:
        fixed_offset = lower
    elif fixed_offset > upper:
        fixed_offset = upper

    return fixed_offset, max_height, n_blocks, tile_size
|
||||
|
||||
|
||||
def generate_sbsc_matrix(dest_size, src_size, kernel_window=4.5, tile_size=64, y_tile_size=32):
    """Allocate an uninitialized SBSCMatrix shaped by generate_sbsc_structure.

    The fp16 data blocks are filled later by compute_sbsc_coord_grid_kernel.
    """
    offset, block_height, num_blocks, col_width = generate_sbsc_structure(
        dest_size, src_size, kernel_window, tile_size, y_tile_size
    )

    return SBSCMatrix(
        size=((offset * (num_blocks - 1)) + block_height, dest_size),
        data=torch.empty((num_blocks, block_height, col_width), dtype=torch.float16, device='cuda'),
        offset=offset,
        block_size=y_tile_size
    )
|
||||
|
||||
@triton.jit
def compute_sbsc_coord_grid_kernel(
    coords_source_ptr, coords_dest_ptr,
    sparse_data_ptr, offset: tl.constexpr,
    stride_xb, stride_xw, stride_xh,
    k: float, M: int, N: int, BLOCK_SIZE: tl.constexpr, SPARSE_BLOCK_SIZE: tl.constexpr
):
    """Fill one stripe of an SBSC block with MKS2021 resize weights.

    Block pid_w starts at global row `offset * pid_w`; pid_h selects a
    BLOCK_SIZE-tall stripe within the block. Only stride_xb is used for
    addressing — blocks are stored densely packed internally.
    """
    pid_w = tl.program_id(0)
    pid_h = tl.program_id(1)

    start_row = offset * pid_w + pid_h * BLOCK_SIZE
    start_col = pid_w * SPARSE_BLOCK_SIZE

    row_offsets = start_row + tl.arange(0, BLOCK_SIZE)
    col_offsets = start_col + tl.arange(0, SPARSE_BLOCK_SIZE)

    mask_row = row_offsets < M
    mask_col = col_offsets < N

    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row, other=0.0)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col, other=0.0)

    y = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)

    y = magic_kernel_sharp_2021_triton(y)

    # Scale weights by the resize ratio k.
    y *= k

    sparse_block_ptr = sparse_data_ptr + pid_w * stride_xb

    local_row_start = pid_h * BLOCK_SIZE

    local_rows = local_row_start + tl.arange(0, BLOCK_SIZE)
    local_cols = tl.arange(0, SPARSE_BLOCK_SIZE)

    local_rows_2d = local_rows[:, None]
    local_cols_2d = local_cols[None, :]

    store_offset = local_rows_2d * SPARSE_BLOCK_SIZE + local_cols_2d

    tl.store(sparse_block_ptr + store_offset, y)
|
||||
|
||||
def compute_sbsc_coord_grid(target_size, source_size, kernel_window, BLOCK_SIZE=32, SPARSE_BLOCK_SIZE=64):
    """Build and fill an SBSCMatrix of resize weights for one axis on CUDA.

    NOTE(review): the strides are unpacked as (stride_xb, stride_xh, stride_xw)
    but the kernel's parameters are ordered (stride_xb, stride_xw, stride_xh);
    harmless today since the kernel only reads stride_xb, but worth aligning.
    """
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    # Pixel-center positions of the (padded) source and the destination,
    # both expressed in destination pixel units.
    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')

    M, N = coords_source.shape[0], coords_dest.shape[0]
    x = generate_sbsc_matrix(target_size, source_size, kernel_window, SPARSE_BLOCK_SIZE, BLOCK_SIZE)

    SPARSE_BLOCKS, BLOCK_HEIGHT, _ = x.data.shape
    stride_xb, stride_xh, stride_xw = x.data.stride()

    grid = lambda meta: (SPARSE_BLOCKS, triton.cdiv(BLOCK_HEIGHT, meta['BLOCK_SIZE']))
    compute_sbsc_coord_grid_kernel[grid](
        coords_source, coords_dest,
        x.data, x.offset,
        stride_xb, stride_xh, stride_xw,
        k, M, N, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    return x
|
||||
|
||||
|
||||
def downscale_sbsc(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling: bool = True,
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    """Downscale a (C, H, W) CUDA tensor using the SBSC sparse layout.

    Width pass then height pass; sRGB -> linear is fused into the replicate
    pad, linear -> sRGB and clamping into the final matmul. NOTE(review):
    only the window of resize_kernel is honored — the SBSC coord-grid kernel
    hardcodes MKS2021 weights, so `kernel` is unused.
    """
    kernel, window = _get_resize_kernel_triton(resize_kernel)

    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]

    y_s_w = compute_sbsc_coord_grid(T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    y_s_h = compute_sbsc_coord_grid(T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)

    # Padding so the kernel window never samples outside the source.
    PAD_W = math.ceil((window - 0.5) / (T_W / S_W))
    PAD_H = math.ceil((window - 0.5) / (T_H / S_H))

    image = pad_replicate(
        image,
        PAD_H,
        PAD_W,
        fuse_linrgb=do_gamma_handling,
        sparse_block_size=SPARSE_BLOCK_SIZE,
    )

    # Width pass; output transposed (output_mt) so the height pass is again a
    # right-multiplication.
    image = triton_dds_sbsc(
        image,
        y_s_w,
        output_mt=True
    )

    # Height pass with fused linear -> sRGB and clamp.
    image = triton_dds_sbsc(
        image,
        y_s_h,
        fuse_srgb=do_gamma_handling,
        clamp_output=True,
        output_mt=True,
    )

    return image
|
||||
|
||||
|
||||
def downscale_sbsc_zerorhs(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
    gamma_handling_type: str = 'fast',
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    """Downscale via SBSC matmuls that materialize weights on the fly ("zero RHS").

    Unlike downscale_sbsc, no weight matrix is precomputed — only the block
    layout from generate_sbsc_structure is passed and the weights are
    generated inside triton_dds_zerorhs_sbsc. Gamma conversion is fused on
    the input of the width pass and the output of the height pass.
    """
    kernel, window = _get_resize_kernel_triton(resize_kernel)

    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]

    block_specs_w = generate_sbsc_structure(
        T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE
    )

    block_specs_h = generate_sbsc_structure(
        T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE
    )

    # Width pass, converting sRGB input to linear on the fly.
    image = triton_dds_zerorhs_sbsc(
        image,
        T_W, S_W, window, block_specs_w,
        fuse_srgb='input' if do_gamma_handling else '',
        gamma_correction=gamma_handling_type,
        output_mt=True
    )

    # Height pass, converting linear back to sRGB on output and clamping.
    image = triton_dds_zerorhs_sbsc(
        image,
        T_H, S_H, window, block_specs_h,
        fuse_srgb='output' if do_gamma_handling else '',
        gamma_correction=gamma_handling_type,
        clamp_output=True,
        output_mt=True,
    )

    return image
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
"""Sharpfin utility types and color space conversion functions.
|
||||
|
||||
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
import torch
|
||||
|
||||
|
||||
def srgb_to_linear(image: torch.Tensor) -> torch.Tensor:
    """Decode sRGB-encoded values to linear light via the piecewise sRGB transfer curve."""
    linear_segment = image / 12.92
    # Clamping keeps the power's base positive so gradients stay NaN-free in
    # the branch torch.where still evaluates for small inputs.
    gamma_segment = ((torch.clamp(image, min=0.04045) + 0.055) / 1.055) ** 2.4
    return torch.where(image <= 0.04045, linear_segment, gamma_segment)
|
||||
|
||||
|
||||
def linear_to_srgb(image: torch.Tensor) -> torch.Tensor:
    """Encode linear-light values with the sRGB transfer curve, clamped to [0, 1]."""
    linear_segment = image * 12.92
    gamma_segment = torch.clamp(1.055 * image ** (1 / 2.4) - 0.055, min=0.0, max=1.0)
    return torch.where(image <= 0.0031308, linear_segment, gamma_segment)
|
||||
|
||||
|
||||
class ResizeKernel(Enum):
    """Resampling kernels selectable for sharpfin resize operations."""
    NEAREST = "nearest"
    BILINEAR = "bilinear"
    CATMULL_ROM = "catmull-rom"
    MITCHELL = "mitchell"
    B_SPLINE = "b-spline"
    LANCZOS2 = "lanczos2"
    LANCZOS3 = "lanczos3"
    MAGIC_KERNEL = "magic_kernel"
    MAGIC_KERNEL_SHARP_2013 = "magic_kernel_sharp_2013"
    MAGIC_KERNEL_SHARP_2021 = "magic_kernel_sharp_2021"
|
||||
|
||||
|
||||
class SharpenKernel(Enum):
    """Sharpening kernel variants (Magic Kernel Sharp family)."""
    SHARP_2013 = "sharp_2013"
    SHARP_2021 = "sharp_2021"
|
||||
|
||||
|
||||
class QuantHandling(Enum):
    """Strategies for quantizing float pixel values back to integer storage."""
    TRUNCATE = "truncate"
    ROUND = "round"
    STOCHASTIC_ROUND = "stochastic_round"
    BAYER = "bayer"
|
||||
|
|
@ -109,7 +109,8 @@ class Upscaler:
|
|||
if img.width >= dest_w and img.height >= dest_h:
|
||||
break
|
||||
if img.width != dest_w or img.height != dest_h:
|
||||
img = img.resize((int(dest_w), int(dest_h)), resample=Image.Resampling.LANCZOS)
|
||||
from modules import images_sharpfin
|
||||
img = images_sharpfin.resize(img, (int(dest_w), int(dest_h)))
|
||||
shared.state.end(jobid)
|
||||
return img
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ class UpscalerResize(Upscaler):
|
|||
UpscalerData("Resize Bilinear", None, self),
|
||||
UpscalerData("Resize Hamming", None, self),
|
||||
UpscalerData("Resize Box", None, self),
|
||||
UpscalerData("Resize Sharpfin MKS2021", None, self),
|
||||
UpscalerData("Resize Sharpfin Lanczos3", None, self),
|
||||
]
|
||||
|
||||
def do_upscale(self, img: Image, selected_model=None):
|
||||
|
|
@ -44,6 +46,12 @@ class UpscalerResize(Upscaler):
|
|||
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=Image.Resampling.HAMMING)
|
||||
elif selected_model == "Resize Box":
|
||||
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=Image.Resampling.BOX)
|
||||
elif selected_model == "Resize Sharpfin MKS2021":
|
||||
from modules import images_sharpfin
|
||||
return images_sharpfin.resize(img, (int(img.width * self.scale), int(img.height * self.scale)), kernel="Sharpfin MKS2021")
|
||||
elif selected_model == "Resize Sharpfin Lanczos3":
|
||||
from modules import images_sharpfin
|
||||
return images_sharpfin.resize(img, (int(img.width * self.scale), int(img.height * self.scale)), kernel="Sharpfin Lanczos3")
|
||||
else:
|
||||
return img
|
||||
|
||||
|
|
|
|||
|
|
@ -25,15 +25,15 @@ class UpscalerSpandrel(Upscaler):
|
|||
self.scalers.append(scaler)
|
||||
|
||||
def process(self, img: Image.Image) -> Image.Image:
|
||||
import torchvision.transforms.functional as TF
|
||||
tensor = TF.to_tensor(img).unsqueeze(0).to(devices.device)
|
||||
from modules import images_sharpfin
|
||||
tensor = images_sharpfin.to_tensor(img).unsqueeze(0).to(devices.device)
|
||||
img = img.convert('RGB')
|
||||
t0 = time.time()
|
||||
with devices.inference_context():
|
||||
tensor = self.model(tensor)
|
||||
tensor = tensor.clamp(0, 1).squeeze(0).cpu()
|
||||
t1 = time.time()
|
||||
upscaled = TF.to_pil_image(tensor)
|
||||
upscaled = images_sharpfin.to_pil(tensor)
|
||||
log.debug(f'Upscale: name="{self.selected}" input={img.size} output={upscaled.size} time={t1 - t0:.2f}')
|
||||
return upscaled
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,8 @@ class UpscalerAsymmetricVAE(Upscaler):
|
|||
def do_upscale(self, img: Image, selected_model=None):
|
||||
if selected_model is None:
|
||||
return img
|
||||
import torchvision.transforms.functional as F
|
||||
import diffusers
|
||||
from modules import shared, devices
|
||||
from modules import shared, devices, images_sharpfin
|
||||
if self.vae is None or (selected_model != self.selected):
|
||||
if 'v1' in selected_model:
|
||||
repo_id = 'Heasterian/AsymmetricAutoencoderKLUpscaler'
|
||||
|
|
@ -32,11 +31,11 @@ class UpscalerAsymmetricVAE(Upscaler):
|
|||
self.selected = selected_model
|
||||
shared.log.debug(f'Upscaler load: selected="{self.selected}" vae="{repo_id}"')
|
||||
t0 = time.time()
|
||||
img = img.resize((8 * (img.width // 8), 8 * (img.height // 8)), resample=Image.Resampling.LANCZOS).convert('RGB')
|
||||
tensor = (F.pil_to_tensor(img).unsqueeze(0) / 255.0).to(device=devices.device, dtype=devices.dtype)
|
||||
img = images_sharpfin.resize(img, (8 * (img.width // 8), 8 * (img.height // 8))).convert('RGB')
|
||||
tensor = images_sharpfin.to_tensor(img).unsqueeze(0).to(device=devices.device, dtype=devices.dtype)
|
||||
self.vae = self.vae.to(device=devices.device)
|
||||
tensor = self.vae(tensor).sample
|
||||
upscaled = F.to_pil_image(tensor.squeeze().clamp(0.0, 1.0).float().cpu())
|
||||
upscaled = images_sharpfin.to_pil(tensor.squeeze().clamp(0.0, 1.0).float().cpu())
|
||||
self.vae = self.vae.to(device=devices.cpu)
|
||||
t1 = time.time()
|
||||
shared.log.debug(f'Upscale: name="{self.selected}" input={img.size} output={upscaled.size} time={t1 - t0:.2f}')
|
||||
|
|
@ -57,10 +56,9 @@ class UpscalerWanUpscale(Upscaler):
|
|||
def do_upscale(self, img: Image, selected_model=None):
|
||||
if selected_model is None:
|
||||
return img
|
||||
import torchvision.transforms.functional as F
|
||||
import torch.nn.functional as FN
|
||||
import diffusers
|
||||
from modules import shared, devices
|
||||
from modules import shared, devices, images_sharpfin
|
||||
if (self.vae_encode is None) or (self.vae_decode is None) or (selected_model != self.selected):
|
||||
repo_encode = 'Qwen/Qwen-Image-Edit-2509'
|
||||
subfolder_encode = 'vae'
|
||||
|
|
@ -79,7 +77,7 @@ class UpscalerWanUpscale(Upscaler):
|
|||
|
||||
t0 = time.time()
|
||||
self.vae_encode = self.vae_encode.to(device=devices.device)
|
||||
tensor = (F.pil_to_tensor(img).unsqueeze(0).unsqueeze(2) / 255.0).to(device=devices.device, dtype=devices.dtype)
|
||||
tensor = images_sharpfin.to_tensor(img).unsqueeze(0).unsqueeze(2).to(device=devices.device, dtype=devices.dtype)
|
||||
tensor = self.vae_encode.encode(tensor).latent_dist.mode()
|
||||
self.vae_encode.to(device=devices.cpu)
|
||||
|
||||
|
|
@ -88,7 +86,7 @@ class UpscalerWanUpscale(Upscaler):
|
|||
tensor = FN.pixel_shuffle(tensor.movedim(2, 1), upscale_factor=2).movedim(1, 2) # pixel shuffle needs [..., C, H, W] format
|
||||
self.vae_decode.to(device=devices.cpu)
|
||||
|
||||
upscaled = F.to_pil_image(tensor.squeeze().clamp(0.0, 1.0).float().cpu())
|
||||
upscaled = images_sharpfin.to_pil(tensor.squeeze().clamp(0.0, 1.0).float().cpu())
|
||||
t1 = time.time()
|
||||
shared.log.debug(f'Upscale: name="{self.selected}" input={img.size} output={upscaled.size} time={t1 - t0:.2f}')
|
||||
return upscaled
|
||||
|
|
|
|||
|
|
@ -293,10 +293,9 @@ class FLitePipeline(DiffusionPipeline):
|
|||
raise
|
||||
|
||||
# 8. Post-process images
|
||||
from modules import images_sharpfin
|
||||
images = (decoded_images / 2 + 0.5).clamp(0, 1)
|
||||
# Convert to PIL Images
|
||||
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu()
|
||||
pil_images = [Image.fromarray(img.permute(1, 2, 0).numpy()) for img in images]
|
||||
pil_images = [images_sharpfin.to_pil(img) for img in images]
|
||||
|
||||
return FLitePipelineOutput(
|
||||
images=pil_images,
|
||||
|
|
|
|||
|
|
@ -332,8 +332,8 @@ class StableCascadeDecoderPipelineFixed(diffusers.StableCascadeDecoderPipeline):
|
|||
if output_type == "np":
|
||||
images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
|
||||
elif output_type == "pil":
|
||||
images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
|
||||
images = self.numpy_to_pil(images)
|
||||
from modules import images_sharpfin
|
||||
images = [images_sharpfin.to_pil(images[i]) for i in range(images.shape[0])]
|
||||
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
|
||||
else:
|
||||
images = latents
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as vF
|
||||
from modules import images_sharpfin
|
||||
import PIL
|
||||
|
||||
|
||||
|
|
@ -13,7 +13,7 @@ def preprocess(image, processor, **kwargs):
|
|||
elif isinstance(image, np.ndarray):
|
||||
image = PIL.Image.fromarray(image)
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = vF.to_pil_image(image)
|
||||
image = images_sharpfin.to_pil(image)
|
||||
else:
|
||||
raise TypeError(f"Image must be of type PIL.Image, np.ndarray, or torch.Tensor, got {type(image)} instead.")
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from packaging import version
|
|||
import PIL.Image
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
|
|
@ -859,7 +858,8 @@ class StableDiffusionXLDiffImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixi
|
|||
|
||||
# 4. Preprocess image
|
||||
#image = self.image_processor.preprocess(image) #ideally we would have preprocess the image with diffusers, but for this POC we won't --- it throws a deprecated warning
|
||||
map = torchvision.transforms.Resize(tuple(s // self.vae_scale_factor for s in original_image.shape[2:]),antialias=None)(map)
|
||||
from modules import images_sharpfin
|
||||
map = images_sharpfin.resize_tensor(map, tuple(s // self.vae_scale_factor for s in original_image.shape[2:]), linearize=False)
|
||||
# 5. Prepare timesteps
|
||||
def denoising_value_valid(dnv):
|
||||
return type(denoising_end) == float and 0 < dnv < 1
|
||||
|
|
@ -1758,7 +1758,8 @@ class StableDiffusionDiffImg2ImgPipeline(DiffusionPipeline):
|
|||
|
||||
# 7. Prepare extra step kwargs.
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
map = torchvision.transforms.Resize(tuple(s // self.vae_scale_factor for s in image.shape[2:]),antialias=None)(map)
|
||||
from modules import images_sharpfin
|
||||
map = images_sharpfin.resize_tensor(map, tuple(s // self.vae_scale_factor for s in image.shape[2:]), linearize=False)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
|
@ -1833,8 +1834,7 @@ class StableDiffusionDiffImg2ImgPipeline(DiffusionPipeline):
|
|||
import gradio as gr
|
||||
import diffusers
|
||||
from PIL import Image, ImageEnhance, ImageOps # pylint: disable=reimported
|
||||
from torchvision import transforms
|
||||
from modules import errors, shared, devices, scripts_manager, processing, sd_models, images
|
||||
from modules import errors, shared, devices, scripts_manager, processing, sd_models, images, images_sharpfin
|
||||
|
||||
|
||||
detector = None
|
||||
|
|
@ -1888,9 +1888,9 @@ class Script(scripts_manager.Script):
|
|||
else:
|
||||
return None, None, None
|
||||
image_mask = image_map.copy()
|
||||
image_map = transforms.ToTensor()(image_map)
|
||||
image_map = images_sharpfin.to_tensor(image_map)
|
||||
image_map = image_map.to(devices.device)
|
||||
image_init = 2 * transforms.ToTensor()(image_init) - 1
|
||||
image_init = 2 * images_sharpfin.to_tensor(image_init) - 1
|
||||
image_init = image_init.unsqueeze(0)
|
||||
image_init = image_init.to(devices.device)
|
||||
return image_init, image_map, image_mask
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class Script(scripts_manager.Script):
|
|||
from installer import install
|
||||
install('lpips')
|
||||
|
||||
from torchvision.transforms import ToPILImage, ToTensor
|
||||
from modules import images_sharpfin
|
||||
from scripts.lbm import get_model, extract_object, resize_and_center_crop # pylint: disable=no-name-in-module
|
||||
|
||||
ori_h_bg, ori_w_bg = fg_image.size
|
||||
|
|
@ -110,7 +110,7 @@ class Script(scripts_manager.Script):
|
|||
if lbm_method == 'Simple':
|
||||
output_image = img_pasted
|
||||
else:
|
||||
img_pasted_tensor = ToTensor()(img_pasted).to(device=devices.device, dtype=devices.dtype).unsqueeze(0) * 2 - 1
|
||||
img_pasted_tensor = images_sharpfin.to_tensor(img_pasted).to(device=devices.device, dtype=devices.dtype).unsqueeze(0) * 2 - 1
|
||||
batch = { "source_image": img_pasted_tensor }
|
||||
z_source = model.vae.encode(batch[model.source_key])
|
||||
output_image = model.sample(
|
||||
|
|
@ -120,7 +120,7 @@ class Script(scripts_manager.Script):
|
|||
max_samples=1,
|
||||
)
|
||||
output_image = (output_image[0].clamp(-1, 1).float().cpu() + 1) / 2
|
||||
output_image = ToPILImage()(output_image)
|
||||
output_image = images_sharpfin.to_pil(output_image)
|
||||
if lbm_composite:
|
||||
output_image = Image.composite(output_image, bg_image, fg_mask)
|
||||
|
||||
|
|
|
|||
|
|
@ -26,17 +26,13 @@ class Script(scripts_manager.Script):
|
|||
def encode(self, p: processing.StableDiffusionProcessing, image: Image.Image):
|
||||
if image is None:
|
||||
return None
|
||||
import numpy as np
|
||||
import torch
|
||||
from modules import images_sharpfin
|
||||
if p.width is None or p.width == 0:
|
||||
p.width = int(8 * (image.width * p.scale_by // 8))
|
||||
if p.height is None or p.height == 0:
|
||||
p.height = int(8 * (image.height * p.scale_by // 8))
|
||||
image = images.resize_image(p.resize_mode, image, p.width, p.height, upscaler_name=p.resize_name, context=p.resize_context)
|
||||
tensor = np.array(image).astype(np.float16) / 255.0
|
||||
tensor = tensor[None].transpose(0, 3, 1, 2)
|
||||
# image = image.transpose(0, 3, 1, 2)
|
||||
tensor = torch.from_numpy(tensor).to(device=devices.device, dtype=devices.dtype)
|
||||
tensor = images_sharpfin.to_tensor(image).unsqueeze(0).to(device=devices.device, dtype=devices.dtype)
|
||||
tensor = 2.0 * tensor - 1.0
|
||||
with devices.inference_context():
|
||||
latent = shared.sd_model.vae.tiled_encode(tensor)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ import cv2
|
|||
import numpy as np
|
||||
from PIL import Image, ImageFilter
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
|
|
@ -1323,7 +1322,8 @@ class StableDiffusionXLSoftFillPipeline(
|
|||
image.save("noised_image.png")
|
||||
|
||||
image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
|
||||
image = transforms.ToTensor()(image)
|
||||
from modules import images_sharpfin
|
||||
image = images_sharpfin.to_tensor(image)
|
||||
image = image * 2 - 1 # Normalize to [-1, 1]
|
||||
return image.unsqueeze(0)
|
||||
|
||||
|
|
@ -1334,7 +1334,8 @@ class StableDiffusionXLSoftFillPipeline(
|
|||
"""
|
||||
map = map.convert("L")
|
||||
map = transforms.CenterCrop((map.size[1] // 64 * 64, map.size[0] // 64 * 64))(map)
|
||||
map = transforms.ToTensor()(map)
|
||||
from modules import images_sharpfin
|
||||
map = images_sharpfin.to_tensor(map)
|
||||
map = (map - 0.05) / (0.95 - 0.05)
|
||||
map = torch.clamp(map, 0.0, 1.0)
|
||||
return 1.0 - map
|
||||
|
|
@ -1349,9 +1350,8 @@ class StableDiffusionXLSoftFillPipeline(
|
|||
|
||||
# Prepare mask as rescaled tensor map
|
||||
map = preprocess_map(mask).to(device)
|
||||
map = torchvision.transforms.Resize(
|
||||
tuple(s // self.vae_scale_factor for s in original_image_tensor.shape[2:]), antialias=None
|
||||
)(map)
|
||||
from modules import images_sharpfin
|
||||
map = images_sharpfin.resize_tensor(map, tuple(s // self.vae_scale_factor for s in original_image_tensor.shape[2:]), linearize=False)
|
||||
|
||||
# Generate latent tensor with noise
|
||||
original_with_noise = self.prepare_latents(
|
||||
|
|
|
|||
Loading…
Reference in New Issue