"""Sharpfin Triton-accelerated GPU scaling functions. Vendored from https://github.com/drhead/sharpfin (Apache 2.0) Imports patched: absolute sharpfin.X -> relative .X Requires: triton (only available on CUDA platforms) """ import torch import math import triton import triton.language as tl from .util import ResizeKernel from typing import Tuple import torch.nn.functional as F from triton.language.extra import libdevice from .util import linear_to_srgb, srgb_to_linear # Magic Kernel Sharp with Triton optimizations. Mainly converted to polynomials so that # FMA operators can be used. @triton.jit def magic_kernel_sharp_2021_triton(x: tl.tensor): out = tl.zeros_like(x) # inplace operation doesn't help much. x = tl.abs(x) lte_05 = x <= 0.5 lte_15 = x <= 1.5 lte_25 = x <= 2.5 lte_35 = x <= 3.5 lte_45 = x <= 4.5 x_sq = x*x # triton would compile like this anyways but it helps readability out = tl.where(lte_05, tl.fma(x_sq, -239/144, 577/576), out) out = tl.where(lte_15 and not lte_05, tl.fma(x_sq, 35/36, tl.fma(x, -379/144, 239/144)), out) out = tl.where(lte_25 and not lte_15, tl.fma(x_sq, -1/6, tl.fma(x, 113/144, -65/72)), out) out = tl.where(lte_35 and not lte_25, tl.fma(x_sq, 1/36, tl.fma(x, -3/16, 5/16)), out) out = tl.where(lte_45 and not lte_35, tl.fma(x_sq, -1/288, tl.fma(x, 1/32, -9/128)), out) return out @triton.jit def sinc_triton(x: tl.tensor): y = tl.fma(x, math.pi, 1e-8) return libdevice.fast_sinf(y) / y @triton.jit def lanczos_triton(x: tl.tensor, n: tl.constexpr = 3): return tl.where( tl.abs(x) < n, sinc_triton(x) * sinc_triton(x/n), 0 ) # NOTE: there is no reason to use libdevice.pow, its only differences are with subnormals @triton.jit def linear_to_srgb_triton(x): return tl.where( x <= 0.0031308, x * 12.92, tl.fma(1.055, libdevice.fast_powf(x, 1/2.4), -0.055) ) @triton.jit def srgb_to_linear_triton(x): return tl.where( x <= 0.04045, x / 12.92, libdevice.fast_powf(tl.fma(1/1.055, x, 0.055/1.055), 2.4) ) from .sparse_backend import triton_dds, triton_dds_sbsc, triton_dds_zerorhs_sbsc, Matrix, SBSCMatrix def _get_resize_kernel_triton(k: ResizeKernel): match k: case ResizeKernel.NEAREST: raise NotImplementedError case ResizeKernel.BILINEAR: raise NotImplementedError case ResizeKernel.MITCHELL: raise NotImplementedError case ResizeKernel.CATMULL_ROM: raise NotImplementedError case ResizeKernel.B_SPLINE: raise NotImplementedError case ResizeKernel.LANCZOS2: raise NotImplementedError case ResizeKernel.LANCZOS3: resize_kernel = lanczos_triton kernel_window = 3. case ResizeKernel.MAGIC_KERNEL: raise NotImplementedError case ResizeKernel.MAGIC_KERNEL_SHARP_2013: raise NotImplementedError case ResizeKernel.MAGIC_KERNEL_SHARP_2021: resize_kernel = magic_kernel_sharp_2021_triton kernel_window = 4.5 case _: raise ValueError(f"Unknown resize kernel {k}") return resize_kernel, kernel_window # Sparse Downscale and support functions. 
# Sparse Downscale and support functions.
# Amanatides, John and Woo, Andrew -- Fast Voxel Traversal
def grid_line_tiles(x0, y0, x1, y1, grid_width, grid_height):
    tiles = set()
    dx = x1 - x0
    dy = y1 - y0

    x = math.floor(x0)
    y = math.floor(y0)
    end_x = math.floor(x1)
    end_y = math.floor(y1)

    step_x = 1 if dx > 0 else -1
    step_y = 1 if dy > 0 else -1

    t_max_x = ((x + (step_x > 0)) - x0) / dx if dx != 0 else float('inf')
    t_max_y = ((y + (step_y > 0)) - y0) / dy if dy != 0 else float('inf')

    t_delta_x = abs(1 / dx) if dx != 0 else float('inf')
    t_delta_y = abs(1 / dy) if dy != 0 else float('inf')

    while True:
        if 0 <= x < grid_width and 0 <= y < grid_height:
            tiles.add((y, x))
        if x == end_x and y == end_y:
            break
        if t_max_x < t_max_y:
            t_max_x += t_delta_x
            x += step_x
        else:
            t_max_y += t_delta_y
            y += step_y

    return tiles


def tile_mask_function(dest_size, src_size, kernel_window=4.5, tile_size=64):
    k = dest_size / src_size
    PAD = math.ceil((kernel_window - 0.5) / k)
    grid_size = math.ceil((src_size + 2*PAD) / tile_size), math.ceil(dest_size / tile_size)
    line_1 = 0, 0.5/tile_size, dest_size/tile_size, (src_size + 0.5)/tile_size
    line_2 = 0, (2*PAD - 0.5)/tile_size, dest_size/tile_size, (src_size + 2*PAD - 0.5)/tile_size
    lines = line_1, line_2

    mask = torch.zeros(grid_size, dtype=torch.bool)
    tiles = set()
    for (x0, y0, x1, y1) in lines:
        tiles.update(grid_line_tiles(x0, y0, x1, y1, grid_size[1], grid_size[0]))
    tiles = torch.tensor(list(tiles))
    mask[tiles[:, 0], tiles[:, 1]] = True
    return mask, tiles


def create_tensor_metadata(
    tile_mask: torch.Tensor,
    tiles: torch.Tensor,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    offsets_t: torch.Tensor,
):
    indices[:, :2] = tiles
    torch.argsort(indices[:, 1], stable=True, out=indices[:, 2])  # block_offsets_t
    torch.take(indices[:, 0], indices[:, 2], out=indices[:, 3])  # col_indices_t

    # reusing the offsets buffer here helps performance
    torch.sum(tile_mask, dim=1, out=offsets[1:])
    torch.sum(tile_mask, dim=0, out=offsets_t[1:])
    torch.cumsum(offsets, dim=0, out=offsets)
    torch.cumsum(offsets_t, dim=0, out=offsets_t)

    return indices, offsets, offsets_t


# for isolating the one mandatory graph break
@torch.compiler.disable
def _get_nnz_and_buffers(tile_mask):
    num_sparse_blocks = torch.sum(tile_mask).item()
    return [
        torch.empty((4, num_sparse_blocks), dtype=torch.int64, pin_memory=True).T,  # indices
        torch.zeros((tile_mask.shape[0] + 1,), dtype=torch.int32, pin_memory=True),  # offsets
        torch.zeros((tile_mask.shape[1] + 1,), dtype=torch.int32, pin_memory=True),  # offsets_t
    ]


def generate_sparse_matrix(dest_size, src_size, kernel_window=4.5, tile_size=64):
    tile_mask, tiles = tile_mask_function(dest_size, src_size, kernel_window, tile_size)
    buffers = _get_nnz_and_buffers(tile_mask)
    num_sparse_blocks = buffers[0].shape[0]

    indices, offsets, offsets_t = create_tensor_metadata(tile_mask, tiles, *buffers)
    indices = indices.to(device='cuda', dtype=torch.int32, non_blocking=True)

    return Matrix(
        (tile_mask.shape[0] * tile_size, tile_mask.shape[1] * tile_size),
        torch.empty(num_sparse_blocks, tile_size, tile_size, dtype=torch.float16, device='cuda'),
        row_indices=indices[:, 0],
        column_indices=indices[:, 1],
        offsets=offsets.to(device='cuda', non_blocking=True),
        column_indices_t=indices[:, 3],
        offsets_t=offsets_t.to(device='cuda', non_blocking=True),
        block_offsets_t=indices[:, 2],
    )
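
# Hedged usage sketch (CPU-only): the resampling matrix is a diagonal band, so
# the tile mask marks only the tiles crossed by its two bounding lines. This
# demo helper and its sizes are illustrative, not part of the vendored API.
def _example_tile_mask() -> None:
    mask, tiles = tile_mask_function(dest_size=256, src_size=1024)
    # rows span the padded source axis in 64-px tiles, columns the dest axis
    print(mask.shape)
    # only a thin diagonal of blocks is marked, far fewer than mask.numel()
    print(mask.sum().item(), "of", mask.numel(), "tiles active")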
@triton.jit
def compute_sparse_coord_grid_kernel(
    coords_source_ptr,
    coords_dest_ptr,
    sparse_data_ptr,
    row_indices_ptr,
    col_indices_ptr,
    k: float,
    M: int,
    N: int,
    BLOCK_SIZE: tl.constexpr,
    SPARSE_BLOCK_SIZE: tl.constexpr,
):
    SPARSE_BLOCK_NUMEL = SPARSE_BLOCK_SIZE * SPARSE_BLOCK_SIZE
    sparse_block = tl.program_id(0)
    tile_row = tl.program_id(1)
    tile_col = tl.program_id(2)

    row_offsets = tl.load(row_indices_ptr + sparse_block) * SPARSE_BLOCK_SIZE + tile_row * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = tl.load(col_indices_ptr + sparse_block) * SPARSE_BLOCK_SIZE + tile_col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    mask_row = row_offsets < M
    mask_col = col_offsets < N

    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row, other=0.0)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col, other=0.0)

    x = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)
    x = magic_kernel_sharp_2021_triton(x)
    x *= k

    sparse_block_ptr = sparse_data_ptr + sparse_block * SPARSE_BLOCK_NUMEL
    local_row_start = tile_row * BLOCK_SIZE
    local_col_start = tile_col * BLOCK_SIZE
    local_rows = local_row_start + tl.arange(0, BLOCK_SIZE)
    local_cols = local_col_start + tl.arange(0, BLOCK_SIZE)
    local_rows_2d = local_rows[:, None]
    local_cols_2d = local_cols[None, :]
    store_offset = local_rows_2d * SPARSE_BLOCK_SIZE + local_cols_2d
    tl.store(sparse_block_ptr + store_offset, x)


def compute_sparse_coord_grid(target_size, source_size, kernel_window, BLOCK_SIZE=32, SPARSE_BLOCK_SIZE=64):
    assert SPARSE_BLOCK_SIZE % BLOCK_SIZE == 0
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')
    M, N = coords_source.shape[0], coords_dest.shape[0]

    x = generate_sparse_matrix(target_size, source_size, kernel_window, SPARSE_BLOCK_SIZE)
    SPARSE_NUM_BLOCKS = x.data.shape[0]

    grid = lambda meta: (
        SPARSE_NUM_BLOCKS,
        triton.cdiv(SPARSE_BLOCK_SIZE, meta['BLOCK_SIZE']),
        triton.cdiv(SPARSE_BLOCK_SIZE, meta['BLOCK_SIZE']),
    )
    compute_sparse_coord_grid_kernel[grid](
        coords_source, coords_dest, x.data,
        x.row_indices, x.column_indices,
        k, M, N, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    return x


# Dense kernel for downsampling coord_grids
@triton.jit
def compute_coord_grid_kernel(
    coords_source_ptr,
    coords_dest_ptr,
    coord_grid_ptr,
    k,
    M,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    row_offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask_row = row_offsets < M
    mask_col = col_offsets < N

    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col)

    x = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)
    x = magic_kernel_sharp_2021_triton(x)
    x *= k

    tl.store(coord_grid_ptr + row_offsets[:, None] * N + col_offsets[None, :], x, mask=mask_row[:, None] & mask_col[None, :])


def compute_coord_grid(target_size, source_size, kernel_window=4.5, BLOCK_SIZE=32):
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')
    M, N = coords_source.shape[0], coords_dest.shape[0]

    coord_grid = torch.empty((M, N), dtype=torch.float16, device='cuda')
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    compute_coord_grid_kernel[grid](coords_source, coords_dest, coord_grid, k, M, N, BLOCK_SIZE)
    return coord_grid
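
# Hedged sanity check (requires CUDA): each column of the dense coord grid
# holds one destination pixel's weights, which should sum to roughly 1 since
# the kernel is normalized and the grid is scaled by k. The helper and its
# tolerance are assumptions, chosen to absorb float16 storage and the
# Riemann-sum approximation; they are not part of the vendored API.
def _check_coord_grid_normalization(target_size: int = 256, source_size: int = 1024) -> None:
    grid = compute_coord_grid(target_size, source_size)
    col_sums = grid.float().sum(dim=0)
    assert torch.allclose(col_sums, torch.ones_like(col_sums), atol=1e-2)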
@triton.jit
def pad_replicate_kernel(
    A, B,
    M_X, N_X,
    M_Y, N_Y,
    M_PAD, N_PAD,
    stride_xc, stride_xm, stride_xn,
    stride_yc, stride_ym, stride_yn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    fuse_linrgb: tl.constexpr,
):
    pid_c = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_n = tl.program_id(2)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # clamp source coordinates into [0, size - 1] so edge pixels replicate
    offs_m_cl = tl.maximum(offs_m, M_PAD) - M_PAD
    offs_m_cl = tl.minimum(offs_m_cl, M_X - 1)
    offs_n_cl = tl.maximum(offs_n, N_PAD) - N_PAD
    offs_n_cl = tl.minimum(offs_n_cl, N_X - 1)

    mask_m = offs_m < M_Y
    mask_n = offs_n < N_Y

    A_ptr = A + pid_c * stride_xc + offs_m_cl[:, None] * stride_xm + offs_n_cl[None, :] * stride_xn
    B_ptr = B + pid_c * stride_yc + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn

    t = tl.load(A_ptr)
    if fuse_linrgb:
        t = srgb_to_linear_triton(t)
    tl.store(B_ptr, t, mask=mask_m[:, None] & mask_n[None, :])


def pad_replicate(
    img: torch.Tensor,
    pad_h: int,
    pad_w: int,
    sparse_block_size: int = 0,
    fuse_linrgb: bool = True,
):
    C = img.shape[0]
    M_PAD = pad_h
    N_PAD = pad_w
    if sparse_block_size != 0:
        # round the padded size up to a multiple of the sparse block size
        out_H = img.shape[-2] + M_PAD + (-(img.shape[-2] + M_PAD)) % sparse_block_size
        out_W = img.shape[-1] + N_PAD + (-(img.shape[-1] + N_PAD)) % sparse_block_size
    else:
        out_H = img.shape[-2] + M_PAD + M_PAD
        out_W = img.shape[-1] + N_PAD + N_PAD
    out = torch.empty(C, out_H, out_W, dtype=img.dtype, device=img.device)

    BLOCK_M = 1
    BLOCK_N = 512

    grid = lambda META: (
        C,
        (out.shape[1] + META['BLOCK_M'] - 1) // META['BLOCK_M'],
        (out.shape[2] + META['BLOCK_N'] - 1) // META['BLOCK_N'],
    )

    pad_replicate_kernel[grid](
        img, out,
        img.shape[1], img.shape[2],
        out.shape[1], out.shape[2],
        M_PAD, N_PAD,
        img.stride(0), img.stride(1), img.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        fuse_linrgb=fuse_linrgb,
    )
    return out


def downscale_sparse(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    kernel, window = _get_resize_kernel_triton(resize_kernel)
    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]

    y_s_w = compute_sparse_coord_grid(T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    y_s_h = compute_sparse_coord_grid(T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)

    PAD_W = math.ceil((window - 0.5) / (T_W / S_W))
    PAD_H = math.ceil((window - 0.5) / (T_H / S_H))

    image = pad_replicate(image, PAD_H, PAD_W, SPARSE_BLOCK_SIZE, fuse_linrgb=do_gamma_handling)

    image = triton_dds(image, y_s_w, output_mt=True)
    image = triton_dds(
        image, y_s_h,
        fuse_srgb=do_gamma_handling,
        clamp_output=True,
        output_mt=True,
        output_slice=(T_H, T_W),
    )
    return image


def downscale_triton(
    image: torch.Tensor,
    target_size: torch.Size,
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
) -> torch.Tensor:
    kernel, window = _get_resize_kernel_triton(resize_kernel)

    y_s_w = compute_coord_grid(target_size[-1], image.shape[-1], window)
    y_s_h = compute_coord_grid(target_size[-2], image.shape[-2], window)

    PAD_W = math.ceil((window - 0.5) / (target_size[-1] / image.shape[-1]))
    PAD_H = math.ceil((window - 0.5) / (target_size[-2] / image.shape[-2]))

    image = pad_replicate(image, PAD_H, PAD_W, fuse_linrgb=do_gamma_handling)

    # resample width, then height, via dense matmuls against the coord grids
    # (assumes a 3-channel image)
    image = image.view(-1, image.shape[-1])
    image = image @ y_s_w
    image = image.view(3, -1, image.shape[-1])
    image = image.mT
    image = image.reshape(-1, image.shape[-1])
    image = image @ y_s_h
    image = image.view(3, -1, image.shape[-1])
    image = image.mT

    # crop to the exact target size in both paths (previously only applied when
    # gamma handling was on), then convert back to sRGB if requested
    image = image[:, :target_size[0], :target_size[1]]
    if do_gamma_handling:
        image = linear_to_srgb(image)
    image.clamp_(0., 1.)
    return image
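
# Example usage (hedged sketch, requires CUDA): the dense path expects a
# 3-channel CHW sRGB image in [0, 1]; float16 keeps the matmuls compatible
# with the float16 coord grids. This demo wrapper and the target size are
# illustrative, not part of the vendored API.
def _example_downscale_dense(image: torch.Tensor) -> torch.Tensor:
    image = image.to(device='cuda', dtype=torch.float16)
    return downscale_triton(image, (256, 256))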
# Single Block Sparse Column implementations.
def evaluate_line(x, x0, y0, x1, y1):
    """Evaluate the y-coordinate at a given x along a line from (x0, y0) to (x1, y1)."""
    if x1 == x0:
        return float('inf')
    t = (x - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)


def pad_height_to_multiple(height, multiple):
    """Pad a height up to the next multiple of 'multiple'."""
    return int(math.ceil(height / multiple) * multiple)


def generate_sbsc_structure(dest_size, src_size, kernel_window=4.5, tile_size=64, y_tile_size=32):
    k = dest_size / src_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    line1 = (0, 0.5, dest_size, src_size + 0.5)
    line2 = (0, 2 * PAD - 0.5, dest_size, src_size + 2 * PAD - 0.5)

    y_mins = []
    y_maxs = []
    n_blocks = math.ceil(dest_size / tile_size)
    max_height = 0

    for i in range(n_blocks):
        x0 = i * tile_size
        x1 = min(dest_size - 1, x0 + tile_size - 1)

        yt0 = evaluate_line(x0, *line1)
        yt1 = evaluate_line(x1, *line1)
        yb0 = evaluate_line(x0, *line2)
        yb1 = evaluate_line(x1, *line2)

        y_min = min(yt0, yt1)
        y_max = max(yb0, yb1)
        height = y_max - y_min
        padded = pad_height_to_multiple(height, y_tile_size)

        y_mins.append(y_min)
        y_maxs.append(y_max)
        max_height = max(max_height, padded)

    slope_top = (line1[3] - line1[1]) / (line1[2] - line1[0])
    ideal_step = slope_top * tile_size

    lower_bounds = []
    upper_bounds = []
    for i in range(1, n_blocks):
        lower_bounds.append((y_maxs[i] - max_height) / i)
        upper_bounds.append(y_mins[i] / i)

    lower = math.ceil(max(lower_bounds)) if lower_bounds else 0
    upper = math.floor(min(upper_bounds)) if upper_bounds else int(round(ideal_step))

    fixed_offset = int(round(ideal_step))
    if fixed_offset < lower:
        fixed_offset = lower
    elif fixed_offset > upper:
        fixed_offset = upper

    return fixed_offset, max_height, n_blocks, tile_size


def generate_sbsc_matrix(dest_size, src_size, kernel_window=4.5, tile_size=64, y_tile_size=32):
    offset, block_height, num_blocks, col_width = generate_sbsc_structure(
        dest_size, src_size, kernel_window, tile_size, y_tile_size
    )
    return SBSCMatrix(
        size=((offset * (num_blocks - 1)) + block_height, dest_size),
        data=torch.empty((num_blocks, block_height, col_width), dtype=torch.float16, device='cuda'),
        offset=offset,
        block_size=y_tile_size,
    )
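
# Hedged sketch of the SBSC layout: block b of `data` occupies rows
# [offset*b, offset*b + block_height) and columns [b*col_width, (b+1)*col_width)
# of the equivalent dense matrix. This densify helper is for testing only and
# is not part of the vendored API.
def _sbsc_to_dense(m: SBSCMatrix) -> torch.Tensor:
    num_blocks, block_height, col_width = m.data.shape
    dense = torch.zeros(m.size, dtype=m.data.dtype, device=m.data.device)
    for b in range(num_blocks):
        r0 = m.offset * b
        c0 = b * col_width
        # the last block may hang past the dest edge; drop the padding columns
        blk = m.data[b][:, : dense.shape[1] - c0]
        dense[r0:r0 + block_height, c0:c0 + blk.shape[1]] = blk
    return dense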
@triton.jit
def compute_sbsc_coord_grid_kernel(
    coords_source_ptr,
    coords_dest_ptr,
    sparse_data_ptr,
    offset: tl.constexpr,
    stride_xb, stride_xh, stride_xw,  # ordered as x.data.stride(); only stride_xb is used
    k: float,
    M: int,
    N: int,
    BLOCK_SIZE: tl.constexpr,
    SPARSE_BLOCK_SIZE: tl.constexpr,
):
    pid_w = tl.program_id(0)
    pid_h = tl.program_id(1)

    start_row = offset * pid_w + pid_h * BLOCK_SIZE
    start_col = pid_w * SPARSE_BLOCK_SIZE

    row_offsets = start_row + tl.arange(0, BLOCK_SIZE)
    col_offsets = start_col + tl.arange(0, SPARSE_BLOCK_SIZE)

    mask_row = row_offsets < M
    mask_col = col_offsets < N

    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row, other=0.0)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col, other=0.0)

    y = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)
    y = magic_kernel_sharp_2021_triton(y)
    y *= k

    sparse_block_ptr = sparse_data_ptr + pid_w * stride_xb
    local_row_start = pid_h * BLOCK_SIZE
    local_rows = local_row_start + tl.arange(0, BLOCK_SIZE)
    local_cols = tl.arange(0, SPARSE_BLOCK_SIZE)
    local_rows_2d = local_rows[:, None]
    local_cols_2d = local_cols[None, :]
    store_offset = local_rows_2d * SPARSE_BLOCK_SIZE + local_cols_2d
    tl.store(sparse_block_ptr + store_offset, y)


def compute_sbsc_coord_grid(target_size, source_size, kernel_window, BLOCK_SIZE=32, SPARSE_BLOCK_SIZE=64):
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)

    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')
    M, N = coords_source.shape[0], coords_dest.shape[0]

    x = generate_sbsc_matrix(target_size, source_size, kernel_window, SPARSE_BLOCK_SIZE, BLOCK_SIZE)
    SPARSE_BLOCKS, BLOCK_HEIGHT, _ = x.data.shape
    stride_xb, stride_xh, stride_xw = x.data.stride()

    grid = lambda meta: (SPARSE_BLOCKS, triton.cdiv(BLOCK_HEIGHT, meta['BLOCK_SIZE']))
    compute_sbsc_coord_grid_kernel[grid](
        coords_source, coords_dest, x.data,
        x.offset,
        stride_xb, stride_xh, stride_xw,
        k, M, N, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    return x


def downscale_sbsc(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling: bool = True,
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    kernel, window = _get_resize_kernel_triton(resize_kernel)
    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]

    y_s_w = compute_sbsc_coord_grid(T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    y_s_h = compute_sbsc_coord_grid(T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)

    PAD_W = math.ceil((window - 0.5) / (T_W / S_W))
    PAD_H = math.ceil((window - 0.5) / (T_H / S_H))

    image = pad_replicate(
        image, PAD_H, PAD_W,
        fuse_linrgb=do_gamma_handling,
        sparse_block_size=SPARSE_BLOCK_SIZE,
    )

    image = triton_dds_sbsc(image, y_s_w, output_mt=True)
    image = triton_dds_sbsc(
        image, y_s_h,
        fuse_srgb=do_gamma_handling,
        clamp_output=True,
        output_mt=True,
    )
    return image


def downscale_sbsc_zerorhs(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
    gamma_handling_type: str = 'fast',
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    kernel, window = _get_resize_kernel_triton(resize_kernel)
    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]

    block_specs_w = generate_sbsc_structure(T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    block_specs_h = generate_sbsc_structure(T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)

    image = triton_dds_zerorhs_sbsc(
        image,
        T_W, S_W, window, block_specs_w,
        fuse_srgb='input' if do_gamma_handling else '',
        gamma_correction=gamma_handling_type,
        output_mt=True,
    )
    image = triton_dds_zerorhs_sbsc(
        image,
        T_H, S_H, window, block_specs_h,
        fuse_srgb='output' if do_gamma_handling else '',
        gamma_correction=gamma_handling_type,
        clamp_output=True,
        output_mt=True,
    )
    return image
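
# Example (hedged sketch, requires CUDA): the downscale entry points implement
# the same resample, so the SBSC fast path should agree with the dense
# reference to within float16 tolerance. The input is assumed to be a
# (3, H, W) float16 sRGB image in [0, 1]; this demo helper is not part of the
# vendored API.
def _example_compare_paths(image: torch.Tensor) -> None:
    out_dense = downscale_triton(image, (256, 256))
    out_sbsc = downscale_sbsc(image, (256, 256))
    print((out_dense.float() - out_sbsc.float()).abs().max().item())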