"""Sharpfin Triton-accelerated GPU scaling functions.
|
|
|
|
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
|
|
Imports patched: absolute sharpfin.X -> relative .X
|
|
Requires: triton (only available on CUDA platforms)
|
|
"""
|
|
|
|

import torch
import math
import triton
import triton.language as tl

from .util import ResizeKernel
from typing import Tuple
import torch.nn.functional as F
from triton.language.extra import libdevice
from .util import linear_to_srgb, srgb_to_linear

# Magic Kernel Sharp 2021 with Triton optimizations. The piecewise kernel is expressed as
# nested polynomials so that FMA operations can be used.
@triton.jit
def magic_kernel_sharp_2021_triton(x: tl.tensor):
    out = tl.zeros_like(x)  # inplace operation doesn't help much.
    x = tl.abs(x)

    lte_05 = x <= 0.5
    lte_15 = x <= 1.5
    lte_25 = x <= 2.5
    lte_35 = x <= 3.5
    lte_45 = x <= 4.5

    x_sq = x*x  # Triton would compile like this anyway, but it helps readability

    out = tl.where(lte_05, tl.fma(x_sq, -239/144, 577/576), out)
    out = tl.where(lte_15 and not lte_05, tl.fma(x_sq, 35/36, tl.fma(x, -379/144, 239/144)), out)
    out = tl.where(lte_25 and not lte_15, tl.fma(x_sq, -1/6, tl.fma(x, 113/144, -65/72)), out)
    out = tl.where(lte_35 and not lte_25, tl.fma(x_sq, 1/36, tl.fma(x, -3/16, 5/16)), out)
    out = tl.where(lte_45 and not lte_35, tl.fma(x_sq, -1/288, tl.fma(x, 1/32, -9/128)), out)

    return out
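
# Editor note (hedged): the helper below is not part of the vendored sharpfin code. It is a
# plain-PyTorch sketch of the same Magic Kernel Sharp 2021 piecewise polynomial, using the
# coefficients from the Triton kernel above, intended only for spot-checking on small inputs.
def _magic_kernel_sharp_2021_reference(x: torch.Tensor) -> torch.Tensor:
    x = x.abs()
    x_sq = x * x
    out = torch.zeros_like(x)
    # apply the widest interval first so narrower intervals overwrite it
    out = torch.where(x <= 4.5, -x_sq/288 + x/32 - 9/128, out)
    out = torch.where(x <= 3.5, x_sq/36 - 3*x/16 + 5/16, out)
    out = torch.where(x <= 2.5, -x_sq/6 + 113*x/144 - 65/72, out)
    out = torch.where(x <= 1.5, 35*x_sq/36 - 379*x/144 + 239/144, out)
    out = torch.where(x <= 0.5, -239*x_sq/144 + 577/576, out)
    return out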

@triton.jit
def sinc_triton(x: tl.tensor):
    # normalized sinc; the 1e-8 offset avoids a 0/0 at x == 0
    y = tl.fma(x, math.pi, 1e-8)
    return libdevice.fast_sinf(y) / y

@triton.jit
def lanczos_triton(x: tl.tensor, n: tl.constexpr = 3):
    return tl.where(
        tl.abs(x) < n,
        sinc_triton(x) * sinc_triton(x/n),
        0
    )

# NOTE: there is no reason to use libdevice.pow; it only differs for subnormals
@triton.jit
def linear_to_srgb_triton(x):
    return tl.where(
        x <= 0.0031308,
        x * 12.92,
        tl.fma(1.055, libdevice.fast_powf(x, 1/2.4), -0.055)
    )

@triton.jit
def srgb_to_linear_triton(x):
    return tl.where(
        x <= 0.04045,
        x / 12.92,
        libdevice.fast_powf(tl.fma(1/1.055, x, 0.055/1.055), 2.4)
    )

from .sparse_backend import triton_dds, triton_dds_sbsc, triton_dds_zerorhs_sbsc, Matrix, SBSCMatrix

def _get_resize_kernel_triton(k: ResizeKernel):
    match k:
        case ResizeKernel.NEAREST:
            raise NotImplementedError
        case ResizeKernel.BILINEAR:
            raise NotImplementedError
        case ResizeKernel.MITCHELL:
            raise NotImplementedError
        case ResizeKernel.CATMULL_ROM:
            raise NotImplementedError
        case ResizeKernel.B_SPLINE:
            raise NotImplementedError
        case ResizeKernel.LANCZOS2:
            raise NotImplementedError
        case ResizeKernel.LANCZOS3:
            resize_kernel = lanczos_triton
            kernel_window = 3.
        case ResizeKernel.MAGIC_KERNEL:
            raise NotImplementedError
        case ResizeKernel.MAGIC_KERNEL_SHARP_2013:
            raise NotImplementedError
        case ResizeKernel.MAGIC_KERNEL_SHARP_2021:
            resize_kernel = magic_kernel_sharp_2021_triton
            kernel_window = 4.5
        case _:
            raise ValueError(f"Unknown resize kernel {k}")
    return resize_kernel, kernel_window
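
# Editor note (hedged): as vendored, only LANCZOS3 and MAGIC_KERNEL_SHARP_2021 resolve here
# without raising, e.g.:
#   fn, window = _get_resize_kernel_triton(ResizeKernel.MAGIC_KERNEL_SHARP_2021)
#   # fn is magic_kernel_sharp_2021_triton, window is 4.5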

# Sparse Downscale and support functions.

# Amanatides, John and Woo, Andrew -- Fast Voxel Traversal
def grid_line_tiles(x0, y0, x1, y1, grid_width, grid_height):
    tiles = set()

    dx = x1 - x0
    dy = y1 - y0

    x = math.floor(x0)
    y = math.floor(y0)

    end_x = math.floor(x1)
    end_y = math.floor(y1)

    step_x = 1 if dx > 0 else -1
    step_y = 1 if dy > 0 else -1

    t_max_x = ((x + (step_x > 0)) - x0) / dx if dx != 0 else float('inf')
    t_max_y = ((y + (step_y > 0)) - y0) / dy if dy != 0 else float('inf')

    t_delta_x = abs(1 / dx) if dx != 0 else float('inf')
    t_delta_y = abs(1 / dy) if dy != 0 else float('inf')

    while True:
        if 0 <= x < grid_width and 0 <= y < grid_height:
            tiles.add((y, x))
        if x == end_x and y == end_y:
            break
        if t_max_x < t_max_y:
            t_max_x += t_delta_x
            x += step_x
        else:
            t_max_y += t_delta_y
            y += step_y

    return tiles
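
# Editor note (hedged usage sketch): grid_line_tiles rasterizes a line segment, given in tile
# units, into the set of (row, col) grid tiles it crosses, clipped to grid_width x grid_height.
# For example:
#   crossed = grid_line_tiles(0.0, 0.0, 3.5, 1.25, grid_width=4, grid_height=4)
#   # 'crossed' holds (y, x) pairs for tiles touched by the segment from (0, 0) to (3.5, 1.25)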

def tile_mask_function(dest_size, src_size, kernel_window=4.5, tile_size=64):
    k = dest_size / src_size
    PAD = math.ceil((kernel_window-0.5) / k)

    grid_size = math.ceil((src_size + 2*PAD)/tile_size), math.ceil(dest_size/tile_size)

    line_1 = 0, 0.5/tile_size, (dest_size)/tile_size, (src_size+0.5)/tile_size
    line_2 = 0, (2*PAD - 0.5)/tile_size, (dest_size)/tile_size, (src_size + 2*PAD - 0.5)/tile_size
    lines = line_1, line_2

    mask = torch.zeros(grid_size, dtype=torch.bool)

    tiles = set()

    for (x0, y0, x1, y1) in lines:
        tiles.update(grid_line_tiles(x0, y0, x1, y1, grid_size[1], grid_size[0]))

    tiles = torch.tensor(list(tiles))

    mask[tiles[:,0], tiles[:,1]] = True

    return mask, tiles

def create_tensor_metadata(
    tile_mask: torch.Tensor,
    tiles: torch.Tensor,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    offsets_t: torch.Tensor,
):

    indices[:,:2] = tiles

    torch.argsort(indices[:,1], stable=True, out=indices[:,2]) # block_offsets_t
    torch.take(indices[:,0], indices[:,2], out=indices[:,3]) # col_indices_t

    # reusing the offsets buffer here helps performance
    torch.sum(tile_mask, dim=1, out=offsets[1:])
    torch.sum(tile_mask, dim=0, out=offsets_t[1:])
    torch.cumsum(offsets, dim=0, out=offsets)
    torch.cumsum(offsets_t, dim=0, out=offsets_t)

    return indices, offsets, offsets_t

# for isolating the one mandatory graph break
@torch.compiler.disable
def _get_nnz_and_buffers(tile_mask):
    num_sparse_blocks = torch.sum(tile_mask).item()

    return [
        torch.empty((4, num_sparse_blocks), dtype=torch.int64, pin_memory=True).T, # indices
        torch.zeros((tile_mask.shape[0] + 1,), dtype=torch.int32, pin_memory=True), # offsets
        torch.zeros((tile_mask.shape[1] + 1,), dtype=torch.int32, pin_memory=True) # offsets_t
    ]


def generate_sparse_matrix(dest_size, src_size, kernel_window=4.5, tile_size=64):
    tile_mask, tiles = tile_mask_function(dest_size, src_size, kernel_window, tile_size)

    buffers = _get_nnz_and_buffers(tile_mask)
    num_sparse_blocks = buffers[0].shape[0]

    indices, offsets, offsets_t = create_tensor_metadata(
        tile_mask,
        tiles,
        *buffers
    )

    indices = indices.to(device='cuda', dtype=torch.int32, non_blocking=True)

    return Matrix(
        (tile_mask.shape[0] * tile_size, tile_mask.shape[1] * tile_size),
        torch.empty(num_sparse_blocks, tile_size, tile_size, dtype=torch.float16, device='cuda'),
        row_indices=indices[:,0],
        column_indices=indices[:,1],
        offsets=offsets.to(device='cuda', non_blocking=True),
        column_indices_t=indices[:,3],
        offsets_t=offsets_t.to(device='cuda', non_blocking=True),
        block_offsets_t=indices[:,2]
    )

@triton.jit
def compute_sparse_coord_grid_kernel(
    coords_source_ptr, coords_dest_ptr, sparse_data_ptr,
    row_indices_ptr, col_indices_ptr,
    k: float, M: int, N: int, BLOCK_SIZE: tl.constexpr, SPARSE_BLOCK_SIZE: tl.constexpr
):
    SPARSE_BLOCK_NUMEL = SPARSE_BLOCK_SIZE * SPARSE_BLOCK_SIZE
    sparse_block = tl.program_id(0)

    tile_row = tl.program_id(1)
    tile_col = tl.program_id(2)

    row_offsets = tl.load(row_indices_ptr + sparse_block) * SPARSE_BLOCK_SIZE + tile_row * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = tl.load(col_indices_ptr + sparse_block) * SPARSE_BLOCK_SIZE + tile_col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    mask_row = row_offsets < M
    mask_col = col_offsets < N

    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row, other=0.0)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col, other=0.0)

    x = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)

    # NOTE: the filter is hardcoded to Magic Kernel Sharp 2021 here, regardless of the kernel
    # resolved by _get_resize_kernel_triton.
    x = magic_kernel_sharp_2021_triton(x)

    x *= k

    sparse_block_ptr = sparse_data_ptr + sparse_block * SPARSE_BLOCK_NUMEL

    local_row_start = tile_row * BLOCK_SIZE
    local_col_start = tile_col * BLOCK_SIZE

    local_rows = local_row_start + tl.arange(0, BLOCK_SIZE)
    local_cols = local_col_start + tl.arange(0, BLOCK_SIZE)

    local_rows_2d = local_rows[:, None]
    local_cols_2d = local_cols[None, :]

    store_offset = local_rows_2d * SPARSE_BLOCK_SIZE + local_cols_2d

    tl.store(sparse_block_ptr + store_offset, x)

def compute_sparse_coord_grid(target_size, source_size, kernel_window, BLOCK_SIZE=32, SPARSE_BLOCK_SIZE=64):
    assert SPARSE_BLOCK_SIZE % BLOCK_SIZE == 0

    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')

    M, N = coords_source.shape[0], coords_dest.shape[0]
    x = generate_sparse_matrix(target_size, source_size, kernel_window, SPARSE_BLOCK_SIZE)

    SPARSE_NUM_BLOCKS = x.data.shape[0]

    grid = lambda meta: (SPARSE_NUM_BLOCKS, triton.cdiv(SPARSE_BLOCK_SIZE, meta['BLOCK_SIZE']), triton.cdiv(SPARSE_BLOCK_SIZE, meta['BLOCK_SIZE']))
    compute_sparse_coord_grid_kernel[grid](
        coords_source, coords_dest, x.data,
        x.row_indices, x.column_indices,
        k, M, N, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    return x

# Dense kernel for downsampling coord_grids

@triton.jit
def compute_coord_grid_kernel(
    coords_source_ptr, coords_dest_ptr, coord_grid_ptr, k,
    M, N, BLOCK_SIZE: tl.constexpr,
):
    row_offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    mask_row = row_offsets < M
    mask_col = col_offsets < N

    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col)

    x = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)

    x = magic_kernel_sharp_2021_triton(x)

    x *= k

    tl.store(coord_grid_ptr + row_offsets[:, None] * N + col_offsets[None, :], x, mask=mask_row[:, None] & mask_col[None, :])

def compute_coord_grid(target_size, source_size, kernel_window=4.5, BLOCK_SIZE=32):
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')

    M, N = coords_source.shape[0], coords_dest.shape[0]
    coord_grid = torch.empty((M, N), dtype=torch.float16, device='cuda')

    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    compute_coord_grid_kernel[grid](coords_source, coords_dest, coord_grid, k, M, N, BLOCK_SIZE)
    return coord_grid

@triton.jit
def pad_replicate_kernel(
    A, B,
    M_X, N_X,
    M_Y, N_Y,
    M_PAD, N_PAD,
    stride_xc, stride_xm, stride_xn,
    stride_yc, stride_ym, stride_yn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    fuse_linrgb: tl.constexpr
):
    pid_c = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_n = tl.program_id(2)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m_cl = tl.maximum(offs_m, M_PAD) - M_PAD
    offs_m_cl = tl.minimum(offs_m_cl, M_X - 1)
    offs_n_cl = tl.maximum(offs_n, N_PAD) - N_PAD
    offs_n_cl = tl.minimum(offs_n_cl, N_X - 1)

    mask_m = offs_m < M_Y
    mask_n = offs_n < N_Y

    A_ptr = A + pid_c * stride_xc + offs_m_cl[:, None] * stride_xm + offs_n_cl[None, :] * stride_xn
    B_ptr = B + pid_c * stride_yc + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn

    t = tl.load(A_ptr)
    if fuse_linrgb:
        t = srgb_to_linear_triton(t)

    tl.store(B_ptr, t, mask=mask_m[:, None] & mask_n[None, :])

def pad_replicate(
    img: torch.Tensor,
    pad_h: int,
    pad_w: int,
    sparse_block_size: int = 0,
    fuse_linrgb: bool = True,
):
    C = img.shape[0]

    M_PAD = pad_h
    N_PAD = pad_w

    if sparse_block_size != 0:
        out_H = img.shape[-2] + M_PAD + (-(img.shape[-2] + M_PAD)) % sparse_block_size
        out_W = img.shape[-1] + N_PAD + (-(img.shape[-1] + N_PAD)) % sparse_block_size
    else:
        out_H = img.shape[-2] + M_PAD + M_PAD
        out_W = img.shape[-1] + N_PAD + N_PAD

    out = torch.empty(C, out_H, out_W, dtype=img.dtype, device=img.device)

    BLOCK_M = 1
    BLOCK_N = 512

    grid = lambda META: (
        C,
        (out.shape[1] + META['BLOCK_M'] - 1) // META['BLOCK_M'],
        (out.shape[2] + META['BLOCK_N'] - 1) // META['BLOCK_N'],
    )

    pad_replicate_kernel[grid](
        img, out,
        img.shape[1], img.shape[2],
        out.shape[1], out.shape[2],
        M_PAD, N_PAD,
        img.stride(0), img.stride(1), img.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        fuse_linrgb=fuse_linrgb,
    )
    return out
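
# Editor note (hedged sketch): for the non-sparse case (sparse_block_size == 0) the routine above
# is intended to behave like ordinary replicate padding with an optional fused sRGB -> linear
# conversion. The helper below is an illustrative PyTorch reference, not part of the vendored
# sharpfin API; it assumes srgb_to_linear from .util takes and returns a (C, H, W) tensor.
def _pad_replicate_reference(img: torch.Tensor, pad_h: int, pad_w: int, fuse_linrgb: bool = True) -> torch.Tensor:
    if fuse_linrgb:
        img = srgb_to_linear(img)
    # mode='replicate' with a 4-element pad expects a (C, H, W) or (N, C, H, W) input
    return F.pad(img, (pad_w, pad_w, pad_h, pad_h), mode='replicate')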

def downscale_sparse(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    # only the window is used below; the coord-grid kernels evaluate Magic Kernel Sharp 2021 directly
    kernel, window = _get_resize_kernel_triton(resize_kernel)

    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]

    y_s_w = compute_sparse_coord_grid(T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    y_s_h = compute_sparse_coord_grid(T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)

    PAD_W = math.ceil((window - 0.5) / (T_W / S_W))
    PAD_H = math.ceil((window - 0.5) / (T_H / S_H))

    image = pad_replicate(
        image,
        PAD_H,
        PAD_W,
        SPARSE_BLOCK_SIZE,
        fuse_linrgb=do_gamma_handling
    )

    image = triton_dds(
        image,
        y_s_w,
        output_mt=True
    )

    image = triton_dds(
        image,
        y_s_h,
        fuse_srgb=do_gamma_handling,
        clamp_output=True,
        output_mt=True,
        output_slice=(T_H, T_W)
    )

    return image
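
# Editor note (hedged usage sketch): typical call on a CUDA image laid out as (C, H, W) with
# values in [0, 1]; dtype requirements depend on the sparse backend (the coord grids are fp16).
#   img = torch.rand(3, 2048, 2048, device='cuda', dtype=torch.float16)
#   small = downscale_sparse(img, (512, 512))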

def downscale_triton(
    image: torch.Tensor,
    target_size: torch.Size,
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
) -> torch.Tensor:
    kernel, window = _get_resize_kernel_triton(resize_kernel)

    y_s_w = compute_coord_grid(target_size[-1], image.shape[-1], window)
    y_s_h = compute_coord_grid(target_size[-2], image.shape[-2], window)

    PAD_W = math.ceil((window - 0.5) / (target_size[-1] / image.shape[-1]))
    PAD_H = math.ceil((window - 0.5) / (target_size[-2] / image.shape[-2]))

    image = pad_replicate(image, PAD_H, PAD_W, fuse_linrgb=do_gamma_handling)

    image = image.view(-1, image.shape[-1])
    image = image @ y_s_w
    image = image.view(3, -1, image.shape[-1])
    image = image.mT
    image = image.reshape(-1, image.shape[-1])
    image = image @ y_s_h
    image = image.view(3, -1, image.shape[-1])
    image = image.mT
    if do_gamma_handling:
        image = linear_to_srgb(image[:, :target_size[0], :target_size[1]])
    image.clamp_(0., 1.)
    return image

# Single Block Sparse Column implementations.

def evaluate_line(x, x0, y0, x1, y1):
    """Evaluate the y-coordinate at a given x along a line from (x0, y0) to (x1, y1)."""
    if x1 == x0:
        return float('inf')
    t = (x - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)

def pad_height_to_multiple(height, multiple):
    """Pad a height up to the next multiple of 'multiple'."""
    return int(math.ceil(height / multiple) * multiple)

def generate_sbsc_structure(
    dest_size,
    src_size,
    kernel_window=4.5,
    tile_size=64,
    y_tile_size=32
):
    k = dest_size / src_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    line1 = (0, 0.5, dest_size, src_size + 0.5)
    line2 = (0, 2 * PAD - 0.5, dest_size, src_size + 2 * PAD - 0.5)

    y_mins = []
    y_maxs = []
    n_blocks = math.ceil(dest_size / tile_size)
    max_height = 0

    for i in range(n_blocks):
        x0 = i * tile_size
        x1 = min(dest_size - 1, x0 + tile_size - 1)

        yt0 = evaluate_line(x0, *line1)
        yt1 = evaluate_line(x1, *line1)
        yb0 = evaluate_line(x0, *line2)
        yb1 = evaluate_line(x1, *line2)

        y_min = min(yt0, yt1)
        y_max = max(yb0, yb1)

        height = y_max - y_min
        padded = pad_height_to_multiple(height, y_tile_size)

        y_mins.append(y_min)
        y_maxs.append(y_max)
        max_height = max(max_height, padded)

    slope_top = (line1[3] - line1[1]) / (line1[2] - line1[0])
    ideal_step = slope_top * tile_size

    lower_bounds = []
    upper_bounds = []
    for i in range(1, n_blocks):
        lower_bounds.append((y_maxs[i] - max_height) / i)
        upper_bounds.append(y_mins[i] / i)

    lower = math.ceil(max(lower_bounds)) if lower_bounds else 0
    upper = math.floor(min(upper_bounds)) if upper_bounds else int(round(ideal_step))

    fixed_offset = int(round(ideal_step))
    if fixed_offset < lower:
        fixed_offset = lower
    elif fixed_offset > upper:
        fixed_offset = upper

    return fixed_offset, max_height, n_blocks, tile_size


def generate_sbsc_matrix(dest_size, src_size, kernel_window=4.5, tile_size=64, y_tile_size=32):
    offset, block_height, num_blocks, col_width = generate_sbsc_structure(
        dest_size, src_size, kernel_window, tile_size, y_tile_size
    )

    return SBSCMatrix(
        size=((offset * (num_blocks - 1)) + block_height, dest_size),
        data=torch.empty((num_blocks, block_height, col_width), dtype=torch.float16, device='cuda'),
        offset=offset,
        block_size=y_tile_size
    )

@triton.jit
def compute_sbsc_coord_grid_kernel(
    coords_source_ptr, coords_dest_ptr,
    sparse_data_ptr, offset: tl.constexpr,
    stride_xb, stride_xw, stride_xh,
    k: float, M: int, N: int, BLOCK_SIZE: tl.constexpr, SPARSE_BLOCK_SIZE: tl.constexpr
):

    pid_w = tl.program_id(0)
    pid_h = tl.program_id(1)

    start_row = offset * pid_w + pid_h * BLOCK_SIZE
    start_col = pid_w * SPARSE_BLOCK_SIZE

    row_offsets = start_row + tl.arange(0, BLOCK_SIZE)
    col_offsets = start_col + tl.arange(0, SPARSE_BLOCK_SIZE)

    mask_row = row_offsets < M
    mask_col = col_offsets < N

    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row, other=0.0)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col, other=0.0)

    y = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)

    y = magic_kernel_sharp_2021_triton(y)

    y *= k

    sparse_block_ptr = sparse_data_ptr + pid_w * stride_xb

    local_row_start = pid_h * BLOCK_SIZE

    local_rows = local_row_start + tl.arange(0, BLOCK_SIZE)
    local_cols = tl.arange(0, SPARSE_BLOCK_SIZE)

    local_rows_2d = local_rows[:, None]
    local_cols_2d = local_cols[None, :]

    store_offset = local_rows_2d * SPARSE_BLOCK_SIZE + local_cols_2d

    tl.store(sparse_block_ptr + store_offset, y)

def compute_sbsc_coord_grid(target_size, source_size, kernel_window, BLOCK_SIZE=32, SPARSE_BLOCK_SIZE=64):
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')

    M, N = coords_source.shape[0], coords_dest.shape[0]
    x = generate_sbsc_matrix(target_size, source_size, kernel_window, SPARSE_BLOCK_SIZE, BLOCK_SIZE)

    SPARSE_BLOCKS, BLOCK_HEIGHT, _ = x.data.shape
    stride_xb, stride_xh, stride_xw = x.data.stride()

    grid = lambda meta: (SPARSE_BLOCKS, triton.cdiv(BLOCK_HEIGHT, meta['BLOCK_SIZE']))
    compute_sbsc_coord_grid_kernel[grid](
        coords_source, coords_dest,
        x.data, x.offset,
        stride_xb, stride_xh, stride_xw,
        k, M, N, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    return x


def downscale_sbsc(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling: bool = True,
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    kernel, window = _get_resize_kernel_triton(resize_kernel)

    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]

    y_s_w = compute_sbsc_coord_grid(T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    y_s_h = compute_sbsc_coord_grid(T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)

    PAD_W = math.ceil((window - 0.5) / (T_W / S_W))
    PAD_H = math.ceil((window - 0.5) / (T_H / S_H))

    image = pad_replicate(
        image,
        PAD_H,
        PAD_W,
        fuse_linrgb=do_gamma_handling,
        sparse_block_size=SPARSE_BLOCK_SIZE,
    )

    image = triton_dds_sbsc(
        image,
        y_s_w,
        output_mt=True
    )

    image = triton_dds_sbsc(
        image,
        y_s_h,
        fuse_srgb=do_gamma_handling,
        clamp_output=True,
        output_mt=True,
    )

    return image


def downscale_sbsc_zerorhs(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
    gamma_handling_type: str = 'fast',
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    kernel, window = _get_resize_kernel_triton(resize_kernel)

    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]

    block_specs_w = generate_sbsc_structure(
        T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE
    )

    block_specs_h = generate_sbsc_structure(
        T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE
    )

    image = triton_dds_zerorhs_sbsc(
        image,
        T_W, S_W, window, block_specs_w,
        fuse_srgb='input' if do_gamma_handling else '',
        gamma_correction=gamma_handling_type,
        output_mt=True
    )

    image = triton_dds_zerorhs_sbsc(
        image,
        T_H, S_H, window, block_specs_h,
        fuse_srgb='output' if do_gamma_handling else '',
        gamma_correction=gamma_handling_type,
        clamp_output=True,
        output_mt=True,
    )

    return image
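
# Editor note (hedged usage sketch): the SBSC entry points follow the same calling convention as
# downscale_sparse, e.g. for a (C, H, W) CUDA tensor with values in [0, 1]:
#   small = downscale_sbsc(img, (512, 512))
#   small = downscale_sbsc_zerorhs(img, (512, 512), gamma_handling_type='fast')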