refactor: address remaining PR #4640 review comments

- Remove _get_device_dtype() indirection, inline device/dtype at call sites
- Remove commented-out fallback blocks and try/finally wrappers
- Add modules/sharpfin to ruff and pylint excludes in pyproject.toml
- Fix import ordering in joytag.py and pixelart.py
pull/4668/head
CalamitousFelicitousness 2026-02-10 19:59:42 +00:00 committed by vladmandic
parent 162651cbdb
commit dc8ecb0a64
4 changed files with 44 additions and 67 deletions

View File

@ -14,10 +14,10 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import QuickGELUActivation
import torchvision
from modules import images_sharpfin
import einops
from einops.layers.torch import Rearrange
import huggingface_hub
from modules import images_sharpfin
from modules import shared, devices, sd_models

View File

@ -4,7 +4,7 @@ Provides drop-in replacements for torchvision.transforms.functional operations
with higher quality resampling (Magic Kernel Sharp 2021), sRGB linearization,
and Triton GPU acceleration when available.
All public functions include try/except fallback to PIL/torchvision.
Non-CUDA devices fall back to PIL/torch.nn.functional automatically.
"""
import torch
@ -62,16 +62,6 @@ def _resolve_linearize(linearize=None, is_mask=False):
return shared.opts.resize_linearize_srgb
def _get_device_dtype(device=None, dtype=None):
"""Get device/dtype for sharpfin operations."""
from modules import devices
dev = device if device is not None else devices.device
if dtype is not None:
return dev, dtype
# float16 for CUDA (efficient), float32 for CPU/other (accurate)
return dev, torch.float16 if dev.type == 'cuda' else torch.float32
def _should_use_sharpfin(device=None):
"""Determine if sharpfin should be used based on device."""
if device is None:
@ -147,32 +137,27 @@ def _resize_pil(image, target_size, *, kernel=None, linearize=None, device=None,
w, h = target_size
if image.width == w and image.height == h:
return image
dev, dt = _get_device_dtype(device, dtype)
# Non-CUDA: use PIL (torchvision has optimized kernels for these devices)
from modules import devices
dev = device if device is not None else devices.device
if not _should_use_sharpfin(dev):
return image.resize((w, h), resample=Image.Resampling.LANCZOS)
is_mask = image.mode == 'L'
rk = _resolve_kernel(kernel)
if rk is None:
return image.resize((w, h), resample=Image.Resampling.LANCZOS)
try:
from modules.sharpfin.functional import scale
do_linear = _resolve_linearize(linearize, is_mask=is_mask)
tensor = to_tensor(image)
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)
tensor = tensor.to(device=dev, dtype=dt)
out_res = (h, w) # sharpfin uses (H, W)
src_h, src_w = tensor.shape[-2], tensor.shape[-1]
both_down = (h <= src_h and w <= src_w)
both_up = (h >= src_h and w >= src_w)
result = _scale_pil(scale, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up)
return to_pil(result)
# except Exception as e: # DEBUG: PIL fallback disabled for testing
# _get_log().warning(f"Sharpfin resize failed, falling back to PIL: {e}")
# return image.resize((w, h), resample=Image.Resampling.LANCZOS)
finally:
pass
from modules.sharpfin.functional import scale
dt = dtype if dtype is not None else torch.float16
do_linear = _resolve_linearize(linearize, is_mask=is_mask)
tensor = to_tensor(image)
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)
tensor = tensor.to(device=dev, dtype=dt)
out_res = (h, w) # sharpfin uses (H, W)
src_h, src_w = tensor.shape[-2], tensor.shape[-1]
both_down = (h <= src_h and w <= src_w)
both_up = (h >= src_h and w >= src_w)
result = _scale_pil(scale, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up)
return to_pil(result)
def resize_tensor(tensor, target_size, *, kernel=None, linearize=False):
@ -185,8 +170,8 @@ def resize_tensor(tensor, target_size, *, kernel=None, linearize=False):
linearize: sRGB linearization (default False for latent/mask data)
"""
_check()
dev, dt = _get_device_dtype()
# Non-CUDA: use F.interpolate (has optimized kernels for CPU/MPS/etc)
from modules import devices
dev = devices.device
if not _should_use_sharpfin(dev):
mode = 'bilinear' if target_size[0] * target_size[1] > tensor.shape[-2] * tensor.shape[-1] else 'area'
inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
@ -198,38 +183,29 @@ def resize_tensor(tensor, target_size, *, kernel=None, linearize=False):
inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
result = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=True)
return result.squeeze(0) if tensor.dim() == 3 else result
try:
from modules.sharpfin.functional import scale
dev, dt = _get_device_dtype()
squeezed = False
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)
squeezed = True
src_h, src_w = tensor.shape[-2], tensor.shape[-1]
th, tw = target_size
both_down = (th <= src_h and tw <= src_w)
both_up = (th >= src_h and tw >= src_w)
if both_down or both_up:
use_sparse = _triton_ok and dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021' and both_down
result = scale(tensor, target_size, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=use_sparse)
from modules.sharpfin.functional import scale
dt = torch.float16
squeezed = False
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)
squeezed = True
src_h, src_w = tensor.shape[-2], tensor.shape[-1]
th, tw = target_size
both_down = (th <= src_h and tw <= src_w)
both_up = (th >= src_h and tw >= src_w)
if both_down or both_up:
use_sparse = _triton_ok and dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021' and both_down
result = scale(tensor, target_size, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=use_sparse)
else:
if th > src_h:
intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
else:
if th > src_h:
intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
else:
intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
if squeezed:
result = result.squeeze(0)
return result
# except Exception as e: # DEBUG: F.interpolate fallback disabled for testing
# _get_log().warning(f"Sharpfin resize_tensor failed, falling back to F.interpolate: {e}")
# mode = 'bilinear' if target_size[0] * target_size[1] > tensor.shape[-2] * tensor.shape[-1] else 'area'
# inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
# result = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=True)
# return result.squeeze(0) if tensor.dim() == 3 else result
finally:
pass
intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
if squeezed:
result = result.squeeze(0)
return result
def to_tensor(image):

View File

@ -3,14 +3,13 @@ from typing import List
import math
import torch
import numpy as np
from modules import images_sharpfin
from PIL import Image
from diffusers.utils import CONFIG_NAME
from diffusers.image_processor import PipelineImageInput
from diffusers.configuration_utils import ConfigMixin, register_to_config
from transformers import ImageProcessingMixin
from modules import images_sharpfin
from modules import devices

View File

@ -24,6 +24,7 @@ exclude = [
"modules/schedulers",
"modules/teacache",
"modules/seedvr",
"modules/sharpfin",
"modules/control/proc",
"modules/control/units",
@ -150,6 +151,7 @@ main.ignore-paths=[
"modules/prompt_parser_xhinker.py",
"modules/ras",
"modules/seedvr",
"modules/sharpfin",
"modules/rife",
"modules/schedulers",
"modules/taesd",