diff --git a/modules/caption/joytag.py b/modules/caption/joytag.py index ac28e55a1..346e296f1 100644 --- a/modules/caption/joytag.py +++ b/modules/caption/joytag.py @@ -14,10 +14,10 @@ import torch.nn as nn import torch.nn.functional as F from transformers.activations import QuickGELUActivation import torchvision -from modules import images_sharpfin import einops from einops.layers.torch import Rearrange import huggingface_hub +from modules import images_sharpfin from modules import shared, devices, sd_models diff --git a/modules/images_sharpfin.py b/modules/images_sharpfin.py index 37ab8056c..11ab6504a 100644 --- a/modules/images_sharpfin.py +++ b/modules/images_sharpfin.py @@ -4,7 +4,7 @@ Provides drop-in replacements for torchvision.transforms.functional operations with higher quality resampling (Magic Kernel Sharp 2021), sRGB linearization, and Triton GPU acceleration when available. -All public functions include try/except fallback to PIL/torchvision. +Non-CUDA devices fall back to PIL/torch.nn.functional automatically. """ import torch @@ -62,16 +62,6 @@ def _resolve_linearize(linearize=None, is_mask=False): return shared.opts.resize_linearize_srgb -def _get_device_dtype(device=None, dtype=None): - """Get device/dtype for sharpfin operations.""" - from modules import devices - dev = device if device is not None else devices.device - if dtype is not None: - return dev, dtype - # float16 for CUDA (efficient), float32 for CPU/other (accurate) - return dev, torch.float16 if dev.type == 'cuda' else torch.float32 - - def _should_use_sharpfin(device=None): """Determine if sharpfin should be used based on device.""" if device is None: @@ -147,32 +137,27 @@ def _resize_pil(image, target_size, *, kernel=None, linearize=None, device=None, w, h = target_size if image.width == w and image.height == h: return image - dev, dt = _get_device_dtype(device, dtype) - # Non-CUDA: use PIL (torchvision has optimized kernels for these devices) + from modules import devices + dev = device if device is not None else devices.device if not _should_use_sharpfin(dev): return image.resize((w, h), resample=Image.Resampling.LANCZOS) is_mask = image.mode == 'L' rk = _resolve_kernel(kernel) if rk is None: return image.resize((w, h), resample=Image.Resampling.LANCZOS) - try: - from modules.sharpfin.functional import scale - do_linear = _resolve_linearize(linearize, is_mask=is_mask) - tensor = to_tensor(image) - if tensor.dim() == 3: - tensor = tensor.unsqueeze(0) - tensor = tensor.to(device=dev, dtype=dt) - out_res = (h, w) # sharpfin uses (H, W) - src_h, src_w = tensor.shape[-2], tensor.shape[-1] - both_down = (h <= src_h and w <= src_w) - both_up = (h >= src_h and w >= src_w) - result = _scale_pil(scale, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up) - return to_pil(result) - # except Exception as e: # DEBUG: PIL fallback disabled for testing - # _get_log().warning(f"Sharpfin resize failed, falling back to PIL: {e}") - # return image.resize((w, h), resample=Image.Resampling.LANCZOS) - finally: - pass + from modules.sharpfin.functional import scale + dt = dtype if dtype is not None else torch.float16 + do_linear = _resolve_linearize(linearize, is_mask=is_mask) + tensor = to_tensor(image) + if tensor.dim() == 3: + tensor = tensor.unsqueeze(0) + tensor = tensor.to(device=dev, dtype=dt) + out_res = (h, w) # sharpfin uses (H, W) + src_h, src_w = tensor.shape[-2], tensor.shape[-1] + both_down = (h <= src_h and w <= src_w) + both_up = (h >= src_h and w >= src_w) + result = _scale_pil(scale, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up) + return to_pil(result) def resize_tensor(tensor, target_size, *, kernel=None, linearize=False): @@ -185,8 +170,8 @@ def resize_tensor(tensor, target_size, *, kernel=None, linearize=False): linearize: sRGB linearization (default False for latent/mask data) """ _check() - dev, dt = _get_device_dtype() - # Non-CUDA: use F.interpolate (has optimized kernels for CPU/MPS/etc) + from modules import devices + dev = devices.device if not _should_use_sharpfin(dev): mode = 'bilinear' if target_size[0] * target_size[1] > tensor.shape[-2] * tensor.shape[-1] else 'area' inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0) @@ -198,38 +183,29 @@ def resize_tensor(tensor, target_size, *, kernel=None, linearize=False): inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0) result = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=True) return result.squeeze(0) if tensor.dim() == 3 else result - try: - from modules.sharpfin.functional import scale - dev, dt = _get_device_dtype() - squeezed = False - if tensor.dim() == 3: - tensor = tensor.unsqueeze(0) - squeezed = True - src_h, src_w = tensor.shape[-2], tensor.shape[-1] - th, tw = target_size - both_down = (th <= src_h and tw <= src_w) - both_up = (th >= src_h and tw >= src_w) - if both_down or both_up: - use_sparse = _triton_ok and dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021' and both_down - result = scale(tensor, target_size, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=use_sparse) + from modules.sharpfin.functional import scale + dt = torch.float16 + squeezed = False + if tensor.dim() == 3: + tensor = tensor.unsqueeze(0) + squeezed = True + src_h, src_w = tensor.shape[-2], tensor.shape[-1] + th, tw = target_size + both_down = (th <= src_h and tw <= src_w) + both_up = (th >= src_h and tw >= src_w) + if both_down or both_up: + use_sparse = _triton_ok and dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021' and both_down + result = scale(tensor, target_size, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=use_sparse) + else: + if th > src_h: + intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False) + result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False) else: - if th > src_h: - intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False) - result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False) - else: - intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False) - result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False) - if squeezed: - result = result.squeeze(0) - return result - # except Exception as e: # DEBUG: F.interpolate fallback disabled for testing - # _get_log().warning(f"Sharpfin resize_tensor failed, falling back to F.interpolate: {e}") - # mode = 'bilinear' if target_size[0] * target_size[1] > tensor.shape[-2] * tensor.shape[-1] else 'area' - # inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0) - # result = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=True) - # return result.squeeze(0) if tensor.dim() == 3 else result - finally: - pass + intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False) + result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False) + if squeezed: + result = result.squeeze(0) + return result def to_tensor(image): diff --git a/modules/postprocess/pixelart.py b/modules/postprocess/pixelart.py index 8b1ef1254..f2b27e02b 100644 --- a/modules/postprocess/pixelart.py +++ b/modules/postprocess/pixelart.py @@ -3,14 +3,13 @@ from typing import List import math import torch import numpy as np -from modules import images_sharpfin - from PIL import Image from diffusers.utils import CONFIG_NAME from diffusers.image_processor import PipelineImageInput from diffusers.configuration_utils import ConfigMixin, register_to_config from transformers import ImageProcessingMixin +from modules import images_sharpfin from modules import devices diff --git a/pyproject.toml b/pyproject.toml index 372c6fad9..2ec740e67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ exclude = [ "modules/schedulers", "modules/teacache", "modules/seedvr", + "modules/sharpfin", "modules/control/proc", "modules/control/units", @@ -150,6 +151,7 @@ main.ignore-paths=[ "modules/prompt_parser_xhinker.py", "modules/ras", "modules/seedvr", + "modules/sharpfin", "modules/rife", "modules/schedulers", "modules/taesd",