refactor: integrate sharpfin for high-quality image resize

Vendor sharpfin library (Apache 2.0) and add centralized wrapper
module (images_sharpfin.py) replacing torchvision tensor/PIL
conversion and resize operations throughout the codebase.

- Add modules/sharpfin/ vendored library with MKS2021, Lanczos3,
  Mitchell, Catmull-Rom kernels and optional Triton sparse acceleration
- Add modules/images_sharpfin.py wrapper with to_tensor(), to_pil(),
  pil_to_tensor(), normalize(), resize(), resize_tensor()
- Add resize_quality and resize_linearize_srgb settings
- Add MKS2021 and Lanczos3 upscaler entries
- Replace torchvision.transforms.functional imports across 18 files
- to_pil() auto-detects HWC/BHWC layout, adds .round() before uint8
- Sparse Triton path falls back to dense GPU on compilation failure
- Mixed-axis resize splits into two single-axis scale() calls
- Masks and non-sRGB data always use linearize=False
pull/4668/head
CalamitousFelicitousness 2026-02-10 00:13:35 +00:00 committed by vladmandic
parent 2c4d0751d9
commit 76aa949a26
30 changed files with 2878 additions and 78 deletions

View File

@ -14,7 +14,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import QuickGELUActivation
import torchvision
import torchvision.transforms.functional as TVF
from modules import images_sharpfin
import einops
from einops.layers.torch import Rearrange
import huggingface_hub
@ -1035,8 +1035,8 @@ def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
padded_image.paste(image, (pad_left, pad_top))
if max_dim != target_size:
padded_image = padded_image.resize((target_size, target_size), Image.Resampling.LANCZOS)
image_tensor = TVF.pil_to_tensor(padded_image) / 255.0
image_tensor = TVF.normalize(image_tensor, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
image_tensor = images_sharpfin.to_tensor(padded_image)
image_tensor = images_sharpfin.normalize(image_tensor, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
return image_tensor

View File

@ -4,7 +4,7 @@ import time
import numpy as np
import torch
from PIL import Image
from modules import shared, upscaler
from modules import shared, upscaler, images_sharpfin
def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None):
@ -36,7 +36,7 @@ def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width:
def resize(im: Union[Image.Image, torch.Tensor], w, h):
w, h = int(w), int(h)
if upscaler_name is None or upscaler_name == "None" or (hasattr(im, 'mode') and im.mode == 'L'):
return im.resize((w, h), resample=Image.Resampling.LANCZOS) # force for mask
return images_sharpfin.resize(im, (w, h), linearize=False) # force for mask
if isinstance(im, torch.Tensor):
scale = max(w // 8 / im.shape[-1] , h // 8 / im.shape[-2])
else:
@ -53,7 +53,7 @@ def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width:
shared.log.warning(f"Resize upscaler: invalid={upscaler_name} fallback={selected_upscaler.name}")
shared.log.debug(f"Resize upscaler: available={[u.name for u in shared.sd_upscalers]}")
if isinstance(im, Image.Image) and (im.width != w or im.height != h): # probably downsample after upscaler created larger image
im = im.resize((w, h), resample=Image.Resampling.LANCZOS)
im = images_sharpfin.resize(im, (w, h))
return im
def crop(im: Image.Image):

298
modules/images_sharpfin.py Normal file
View File

@ -0,0 +1,298 @@
"""Sharpfin wrapper for high-quality image resize and tensor conversion.
Provides drop-in replacements for torchvision.transforms.functional operations
with higher quality resampling (Magic Kernel Sharp 2021), sRGB linearization,
and Triton GPU acceleration when available.
All public functions include try/except fallback to PIL/torchvision.
"""
import torch
import numpy as np
from PIL import Image
_sharpfin_checked = False
_sharpfin_ok = False
_triton_ok = False
_log = None
def _get_log():
global _log
if _log is None:
try:
from modules.shared import log
_log = log
except Exception:
import logging
_log = logging.getLogger(__name__)
return _log
def _check():
    """One-time probe of the vendored sharpfin package; results cached in globals.

    Sets _sharpfin_ok when modules.sharpfin.functional imports, and _triton_ok
    when the package reports Triton sparse support. All import errors are
    swallowed (logged once) so callers degrade to the PIL/torch fallback paths
    instead of crashing, as promised by the module docstring.
    """
    global _sharpfin_checked, _sharpfin_ok, _triton_ok
    if _sharpfin_checked:
        return
    try:
        from modules.sharpfin.functional import scale  # pylint: disable=unused-import
        _sharpfin_ok = True
    except Exception as e:
        _sharpfin_ok = False
        _get_log().warning(f"Sharpfin unavailable: {e}")
    try:
        from modules.sharpfin import TRITON_AVAILABLE
        _triton_ok = TRITON_AVAILABLE
    except Exception:
        _triton_ok = False
    _sharpfin_checked = True
def is_available():
    """Return True when the sharpfin functional module imported successfully."""
    _check()  # lazy one-time probe
    return _sharpfin_ok
# Maps UI setting values (shared.opts.resize_quality) to attribute names on
# modules.sharpfin.util.ResizeKernel. "PIL Lanczos" is deliberately absent so
# _resolve_kernel() returns None and callers take the PIL fallback path.
KERNEL_MAP = {
    "Sharpfin MKS2021": "MAGIC_KERNEL_SHARP_2021",
    "Sharpfin Lanczos3": "LANCZOS3",
    "Sharpfin Mitchell": "MITCHELL",
    "Sharpfin Catmull-Rom": "CATMULL_ROM",
}
def _resolve_kernel(kernel=None):
    """Resolve kernel name to ResizeKernel enum. Returns None for PIL fallback."""
    name = kernel
    if name is None:
        # no explicit override: read the user setting, defaulting to MKS2021
        try:
            from modules import shared
            name = getattr(shared.opts, 'resize_quality', 'Sharpfin MKS2021')
        except Exception:
            name = 'Sharpfin MKS2021'
    if name not in KERNEL_MAP:  # covers "PIL Lanczos" and any unknown name
        return None
    from modules.sharpfin.util import ResizeKernel
    return getattr(ResizeKernel, KERNEL_MAP[name])
def _resolve_linearize(linearize=None, is_mask=False):
"""Determine sRGB linearization setting."""
if is_mask:
return False
if linearize is not None:
return linearize
try:
from modules import shared
return getattr(shared.opts, 'resize_linearize_srgb', True)
except Exception:
return True
def _get_device_dtype(device=None, dtype=None):
"""Get optimal device/dtype for sharpfin operations."""
if device is not None and dtype is not None:
return device, dtype
try:
from modules import devices
dev = device or devices.device
if dev.type == 'cuda':
return dev, dtype or torch.float16
return dev, dtype or torch.float32
except Exception:
return device or torch.device('cpu'), dtype or torch.float32
def resize(image, target_size, *, kernel=None, linearize=None, device=None, dtype=None):
    """Resize a PIL.Image or torch.Tensor, returning the same type.

    Args:
        image: PIL.Image or torch.Tensor [B,C,H,W] / [C,H,W]
        target_size: (width, height) for PIL inputs, (H, W) for tensors
        kernel: override kernel name, or None to use settings
        linearize: override sRGB linearization, or None to use settings
        device: override compute device
        dtype: override compute dtype
    Unknown input types are passed through unchanged.
    """
    _check()
    if isinstance(image, Image.Image):
        return _resize_pil(image, target_size, kernel=kernel, linearize=linearize, device=device, dtype=dtype)
    if isinstance(image, torch.Tensor):
        # tensors default to linearize=False: latent/mask data is not sRGB
        return resize_tensor(image, target_size, kernel=kernel, linearize=False if linearize is None else linearize)
    return image
def _want_sparse(dev, rk, both_down):
    """True when the Triton sparse path applies: CUDA + MKS2021 kernel + pure downscale."""
    if not _triton_ok or not both_down:
        return False
    return dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021'
def _scale_pil(scale_fn, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up):
    """Run sharpfin scale with Triton-sparse fallback. Returns the scaled tensor.

    Uniform resizes (both axes up, or both down) are one scale() call; mixed-axis
    resizes split into two single-axis passes, attempting the sparse path only on
    the downscaled axis. A sparse-path failure permanently disables Triton for
    this process and the same call is retried densely.
    """
    def _run(inp, res, try_sparse):
        # Attempt Triton sparse when eligible; on failure disable it and retry dense.
        # Only the sparse call itself is inside the try, so a dense failure can
        # never be mistaken for a Triton problem.
        global _triton_ok  # pylint: disable=global-statement
        if _want_sparse(dev, rk, try_sparse):
            try:
                return scale_fn(inp, res, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=True)
            except Exception:
                _triton_ok = False
                _get_log().info("Sharpfin: Triton sparse disabled, using dense path")
        return scale_fn(inp, res, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=do_linear, use_sparse=False)

    if both_down or both_up:
        return _run(tensor, out_res, both_down)
    if h > src_h:
        # H up, W down: dense upscale first, sparse-eligible downscale second
        intermediate = _run(tensor, (h, src_w), False)
        return _run(intermediate, (h, w), True)
    # H down, W up: sparse-eligible downscale first, dense upscale second
    intermediate = _run(tensor, (h, src_w), True)
    return _run(intermediate, (h, w), False)
def _resize_pil(image, target_size, *, kernel=None, linearize=None, device=None, dtype=None):
    """Resize a PIL Image via sharpfin, falling back to PIL Lanczos on any error.

    Args:
        image: source PIL.Image
        target_size: (width, height), PIL convention
        kernel: override kernel name, or None to use settings
        linearize: override sRGB linearization, or None to use settings
        device: override compute device
        dtype: override compute dtype
    """
    w, h = target_size
    if image.width == w and image.height == h:
        return image  # no-op resize
    is_mask = image.mode == 'L'  # masks never get sRGB linearization
    rk = _resolve_kernel(kernel)
    if rk is None:  # "PIL Lanczos" selected or unknown kernel name
        return image.resize((w, h), resample=Image.Resampling.LANCZOS)
    try:
        from modules.sharpfin.functional import scale
        do_linear = _resolve_linearize(linearize, is_mask=is_mask)
        dev, dt = _get_device_dtype(device, dtype)
        tensor = to_tensor(image)
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)  # scale() expects BCHW
        tensor = tensor.to(device=dev, dtype=dt)
        out_res = (h, w)  # sharpfin uses (H, W)
        src_h, src_w = tensor.shape[-2], tensor.shape[-1]
        both_down = (h <= src_h and w <= src_w)
        both_up = (h >= src_h and w >= src_w)
        result = _scale_pil(scale, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up)
        return to_pil(result)
    except Exception as e:
        _get_log().warning(f"Sharpfin resize failed, falling back to PIL: {e}")
        return image.resize((w, h), resample=Image.Resampling.LANCZOS)
def resize_tensor(tensor, target_size, *, kernel=None, linearize=False):
    """Resize tensor [B,C,H,W] or [C,H,W] -> tensor of the same rank.

    Args:
        tensor: input tensor
        target_size: (H, W) tuple
        kernel: override kernel name, or None to use settings
        linearize: sRGB linearization (default False for latent/mask data)
    Falls back to torch.nn.functional.interpolate when the "PIL Lanczos"
    setting is active or when sharpfin fails.
    """
    _check()

    def _interp_fallback(t):
        # bilinear for upscale, area for downscale; antialias is only valid for
        # bilinear/bicubic, so it must not be passed on the area path
        mode = 'bilinear' if target_size[0] * target_size[1] > t.shape[-2] * t.shape[-1] else 'area'
        inp = t if t.dim() == 4 else t.unsqueeze(0)
        out = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=mode == 'bilinear')
        return out.squeeze(0) if t.dim() == 3 else out

    rk = _resolve_kernel(kernel)
    if rk is None:  # "PIL Lanczos" selected or unknown kernel name
        return _interp_fallback(tensor)
    try:
        from modules.sharpfin.functional import scale
        dev, dt = _get_device_dtype()
        batched = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
        src_h, src_w = batched.shape[-2], batched.shape[-1]
        th, tw = target_size
        both_down = (th <= src_h and tw <= src_w)
        both_up = (th >= src_h and tw >= src_w)
        if both_down or both_up:
            use_sparse = _want_sparse(dev, rk, both_down)
            result = scale(batched, target_size, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=use_sparse)
        else:
            # mixed axes: split into two dense single-axis passes (H first, then W)
            intermediate = scale(batched, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
            result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
        return result if tensor.dim() == 4 else result.squeeze(0)
    except Exception as e:
        _get_log().warning(f"Sharpfin resize_tensor failed, falling back to F.interpolate: {e}")
        return _interp_fallback(tensor)
def to_tensor(image):
    """PIL Image -> float32 CHW tensor in [0,1]. Pure torch, no torchvision."""
    if not isinstance(image, Image.Image):
        raise TypeError(f"Expected PIL Image, got {type(image)}")
    arr = np.array(image, copy=True)
    if arr.ndim == 2:  # grayscale: add a channel axis
        arr = arr[:, :, np.newaxis]
    chw = torch.from_numpy(arr.transpose((2, 0, 1))).contiguous()
    if chw.dtype == torch.uint8:
        # byte images scale to [0,1]; other dtypes (e.g. mode 'I'/'F') do not
        return chw.float().div_(255.0)
    return chw.float()
def to_pil(tensor):
    """Float CHW/HWC or BCHW/BHWC tensor in [0,1] -> PIL Image. Pure torch, no torchvision."""
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")
    t = tensor.detach().cpu()
    if t.dim() == 4:
        # heuristic: a trailing dim of 1/3/4 smaller than the height axis means channels-last
        if t.shape[-1] in (1, 3, 4) and t.shape[-1] < t.shape[-2]:
            t = t.permute(0, 3, 1, 2)
        t = t[0]  # first item of the batch
    elif t.dim() == 3:
        if t.shape[-1] in (1, 3, 4) and t.shape[-1] < t.shape[-2] and t.shape[-1] < t.shape[-3]:
            t = t.permute(2, 0, 1)
    if t.dtype != torch.uint8:
        # round before cast to avoid truncation banding
        t = (t.clamp(0, 1) * 255).round().to(torch.uint8)
    arr = t.permute(1, 2, 0).numpy()
    if arr.shape[2] == 1:
        return Image.fromarray(arr[:, :, 0], mode='L')
    return Image.fromarray(arr)
def pil_to_tensor(image):
    """PIL Image -> uint8 CHW tensor (no float scaling). Replaces TF.pil_to_tensor."""
    if not isinstance(image, Image.Image):
        raise TypeError(f"Expected PIL Image, got {type(image)}")
    arr = np.array(image, copy=True)
    if arr.ndim == 2:  # grayscale: add a channel axis
        arr = arr[:, :, None]
    return torch.from_numpy(arr.transpose((2, 0, 1))).contiguous()
def normalize(tensor, mean, std, inplace=False):
    """Channel-wise (x - mean) / std normalization. Replaces TF.normalize."""
    out = tensor if inplace else tensor.clone()
    m = torch.as_tensor(mean, dtype=out.dtype, device=out.device)
    s = torch.as_tensor(std, dtype=out.dtype, device=out.device)
    # broadcast per-channel stats over the spatial axes
    if m.ndim == 1:
        m = m.view(-1, 1, 1)
    if s.ndim == 1:
        s = s.view(-1, 1, 1)
    return out.sub_(m).div_(s)

View File

@ -5,7 +5,7 @@ import torch
import numpy as np
from torch.hub import download_url_to_file, get_dir
from PIL import Image
from modules import devices
from modules import devices, images_sharpfin
from installer import log
@ -96,7 +96,5 @@ class SimpleLama:
image, mask = prepare_img_and_mask(image, mask, self.device)
with devices.inference_context():
inpainted = self.model(image, mask)
cur_res = inpainted[0].permute(1, 2, 0).detach().float().cpu().numpy()
cur_res = np.clip(cur_res * 255, 0, 255).astype(np.uint8)
cur_res = Image.fromarray(cur_res)
cur_res = images_sharpfin.to_pil(inpainted[0])
return cur_res

View File

@ -70,7 +70,7 @@ def setup_model(dirname):
self.face_helper.face_parse.to(device)
def restore(self, np_image, p=None, w=None): # pylint: disable=unused-argument
from torchvision.transforms.functional import normalize
from modules import images_sharpfin
from basicsr.utils import img2tensor, tensor2img
np_image = np_image[:, :, ::-1]
original_resolution = np_image.shape[0:2]
@ -84,7 +84,7 @@ def setup_model(dirname):
self.face_helper.align_warp_face()
for cropped_face in self.face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
images_sharpfin.normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device)
try:
with devices.inference_context():

View File

@ -2,8 +2,8 @@ from typing import List
import math
import torch
import torchvision
import numpy as np
from modules import images_sharpfin
from PIL import Image
from diffusers.utils import CONFIG_NAME
@ -65,11 +65,9 @@ def edge_detect_for_pixelart(image: PipelineImageInput, image_weight: float = 1.
greyscale_reshaped = greyscale_reshaped.reshape(batch_size, block_size_sq, block_height, block_width)
greyscale_range = greyscale_reshaped.amax(dim=1, keepdim=True).sub_(greyscale_reshaped.amin(dim=1, keepdim=True))
upsample = torchvision.transforms.Resize((height, width), interpolation=torchvision.transforms.InterpolationMode.BICUBIC)
range_weight = upsample(greyscale_range)
range_weight = images_sharpfin.resize_tensor(greyscale_range, (height, width), linearize=False)
range_weight = range_weight.div_(range_weight.max())
weight_map = upsample((greyscale > greyscale.median()).to(dtype=torch.float32))
weight_map = images_sharpfin.resize_tensor((greyscale > greyscale.median()).to(dtype=torch.float32), (height, width), linearize=False)
weight_map = weight_map.unsqueeze(0).add_(range_weight).mul_(image_weight / 2)
new_image = new_image.mul_(weight_map).addcmul_(min_pool, (1-weight_map))
@ -161,8 +159,7 @@ def encode_jpeg_tensor(img: torch.FloatTensor, block_size: int=16, cbcr_downscal
img = img[:, :, :(img.shape[-2]//block_size)*block_size, :(img.shape[-1]//block_size)*block_size] # crop to a multiply of block_size
cbcr_block_size = block_size//cbcr_downscale
_, _, height, width = img.shape
downsample = torchvision.transforms.Resize((height//cbcr_downscale, width//cbcr_downscale), interpolation=torchvision.transforms.InterpolationMode.BICUBIC)
down_img = downsample(img[:, 1:,:,:])
down_img = images_sharpfin.resize_tensor(img[:, 1:,:,:], (height//cbcr_downscale, width//cbcr_downscale), linearize=False)
y = encode_single_channel_dct_2d(img[:, 0, :,:], block_size=block_size, norm=norm)
cb = encode_single_channel_dct_2d(down_img[:, 0, :,:], block_size=cbcr_block_size, norm=norm)
cr = encode_single_channel_dct_2d(down_img[:, 1, :,:], block_size=cbcr_block_size, norm=norm)
@ -180,9 +177,8 @@ def decode_jpeg_tensor(jpeg_img: torch.FloatTensor, block_size: int=16, cbcr_dow
y = decode_single_channel_dct_2d(y, norm=norm)
cb = decode_single_channel_dct_2d(cb, norm=norm)
cr = decode_single_channel_dct_2d(cr, norm=norm)
upsample = torchvision.transforms.Resize((h_blocks*block_size, w_blocks*block_size), interpolation=torchvision.transforms.InterpolationMode.BICUBIC)
cb = upsample(cb)
cr = upsample(cr)
cb = images_sharpfin.resize_tensor(cb, (h_blocks*block_size, w_blocks*block_size), linearize=False)
cr = images_sharpfin.resize_tensor(cr, (h_blocks*block_size, w_blocks*block_size), linearize=False)
return torch.stack([y,cb,cr], dim=1)

View File

@ -3,8 +3,7 @@ import random
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToPILImage
from modules import devices
from modules import devices, images_sharpfin
from modules.shared import opts, log
from modules.upscaler import Upscaler, UpscalerData
@ -14,7 +13,7 @@ MODELS_MAP = {
"SeedVR2 7B": "seedvr2_ema_7b_fp16.safetensors",
"SeedVR2 7B Sharp": "seedvr2_ema_7b_sharp_fp16.safetensors",
}
to_pil = ToPILImage()
to_pil = images_sharpfin.to_pil
class UpscalerSeedVR(Upscaler):
@ -159,7 +158,7 @@ class UpscalerSeedVR(Upscaler):
)
t1 = time.time()
log.info(f'Upscaler: type="{self.name}" model="{selected_file}" scale={self.scale} cfg={opts.seedvt_cfg_scale} seed={seed} time={t1 - t0:.2f}')
img = to_pil(result_tensor.squeeze().permute((2, 0, 1)))
img = to_pil(result_tensor.squeeze())
if opts.upscaler_unload:
self.model.dit = None

View File

@ -3,9 +3,8 @@ import os
import time
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from modules import shared, devices, processing, sd_models, errors, sd_hijack_hypertile, processing_vae, sd_models_compile, timer, modelstats, extra_networks, attention
from modules import shared, devices, processing, sd_models, errors, sd_hijack_hypertile, processing_vae, sd_models_compile, timer, modelstats, extra_networks, attention, images_sharpfin
from modules.processing_helpers import resize_hires, calculate_base_steps, calculate_hires_steps, calculate_refiner_steps, save_intermediate, update_sampler, is_txt2img, is_refiner_enabled, get_job_name
from modules.processing_args import set_pipeline_args
from modules.onnx_impl import preprocess_pipeline as preprocess_onnx_pipeline, check_parameters_changed as olive_check_parameters_changed
@ -270,9 +269,9 @@ def process_hires(p: processing.StableDiffusionProcessing, output):
sd_hijack_hypertile.hypertile_set(p, hr=True)
elif torch.is_tensor(output.images) and output.images.shape[-1] == 3: # nhwc
if output.images.dim() == 3:
output.images = TF.to_pil_image(output.images.permute(2,0,1))
output.images = images_sharpfin.to_pil(output.images)
elif output.images.dim() == 4:
output.images = [TF.to_pil_image(output.images[i].permute(2,0,1)) for i in range(output.images.shape[0])]
output.images = [images_sharpfin.to_pil(output.images[i]) for i in range(output.images.shape[0])]
strength = p.hr_denoising_strength if p.hr_denoising_strength > 0 else p.denoising_strength
if (p.hr_upscaler is not None) and (p.hr_upscaler.lower().startswith('latent') or p.hr_force) and strength > 0:
@ -572,7 +571,7 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
if hasattr(shared.sd_model, 'unet') and hasattr(shared.sd_model.unet, 'config') and hasattr(shared.sd_model.unet.config, 'in_channels') and shared.sd_model.unet.config.in_channels == 9 and not is_control:
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.INPAINTING) # force pipeline
if len(getattr(p, 'init_images', [])) == 0:
p.init_images = [TF.to_pil_image(torch.rand((3, getattr(p, 'height', 512), getattr(p, 'width', 512))))]
p.init_images = [images_sharpfin.to_pil(torch.rand((3, getattr(p, 'height', 512), getattr(p, 'width', 512))))]
if not p.prompts:
p.prompts = p.all_prompts[p.iteration * p.batch_size:(p.iteration+1) * p.batch_size]
if not p.negative_prompts:

View File

@ -334,7 +334,7 @@ def vae_decode(latents, model, output_type='np', vae_type='Full', width=None, he
def vae_encode(image, model, vae_type='Full'): # pylint: disable=unused-variable
jobid = shared.state.begin('VAE Encode')
import torchvision.transforms.functional as f
from modules import images_sharpfin
if shared.state.interrupted or shared.state.skipped:
return []
if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
@ -342,7 +342,7 @@ def vae_encode(image, model, vae_type='Full'): # pylint: disable=unused-variable
if not hasattr(model, 'vae'):
shared.log.error('VAE not found in model')
return []
tensor = f.to_tensor(image.convert("RGB")).unsqueeze(0).to(devices.device, devices.dtype_vae)
tensor = images_sharpfin.to_tensor(image.convert("RGB")).unsqueeze(0).to(devices.device, devices.dtype_vae)
if vae_type == 'Tiny':
latents = taesd_vae_encode(image=tensor)
elif vae_type == 'Full' and hasattr(model, 'vae'):

View File

@ -2,9 +2,8 @@ import time
import threading
from collections import namedtuple
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from modules import shared, devices, processing, images, sd_samplers, timer
from modules import shared, devices, processing, images, sd_samplers, timer, images_sharpfin
from modules.vae import sd_vae_approx, sd_vae_taesd, sd_vae_stablecascade
@ -84,7 +83,7 @@ def single_sample_to_image(sample, approximation=None):
x_sample = (255.0 * x_sample).to(torch.uint8)
if len(x_sample.shape) == 4:
x_sample = x_sample[0]
image = TF.to_pil_image(x_sample)
image = images_sharpfin.to_pil(x_sample)
except Exception as e:
warn_once(f'Preview: {e}')
image = Image.new(mode="RGB", size=(512, 512))

View File

@ -671,6 +671,10 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
"upscaler_latent_steps": OptionInfo(20, "Upscaler latent steps", gr.Slider, {"minimum": 4, "maximum": 100, "step": 1}),
"upscaler_tile_size": OptionInfo(192, "Upscaler tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"upscaler_tile_overlap": OptionInfo(8, "Upscaler tile overlap", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
"postprocessing_sep_resize": OptionInfo("<h2>Resize</h2>", "", gr.HTML),
"resize_quality": OptionInfo("Sharpfin MKS2021", "Image resize algorithm", gr.Dropdown, {"choices": ["PIL Lanczos", "Sharpfin MKS2021", "Sharpfin Lanczos3", "Sharpfin Mitchell", "Sharpfin Catmull-Rom"]}),
"resize_linearize_srgb": OptionInfo(True, "Apply sRGB linearization during image resize"),
}))

190
modules/sharpfin/LICENSE Normal file
View File

@ -0,0 +1,190 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
Copyright 2024 drhead
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,20 @@
"""Sharpfin - High quality image resizing with GPU acceleration.
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
Provides Magic Kernel Sharp 2021 resampling, sRGB linearization,
and Triton sparse GPU acceleration.
"""
from .util import ResizeKernel, SharpenKernel, QuantHandling, srgb_to_linear, linear_to_srgb
# Core resize API. The import is guarded so the package can still be imported
# (with FUNCTIONAL_AVAILABLE == False) when torch is missing or incompatible.
try:
    from .functional import scale, _upscale, _downscale, _get_resize_kernel
    FUNCTIONAL_AVAILABLE = True
except Exception:
    FUNCTIONAL_AVAILABLE = False

# Optional Triton-accelerated sparse downscale. Broad `except Exception`
# because Triton can fail at import time in many ways (package missing,
# no CUDA device, incompatible driver); callers check TRITON_AVAILABLE.
try:
    from .triton_functional import downscale_sparse
    TRITON_AVAILABLE = True
except Exception:
    TRITON_AVAILABLE = False

174
modules/sharpfin/cms.py Normal file
View File

@ -0,0 +1,174 @@
"""Sharpfin color management (ICC profile handling).
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
"""
from io import BytesIO
from typing import Any, cast
from warnings import warn
import numpy as np
from torch import Tensor
import PIL.Image as image
import PIL.ImageCms as image_cms
from PIL.Image import Image
from PIL.ImageCms import (
Direction, Intent, ImageCmsProfile, PyCMSError,
createProfile, getDefaultIntent, isIntentSupported, profileToProfile
)
from PIL.ImageOps import exif_transpose
# Disable Pillow's decompression-bomb limit; this module is expected to
# handle very large (trusted) images.
image.MAX_IMAGE_PIXELS = None

# Reference sRGB profile every input image gets converted into.
_SRGB = createProfile(colorSpace='sRGB')

# lcms flags used per rendering intent. NOTE(review): Intent.SATURATION has
# no entry here, so apply_srgb() falls through to its "unsupported intent"
# warning path for profiles defaulting to saturation intent.
_INTENT_FLAGS = {
    Intent.PERCEPTUAL: image_cms.FLAGS["HIGHRESPRECALC"],
    Intent.RELATIVE_COLORIMETRIC: (
        image_cms.FLAGS["HIGHRESPRECALC"] |
        image_cms.FLAGS["BLACKPOINTCOMPENSATION"]
    ),
    Intent.ABSOLUTE_COLORIMETRIC: image_cms.FLAGS["HIGHRESPRECALC"]
}
class CMSWarning(UserWarning):
def __init__(
self,
message: str,
*,
path: str | None = None,
cms_info: dict[str, Any] | None = None,
cause: Exception | None = None,
):
super().__init__(message)
self.__cause__ = cause
self.path = path
self.cms_info = cms_info
def _coalesce_intent(intent: Intent | int) -> Intent:
    """Coerce a raw integer rendering intent to the Intent enum.

    Enum values pass through unchanged; the integers 0-3 map to the four
    ICC rendering intents. Anything else raises ValueError.
    """
    if isinstance(intent, Intent):
        return intent
    lookup = {
        0: Intent.PERCEPTUAL,
        1: Intent.RELATIVE_COLORIMETRIC,
        2: Intent.SATURATION,
        3: Intent.ABSOLUTE_COLORIMETRIC,
    }
    try:
        return lookup[intent]
    except KeyError:
        raise ValueError("invalid intent") from None
def _add_info(info: dict[str, Any], source: object, key: str) -> None:
try:
if (value := getattr(source, key, None)) is not None:
info[key] = value
except Exception:
pass
def apply_srgb(
    img: Image
) -> Image:
    """Normalize an image to the sRGB color space in RGB or RGBA mode.

    Applies EXIF orientation in place, then -- if the image embeds an ICC
    profile -- converts it to sRGB via lcms. All CMS failures are
    non-fatal: they emit a CMSWarning (or print) and the pixel data is
    assumed to already be sRGB. The returned image is always mode "RGB",
    or "RGBA" when the source carries transparency.
    """
    # Best-effort source path, used only for warning/diagnostic messages.
    if hasattr(img, 'filename'):
        path = img.filename
    else:
        path = ""
    try:
        img.load()
        try:
            exif_transpose(img, in_place=True)
        except Exception:
            pass # corrupt EXIF metadata is fine
        if (icc_raw := img.info.get("icc_profile")) is not None:
            # Diagnostic context attached to any CMSWarning we emit below.
            cms_info: dict[str, Any] = {
                "native_mode": img.mode,
                "transparency": img.has_transparency_data,
            }
            try:
                profile = ImageCmsProfile(BytesIO(icc_raw))
                _add_info(cms_info, profile.profile, "profile_description")
                _add_info(cms_info, profile.profile, "target")
                _add_info(cms_info, profile.profile, "xcolor_space")
                _add_info(cms_info, profile.profile, "connection_space")
                _add_info(cms_info, profile.profile, "colorimetric_intent")
                _add_info(cms_info, profile.profile, "rendering_intent")
                # Pick a working mode lcms can transform: RGB(A) for
                # color-like sources, L(A) for grayscale-like ones.
                working_mode = img.mode
                if img.mode.startswith(("RGB", "BGR", "P")):
                    working_mode = "RGBA" if img.has_transparency_data else "RGB"
                elif img.mode.startswith(("L", "I", "F")) or img.mode == "1":
                    working_mode = "LA" if img.has_transparency_data else "L"
                if img.mode != working_mode:
                    cms_info["working_mode"] = working_mode
                    img = img.convert(working_mode)
                mode = "RGBA" if img.has_transparency_data else "RGB"
                # Prefer relative colorimetric; fall back to the profile's
                # default intent when that is not supported for input.
                intent = Intent.RELATIVE_COLORIMETRIC
                if isIntentSupported(profile, intent, Direction.INPUT) != 1:
                    intent = _coalesce_intent(getDefaultIntent(profile))
                cms_info["conversion_intent"] = intent
                if (flags := _INTENT_FLAGS.get(intent)) is not None:
                    if img.mode == mode:
                        # Same mode before and after: transform pixels in place.
                        profileToProfile(
                            img,
                            profile,
                            _SRGB,
                            renderingIntent=intent,
                            inPlace=True,
                            flags=flags
                        )
                    else:
                        # Mode changes during conversion: lcms returns a new image.
                        img = cast(Image, profileToProfile(
                            img,
                            profile,
                            _SRGB,
                            renderingIntent=intent,
                            outputMode=mode,
                            flags=flags
                        ))
                else:
                    # No flag preset for this intent (e.g. SATURATION, which
                    # _INTENT_FLAGS omits): skip conversion, assume sRGB.
                    warn(CMSWarning(
                        f"unsupported intent on {path} assuming sRGB: {cms_info}",
                        path=path,
                        cms_info=cms_info
                    ))
            except PyCMSError as ex:
                warn(CMSWarning(
                    f"{ex} on {path}, assuming sRGB: {cms_info}",
                    path=path,
                    cms_info=cms_info,
                    cause=ex,
                ))
    except Exception as ex:
        print(f"{ex} on {path}")
    # Final normalization to RGB/RGBA regardless of the CMS outcome above.
    if img.has_transparency_data:
        if img.mode != "RGBA":
            try:
                img = img.convert("RGBA")
            except ValueError:
                # Some modes only reach RGBA via premultiplied alpha.
                img = img.convert("RGBa").convert("RGBA")
    elif img.mode != "RGB":
        img = img.convert("RGB")
    return img
def put_srgb(img: "Image", tensor: Tensor) -> None:
    """Copy the RGB channels of `img` into `tensor` in place.

    `tensor` must already have matching shape and dtype: `casting="no"`
    forbids any implicit conversion. Alpha (if present) is dropped.

    Raises:
        ValueError: if the image mode is not RGB/RGBA/RGBa.
    """
    if img.mode not in ("RGB", "RGBA", "RGBa"):
        raise ValueError(f"Image has non-RGB mode {img.mode}.")
    pixels = np.asarray(img)
    np.copyto(tensor.numpy(), pixels[:, :, :3], casting="no")

View File

@ -0,0 +1,285 @@
"""Sharpfin functional image scaling operations.
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
Imports patched: absolute ``sharpfin.X`` imports rewritten to relative ``.X``; triton import guarded.
"""
import torch
import numpy as np
import torch.nn.functional as F
from typing import Callable, Tuple
import math
from contextlib import nullcontext
from .util import ResizeKernel, linear_to_srgb, srgb_to_linear
# from Pytorch >= 2.6
# torch.compiler.set_stance lets scale() force eager execution on CPU;
# on older PyTorch this resolves to None and call sites must guard on it.
set_stance = getattr(torch.compiler, "set_stance", None)
def _get_resize_kernel(k: ResizeKernel):
    """Resolve a ResizeKernel enum member to (kernel_fn, support_radius).

    The radius is the half-width beyond which the kernel is identically
    zero; callers use it to size padding and the weight matrix.
    """
    table = {
        ResizeKernel.NEAREST: (nearest, 0.5),
        ResizeKernel.BILINEAR: (bilinear, 1.),
        ResizeKernel.MITCHELL: (mitchell, 2.),  # B = 1/3, C = 1/3
        ResizeKernel.CATMULL_ROM: (lambda x: mitchell(x, 0.0, 0.5), 2.),
        ResizeKernel.B_SPLINE: (lambda x: mitchell(x, 1.0, 0.0), 2.),
        ResizeKernel.LANCZOS2: (lambda x: lanczos(x, 2), 2.),
        ResizeKernel.LANCZOS3: (lambda x: lanczos(x, 3), 3.),
        ResizeKernel.MAGIC_KERNEL: (magic_kernel, 1.5),
        ResizeKernel.MAGIC_KERNEL_SHARP_2013: (magic_kernel_sharp_2013, 2.5),
        ResizeKernel.MAGIC_KERNEL_SHARP_2021: (magic_kernel_sharp_2021, 4.5),
    }
    try:
        resize_kernel, kernel_window = table[k]
    except KeyError:
        raise ValueError(f"Unknown resize kernel {k}") from None
    return resize_kernel, kernel_window
### Resampling kernels
def nearest(x: torch.Tensor) -> torch.Tensor:
    """Box (nearest-neighbour) kernel: 1 where |x| <= 0.5, else 0."""
    magnitude = torch.abs(x)
    return torch.where(magnitude <= 0.5, 1., 0.)
def bilinear(x: torch.Tensor) -> torch.Tensor:
    """Triangle (linear interpolation) kernel: 1 - |x| on [0, 1], else 0."""
    magnitude = torch.abs(x)
    return torch.where(magnitude <= 1.0, 1 - magnitude, torch.zeros_like(magnitude))
def mitchell(x: torch.Tensor, B: float = 1 / 3, C: float = 1 / 3) -> torch.Tensor:
    """Mitchell-Netravali cubic kernel family (without the usual 1/6 factor).

    B = C = 1/3 is the classic Mitchell filter; B=0, C=0.5 is Catmull-Rom;
    B=1, C=0 is the cubic B-spline. The missing 1/6 normalization factor
    cancels because callers renormalize weight columns to sum to 1.
    """
    magnitude = torch.abs(x)
    outer = ((-B - 6 * C) * magnitude**3 + (6 * B + 30 * C) * magnitude**2
             + (-12 * B - 48 * C) * magnitude + (8 * B + 24 * C))
    inner = ((12 - 9 * B - 6 * C) * magnitude**3
             + (-18 + 12 * B + 6 * C) * magnitude**2 + (6 - 2 * B))
    result = torch.where(magnitude <= 2, outer, torch.zeros_like(magnitude))
    return torch.where(magnitude <= 1, inner, result)
def magic_kernel(x: torch.Tensor) -> torch.Tensor:
    """Magic Kernel (Costella): piecewise quadratic with support |x| <= 1.5."""
    magnitude = torch.abs(x)
    tail = (1/2) * (magnitude - 3/2) ** 2
    result = torch.where(magnitude <= 1.5, tail, torch.zeros_like(magnitude))
    return torch.where(magnitude <= 0.5, (3/4) - magnitude ** 2, result)
def magic_kernel_sharp_2013(x: torch.Tensor):
    """Magic Kernel Sharp 2013: quadratic pieces, support |x| <= 2.5.

    Has negative lobes (sharpening), unlike the plain magic kernel.
    """
    magnitude = torch.abs(x)
    result = torch.where(magnitude <= 2.5, (-1/8) * (magnitude - 5/2) ** 2,
                         torch.zeros_like(magnitude))
    result = torch.where(magnitude <= 1.5, (1 - magnitude) * (7/4 - magnitude), result)
    return torch.where(magnitude <= 0.5, (17/16) - (7/4) * magnitude ** 2, result)
def magic_kernel_sharp_2021(x: torch.Tensor):
    """Magic Kernel Sharp 2021: five quadratic pieces, support |x| <= 4.5.

    The wider support and extra lobes give a closer approximation to the
    ideal sharpened magic kernel than the 2013 variant.
    """
    magnitude = torch.abs(x)
    result = torch.where(magnitude <= 4.5, (-1/288) * (magnitude - 9/2) ** 2,
                         torch.zeros_like(magnitude))
    result = torch.where(magnitude <= 3.5, (1/36) * (magnitude - 3) * (magnitude - 15/4), result)
    result = torch.where(magnitude <= 2.5, (1/6) * (magnitude - 2) * (65/24 - magnitude), result)
    result = torch.where(magnitude <= 1.5, (35/36) * (magnitude - 1) * (magnitude - 239/140), result)
    return torch.where(magnitude <= 0.5, (577/576) - (239/144) * magnitude ** 2, result)
def lanczos(x: torch.Tensor, n: int):
    """Lanczos-n kernel: sinc(x) windowed by sinc(x/n), zero outside |x| < n."""
    inside = torch.abs(x) < n
    return torch.where(inside, torch.sinc(x) * torch.sinc(x / n), torch.zeros_like(x))
def sharpen_conv2d(image: torch.Tensor, kernel: torch.Tensor, pad: int) -> torch.Tensor:
    """Depthwise 2D convolution with replicate edge padding.

    One kernel group per channel, so channels are filtered independently.
    Replicate padding by `pad` keeps the spatial size unchanged when the
    kernel half-width equals `pad`.
    """
    padded = F.pad(image, (pad, pad, pad, pad), mode='replicate')
    return F.conv2d(padded, kernel, groups=padded.shape[-3])
### Dithering and related functions.
def stochastic_round(
    x: torch.Tensor,
    out_dtype: torch.dtype,
    generator: torch.Generator = torch.Generator(),
):
    """Quantize `x` in [0, 1] to an integer dtype with stochastic rounding.

    Each value rounds up with probability equal to its fractional quantization
    error, making the rounding unbiased in expectation (a simple dither).

    NOTE: the default generator is created once at definition time and is
    deliberately shared across default-argument calls, so successive calls
    advance the same RNG state.
    """
    scaled = x * torch.iinfo(out_dtype).max
    quantized = scaled.to(out_dtype)
    # Fractional part lost by the cast, in [0, 1): used as a probability.
    residual = scaled - quantized.to(scaled.dtype)
    bump = torch.empty_like(quantized, dtype=torch.bool)
    torch.bernoulli(residual, generator=generator, out=bump)
    return quantized + bump
def generate_bayer_matrix(n):
    """Generate an n x n Bayer ordered-dithering matrix.

    Uses the classic recursive construction: B(2m) tiles 4*B(m) plus the
    offsets [[0, 2], [3, 1]]. The result is a permutation of 0 .. n*n - 1.

    Args:
        n: Side length; must be a positive power of two.

    Returns:
        np.ndarray of shape (n, n) with integer threshold values.

    Raises:
        ValueError: if n is not a positive power of two.
    """
    # Explicit exception instead of `assert`, which is stripped under `python -O`.
    if n <= 0 or (n & (n - 1)) != 0:
        raise ValueError("n must be a power of 2")
    if n == 1:
        return np.array([[0]])  # Base case
    smaller_matrix = generate_bayer_matrix(n // 2)
    return np.block([
        [4 * smaller_matrix + 0, 4 * smaller_matrix + 2],
        [4 * smaller_matrix + 3, 4 * smaller_matrix + 1]
    ])
### Scaling transforms
def _downscale_axis(
    image: torch.Tensor,
    size: int,
    resize_kernel: ResizeKernel,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Resample the LAST axis of `image` down to `size` samples.

    Builds a dense (padded_source x dest) weight matrix by evaluating the
    kernel on the grid of source/destination coordinate differences, then
    applies it with a single matmul. Replicate padding supplies the
    out-of-range source samples needed by the kernel window at the edges.
    """
    kernel, window = _get_resize_kernel(resize_kernel)
    # k < 1 when downscaling; the kernel window stretches by 1/k in source space.
    k = size / image.shape[-1]
    PAD = math.ceil((window - 0.5) / k)
    # Optimization note: doing torch.arange like this will compile to doing a int64 arange. Float arange
    # is much slower. So don't try to get clever and "optimize" by adding the +0.5 and *k to this.
    # Source grid is padded to allow "out of range" sampling from the source image.
    coords_source = (torch.arange(-PAD, image.shape[-1]+PAD, 1, dtype=torch.float32, device=device) + 0.5) * k
    coords_dest = (torch.arange(0, size, 1, dtype=torch.float32, device=device) + 0.5)
    # Create a grid of relative distances between each point on this axis.
    coord_grid = torch.empty((coords_source.shape[0], coords_dest.shape[0]), dtype=dtype, device=device)
    # Coord grid always constructed in torch.float32 because float16 precision breaks down for this
    # after 1024.0. This subtraction is the first opportunity we have to safely cast to float16.
    torch.sub(coords_source.unsqueeze(-1), other=coords_dest, out=coord_grid)
    weights = kernel(coord_grid)
    # Normalizing weights to sum to 1 along axis we are resizing on
    weights /= weights.sum(dim=0, keepdim=True)
    # weights /= (1/k)
    # Padded dimension is reduced by the matmul here.
    return F.pad(image, (PAD,PAD,0,0), mode='replicate') @ weights
def _upscale_axis(
    image: torch.Tensor,
    size: int,
    resize_kernel: ResizeKernel,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Resample the LAST axis of `image` up to `size` samples.

    Builds a dense (source x padded_dest) weight matrix from the kernel and
    applies it with a single matmul. The destination grid is padded because
    upscaling samples slightly outside the destination range; that padding
    is folded back into the edge columns before the matmul.
    """
    kernel, window = _get_resize_kernel(resize_kernel)
    # k > 1 when upscaling.
    k = size / image.shape[-1]
    PAD = math.ceil((window - 0.5) * k)
    # For upsizing, we expect out of range sampling from the destination image.
    coords_source = (torch.arange(0, image.shape[-1], 1, dtype=torch.float32, device=device) + 0.5)
    coords_dest = (torch.arange(-PAD, size+PAD, 1, dtype=torch.float32, device=device) + 0.5) / k
    coord_grid = torch.empty((coords_source.shape[0], coords_dest.shape[0]), dtype=dtype, device=device)
    torch.sub(coords_source.unsqueeze(-1), other=coords_dest, out=coord_grid)
    weights = kernel(coord_grid)
    # We need to explicitly trim padding by summing it into the real area of the destination grid.
    # Fix: guard PAD == 0 (e.g. NEAREST, whose window is exactly 0.5). The
    # previous unconditional code used `weights[:, -0:]`, which aliases the
    # WHOLE tensor, and `weights[:, 0:-0]`, which is empty -- corrupting the
    # edge column and producing a zero-width weight matrix.
    if PAD > 0:
        weights[:, PAD] += weights[:, :PAD].sum(dim=1)
        weights[:, -PAD-1] += weights[:, -PAD:].sum(dim=1)
        weights = weights[:, PAD:-PAD]
    weights /= weights.sum(dim=0, keepdim=True)
    return image @ weights
@torch.compile
def _downscale(
    image: torch.Tensor,
    out_res: tuple[int, int],
    resize_kernel: ResizeKernel,
    device: torch.device,
    dtype: torch.dtype,
    do_srgb_conversion: bool,
):
    """Downscale the last two axes of `image` to out_res = (H, W).

    Separable filtering: resizes W, then transposes so H becomes the last
    axis, resizes, and transposes back. Optionally converts sRGB -> linear
    before filtering and back after, then clamps to [0, 1].
    """
    H, W = out_res
    image = image.to(device=device, dtype=dtype)
    if do_srgb_conversion:
        image = srgb_to_linear(image)
    image = _downscale_axis(image, W, resize_kernel, device, dtype)
    # .mT swaps the last two axes so the H axis can be resized as "last".
    image = _downscale_axis(image.mT, H, resize_kernel, device, dtype).mT
    if do_srgb_conversion:
        image = linear_to_srgb(image)
    image = image.clamp(0,1)
    return image
@torch.compile
def _upscale(
    image: torch.Tensor,
    out_res: tuple[int, int],
    resize_kernel: ResizeKernel,
    device: torch.device,
    dtype: torch.dtype,
    do_srgb_conversion: bool,
):
    """Upscale the last two axes of `image` to out_res = (H, W).

    Mirrors _downscale: separable per-axis resampling via _upscale_axis,
    with optional sRGB linearization around the filtering and a final
    clamp to [0, 1].
    """
    H, W = out_res
    image = image.to(device=device, dtype=dtype)
    if do_srgb_conversion:
        image = srgb_to_linear(image)
    image = _upscale_axis(image, W, resize_kernel, device, dtype)
    # .mT swaps the last two axes so the H axis can be resized as "last".
    image = _upscale_axis(image.mT, H, resize_kernel, device, dtype).mT
    if do_srgb_conversion:
        image = linear_to_srgb(image)
    image = image.clamp(0,1)
    return image
# Triton sparse downscale - only available with Triton (CUDA)
try:
    from .triton_functional import downscale_sparse
except ImportError:
    # No Triton on this platform; scale() raises ImportError if use_sparse
    # is requested anyway.
    downscale_sparse = None
def scale(
    image: torch.Tensor,
    out_res: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
    do_srgb_conversion: bool = True,
    use_sparse: bool = False,
) -> torch.Tensor:
    """Resize the last two axes of `image` to out_res = (H, W).

    Dispatches to the compiled upscale or downscale path based on the
    requested size. Mixed-axis scaling (one axis up, the other down) is
    not supported here and raises ValueError. With use_sparse=True, the
    Triton sparse GPU downscale is used instead (MKS2021 kernel only).

    Raises:
        NotImplementedError: use_sparse with a kernel other than MKS2021.
        ImportError: use_sparse without Triton available.
        ValueError: mixed-axis resize request.
    """
    if isinstance(device, str):
        device = torch.device(device)
    if use_sparse:
        assert device.type != "cpu", "sparse implementation is only for GPU!"
        if resize_kernel != ResizeKernel.MAGIC_KERNEL_SHARP_2021:
            raise NotImplementedError
        if downscale_sparse is None:
            raise ImportError("Triton is required for sparse GPU acceleration")
    # Force eager on CPU (torch.compile'd paths are slow/fragile there);
    # set_stance is only available on PyTorch >= 2.6.
    context_manager = (
        set_stance("force_eager") if set_stance and device.type == "cpu" else nullcontext()
    )
    with context_manager:
        if image.shape[-1] <= out_res[-1] and image.shape[-2] <= out_res[-2]:
            # Equal sizes also land here, which is harmless.
            assert not use_sparse
            return _upscale(image, out_res, resize_kernel, device, dtype, do_srgb_conversion)
        elif image.shape[-1] >= out_res[-1] and image.shape[-2] >= out_res[-2]:
            if use_sparse:
                return downscale_sparse(image, out_res, resize_kernel)
            return _downscale(image, out_res, resize_kernel, device, dtype, do_srgb_conversion)
        else:
            raise ValueError("Mixed axis resizing (e.g. scaling one axis up and the other down) is not supported. File a bug report with your use case if needed.")

View File

@ -0,0 +1,845 @@
"""Sharpfin sparse matrix backend for Triton DDS matmul.
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
Adapted from https://github.com/stanford-futuredata/stk (Apache 2.0)
"""
import numpy as np
import torch
import triton
import triton.language as tl
from typing import Tuple
from dataclasses import dataclass
from .triton_functional import linear_to_srgb_triton, srgb_to_linear_triton, magic_kernel_sharp_2021_triton, lanczos_triton
# Code is all adapted from https://github.com/stanford-futuredata/stk, licensed under Apache-2.0
# Very reduced set of functions for handling DDS (Dense = Dense @ Sparse) matmul only, with the
# DDS kernel modified to be more flexible on input shapes.
def _validate_matrix(shape, data, row_indices, column_indices, offsets):
    """Validate the fields of a BCSR sparse matrix; return normalized `data`.

    Checks block squareness, divisibility of the shape by the block size,
    nnz capacity, metadata dimensionality/lengths, device agreement, and
    the exact dtypes the Triton kernels require (float16 data, int16
    indices, int32 offsets). Raises ValueError on any violation.
    """
    # Scalar-block data may arrive flat; normalize to (nnz, 1, 1).
    if data.dim() == 1:
        data = torch.reshape(data, [data.numel(), 1, 1])
    if data.shape[-2] != data.shape[-1]:
        raise ValueError(
            "Expected square blocking in data. "
            f"Got block shape {[data.shape[-2], data.shape[-1]]}")
    block_size = data.shape[-1]
    data = data.view([-1, block_size, block_size])
    # NOTE(review): after the view above, data is always 3D, so this branch
    # looks unreachable; kept as written in the vendored upstream.
    if data.dim() != 3:
        raise ValueError(
            "Expected 3D shape for data (nnz, block, block). "
            f"Got shape {data.dim()}D shape.")
    block_size = data.shape[1]
    if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
        raise ValueError(
            "Matrix shape must be dividible by blocking. "
            f"Got shape {shape} with "
            f"{[block_size, block_size]} blocking.")
    if np.prod(shape) < data.numel():
        raise ValueError(
            "Invalid matrix. Number of nonzeros exceeds matrix capacity "
            f"({data.numel()} v. {np.prod(shape)})")
    if row_indices.dim() != 1:
        raise ValueError(
            f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
    if column_indices.dim() != 1:
        raise ValueError(
            f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
    if offsets.dim() != 1:
        raise ValueError(
            f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
    if row_indices.numel() != data.shape[0]:
        raise ValueError(
            "Expected 1 index per nonzero block. "
            f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
    if column_indices.numel() != data.shape[0]:
        raise ValueError(
            "Expected 1 index per nonzero block. "
            f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
    # One offset per block row, plus the trailing end sentinel.
    block_rows = np.prod(shape[:-1]) / block_size
    if offsets.numel() != block_rows + 1:
        raise ValueError(
            "Expected one offset per block row plus one. "
            f"Got {offsets.numel()} offsets with {block_rows} block rows.")
    # Data and all metadata must live on the same kind of device.
    is_cuda = (data.is_cuda and
               row_indices.is_cuda and
               column_indices.is_cuda and
               offsets.is_cuda)
    is_cpu = (not data.is_cuda and
              not row_indices.is_cuda and
              not column_indices.is_cuda and
              not offsets.is_cuda)
    if not (is_cuda or is_cpu):
        raise ValueError(
            "Expected data & meta-data on common device. "
            f"Got data on {data.device}, row_indices on {row_indices.device} "
            f"column_indices on {column_indices.device} and "
            f"offsets on {offsets.device}.")
    # Exact dtypes required by the Triton DDS kernels.
    if data.dtype != torch.float16:
        raise ValueError(
            f"Expected float16 data. Got {data.dtype} data.")
    if row_indices.dtype != torch.int16:
        raise ValueError(
            f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
    if column_indices.dtype != torch.int16:
        raise ValueError(
            f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
    if offsets.dtype != torch.int32:
        raise ValueError(
            f"Expected int32 offsets. Got {offsets.dtype} offsets.")
    return data
def _transpose(size, data: torch.Tensor, row_indices: torch.Tensor, column_indices: torch.Tensor, offsets):
block_columns = size[1] // data.shape[1]
gather_indices = column_indices.argsort()
column_indices_t = row_indices.gather(0, gather_indices)
block_offsets_t = gather_indices.int()
column_indices_float = column_indices.float()
zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
nnz_per_column = nnz_per_column.int()
offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
return column_indices_t, offsets_t, block_offsets_t
class SBSCMatrix(torch.nn.Module):
    """Single Block Sparse Column (SBSC) matrix format.

    Holds one dense block per column group; `offset` is the row stride
    between consecutive column blocks, and geometry fields are derived
    from the block tensor's shape.
    """

    def __init__(
        self,
        size,
        data: torch.Tensor,
        offset: int,
        block_size: int
    ):
        super().__init__()
        self.size = size
        self.data = data
        self.offset = offset
        # Geometry cached from the (num_blocks, block_rows, col_width) layout.
        self.num_blocks = data.shape[0]
        self.col_width = data.shape[2]
        self.col_block_size = block_size
class Matrix(torch.nn.Module):
    """A matrix stored in block compressed sparse row (BCSR) format.

    `data` holds the nonzero blocks as (nnz, block, block); `row_indices`,
    `column_indices`, and `offsets` place them. Transposed-traversal
    metadata (`*_t`) is computed via _transpose() when not supplied.
    t() returns a lightweight transposed view sharing the same storage.
    """
    def __init__(self,
                 size,
                 data: torch.Tensor,
                 row_indices: torch.Tensor,
                 column_indices: torch.Tensor,
                 offsets: torch.Tensor,
                 column_indices_t: torch.Tensor = None,
                 offsets_t: torch.Tensor = None,
                 block_offsets_t: torch.Tensor = None):
        super().__init__()
        self._size = size
        self._data = data
        self._row_indices = row_indices
        self._column_indices = column_indices
        self._offsets = offsets
        # Derive transposed-traversal metadata only if not all provided.
        if ((column_indices_t is None) or (offsets_t is None) or
            (block_offsets_t is None)):
            column_indices_t, offsets_t, block_offsets_t = _transpose(
                size, data, row_indices, column_indices, offsets)
        self._column_indices_t = column_indices_t
        self._offsets_t = offsets_t
        self._block_offsets_t = block_offsets_t
        self._transposed = False
        # int16 indices address block positions, so the largest representable
        # dimension is int16 max * block size.
        max_dim = np.iinfo(np.int16).max * self.blocking
        if column_indices.dtype == torch.int16:
            if size[0] > max_dim or size[1] > max_dim:
                # Fix: this message was missing its f-string prefix, so the
                # {size} placeholder was never interpolated.
                raise ValueError(
                    f"Sparse matrix with shape {size} exceeds representable "
                    "size with 16-bit indices.")
    def validate(self):
        """Run the full field validation (raises ValueError on problems)."""
        _validate_matrix(self._size,
                         self._data,
                         self._row_indices,
                         self._column_indices,
                         self._offsets)
    def to(self, device):
        """Move all storage and metadata tensors to `device`; returns self."""
        self._data = self._data.to(device)
        self._row_indices = self._row_indices.to(device)
        self._column_indices = self._column_indices.to(device)
        self._offsets = self._offsets.to(device)
        self._column_indices_t = self._column_indices_t.to(device)
        self._offsets_t = self._offsets_t.to(device)
        self._block_offsets_t = self._block_offsets_t.to(device)
        return self
    def cuda(self):
        """Move to the current CUDA device."""
        return self.to(torch.cuda.current_device())
    def clone(self):
        """Deep-copy all tensors into a new Matrix."""
        return Matrix(
            self.size(),
            self.data.clone(),
            self.row_indices.clone(),
            self.column_indices.clone(),
            self.offsets.clone(),
            self.column_indices_t.clone(),
            self.offsets_t.clone(),
            self.block_offsets_t.clone())
    def t(self):
        """Return a transposed VIEW sharing this matrix's storage."""
        if self.dim() != 2:
            raise ValueError(
                "t() expects a tensor with <= 2 dimensions, "
                f"but self is {self.dim()}D.")
        out = Matrix(self.size(),
                     self.data,
                     self.row_indices,
                     self.column_indices,
                     self.offsets,
                     self.column_indices_t,
                     self.offsets_t,
                     self.block_offsets_t)
        out._transposed = not self._transposed
        out._size = torch.Size((self._size[1], self._size[0]))
        return out
    def contiguous(self):
        raise ValueError("Not yet implemented.")
    def is_contiguous(self):
        # A transposed view is, by definition, non-contiguous.
        return not self._transposed
    @property
    def is_cuda(self):
        return self._data.is_cuda
    @property
    def device(self):
        return self._data.device
    def size(self):
        return self._size
    @property
    def shape(self):
        return self.size()
    def dim(self):
        return len(self._size)
    @property
    def data(self):
        return self._data
    @property
    def row_indices(self):
        return self._row_indices
    @property
    def column_indices(self):
        return self._column_indices
    @property
    def offsets(self):
        return self._offsets
    @property
    def offsets_t(self):
        return self._offsets_t
    @property
    def column_indices_t(self):
        return self._column_indices_t
    @property
    def block_offsets_t(self):
        return self._block_offsets_t
    @property
    def dtype(self):
        return self.data.dtype
    @property
    def nnz(self):
        # Number of nonzero ELEMENTS (not blocks).
        return self.data.numel()
    @property
    def blocking(self):
        # Side length of each square block.
        return self.data.shape[1]
    @property
    def requires_grad(self):
        return self.data.requires_grad
    def requires_grad_(self, x):
        self.data.requires_grad_(x)
        return self
    def view(self, *shape):
        """Reshape the leading (uncompressed) dimensions; metadata is shared."""
        assert self.is_contiguous()
        if shape[-1] != self.size()[-1]:
            raise ValueError(
                "Can't change view on compressed dimension. "
                f"{self.size()[-1]} v. {shape[-1]}.")
        if np.prod(shape) != np.prod(self.size()):
            raise ValueError(
                "Mismatch in numel of Matrix and new shape. "
                f"{np.prod(self.size())} v. {np.prod(shape)}")
        return Matrix(shape,
                      self.data,
                      self.row_indices,
                      self.column_indices,
                      self.offsets,
                      self.column_indices_t,
                      self.offsets_t,
                      self.block_offsets_t)
    @property
    def grad(self):
        """Gradient wrapped in a Matrix with the same sparsity/transposition."""
        size = self.size()
        if not self.is_contiguous():
            size = torch.Size((size[1], size[0]))
        out = Matrix(size,
                     self.data.grad,
                     self.row_indices,
                     self.column_indices,
                     self.offsets,
                     self.column_indices_t,
                     self.offsets_t,
                     self.block_offsets_t)
        return out if self.is_contiguous() else out.t()
@torch.no_grad()
def _expand_for_blocking(idxs, blocking):
idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)
idxs[:, :, 1] *= blocking
idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])
idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
idxs = idxs.repeat(1, blocking, 1, 1)
idxs[:, :, :, 0] *= blocking
idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
idxs = torch.reshape(idxs, [-1, 2])
return idxs
@torch.no_grad()
def to_dense(x):
    """Materialize a block-sparse Matrix as a dense tensor of the same shape."""
    assert isinstance(x, Matrix)
    # Flatten all leading dims into one row dimension for the scatter.
    flat_shape = (np.prod(x.shape[:-1]), x.shape[-1])
    block_pairs = torch.stack(
        [x.row_indices.type(torch.int32), x.column_indices.type(torch.int32)],
        dim=1)
    # Expand block coordinates to element coordinates, then linearize.
    elem_pairs = _expand_for_blocking(block_pairs, x.blocking)
    flat_idxs = (elem_pairs[:, 0] * flat_shape[1] + elem_pairs[:, 1]).type(torch.int64)
    dense = torch.zeros(flat_shape[0] * flat_shape[1], dtype=x.dtype, device=x.device)
    dense.scatter_(0, flat_idxs, x.data.flatten())
    return dense.reshape(x.size())
@dataclass
class TritonConfig:
    # Default tile/launch parameters for the Triton DDS matmul kernels.
    BLOCK_M: int = 128   # rows of the dense LHS tile per program
    BLOCK_N: int = 128   # columns of the output tile per program
    BLOCK_K: int = 32    # reduction-dimension step per inner iteration
    BLOCK_SIZE: int = 64 # sparse block side length
    NUM_STAGES: int = 4  # software pipelining stages
    NUM_WARPS: int = 4   # warps per program
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _dds_kernel(
    A: tl.tensor, B: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N,
    stride_ac, stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cc, stride_cm, stride_cn,
    row_indices: tl.tensor, column_indices: tl.tensor,
    offsets: tl.tensor, block_offsets_t: tl.tensor,
    fuse_srgb: tl.constexpr, clamp_output: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    # Dense = Dense @ Sparse (BCSR) tile kernel.
    # One program per (channel, M-tile, N-tile) of the output C.
    pid_c = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_n = tl.program_id(2)
    num_pid_m = tl.num_programs(1)
    num_pid_n = tl.num_programs(2)
    # Swizzle the 2D launch for better L2 reuse across neighboring tiles.
    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
    # Range of nonzero blocks contributing to this output block column.
    offsets += pid_n
    start_inx = tl.load(offsets)
    end_inx = tl.load(offsets + 1)
    column_indices += start_inx
    block_offsets_t += start_inx
    BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
    A_block_ptr = tl.make_block_ptr(
        base=A + pid_c * stride_ac, shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(0, 1)
    )
    rn = tl.arange(0, BLOCK_N)
    rbk = tl.arange(0, BLOCK_K)
    B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float16)
    # Each sparse block may be wider than BLOCK_K; step through it in chunks.
    nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
    bk_sub_incr = BLOCK_K * stride_bk
    for block_inx in range(end_inx - start_inx):
        # Column of A this sparse block multiplies, and its storage offset in B.
        a_col_idx = tl.load(column_indices + block_inx)
        ptr_A = tl.advance(A_block_ptr, (0, a_col_idx * BLOCK_SIZE))
        b_block_offset = tl.load(block_offsets_t + block_inx)
        ptr_B = B + b_block_offset * BLOCK_ELEMENTS
        for sub_block_inx in range(nsub_blocks):
            a = tl.load(ptr_A)
            b = tl.load(ptr_B)
            acc = tl.dot(a, b, acc, out_dtype=tl.float16)
            ptr_A = tl.advance(ptr_A, (0, BLOCK_K))
            ptr_B += bk_sub_incr
    # Optional fused epilogue: linear->sRGB conversion and/or [0, 1] clamp.
    if fuse_srgb:
        acc = linear_to_srgb_triton(acc)
    if clamp_output:
        acc = tl.clamp(acc, 0.0, 1.0)
    if fuse_srgb or clamp_output:
        acc = acc.to(C.dtype.element_ty)
    C_block_ptr = tl.make_block_ptr(
        base=C + pid_c * stride_cc, shape=(O_M, O_N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0)
    )
    # Boundary check trims the tile to the (possibly sliced) O_M x O_N output.
    tl.store(C_block_ptr, acc, boundary_check=(0, 1))
def triton_dds(
    lhs: torch.Tensor,
    rhs: Matrix,
    fuse_srgb: bool = False,
    clamp_output: bool = False,
    output_mt: bool = False,
    output_slice: None | Tuple[int,int] = None
):
    """Dense = Dense @ Sparse matmul via the Triton DDS kernel.

    Args:
        lhs: dense (channels, M, K) tensor, contiguous in its last dim.
        rhs: block-sparse Matrix operand (a transposed view is supported).
        fuse_srgb: apply linear->sRGB conversion in the kernel epilogue.
        clamp_output: clamp results to [0, 1] in the epilogue.
        output_mt: allocate and store the output transposed (N-major).
        output_slice: optional (rows, cols) to trim the logical output to.

    Returns:
        Dense output tensor on lhs's device and dtype.
    """
    assert isinstance(lhs, torch.Tensor)
    assert isinstance(rhs, Matrix)
    assert lhs.ndim == 3
    CH = lhs.shape[0]
    stride_ac = lhs.stride(0)
    M, K = lhs.shape[-2:]
    N = rhs.shape[-1]
    if output_mt:
        # Transposed output: swap logical row/col strides when storing.
        if output_slice is not None:
            O_N, O_M = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_N, O_M = N, M
            out = torch.empty(
                (*lhs.shape[:-2], rhs.shape[1], lhs.shape[-2]),
                dtype=lhs.dtype,
                device=lhs.device
            )
        stride_cm, stride_cn = out.stride(-1), out.stride(-2)
        stride_cc = out.stride(-3)
    else:
        if output_slice is not None:
            O_M, O_N = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_M, O_N = M, N
            out = torch.empty(
                (*lhs.shape[:-1], rhs.shape[1]),
                dtype=lhs.dtype,
                device=lhs.device
            )
        stride_cm, stride_cn = out.stride(-2), out.stride(-1)
        stride_cc = out.stride(-3)
    trans_B = not rhs.is_contiguous()
    trans_A = (lhs.stride(-2) > 1 and lhs.stride(-1) > 1)
    # Fix: the original `assert trans_A == False, trans_B == False` used the
    # second comparison as the assert *message*, so nothing was checked when
    # trans_A was False. A transposed rhs IS supported (handled below);
    # a transposed lhs is not.
    assert not trans_A, "lhs must be contiguous (not transposed)"
    assert lhs.shape[-1] <= rhs.shape[0], "incompatible dimensions"
    stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1)
    if trans_B:
        # Transposed view: read blocks with swapped strides in row-major order.
        stride_bk, stride_bn = rhs.data.stride(2), rhs.data.stride(1)
        b_column_indices, b_offsets = rhs.column_indices, rhs.offsets
    else:
        # Normal orientation: traverse via the transposed metadata (column-major).
        stride_bk, stride_bn = rhs.data.stride(1), rhs.data.stride(2)
        b_column_indices, b_offsets = rhs.column_indices_t, rhs.offsets_t
    grid = lambda META: (CH, triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
    _dds_kernel[grid](
        lhs, rhs.data, out, M, N, K, O_M, O_N,
        stride_ac, stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cc, stride_cm, stride_cn,
        rhs.row_indices, b_column_indices, b_offsets,
        rhs.block_offsets_t, fuse_srgb, clamp_output,
        GROUP_M=128, ACC_TYPE=tl.float16, BLOCK_M=min(rhs.data.shape[1], 64),
        BLOCK_N=rhs.data.shape[1], BLOCK_SIZE=rhs.data.shape[1], BLOCK_K=min(rhs.data.shape[1], 64)
    )
    return out
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=4, num_warps=2),
    ],
    key=['BLOCK_SIZE', 'BLOCK_N'],
)
@triton.jit
def _dds_sbsc_kernel(
    A: tl.tensor, B: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N,
    stride_ac, stride_am, stride_ak,
    stride_bb, stride_bk, stride_bn,
    stride_cc, stride_cm, stride_cn,
    block_offset: tl.constexpr,
    fuse_srgb: tl.constexpr, clamp_output: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
):
    # Dense = Dense @ SBSC tile kernel: the sparse operand has exactly one
    # dense block per column group, starting at row `block_offset * pid_n`.
    pid_n = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_c = tl.program_id(2)
    nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
    # First source row this output column's block reads from.
    start_row = block_offset * pid_n
    A_block_ptr = tl.make_block_ptr(
        base=A + pid_c * stride_ac, shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_M, start_row),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(0, 1)
    )
    B_block_ptr = tl.make_block_ptr(
        base=B + pid_n * stride_bb, shape=(BLOCK_SIZE, BLOCK_N),
        strides=(stride_bk, stride_bn),
        offsets=(0, 0),
        block_shape=(BLOCK_K, BLOCK_N),
        order=(0, 1)
    )
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for block_slice in range(nsub_blocks):
        # A tiles are streamed once (evict_first); B blocks are reused
        # across M-tiles (evict_last). Out-of-range A rows load as zero.
        a = tl.load(A_block_ptr, eviction_policy='evict_first', boundary_check=(0,), padding_option='zero')
        b = tl.load(B_block_ptr, eviction_policy='evict_last')
        acc = tl.dot(a, b, acc, out_dtype=tl.float32)
        A_block_ptr = A_block_ptr.advance((0, BLOCK_K))
        B_block_ptr = B_block_ptr.advance((BLOCK_K, 0))
    # Optional fused epilogue: linear->sRGB and/or clamp, then cast to C's dtype.
    if fuse_srgb:
        acc = linear_to_srgb_triton(acc)
    if clamp_output:
        acc = tl.clamp(acc, 0.0, 1.0)
    acc = acc.to(C.dtype.element_ty)
    C_block_ptr = tl.make_block_ptr(
        base=C + pid_c * stride_cc, shape=(O_M, O_N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0)
    )
    # '.cs' (cache-streaming) hint: output tiles are written once, not reread.
    tl.store(C_block_ptr, acc, boundary_check=(0, 1), cache_modifier='.cs')
def triton_dds_sbsc(
    lhs: torch.Tensor,
    rhs: SBSCMatrix,
    fuse_srgb: bool = False,
    clamp_output: bool = False,
    output_mt: bool = False,
    output_slice: None | Tuple[int,int] = None
):
    """Dense = Dense @ SBSC matmul via the Triton SBSC kernel.

    Args:
        lhs: dense (channels, M, K) tensor.
        rhs: SBSCMatrix operand (one dense block per column group).
        fuse_srgb: apply linear->sRGB conversion in the kernel epilogue.
        clamp_output: clamp results to [0, 1] in the epilogue.
        output_mt: allocate and store the output transposed (N-major).
        output_slice: optional (rows, cols) to trim the logical output to.

    Returns:
        Dense output tensor on lhs's device and dtype.
    """
    assert isinstance(lhs, torch.Tensor)
    assert isinstance(rhs, SBSCMatrix)
    assert lhs.ndim == 3
    CH = lhs.shape[0]
    stride_ac = lhs.stride(0)
    M, K = lhs.shape[-2:]
    N = rhs.size[-1]
    if output_mt:
        # Transposed output: swap logical row/col strides when storing.
        if output_slice is not None:
            O_N, O_M = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_N, O_M = N, M
            out = torch.empty(
                (*lhs.shape[:-2], rhs.size[1], lhs.shape[-2]),
                dtype=lhs.dtype,
                device=lhs.device
            )
        stride_cm, stride_cn = out.stride(-1), out.stride(-2)
        stride_cc = out.stride(-3)
    else:
        if output_slice is not None:
            O_M, O_N = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_M, O_N = M, N
            out = torch.empty(
                (*lhs.shape[:-1], rhs.size[1]),
                dtype=lhs.dtype,
                device=lhs.device
            )
        stride_cm, stride_cn = out.stride(-2), out.stride(-1)
        stride_cc = out.stride(-3)
    assert lhs.shape[-1] <= rhs.size[0], f"incompatible dimensions: {lhs.shape[-1]} > {rhs.size[0]}"
    stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1)
    stride_bb, stride_bk, stride_bn = rhs.data.stride(0), rhs.data.stride(1), rhs.data.stride(2)
    # Grid: one program per (N-tile, M-tile, channel).
    grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), triton.cdiv(M, META['BLOCK_M']), CH)
    _dds_sbsc_kernel[grid](
        lhs, rhs.data, out, M, N, K, O_M, O_N,
        stride_ac, stride_am, stride_ak,
        stride_bb, stride_bk, stride_bn,
        stride_cc, stride_cm, stride_cn,
        rhs.offset, fuse_srgb, clamp_output,
        GROUP_M=32, ACC_TYPE=tl.float16, BLOCK_M=32,
        BLOCK_N=rhs.data.shape[2], BLOCK_SIZE=rhs.data.shape[1], BLOCK_K=rhs.col_block_size
    )
    return out
from triton.language.extra import libdevice
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=4, num_warps=2),
    ],
    key=['BLOCK_SIZE', 'BLOCK_N'],
)
@triton.jit
def _dds_sbsc_zerorhs_kernel(
    A: tl.tensor, C: tl.tensor, M, N, K, O_M, O_N,
    stride_ac, stride_am, stride_ak,
    stride_cc, stride_cm, stride_cn,
    k, PAD, block_offset: tl.constexpr,
    fuse_srgb: tl.constexpr, gamma_correction: tl.constexpr, clamp_output: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # "Zero-RHS" variant: no weight matrix is materialized. Each tile
    # evaluates Magic Kernel Sharp 2021 weights on the fly from the
    # source-minus-destination coordinate deltas (b_base) and accumulates
    # directly against the image A. Optional sRGB handling is fused into the
    # input load ('input') or the epilogue ('output').
    pid_n = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_c = tl.program_id(2)
    nsub_blocks = triton.cdiv(BLOCK_SIZE, BLOCK_K)
    start_row = block_offset * pid_n
    offs_k = (start_row + tl.arange(0, BLOCK_K)) * stride_ak
    m_range = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    A_mask = (m_range < M)[None, :].broadcast_to(BLOCK_K, BLOCK_M)
    A_M_ptr = A + pid_c * stride_ac + stride_am * m_range
    # Continuous source sample coordinates (scaled by the resize ratio k)
    # minus destination pixel centers.
    b_k = ((start_row - PAD + tl.arange(0, BLOCK_K)).to(tl.float32) + 0.5) * k
    b_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.float32) + 0.5
    b_base = (b_k[None, :] - b_n[:, None])
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float16)
    for _ in tl.range(nsub_blocks):
        # Clamp source rows to [0, K-1]: implicit replicate padding.
        # NOTE(review): offs_k is pre-multiplied by stride_ak, yet compared
        # against PAD/K-1 here — this assumes stride_ak == 1; confirm.
        A_ptr = A_M_ptr[None, :] + tl.minimum(tl.maximum(offs_k, PAD) - PAD, K - 1)[:, None]
        b = magic_kernel_sharp_2021_triton(b_base) * k
        b = b.to(tl.float16)
        a = tl.load(A_ptr, mask=A_mask)
        if fuse_srgb == 'input':
            if gamma_correction == 'fast':
                a = libdevice.fast_powf(a, 2.2).to(tl.float16)
            elif gamma_correction == 'srgb':
                a = srgb_to_linear_triton(a).to(tl.float16)
        acc = tl.dot(b, a, acc, out_dtype=tl.float16)
        offs_k += BLOCK_K * stride_ak
        b_base += BLOCK_K * k
    if fuse_srgb == 'output':
        if gamma_correction == 'fast':
            acc = libdevice.fast_powf(acc, 1.0/2.2)
        elif gamma_correction == 'srgb':
            acc = linear_to_srgb_triton(acc)
    if clamp_output:
        acc = tl.clamp(acc, 0.0, 1.0)
    if fuse_srgb == 'output' or clamp_output:
        acc = acc.to(C.dtype.element_ty)
    C_block_ptr = tl.make_block_ptr(
        base=C + pid_c * stride_cc, shape=(O_N, O_M),
        strides=(stride_cn, stride_cm),
        offsets=(pid_n * BLOCK_N, pid_m * BLOCK_M),
        block_shape=(BLOCK_N, BLOCK_M),
        order=(1, 0)
    )
    tl.store(C_block_ptr, acc, boundary_check=(0, 1), cache_modifier='.cs')
import math
def triton_dds_zerorhs_sbsc(
    lhs: torch.Tensor,
    target_size: int,
    source_size: int,
    kernel_window: float,
    block_specs,
    fuse_srgb: str = '',
    gamma_correction: str = 'fast',
    clamp_output: bool = False,
    output_mt: bool = False,
    output_slice: None | Tuple[int,int] = None
):
    """Launch the zero-RHS SBSC kernel: resample one axis of (C, M, K) `lhs`
    from source_size to target_size, evaluating MKS2021 weights in-kernel.

    block_specs is (offset, block_height, num_blocks, col_width) as produced
    by generate_sbsc_structure(); fuse_srgb selects where gamma handling is
    fused ('input', 'output', or '' for none) and gamma_correction chooses the
    exact 'srgb' curve or the 'fast' 2.2 power approximation.
    output_mt / output_slice behave as in triton_dds_sbsc.
    """
    assert isinstance(lhs, torch.Tensor)
    assert fuse_srgb in ['input', 'output', '']
    assert gamma_correction in ['fast', 'srgb']
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)
    # num_blocks is unpacked but unused; the launch grid derives it from N.
    offset, block_height, num_blocks, col_width = block_specs
    assert lhs.ndim == 3
    CH = lhs.shape[0]
    stride_ac = lhs.stride(0)
    M, K = lhs.shape[-2:]
    N = target_size
    if output_mt:
        if output_slice is not None:
            O_N, O_M = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_N, O_M = N, M
            out = torch.empty(
                (*lhs.shape[:-2], N, M),
                dtype=lhs.dtype,
                device=lhs.device
            )
        # Transposed memory layout for the follow-up axis pass.
        stride_cm, stride_cn = out.stride(-1), out.stride(-2)
        stride_cc = out.stride(-3)
    else:
        if output_slice is not None:
            O_M, O_N = output_slice
            out = torch.empty(
                (*lhs.shape[:-2], *output_slice),
                dtype=lhs.dtype,
                device=lhs.device
            )
        else:
            O_M, O_N = M, N
            out = torch.empty(
                (*lhs.shape[:-2], M, N),
                dtype=lhs.dtype,
                device=lhs.device
            )
        stride_cm, stride_cn = out.stride(-2), out.stride(-1)
        stride_cc = out.stride(-3)
    stride_am, stride_ak = lhs.stride(-2), lhs.stride(-1)
    grid = lambda META: (triton.cdiv(N, META['BLOCK_N']), triton.cdiv(M, META['BLOCK_M']), CH)
    _dds_sbsc_zerorhs_kernel[grid](
        lhs, out, M, N, K, O_M, O_N,
        stride_ac, stride_am, stride_ak,
        stride_cc, stride_cm, stride_cn,
        k, PAD, offset, fuse_srgb, gamma_correction, clamp_output,
        BLOCK_M=32, BLOCK_K=16, BLOCK_N=col_width, BLOCK_SIZE=block_height,
    )
    return out

View File

@ -0,0 +1,234 @@
"""Sharpfin transform classes for torchvision integration.
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
Imports patched: absolute sharpfin.X -> relative .X, torchvision guarded.
"""
import torch
import torch.nn.functional as F
try:
from torchvision.transforms.v2 import Transform
except ImportError:
class Transform:
_transformed_types = ()
def __init__(self):
pass
from .util import QuantHandling, ResizeKernel, SharpenKernel, srgb_to_linear, linear_to_srgb
from . import functional as SFF
from .cms import apply_srgb
import math
from typing import Any, Dict, Tuple
from PIL import Image
from .functional import _get_resize_kernel
from contextlib import nullcontext
try:
from .triton_functional import downscale_sparse
except ImportError:
downscale_sparse = None
# from Pytorch >= 2.6
set_stance = getattr(torch.compiler, "set_stance", None)
__all__ = ["ResizeKernel", "SharpenKernel", "QuantHandling"]
class Scale(Transform):
"""Rescaling transform supporting multiple algorithms with sRGB linearization."""
_transformed_types = (torch.Tensor,)
def __init__(self,
out_res: tuple[int, int] | int,
device: torch.device | str = torch.device('cpu'),
dtype: torch.dtype = torch.float32,
out_dtype: torch.dtype | None = None,
quantization: QuantHandling = QuantHandling.ROUND,
generator: torch.Generator | None = None,
resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
sharpen_kernel: SharpenKernel | None = None,
do_srgb_conversion: bool = True,
use_sparse: bool = False,
):
super().__init__()
if isinstance(device, str):
device = torch.device(device)
if not dtype.is_floating_point:
raise ValueError("dtype must be a floating point type")
if dtype.itemsize == 1:
raise ValueError("float8 types are not supported due to severe accuracy issues and limited function support. float16 or float32 is recommended.")
if out_dtype is not None and not out_dtype.is_floating_point and out_dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]:
raise ValueError("out_dtype must be a torch float format or a torch unsigned int format")
if use_sparse:
assert device.type != 'cpu'
if resize_kernel != ResizeKernel.MAGIC_KERNEL_SHARP_2021:
raise NotImplementedError
self.use_sparse = use_sparse
if isinstance(out_res, int):
out_res = (out_res, out_res)
self.out_res = out_res
self.device = device
self.dtype = dtype
self.out_dtype = out_dtype if out_dtype is not None else dtype
self.do_srgb_conversion = do_srgb_conversion
if self.out_dtype in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]:
match quantization:
case QuantHandling.TRUNCATE:
self.quantize_function = lambda x: x.mul(torch.iinfo(self.out_dtype).max).to(self.out_dtype)
case QuantHandling.ROUND:
self.quantize_function = lambda x: x.mul(torch.iinfo(self.out_dtype).max).round().to(self.out_dtype)
case QuantHandling.STOCHASTIC_ROUND:
if generator is not None:
self.generator = torch.Generator(self.device)
else:
self.generator = generator
self.quantize_function = lambda x: SFF.stochastic_round(x, self.out_dtype, self.generator)
case QuantHandling.BAYER:
self.bayer_matrix = torch.tensor(SFF.generate_bayer_matrix(16), dtype=self.dtype, device=self.device) / 255
self.quantize_function = lambda x: self.apply_bayer_matrix(x)
case _:
raise ValueError(f"Unknown quantization handling type {quantization}")
else:
self.quantize_function = lambda x: x.to(dtype=out_dtype)
self.resize_kernel, self.kernel_window = _get_resize_kernel(resize_kernel)
match sharpen_kernel:
case SharpenKernel.SHARP_2013:
kernel = torch.tensor([-1, 6, -1], dtype=dtype, device=device) / 4
self.sharp_2013_kernel = torch.outer(kernel, kernel).view(1, 1, 3, 3).expand(3, -1, -1, -1)
self.sharpen_step = lambda x: SFF.sharpen_conv2d(x, self.sharp_2013_kernel, 1)
case SharpenKernel.SHARP_2021:
kernel = torch.tensor([-1, 6, -35, 204, -35, 6, -1], dtype=dtype, device=device) / 144
self.sharp_2021_kernel = torch.outer(kernel, kernel).view(1, 1, 7, 7).expand(3, -1, -1, -1)
self.sharpen_step = lambda x: SFF.sharpen_conv2d(x, self.sharp_2021_kernel, 3)
case None:
self.sharpen_step = lambda x: x
case _:
raise ValueError(f"Unknown sharpen kernel {sharpen_kernel}")
def apply_bayer_matrix(self, x: torch.Tensor):
H, W = x.shape[-2:]
b = self.bayer_matrix.repeat(1,1,math.ceil(H/16),math.ceil(W/16))[:,:,:H,:W]
return (x*255 + b).to(self.out_dtype)
@torch.compile(disable=False)
def downscale(self, image: torch.Tensor, out_res: tuple[int, int]):
H, W = out_res
image = image.to(dtype=self.dtype)
if self.do_srgb_conversion:
image = srgb_to_linear(image)
image = SFF._downscale_axis(image, W, self.kernel_window, self.resize_kernel, self.device, self.dtype)
image = SFF._downscale_axis(image.mT, H, self.kernel_window, self.resize_kernel, self.device, self.dtype).mT
image = self.sharpen_step(image)
if self.do_srgb_conversion:
image = linear_to_srgb(image)
image = image.clamp(0,1)
image = self.quantize_function(image)
return image
@torch.compile(disable=False)
def downscale_sparse(self, image: torch.Tensor, out_res: tuple[int, int]):
image = image.to(dtype=self.dtype)
if downscale_sparse is not None:
image = downscale_sparse(image, out_res)
image = self.quantize_function(image)
return image
@torch.compile(disable=False)
def upscale(self, image: torch.Tensor, out_res: tuple[int, int]):
H, W = out_res
image = image.to(dtype=self.dtype)
if self.do_srgb_conversion:
image = srgb_to_linear(image)
image = self.sharpen_step(image)
image = SFF._upscale_axis(image, W, self.kernel_window, self.resize_kernel, self.device, self.dtype)
image = SFF._upscale_axis(image.mT, H, self.kernel_window, self.resize_kernel, self.device, self.dtype).mT
if self.do_srgb_conversion:
image = linear_to_srgb(image)
image = image.clamp(0,1)
image = self.quantize_function(image)
return image
def _transform(self, inpt: torch.Tensor, params: Dict[str, Any]) -> torch.Tensor:
image = inpt.to(device=self.device)
context_manager = (
set_stance("force_eager") if set_stance and self.device.type == "cpu" else nullcontext()
)
with context_manager:
if image.shape[-1] <= self.out_res[-1] and image.shape[-2] <= self.out_res[-2]:
return self.upscale(image, self.out_res)
elif image.shape[-1] >= self.out_res[-1] and image.shape[-2] >= self.out_res[-2]:
if self.use_sparse:
return self.downscale_sparse(image, self.out_res)
return self.downscale(image, self.out_res)
else:
raise ValueError("Mixed axis resizing (e.g. scaling one axis up and the other down) is not supported. File a bug report with your use case if needed.")
class ApplyCMS(Transform):
    """Apply color management to a PIL Image to standardize it to sRGB color space."""
    _transformed_types = (Image.Image,)

    def _transform(self, inpt: Image.Image, params: Dict[str, Any]) -> Image.Image:
        # Guard clause inverted: accept the common case first.
        if isinstance(inpt, Image.Image):
            return apply_srgb(inpt)
        raise TypeError(f"inpt should be PIL Image. Got {type(inpt)}")
class AlphaComposite(Transform):
    """Flatten images with transparency onto a solid background color.

    Images without transparency data pass through untouched; everything else
    is composited over `background` and returned as RGB.
    """
    _transformed_types = (Image.Image,)

    def __init__(
        self,
        background: Tuple[int,int,int] = (255, 255, 255)
    ):
        super().__init__()
        self.background = background

    def _transform(self, inpt: Image.Image, params: Dict[str, Any]) -> Image.Image:
        if not isinstance(inpt, Image.Image):
            raise TypeError(f"inpt should be PIL Image. Got {type(inpt)}")
        if not inpt.has_transparency_data:
            return inpt
        backdrop = Image.new("RGB", inpt.size, self.background).convert('RGBA')
        flattened = Image.alpha_composite(backdrop, inpt)
        return flattened.convert('RGB')
class AspectRatioCrop(Transform):
    """Center-crop an image to the aspect ratio implied by (width, height).

    The reference width/height only define the target ratio; no resizing is
    performed, and at most one axis is trimmed (split evenly on both sides).
    """
    _transformed_types = (Image.Image,)

    def __init__(
        self,
        width: int,
        height: int,
    ):
        super().__init__()
        self.ref_width = width
        self.ref_height = height
        self.aspect_ratio = width / height

    def _transform(self, inpt: Image.Image, params: Dict[str, Any]) -> Image.Image:
        if not isinstance(inpt, Image.Image):
            raise TypeError(f"inpt should be PIL Image. Got {type(inpt)}")
        w, h = inpt.width, inpt.height
        box = [0, 0, w, h]
        if w / h > self.aspect_ratio:
            # Too wide: trim equal amounts from the left and right.
            keep_w = int(round(h / self.ref_height * self.ref_width))
            trim = (w - keep_w) // 2
            box[0] += trim
            box[2] -= trim
        elif w / h < self.aspect_ratio:
            # Too tall: trim equal amounts from the top and bottom.
            keep_h = int(round(w / self.ref_width * self.ref_height))
            trim = (h - keep_h) // 2
            box[1] += trim
            box[3] -= trim
        return inpt.crop(tuple(box))

View File

@ -0,0 +1,708 @@
"""Sharpfin Triton-accelerated GPU scaling functions.
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
Imports patched: absolute sharpfin.X -> relative .X
Requires: triton (only available on CUDA platforms)
"""
import torch
import math
import triton
import triton.language as tl
from .util import ResizeKernel
from typing import Tuple
import torch.nn.functional as F
from triton.language.extra import libdevice
from .util import linear_to_srgb, srgb_to_linear
# Magic Kernel Sharp with Triton optimizations. Mainly converted to polynomials so that
# FMA operators can be used.
@triton.jit
def magic_kernel_sharp_2021_triton(x: tl.tensor):
    # Piecewise-polynomial Magic Kernel Sharp 2021 (support |x| <= 4.5),
    # expressed as fused multiply-add chains on x and x*x.
    out = tl.zeros_like(x) # inplace operation doesn't help much.
    x = tl.abs(x)
    lte_05 = x <= 0.5
    lte_15 = x <= 1.5
    lte_25 = x <= 2.5
    lte_35 = x <= 3.5
    lte_45 = x <= 4.5
    x_sq = x*x # triton would compile like this anyways but it helps readability
    # Each tl.where selects exactly one half-open interval; beyond 4.5 the
    # initial zeros remain.
    out = tl.where(lte_05, tl.fma(x_sq, -239/144, 577/576), out)
    out = tl.where(lte_15 and not lte_05, tl.fma(x_sq, 35/36, tl.fma(x, -379/144, 239/144)), out)
    out = tl.where(lte_25 and not lte_15, tl.fma(x_sq, -1/6, tl.fma(x, 113/144, -65/72)), out)
    out = tl.where(lte_35 and not lte_25, tl.fma(x_sq, 1/36, tl.fma(x, -3/16, 5/16)), out)
    out = tl.where(lte_45 and not lte_35, tl.fma(x_sq, -1/288, tl.fma(x, 1/32, -9/128)), out)
    return out
@triton.jit
def sinc_triton(x: tl.tensor):
    # Normalized sinc(x) = sin(pi*x)/(pi*x); the 1e-8 bias avoids 0/0 at x == 0.
    y = tl.fma(x, math.pi, 1e-8)
    return libdevice.fast_sinf(y) / y
@triton.jit
def lanczos_triton(x: tl.tensor, n: tl.constexpr = 3):
    # Lanczos-n kernel: sinc(x) * sinc(x/n) inside |x| < n, zero outside.
    return tl.where(
        tl.abs(x) < n,
        sinc_triton(x) * sinc_triton(x/n),
        0
    )
# NOTE: there is no reason to use libdevice.pow, its only differences are with subnormals
@triton.jit
def linear_to_srgb_triton(x):
    # Standard sRGB encode (linear -> gamma); fast_powf trades a little
    # accuracy for speed.
    return tl.where(
        x <= 0.0031308,
        x * 12.92,
        tl.fma(1.055, libdevice.fast_powf(x, 1/2.4), -0.055)
    )
@triton.jit
def srgb_to_linear_triton(x):
    # Standard sRGB decode (gamma -> linear); (x + 0.055)/1.055 is rewritten
    # as an fma with precomputed reciprocal coefficients.
    return tl.where(
        x <= 0.04045,
        x / 12.92,
        libdevice.fast_powf(tl.fma(1/1.055, x, 0.055/1.055), 2.4)
    )
from .sparse_backend import triton_dds, triton_dds_sbsc, triton_dds_zerorhs_sbsc, Matrix, SBSCMatrix
def _get_resize_kernel_triton(k: ResizeKernel):
    """Map a ResizeKernel enum member to its Triton implementation.

    Returns (kernel_fn, support_window). Only LANCZOS3 and
    MAGIC_KERNEL_SHARP_2021 currently have Triton ports; any other valid enum
    member raises NotImplementedError, and non-members raise ValueError.
    """
    implemented = {
        ResizeKernel.LANCZOS3: (lanczos_triton, 3.),
        ResizeKernel.MAGIC_KERNEL_SHARP_2021: (magic_kernel_sharp_2021_triton, 4.5),
    }
    if not isinstance(k, ResizeKernel):
        raise ValueError(f"Unknown resize kernel {k}")
    if k not in implemented:
        # Valid kernel with no Triton port yet.
        raise NotImplementedError
    return implemented[k]
# Sparse Downscale and support functions.
# Amanatides, John and Woo, Andrew -- Fast Voxel Traversal
def grid_line_tiles(x0, y0, x1, y1, grid_width, grid_height):
    """Return the set of (row, col) grid tiles crossed by the segment
    (x0, y0) -> (x1, y1), using Amanatides & Woo fast voxel traversal.

    Tiles outside the [0, grid_width) x [0, grid_height) bounds are skipped.
    """
    dx, dy = x1 - x0, y1 - y0
    cx, cy = math.floor(x0), math.floor(y0)
    last_x, last_y = math.floor(x1), math.floor(y1)
    sx = 1 if dx > 0 else -1
    sy = 1 if dy > 0 else -1
    # Parametric distance to the next vertical/horizontal grid boundary,
    # and the per-cell increments; infinite for axis-aligned segments.
    tx = ((cx + (sx > 0)) - x0) / dx if dx != 0 else float('inf')
    ty = ((cy + (sy > 0)) - y0) / dy if dy != 0 else float('inf')
    dtx = abs(1 / dx) if dx != 0 else float('inf')
    dty = abs(1 / dy) if dy != 0 else float('inf')
    visited = set()
    while True:
        if 0 <= cx < grid_width and 0 <= cy < grid_height:
            visited.add((cy, cx))
        if cx == last_x and cy == last_y:
            return visited
        # Step into whichever neighboring cell the segment enters first.
        if tx < ty:
            tx += dtx
            cx += sx
        else:
            ty += dty
            cy += sy
def tile_mask_function(dest_size, src_size, kernel_window=4.5, tile_size=64):
    """Build a boolean tile-occupancy mask for the banded resize weight matrix.

    The nonzero band of the (padded_source x dest) kernel matrix lies between
    two straight lines; every tile crossed by either line is marked occupied.
    Returns (mask, tiles) where tiles is an (n, 2) tensor of (row, col) pairs.
    NOTE(review): only tiles touched by the two boundary lines are marked;
    this assumes no fully-interior tiles exist between them — confirm for
    very large kernel_window / small tile_size combinations.
    """
    k = dest_size / src_size
    PAD = math.ceil((kernel_window-0.5) / k)
    grid_size = math.ceil((src_size + 2*PAD)/tile_size), math.ceil(dest_size/tile_size)
    # Upper and lower boundary lines of the band, expressed in tile units.
    line_1 = 0, 0.5/tile_size, (dest_size)/tile_size, (src_size+0.5)/tile_size
    line_2 = 0, (2*PAD - 0.5)/tile_size, (dest_size)/tile_size, (src_size + 2*PAD - 0.5)/tile_size
    lines = line_1, line_2
    mask = torch.zeros(grid_size, dtype=torch.bool)
    tiles = set()
    for (x0, y0, x1, y1) in lines:
        tiles.update(grid_line_tiles(x0, y0, x1, y1, grid_size[1], grid_size[0]))
    tiles = torch.tensor(list(tiles))
    mask[tiles[:,0], tiles[:,1]] = True
    return mask, tiles
def create_tensor_metadata(
    tile_mask: torch.Tensor,
    tiles: torch.Tensor,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    offsets_t: torch.Tensor,
):
    """Fill block-sparse metadata (row/col indices, transposed ordering, and
    per-row/per-column cumulative offsets) into preallocated buffers.

    indices columns: 0 = row, 1 = col, 2 = block_offsets_t, 3 = col_indices_t.
    Buffers come from _get_nnz_and_buffers (pinned host memory).
    """
    indices[:,:2] = tiles
    torch.argsort(indices[:,1], stable=True, out=indices[:,2]) # block_offsets_t
    torch.take(indices[:,0], indices[:,2], out=indices[:,3]) # col_indices_t
    # reusing the offsets buffer here helps performance
    torch.sum(tile_mask, dim=1, out=offsets[1:])
    torch.sum(tile_mask, dim=0, out=offsets_t[1:])
    torch.cumsum(offsets, dim=0, out=offsets)
    torch.cumsum(offsets_t, dim=0, out=offsets_t)
    return indices, offsets, offsets_t
# for isolating the one mandatory graph break
@torch.compiler.disable
def _get_nnz_and_buffers(tile_mask):
    """Count occupied tiles and allocate pinned host buffers for the metadata.

    The data-dependent .item() is a mandatory graph break under torch.compile,
    hence the compiler.disable decorator isolating it.
    """
    num_sparse_blocks = torch.sum(tile_mask).item()
    return [
        torch.empty((4, num_sparse_blocks), dtype=torch.int64, pin_memory=True).T, # indices
        torch.zeros((tile_mask.shape[0] + 1,), dtype=torch.int32, pin_memory=True), # offsets
        torch.zeros((tile_mask.shape[1] + 1,), dtype=torch.int32, pin_memory=True) # offsets_t
    ]
def generate_sparse_matrix(dest_size, src_size, kernel_window=4.5, tile_size=64):
    """Allocate an (uninitialized) block-sparse Matrix whose occupied tiles
    follow the banded structure of the resize weight matrix.

    Metadata is built in pinned host memory and copied to CUDA with
    non-blocking transfers; the data blocks themselves are empty float16 and
    are filled later by compute_sparse_coord_grid.
    """
    tile_mask, tiles = tile_mask_function(dest_size, src_size, kernel_window, tile_size)
    buffers = _get_nnz_and_buffers(tile_mask)
    num_sparse_blocks = buffers[0].shape[0]
    indices, offsets, offsets_t = create_tensor_metadata(
        tile_mask,
        tiles,
        *buffers
    )
    indices = indices.to(device='cuda', dtype=torch.int32, non_blocking=True)
    return Matrix(
        (tile_mask.shape[0] * tile_size, tile_mask.shape[1] * tile_size),
        torch.empty(num_sparse_blocks, tile_size, tile_size, dtype=torch.float16, device='cuda'),
        row_indices=indices[:,0],
        column_indices=indices[:,1],
        offsets=offsets.to(device='cuda', non_blocking=True),
        column_indices_t=indices[:,3],
        offsets_t=offsets_t.to(device='cuda', non_blocking=True),
        block_offsets_t=indices[:,2]
    )
@triton.jit
def compute_sparse_coord_grid_kernel(
    coords_source_ptr, coords_dest_ptr, sparse_data_ptr,
    row_indices_ptr, col_indices_ptr,
    k: float, M: int, N: int, BLOCK_SIZE: tl.constexpr, SPARSE_BLOCK_SIZE: tl.constexpr
):
    # Fill one (BLOCK_SIZE x BLOCK_SIZE) tile of one sparse block with MKS2021
    # weights evaluated at source-minus-destination coordinate deltas,
    # scaled by the resize ratio k.
    SPARSE_BLOCK_NUMEL = SPARSE_BLOCK_SIZE * SPARSE_BLOCK_SIZE
    sparse_block = tl.program_id(0)
    tile_row = tl.program_id(1)
    tile_col = tl.program_id(2)
    row_offsets = tl.load(row_indices_ptr + sparse_block) * SPARSE_BLOCK_SIZE + tile_row * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = tl.load(col_indices_ptr + sparse_block) * SPARSE_BLOCK_SIZE + tile_col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask_row = row_offsets < M
    mask_col = col_offsets < N
    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row, other=0.0)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col, other=0.0)
    x = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)
    x = magic_kernel_sharp_2021_triton(x)
    x *= k
    sparse_block_ptr = sparse_data_ptr + sparse_block * SPARSE_BLOCK_NUMEL
    local_row_start = tile_row * BLOCK_SIZE
    local_col_start = tile_col * BLOCK_SIZE
    local_rows = local_row_start + tl.arange(0, BLOCK_SIZE)
    local_cols = local_col_start + tl.arange(0, BLOCK_SIZE)
    local_rows_2d = local_rows[:, None]
    local_cols_2d = local_cols[None, :]
    store_offset = local_rows_2d * SPARSE_BLOCK_SIZE + local_cols_2d
    # No store mask: each sparse block is fully allocated.
    tl.store(sparse_block_ptr + store_offset, x)
def compute_sparse_coord_grid(target_size, source_size, kernel_window, BLOCK_SIZE=32, SPARSE_BLOCK_SIZE=64):
    """Build the block-sparse MKS2021 weight matrix for one resize axis (CUDA).

    Source coordinates cover the replicate-padded extent; destination
    coordinates are pixel centers of the target axis.
    """
    assert SPARSE_BLOCK_SIZE % BLOCK_SIZE == 0
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)
    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')
    M, N = coords_source.shape[0], coords_dest.shape[0]
    x = generate_sparse_matrix(target_size, source_size, kernel_window, SPARSE_BLOCK_SIZE)
    SPARSE_NUM_BLOCKS = x.data.shape[0]
    grid = lambda meta: (SPARSE_NUM_BLOCKS, triton.cdiv(SPARSE_BLOCK_SIZE, meta['BLOCK_SIZE']), triton.cdiv(SPARSE_BLOCK_SIZE, meta['BLOCK_SIZE']))
    compute_sparse_coord_grid_kernel[grid](
        coords_source, coords_dest, x.data,
        x.row_indices, x.column_indices,
        k, M, N, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    return x
# Dense kernel for downsampling coord_grids
@triton.jit
def compute_coord_grid_kernel(
    coords_source_ptr, coords_dest_ptr, coord_grid_ptr, k,
    M, N, BLOCK_SIZE: tl.constexpr,
):
    # Dense counterpart of the sparse grid kernel: fill one tile of the full
    # (M x N) MKS2021 weight matrix, scaled by the resize ratio k.
    row_offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    col_offsets = tl.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask_row = row_offsets < M
    mask_col = col_offsets < N
    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col)
    x = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)
    x = magic_kernel_sharp_2021_triton(x)
    x *= k
    tl.store(coord_grid_ptr + row_offsets[:, None] * N + col_offsets[None, :], x, mask=mask_row[:, None] & mask_col[None, :])
def compute_coord_grid(target_size, source_size, kernel_window=4.5, BLOCK_SIZE=32):
    """Build the dense (padded_source x target) MKS2021 weight matrix on CUDA.

    NOTE(review): float-step arange determines M; assumed to match the
    padded extent produced by pad_replicate — confirm for extreme ratios.
    """
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)
    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')
    M, N = coords_source.shape[0], coords_dest.shape[0]
    coord_grid = torch.empty((M, N), dtype=torch.float16, device='cuda')
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    compute_coord_grid_kernel[grid](coords_source, coords_dest, coord_grid, k, M, N, BLOCK_SIZE)
    return coord_grid
@triton.jit
def pad_replicate_kernel(
    A, B,
    M_X, N_X,
    M_Y, N_Y,
    M_PAD, N_PAD,
    stride_xc, stride_xm, stride_xn,
    stride_yc, stride_ym, stride_yn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    fuse_linrgb: tl.constexpr
):
    # Edge-replicate padding: each output pixel reads the nearest valid source
    # pixel (clamped index), optionally decoding sRGB -> linear in the same pass.
    pid_c = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_n = tl.program_id(2)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # max-then-subtract keeps the intermediate non-negative for the
    # left/top pad region before clamping to the last valid index.
    offs_m_cl = tl.maximum(offs_m, M_PAD) - M_PAD
    offs_m_cl = tl.minimum(offs_m_cl, M_X - 1)
    offs_n_cl = tl.maximum(offs_n, N_PAD) - N_PAD
    offs_n_cl = tl.minimum(offs_n_cl, N_X - 1)
    mask_m = offs_m < M_Y
    mask_n = offs_n < N_Y
    A_ptr = A + pid_c * stride_xc + offs_m_cl[:, None] * stride_xm + offs_n_cl[None, :] * stride_xn
    B_ptr = B + pid_c * stride_yc + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
    # Loads need no mask: clamped indices are always within the source.
    t = tl.load(A_ptr)
    if fuse_linrgb:
        t = srgb_to_linear_triton(t)
    tl.store(B_ptr, t, mask=mask_m[:, None] & mask_n[None, :])
def pad_replicate(
    img: torch.Tensor,
    pad_h: int,
    pad_w: int,
    sparse_block_size: int = 0,
    fuse_linrgb: bool = True,
):
    """Pad a (C, H, W) image with edge replication on the GPU.

    With sparse_block_size != 0 the padded extent is additionally rounded up
    to a multiple of the sparse block size (the extra rows/cols are replicated
    filler); otherwise padding is symmetric (pad_h, pad_w) on each side.
    fuse_linrgb also converts sRGB -> linear inside the same kernel launch.
    """
    C = img.shape[0]
    M_PAD = pad_h
    N_PAD = pad_w
    if sparse_block_size != 0:
        out_H = img.shape[-2] + M_PAD + (-(img.shape[-2] + M_PAD)) % sparse_block_size
        out_W = img.shape[-1] + N_PAD + (-(img.shape[-1] + N_PAD)) % sparse_block_size
    else:
        out_H = img.shape[-2] + M_PAD + M_PAD
        out_W = img.shape[-1] + N_PAD + N_PAD
    out = torch.empty(C, out_H, out_W, dtype=img.dtype, device=img.device)
    # Thin, wide tiles: rows are contiguous, so wide N blocks coalesce well.
    BLOCK_M = 1
    BLOCK_N = 512
    grid = lambda META: (
        C,
        (out.shape[1] + META['BLOCK_M'] - 1) // META['BLOCK_M'],
        (out.shape[2] + META['BLOCK_N'] - 1) // META['BLOCK_N'],
    )
    pad_replicate_kernel[grid](
        img, out,
        img.shape[1], img.shape[2],
        out.shape[1], out.shape[2],
        M_PAD, N_PAD,
        img.stride(0), img.stride(1), img.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        fuse_linrgb=fuse_linrgb,
    )
    return out
def downscale_sparse(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    """Block-sparse GPU downscale of a (C, H, W) image.

    Pipeline: replicate-pad (+ optional sRGB->linear), sparse matmul over the
    width axis, then over the height axis (+ optional linear->sRGB and clamp).
    NOTE(review): `kernel` is unused — compute_sparse_coord_grid hard-codes
    MKS2021, so a non-default resize_kernel only changes the window.
    """
    kernel, window = _get_resize_kernel_triton(resize_kernel)
    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]
    y_s_w = compute_sparse_coord_grid(T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    y_s_h = compute_sparse_coord_grid(T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    PAD_W = math.ceil((window - 0.5) / (T_W / S_W))
    PAD_H = math.ceil((window - 0.5) / (T_H / S_H))
    image = pad_replicate(
        image,
        PAD_H,
        PAD_W,
        SPARSE_BLOCK_SIZE,
        fuse_linrgb=do_gamma_handling
    )
    # Width pass; output stored transposed for the height pass.
    image = triton_dds(
        image,
        y_s_w,
        output_mt=True
    )
    # Height pass; the slice trims block-size round-up back to the target,
    # and the second transpose restores the original orientation.
    image = triton_dds(
        image,
        y_s_h,
        fuse_srgb=do_gamma_handling,
        clamp_output=True,
        output_mt=True,
        output_slice=(T_H, T_W)
    )
    return image
def downscale_triton(
    image: torch.Tensor,
    target_size: torch.Size,
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
) -> torch.Tensor:
    """Dense GPU downscale of a (3, H, W) image to target_size.

    Builds per-axis weight matrices, pads the image with edge replication
    (optionally decoding sRGB -> linear in the same kernel), applies one
    matmul per axis, then crops, optionally re-encodes, and clamps.

    Fix: the crop to target_size and the clamp now apply on the
    do_gamma_handling=False path as well; previously that path returned the
    uncropped, unclamped matmul result.
    """
    kernel, window = _get_resize_kernel_triton(resize_kernel)
    y_s_w = compute_coord_grid(target_size[-1], image.shape[-1], window)
    y_s_h = compute_coord_grid(target_size[-2], image.shape[-2], window)
    PAD_W = math.ceil((window - 0.5) / (target_size[-1] / image.shape[-1]))
    PAD_H = math.ceil((window - 0.5) / (target_size[-2] / image.shape[-2]))
    image = pad_replicate(image, PAD_H, PAD_W, fuse_linrgb=do_gamma_handling)
    # Width pass: flatten rows and multiply by the width weight matrix.
    image = image.view(-1, image.shape[-1])
    image = image @ y_s_w
    image = image.view(3, -1, image.shape[-1])
    # Height pass on the transposed view, then transpose back.
    image = image.mT
    image = image.reshape(-1, image.shape[-1])
    image = image @ y_s_h
    image = image.view(3, -1, image.shape[-1])
    image = image.mT
    # Crop defensively to the requested size regardless of gamma handling.
    image = image[:, :target_size[-2], :target_size[-1]]
    if do_gamma_handling:
        image = linear_to_srgb(image)
    image.clamp_(0., 1.)
    return image
# Single Block Sparse Column implementations.
def evaluate_line(x, x0, y0, x1, y1):
    """Evaluate the y-coordinate at a given x along a line from (x0, y0) to (x1, y1)."""
    run = x1 - x0
    if run == 0:
        # Vertical line: y is unbounded at any x.
        return float('inf')
    return y0 + (x - x0) / run * (y1 - y0)
def pad_height_to_multiple(height, multiple):
    """Pad a height up to the next multiple of 'multiple'."""
    whole_blocks = math.ceil(height / multiple)
    return int(whole_blocks * multiple)
def generate_sbsc_structure(
    dest_size,
    src_size,
    kernel_window=4.5,
    tile_size=64,
    y_tile_size=32
):
    """Derive the single-block-sparse-column (SBSC) layout of a resize matrix.

    The nonzero band of each tile_size-wide column group is bounded by two
    lines; this computes one common block height that covers every group's
    band, plus a fixed per-group row offset, so the matrix can be stored as
    uniform (num_blocks, block_height, tile_size) blocks.
    Returns (fixed_offset, max_height, n_blocks, tile_size).
    """
    k = dest_size / src_size
    PAD = math.ceil((kernel_window - 0.5) / k)
    # Top and bottom boundary lines of the band (dest x padded-source space).
    line1 = (0, 0.5, dest_size, src_size + 0.5)
    line2 = (0, 2 * PAD - 0.5, dest_size, src_size + 2 * PAD - 0.5)
    y_mins = []
    y_maxs = []
    n_blocks = math.ceil(dest_size / tile_size)
    max_height = 0
    for i in range(n_blocks):
        x0 = i * tile_size
        x1 = min(dest_size - 1, x0 + tile_size - 1)
        yt0 = evaluate_line(x0, *line1)
        yt1 = evaluate_line(x1, *line1)
        yb0 = evaluate_line(x0, *line2)
        yb1 = evaluate_line(x1, *line2)
        y_min = min(yt0, yt1)
        y_max = max(yb0, yb1)
        height = y_max - y_min
        padded = pad_height_to_multiple(height, y_tile_size)
        y_mins.append(y_min)
        y_maxs.append(y_max)
        max_height = max(max_height, padded)
    # The ideal per-block offset follows the band's slope; clamp it into
    # [lower, upper] so every block's [y_min, y_max] span still fits within
    # max_height rows starting at offset * i.
    slope_top = (line1[3] - line1[1]) / (line1[2] - line1[0])
    ideal_step = slope_top * tile_size
    lower_bounds = []
    upper_bounds = []
    for i in range(1, n_blocks):
        lower_bounds.append((y_maxs[i] - max_height) / i)
        upper_bounds.append(y_mins[i] / i)
    lower = math.ceil(max(lower_bounds)) if lower_bounds else 0
    upper = math.floor(min(upper_bounds)) if upper_bounds else int(round(ideal_step))
    fixed_offset = int(round(ideal_step))
    if fixed_offset < lower:
        fixed_offset = lower
    elif fixed_offset > upper:
        fixed_offset = upper
    return fixed_offset, max_height, n_blocks, tile_size
def generate_sbsc_matrix(dest_size, src_size, kernel_window=4.5, tile_size=64, y_tile_size=32):
    """Allocate an (uninitialized) SBSC weight matrix matching the derived layout.

    The logical row extent covers the last block's start plus its height; the
    data blocks are filled later by compute_sbsc_coord_grid.
    """
    offset, block_height, num_blocks, col_width = generate_sbsc_structure(
        dest_size, src_size, kernel_window, tile_size, y_tile_size
    )
    return SBSCMatrix(
        size=((offset * (num_blocks - 1)) + block_height, dest_size),
        data=torch.empty((num_blocks, block_height, col_width), dtype=torch.float16, device='cuda'),
        offset=offset,
        block_size=y_tile_size
    )
@triton.jit
def compute_sbsc_coord_grid_kernel(
    coords_source_ptr, coords_dest_ptr,
    sparse_data_ptr, offset: tl.constexpr,
    stride_xb, stride_xw, stride_xh,
    k: float, M: int, N: int, BLOCK_SIZE: tl.constexpr, SPARSE_BLOCK_SIZE: tl.constexpr
):
    # Fill one BLOCK_SIZE-row slab of one SBSC block with MKS2021 weights.
    # NOTE(review): only stride_xb is used, and the call site passes strides
    # in (b, h, w) order while the signature names them (b, w, h) — the unused
    # names are misleading but harmless; confirm before relying on them.
    pid_w = tl.program_id(0)
    pid_h = tl.program_id(1)
    start_row = offset * pid_w + pid_h * BLOCK_SIZE
    start_col = pid_w * SPARSE_BLOCK_SIZE
    row_offsets = start_row + tl.arange(0, BLOCK_SIZE)
    col_offsets = start_col + tl.arange(0, SPARSE_BLOCK_SIZE)
    mask_row = row_offsets < M
    mask_col = col_offsets < N
    coord_source = tl.load(coords_source_ptr + row_offsets, mask=mask_row, other=0.0)
    coord_dest = tl.load(coords_dest_ptr + col_offsets, mask=mask_col, other=0.0)
    y = tl.cast(coord_source[:, None] - coord_dest[None, :], tl.float16)
    y = magic_kernel_sharp_2021_triton(y)
    y *= k
    sparse_block_ptr = sparse_data_ptr + pid_w * stride_xb
    local_row_start = pid_h * BLOCK_SIZE
    local_rows = local_row_start + tl.arange(0, BLOCK_SIZE)
    local_cols = tl.arange(0, SPARSE_BLOCK_SIZE)
    local_rows_2d = local_rows[:, None]
    local_cols_2d = local_cols[None, :]
    store_offset = local_rows_2d * SPARSE_BLOCK_SIZE + local_cols_2d
    tl.store(sparse_block_ptr + store_offset, y)
def compute_sbsc_coord_grid(target_size, source_size, kernel_window, BLOCK_SIZE=32, SPARSE_BLOCK_SIZE=64):
    """Build the SBSC-format MKS2021 weight matrix for one resize axis (CUDA)."""
    k = target_size / source_size
    PAD = math.ceil((kernel_window - 0.5) / k)
    # Source coords cover the replicate-padded extent; dest coords are pixel centers.
    coords_source = torch.arange((-PAD + 0.5)*k, (source_size + PAD + 0.5)*k, k, dtype=torch.float32, device='cuda')
    coords_dest = torch.arange(0.5, target_size + 0.5, 1, dtype=torch.float32, device='cuda')
    M, N = coords_source.shape[0], coords_dest.shape[0]
    x = generate_sbsc_matrix(target_size, source_size, kernel_window, SPARSE_BLOCK_SIZE, BLOCK_SIZE)
    SPARSE_BLOCKS, BLOCK_HEIGHT, _ = x.data.shape
    stride_xb, stride_xh, stride_xw = x.data.stride()
    grid = lambda meta: (SPARSE_BLOCKS, triton.cdiv(BLOCK_HEIGHT, meta['BLOCK_SIZE']))
    compute_sbsc_coord_grid_kernel[grid](
        coords_source, coords_dest,
        x.data, x.offset,
        stride_xb, stride_xh, stride_xw,
        k, M, N, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    return x
def downscale_sbsc(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling: bool = True,
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    """SBSC-format GPU downscale: pad (+ optional sRGB decode) then one sparse
    matmul per axis; the second pass fuses sRGB encode and clamping.

    NOTE(review): `kernel` is unused — the SBSC grids hard-code MKS2021.
    """
    kernel, window = _get_resize_kernel_triton(resize_kernel)
    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]
    y_s_w = compute_sbsc_coord_grid(T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    y_s_h = compute_sbsc_coord_grid(T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE)
    PAD_W = math.ceil((window - 0.5) / (T_W / S_W))
    PAD_H = math.ceil((window - 0.5) / (T_H / S_H))
    image = pad_replicate(
        image,
        PAD_H,
        PAD_W,
        fuse_linrgb=do_gamma_handling,
        sparse_block_size=SPARSE_BLOCK_SIZE,
    )
    # Width pass (output transposed), then height pass (restores orientation).
    image = triton_dds_sbsc(
        image,
        y_s_w,
        output_mt=True
    )
    image = triton_dds_sbsc(
        image,
        y_s_h,
        fuse_srgb=do_gamma_handling,
        clamp_output=True,
        output_mt=True,
    )
    return image
def downscale_sbsc_zerorhs(
    image: torch.Tensor,
    target_size: Tuple[int, int],
    resize_kernel: ResizeKernel = ResizeKernel.MAGIC_KERNEL_SHARP_2021,
    do_gamma_handling=True,
    gamma_handling_type: str = 'fast',
    BLOCK_SIZE: int = 32,
    SPARSE_BLOCK_SIZE: int = 64,
) -> torch.Tensor:
    """Weight-free SBSC downscale: kernel weights are evaluated inside the
    matmul kernel, so neither padding nor a weight matrix is materialized.

    gamma_handling_type selects the exact 'srgb' curves or a 'fast' 2.2 power
    approximation; the decode is fused into the width pass input and the
    encode into the height pass output.
    NOTE(review): `kernel` is unused — the zero-RHS kernel hard-codes MKS2021.
    """
    kernel, window = _get_resize_kernel_triton(resize_kernel)
    T_W = target_size[-1]
    T_H = target_size[-2]
    S_W = image.shape[-1]
    S_H = image.shape[-2]
    block_specs_w = generate_sbsc_structure(
        T_W, S_W, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE
    )
    block_specs_h = generate_sbsc_structure(
        T_H, S_H, window, BLOCK_SIZE, SPARSE_BLOCK_SIZE
    )
    image = triton_dds_zerorhs_sbsc(
        image,
        T_W, S_W, window, block_specs_w,
        fuse_srgb='input' if do_gamma_handling else '',
        gamma_correction=gamma_handling_type,
        output_mt=True
    )
    image = triton_dds_zerorhs_sbsc(
        image,
        T_H, S_H, window, block_specs_h,
        fuse_srgb='output' if do_gamma_handling else '',
        gamma_correction=gamma_handling_type,
        clamp_output=True,
        output_mt=True,
    )
    return image

49
modules/sharpfin/util.py Normal file
View File

@ -0,0 +1,49 @@
"""Sharpfin utility types and color space conversion functions.
Vendored from https://github.com/drhead/sharpfin (Apache 2.0)
"""
from enum import Enum
import torch
def srgb_to_linear(image: torch.Tensor) -> torch.Tensor:
    """Convert sRGB-encoded values to linear light (piecewise sRGB EOTF)."""
    linear_branch = image / 12.92
    # Clamp the power-branch input: torch.where evaluates both branches, and
    # an unclamped value below the threshold could yield NaN gradients in the
    # backward pass even though that branch is unselected.
    safe = torch.clamp(image, min=0.04045)
    power_branch = ((safe + 0.055) / 1.055) ** 2.4
    return torch.where(image <= 0.04045, linear_branch, power_branch)
def linear_to_srgb(image: torch.Tensor) -> torch.Tensor:
    """Convert linear-light values to sRGB encoding (piecewise sRGB OETF).

    The input to the power branch is clamped, mirroring ``srgb_to_linear``:
    ``torch.where`` evaluates both branches, so a negative input raised to a
    fractional power would produce NaN and poison gradients in the backward
    pass even though that branch is unselected. The clamp does not change the
    forward result, since values <= 0.0031308 select the linear branch.
    """
    return torch.where(
        image <= 0.0031308,
        image * 12.92,
        torch.clamp(1.055 * torch.clamp(image, min=0.0031308) ** (1 / 2.4) - 0.055, min=0.0, max=1.0)
    )
class ResizeKernel(Enum):
    """Resampling kernel identifiers for sharpfin resize operations.

    Values are the string names used to select a kernel, covering simple
    filters (nearest, bilinear), cubic-family kernels (Catmull-Rom, Mitchell,
    B-spline), windowed-sinc Lanczos variants, and the Magic Kernel / Magic
    Kernel Sharp family.
    """
    NEAREST = "nearest"
    BILINEAR = "bilinear"
    CATMULL_ROM = "catmull-rom"
    MITCHELL = "mitchell"
    B_SPLINE = "b-spline"
    LANCZOS2 = "lanczos2"
    LANCZOS3 = "lanczos3"
    MAGIC_KERNEL = "magic_kernel"
    MAGIC_KERNEL_SHARP_2013 = "magic_kernel_sharp_2013"
    MAGIC_KERNEL_SHARP_2021 = "magic_kernel_sharp_2021"
class SharpenKernel(Enum):
    """Sharpening-kernel variants, identified by year/version string.

    NOTE(review): presumably the sharpening steps of the Magic Kernel Sharp
    2013/2021 filters — confirm against the sharpen implementation.
    """
    SHARP_2013 = "sharp_2013"
    SHARP_2021 = "sharp_2021"
class QuantHandling(Enum):
    """Strategy for converting float results to quantized (integer) values."""
    TRUNCATE = "truncate"
    ROUND = "round"
    STOCHASTIC_ROUND = "stochastic_round"
    # NOTE(review): presumably ordered dithering with a Bayer matrix — confirm
    # against the quantization implementation.
    BAYER = "bayer"

View File

@ -109,7 +109,8 @@ class Upscaler:
if img.width >= dest_w and img.height >= dest_h:
break
if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=Image.Resampling.LANCZOS)
from modules import images_sharpfin
img = images_sharpfin.resize(img, (int(dest_w), int(dest_h)))
shared.state.end(jobid)
return img

View File

@ -27,6 +27,8 @@ class UpscalerResize(Upscaler):
UpscalerData("Resize Bilinear", None, self),
UpscalerData("Resize Hamming", None, self),
UpscalerData("Resize Box", None, self),
UpscalerData("Resize Sharpfin MKS2021", None, self),
UpscalerData("Resize Sharpfin Lanczos3", None, self),
]
def do_upscale(self, img: Image, selected_model=None):
@ -44,6 +46,12 @@ class UpscalerResize(Upscaler):
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=Image.Resampling.HAMMING)
elif selected_model == "Resize Box":
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=Image.Resampling.BOX)
elif selected_model == "Resize Sharpfin MKS2021":
from modules import images_sharpfin
return images_sharpfin.resize(img, (int(img.width * self.scale), int(img.height * self.scale)), kernel="Sharpfin MKS2021")
elif selected_model == "Resize Sharpfin Lanczos3":
from modules import images_sharpfin
return images_sharpfin.resize(img, (int(img.width * self.scale), int(img.height * self.scale)), kernel="Sharpfin Lanczos3")
else:
return img

View File

@ -25,15 +25,15 @@ class UpscalerSpandrel(Upscaler):
self.scalers.append(scaler)
def process(self, img: Image.Image) -> Image.Image:
import torchvision.transforms.functional as TF
tensor = TF.to_tensor(img).unsqueeze(0).to(devices.device)
from modules import images_sharpfin
tensor = images_sharpfin.to_tensor(img).unsqueeze(0).to(devices.device)
img = img.convert('RGB')
t0 = time.time()
with devices.inference_context():
tensor = self.model(tensor)
tensor = tensor.clamp(0, 1).squeeze(0).cpu()
t1 = time.time()
upscaled = TF.to_pil_image(tensor)
upscaled = images_sharpfin.to_pil(tensor)
log.debug(f'Upscale: name="{self.selected}" input={img.size} output={upscaled.size} time={t1 - t0:.2f}')
return upscaled

View File

@ -17,9 +17,8 @@ class UpscalerAsymmetricVAE(Upscaler):
def do_upscale(self, img: Image, selected_model=None):
if selected_model is None:
return img
import torchvision.transforms.functional as F
import diffusers
from modules import shared, devices
from modules import shared, devices, images_sharpfin
if self.vae is None or (selected_model != self.selected):
if 'v1' in selected_model:
repo_id = 'Heasterian/AsymmetricAutoencoderKLUpscaler'
@ -32,11 +31,11 @@ class UpscalerAsymmetricVAE(Upscaler):
self.selected = selected_model
shared.log.debug(f'Upscaler load: selected="{self.selected}" vae="{repo_id}"')
t0 = time.time()
img = img.resize((8 * (img.width // 8), 8 * (img.height // 8)), resample=Image.Resampling.LANCZOS).convert('RGB')
tensor = (F.pil_to_tensor(img).unsqueeze(0) / 255.0).to(device=devices.device, dtype=devices.dtype)
img = images_sharpfin.resize(img, (8 * (img.width // 8), 8 * (img.height // 8))).convert('RGB')
tensor = images_sharpfin.to_tensor(img).unsqueeze(0).to(device=devices.device, dtype=devices.dtype)
self.vae = self.vae.to(device=devices.device)
tensor = self.vae(tensor).sample
upscaled = F.to_pil_image(tensor.squeeze().clamp(0.0, 1.0).float().cpu())
upscaled = images_sharpfin.to_pil(tensor.squeeze().clamp(0.0, 1.0).float().cpu())
self.vae = self.vae.to(device=devices.cpu)
t1 = time.time()
shared.log.debug(f'Upscale: name="{self.selected}" input={img.size} output={upscaled.size} time={t1 - t0:.2f}')
@ -57,10 +56,9 @@ class UpscalerWanUpscale(Upscaler):
def do_upscale(self, img: Image, selected_model=None):
if selected_model is None:
return img
import torchvision.transforms.functional as F
import torch.nn.functional as FN
import diffusers
from modules import shared, devices
from modules import shared, devices, images_sharpfin
if (self.vae_encode is None) or (self.vae_decode is None) or (selected_model != self.selected):
repo_encode = 'Qwen/Qwen-Image-Edit-2509'
subfolder_encode = 'vae'
@ -79,7 +77,7 @@ class UpscalerWanUpscale(Upscaler):
t0 = time.time()
self.vae_encode = self.vae_encode.to(device=devices.device)
tensor = (F.pil_to_tensor(img).unsqueeze(0).unsqueeze(2) / 255.0).to(device=devices.device, dtype=devices.dtype)
tensor = images_sharpfin.to_tensor(img).unsqueeze(0).unsqueeze(2).to(device=devices.device, dtype=devices.dtype)
tensor = self.vae_encode.encode(tensor).latent_dist.mode()
self.vae_encode.to(device=devices.cpu)
@ -88,7 +86,7 @@ class UpscalerWanUpscale(Upscaler):
tensor = FN.pixel_shuffle(tensor.movedim(2, 1), upscale_factor=2).movedim(1, 2) # pixel shuffle needs [..., C, H, W] format
self.vae_decode.to(device=devices.cpu)
upscaled = F.to_pil_image(tensor.squeeze().clamp(0.0, 1.0).float().cpu())
upscaled = images_sharpfin.to_pil(tensor.squeeze().clamp(0.0, 1.0).float().cpu())
t1 = time.time()
shared.log.debug(f'Upscale: name="{self.selected}" input={img.size} output={upscaled.size} time={t1 - t0:.2f}')
return upscaled

View File

@ -293,10 +293,9 @@ class FLitePipeline(DiffusionPipeline):
raise
# 8. Post-process images
from modules import images_sharpfin
images = (decoded_images / 2 + 0.5).clamp(0, 1)
# Convert to PIL Images
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu()
pil_images = [Image.fromarray(img.permute(1, 2, 0).numpy()) for img in images]
pil_images = [images_sharpfin.to_pil(img) for img in images]
return FLitePipelineOutput(
images=pil_images,

View File

@ -332,8 +332,8 @@ class StableCascadeDecoderPipelineFixed(diffusers.StableCascadeDecoderPipeline):
if output_type == "np":
images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
elif output_type == "pil":
images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
images = self.numpy_to_pil(images)
from modules import images_sharpfin
images = [images_sharpfin.to_pil(images[i]) for i in range(images.shape[0])]
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
else:
images = latents

View File

@ -1,6 +1,6 @@
import numpy as np
import torch
import torchvision.transforms.functional as vF
from modules import images_sharpfin
import PIL
@ -13,7 +13,7 @@ def preprocess(image, processor, **kwargs):
elif isinstance(image, np.ndarray):
image = PIL.Image.fromarray(image)
elif isinstance(image, torch.Tensor):
image = vF.to_pil_image(image)
image = images_sharpfin.to_pil(image)
else:
raise TypeError(f"Image must be of type PIL.Image, np.ndarray, or torch.Tensor, got {type(image)} instead.")

View File

@ -14,7 +14,6 @@ from packaging import version
import PIL.Image
import numpy as np
import torch
import torchvision
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
@ -859,7 +858,8 @@ class StableDiffusionXLDiffImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixi
# 4. Preprocess image
#image = self.image_processor.preprocess(image) #ideally we would have preprocess the image with diffusers, but for this POC we won't --- it throws a deprecated warning
map = torchvision.transforms.Resize(tuple(s // self.vae_scale_factor for s in original_image.shape[2:]),antialias=None)(map)
from modules import images_sharpfin
map = images_sharpfin.resize_tensor(map, tuple(s // self.vae_scale_factor for s in original_image.shape[2:]), linearize=False)
# 5. Prepare timesteps
def denoising_value_valid(dnv):
return type(denoising_end) == float and 0 < dnv < 1
@ -1758,7 +1758,8 @@ class StableDiffusionDiffImg2ImgPipeline(DiffusionPipeline):
# 7. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
map = torchvision.transforms.Resize(tuple(s // self.vae_scale_factor for s in image.shape[2:]),antialias=None)(map)
from modules import images_sharpfin
map = images_sharpfin.resize_tensor(map, tuple(s // self.vae_scale_factor for s in image.shape[2:]), linearize=False)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@ -1833,8 +1834,7 @@ class StableDiffusionDiffImg2ImgPipeline(DiffusionPipeline):
import gradio as gr
import diffusers
from PIL import Image, ImageEnhance, ImageOps # pylint: disable=reimported
from torchvision import transforms
from modules import errors, shared, devices, scripts_manager, processing, sd_models, images
from modules import errors, shared, devices, scripts_manager, processing, sd_models, images, images_sharpfin
detector = None
@ -1888,9 +1888,9 @@ class Script(scripts_manager.Script):
else:
return None, None, None
image_mask = image_map.copy()
image_map = transforms.ToTensor()(image_map)
image_map = images_sharpfin.to_tensor(image_map)
image_map = image_map.to(devices.device)
image_init = 2 * transforms.ToTensor()(image_init) - 1
image_init = 2 * images_sharpfin.to_tensor(image_init) - 1
image_init = image_init.unsqueeze(0)
image_init = image_init.to(devices.device)
return image_init, image_map, image_mask

View File

@ -84,7 +84,7 @@ class Script(scripts_manager.Script):
from installer import install
install('lpips')
from torchvision.transforms import ToPILImage, ToTensor
from modules import images_sharpfin
from scripts.lbm import get_model, extract_object, resize_and_center_crop # pylint: disable=no-name-in-module
ori_h_bg, ori_w_bg = fg_image.size
@ -110,7 +110,7 @@ class Script(scripts_manager.Script):
if lbm_method == 'Simple':
output_image = img_pasted
else:
img_pasted_tensor = ToTensor()(img_pasted).to(device=devices.device, dtype=devices.dtype).unsqueeze(0) * 2 - 1
img_pasted_tensor = images_sharpfin.to_tensor(img_pasted).to(device=devices.device, dtype=devices.dtype).unsqueeze(0) * 2 - 1
batch = { "source_image": img_pasted_tensor }
z_source = model.vae.encode(batch[model.source_key])
output_image = model.sample(
@ -120,7 +120,7 @@ class Script(scripts_manager.Script):
max_samples=1,
)
output_image = (output_image[0].clamp(-1, 1).float().cpu() + 1) / 2
output_image = ToPILImage()(output_image)
output_image = images_sharpfin.to_pil(output_image)
if lbm_composite:
output_image = Image.composite(output_image, bg_image, fg_mask)

View File

@ -26,17 +26,13 @@ class Script(scripts_manager.Script):
def encode(self, p: processing.StableDiffusionProcessing, image: Image.Image):
if image is None:
return None
import numpy as np
import torch
from modules import images_sharpfin
if p.width is None or p.width == 0:
p.width = int(8 * (image.width * p.scale_by // 8))
if p.height is None or p.height == 0:
p.height = int(8 * (image.height * p.scale_by // 8))
image = images.resize_image(p.resize_mode, image, p.width, p.height, upscaler_name=p.resize_name, context=p.resize_context)
tensor = np.array(image).astype(np.float16) / 255.0
tensor = tensor[None].transpose(0, 3, 1, 2)
# image = image.transpose(0, 3, 1, 2)
tensor = torch.from_numpy(tensor).to(device=devices.device, dtype=devices.dtype)
tensor = images_sharpfin.to_tensor(image).unsqueeze(0).to(device=devices.device, dtype=devices.dtype)
tensor = 2.0 * tensor - 1.0
with devices.inference_context():
latent = shared.sd_model.vae.tiled_encode(tensor)

View File

@ -18,7 +18,6 @@ import cv2
import numpy as np
from PIL import Image, ImageFilter
import torch
import torchvision
from torchvision import transforms
from transformers import (
CLIPImageProcessor,
@ -1323,7 +1322,8 @@ class StableDiffusionXLSoftFillPipeline(
image.save("noised_image.png")
image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)
image = transforms.ToTensor()(image)
from modules import images_sharpfin
image = images_sharpfin.to_tensor(image)
image = image * 2 - 1 # Normalize to [-1, 1]
return image.unsqueeze(0)
@ -1334,7 +1334,8 @@ class StableDiffusionXLSoftFillPipeline(
"""
map = map.convert("L")
map = transforms.CenterCrop((map.size[1] // 64 * 64, map.size[0] // 64 * 64))(map)
map = transforms.ToTensor()(map)
from modules import images_sharpfin
map = images_sharpfin.to_tensor(map)
map = (map - 0.05) / (0.95 - 0.05)
map = torch.clamp(map, 0.0, 1.0)
return 1.0 - map
@ -1349,9 +1350,8 @@ class StableDiffusionXLSoftFillPipeline(
# Prepare mask as rescaled tensor map
map = preprocess_map(mask).to(device)
map = torchvision.transforms.Resize(
tuple(s // self.vae_scale_factor for s in original_image_tensor.shape[2:]), antialias=None
)(map)
from modules import images_sharpfin
map = images_sharpfin.resize_tensor(map, tuple(s // self.vae_scale_factor for s in original_image_tensor.shape[2:]), linearize=False)
# Generate latent tensor with noise
original_with_noise = self.prepare_latents(