mirror of https://github.com/vladmandic/automatic
refactor: address remaining PR #4640 review comments
- Remove _get_device_dtype() indirection, inline device/dtype at call sites
- Remove commented-out fallback blocks and try/finally wrappers
- Add modules/sharpfin to ruff and pylint excludes in pyproject.toml
- Fix import ordering in joytag.py and pixelart.py

branch: pull/4668/head
parent 162651cbdb
commit dc8ecb0a64
joytag.py
@@ -14,10 +14,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers.activations import QuickGELUActivation
 import torchvision
-from modules import images_sharpfin
 import einops
 from einops.layers.torch import Rearrange
 import huggingface_hub
+from modules import images_sharpfin
 from modules import shared, devices, sd_models
modules/images_sharpfin.py
@@ -4,7 +4,7 @@ Provides drop-in replacements for torchvision.transforms.functional operations
 with higher quality resampling (Magic Kernel Sharp 2021), sRGB linearization,
 and Triton GPU acceleration when available.

-All public functions include try/except fallback to PIL/torchvision.
+Non-CUDA devices fall back to PIL/torch.nn.functional automatically.
 """

 import torch
@@ -62,16 +62,6 @@ def _resolve_linearize(linearize=None, is_mask=False):
     return shared.opts.resize_linearize_srgb


-def _get_device_dtype(device=None, dtype=None):
-    """Get device/dtype for sharpfin operations."""
-    from modules import devices
-    dev = device if device is not None else devices.device
-    if dtype is not None:
-        return dev, dtype
-    # float16 for CUDA (efficient), float32 for CPU/other (accurate)
-    return dev, torch.float16 if dev.type == 'cuda' else torch.float32
-
-
 def _should_use_sharpfin(device=None):
     """Determine if sharpfin should be used based on device."""
     if device is None:
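For reference, the removed helper defaulted to float16 on CUDA and float32 elsewhere, and the new call sites inline that choice directly (see the hunks below). A minimal standalone sketch of the logic the helper encapsulated; pick_device_dtype and default_device are hypothetical stand-ins, not names from this repo:

import torch

default_device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # stand-in for devices.device

def pick_device_dtype(device=None, dtype=None):  # hypothetical name
    dev = device if device is not None else default_device
    if dtype is not None:
        return dev, dtype
    # float16 on CUDA for speed, float32 elsewhere for accuracy
    return dev, torch.float16 if dev.type == 'cuda' else torch.float32

print(pick_device_dtype())  # on a CPU box: (device(type='cpu'), torch.float32)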
@@ -147,32 +137,27 @@ def _resize_pil(image, target_size, *, kernel=None, linearize=None, device=None,
     w, h = target_size
     if image.width == w and image.height == h:
         return image
-    dev, dt = _get_device_dtype(device, dtype)
+    # Non-CUDA: use PIL (torchvision has optimized kernels for these devices)
+    from modules import devices
+    dev = device if device is not None else devices.device
     if not _should_use_sharpfin(dev):
         return image.resize((w, h), resample=Image.Resampling.LANCZOS)
     is_mask = image.mode == 'L'
     rk = _resolve_kernel(kernel)
     if rk is None:
         return image.resize((w, h), resample=Image.Resampling.LANCZOS)
-    try:
-        from modules.sharpfin.functional import scale
-        do_linear = _resolve_linearize(linearize, is_mask=is_mask)
-        tensor = to_tensor(image)
-        if tensor.dim() == 3:
-            tensor = tensor.unsqueeze(0)
-        tensor = tensor.to(device=dev, dtype=dt)
-        out_res = (h, w)  # sharpfin uses (H, W)
-        src_h, src_w = tensor.shape[-2], tensor.shape[-1]
-        both_down = (h <= src_h and w <= src_w)
-        both_up = (h >= src_h and w >= src_w)
-        result = _scale_pil(scale, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up)
-        return to_pil(result)
-    # except Exception as e: # DEBUG: PIL fallback disabled for testing
-    #     _get_log().warning(f"Sharpfin resize failed, falling back to PIL: {e}")
-    #     return image.resize((w, h), resample=Image.Resampling.LANCZOS)
-    finally:
-        pass
+    from modules.sharpfin.functional import scale
+    dt = dtype if dtype is not None else torch.float16
+    do_linear = _resolve_linearize(linearize, is_mask=is_mask)
+    tensor = to_tensor(image)
+    if tensor.dim() == 3:
+        tensor = tensor.unsqueeze(0)
+    tensor = tensor.to(device=dev, dtype=dt)
+    out_res = (h, w)  # sharpfin uses (H, W)
+    src_h, src_w = tensor.shape[-2], tensor.shape[-1]
+    both_down = (h <= src_h and w <= src_w)
+    both_up = (h >= src_h and w >= src_w)
+    result = _scale_pil(scale, tensor, out_res, rk, dev, dt, do_linear, src_h, src_w, h, w, both_down, both_up)
+    return to_pil(result)


 def resize_tensor(tensor, target_size, *, kernel=None, linearize=False):
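As the out_res comment in the hunk above notes, PIL sizes are (width, height) while tensors are (..., height, width). A short standalone illustration of the axis swap, using torchvision's to_tensor rather than the module's own helper:

import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

img = Image.new('RGB', (640, 480))  # PIL size is (width, height)
t = to_tensor(img).unsqueeze(0)     # tensor shape is (1, C, H, W)
assert t.shape[-2:] == (480, 640)   # height and width swap places vs PIL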
@@ -185,8 +170,8 @@ def resize_tensor(tensor, target_size, *, kernel=None, linearize=False):
         linearize: sRGB linearization (default False for latent/mask data)
     """
     _check()
-    dev, dt = _get_device_dtype()
+    # Non-CUDA: use F.interpolate (has optimized kernels for CPU/MPS/etc)
+    from modules import devices
+    dev = devices.device
     if not _should_use_sharpfin(dev):
         mode = 'bilinear' if target_size[0] * target_size[1] > tensor.shape[-2] * tensor.shape[-1] else 'area'
         inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
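The non-CUDA branch above picks 'bilinear' when the target has more pixels than the source (upscaling) and 'area' when it has fewer. A standalone sketch of the same heuristic; note this sketch passes antialias only for bilinear, on the assumption that mode='area' rejects the flag, whereas the module's own call passes antialias=True unconditionally:

import torch
import torch.nn.functional as F

def fallback_resize(tensor, target_size):  # hypothetical stand-in for the fallback path
    # bilinear when the target has more pixels (upscale), area when fewer (downscale)
    mode = 'bilinear' if target_size[0] * target_size[1] > tensor.shape[-2] * tensor.shape[-1] else 'area'
    inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
    kwargs = {'antialias': True} if mode == 'bilinear' else {}  # assumption: 'area' does not accept antialias
    result = F.interpolate(inp, size=target_size, mode=mode, **kwargs)
    return result.squeeze(0) if tensor.dim() == 3 else result

x = torch.rand(3, 64, 64)
print(fallback_resize(x, (128, 128)).shape)  # torch.Size([3, 128, 128])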
@@ -198,38 +183,29 @@ def resize_tensor(tensor, target_size, *, kernel=None, linearize=False):
         inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
         result = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=True)
         return result.squeeze(0) if tensor.dim() == 3 else result
-    try:
-        from modules.sharpfin.functional import scale
-        dev, dt = _get_device_dtype()
-        squeezed = False
-        if tensor.dim() == 3:
-            tensor = tensor.unsqueeze(0)
-            squeezed = True
-        src_h, src_w = tensor.shape[-2], tensor.shape[-1]
-        th, tw = target_size
-        both_down = (th <= src_h and tw <= src_w)
-        both_up = (th >= src_h and tw >= src_w)
-        if both_down or both_up:
-            use_sparse = _triton_ok and dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021' and both_down
-            result = scale(tensor, target_size, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=use_sparse)
-        else:
-            if th > src_h:
-                intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
-                result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
-            else:
-                intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
-                result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
-        if squeezed:
-            result = result.squeeze(0)
-        return result
-    # except Exception as e: # DEBUG: F.interpolate fallback disabled for testing
-    #     _get_log().warning(f"Sharpfin resize_tensor failed, falling back to F.interpolate: {e}")
-    #     mode = 'bilinear' if target_size[0] * target_size[1] > tensor.shape[-2] * tensor.shape[-1] else 'area'
-    #     inp = tensor if tensor.dim() == 4 else tensor.unsqueeze(0)
-    #     result = torch.nn.functional.interpolate(inp, size=target_size, mode=mode, antialias=True)
-    #     return result.squeeze(0) if tensor.dim() == 3 else result
-    finally:
-        pass
+    from modules.sharpfin.functional import scale
+    dt = torch.float16
+    squeezed = False
+    if tensor.dim() == 3:
+        tensor = tensor.unsqueeze(0)
+        squeezed = True
+    src_h, src_w = tensor.shape[-2], tensor.shape[-1]
+    th, tw = target_size
+    both_down = (th <= src_h and tw <= src_w)
+    both_up = (th >= src_h and tw >= src_w)
+    if both_down or both_up:
+        use_sparse = _triton_ok and dev.type == 'cuda' and rk.value == 'magic_kernel_sharp_2021' and both_down
+        result = scale(tensor, target_size, resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=use_sparse)
+    else:
+        if th > src_h:
+            intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
+            result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
+        else:
+            intermediate = scale(tensor, (th, src_w), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
+            result = scale(intermediate, (th, tw), resize_kernel=rk, device=dev, dtype=dt, do_srgb_conversion=linearize, use_sparse=False)
+    if squeezed:
+        result = result.squeeze(0)
+    return result


 def to_tensor(image):
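When exactly one axis grows and the other shrinks, the code above resizes in two passes (height first, then width) so each scale() call is a pure upscale or pure downscale. A standalone sketch of that strategy, with a hypothetical scale_to standing in for sharpfin's scale():

import torch
import torch.nn.functional as F

def scale_to(t, size):  # stand-in: any single-pass resize to (H, W)
    return F.interpolate(t, size=size, mode='bilinear', antialias=True)

def two_pass_resize(t, target):  # t: (N, C, H, W), target: (th, tw)
    th, tw = target
    src_h, src_w = t.shape[-2], t.shape[-1]
    if (th <= src_h and tw <= src_w) or (th >= src_h and tw >= src_w):
        return scale_to(t, (th, tw))            # both axes agree: single pass
    intermediate = scale_to(t, (th, src_w))     # pass 1: height only
    return scale_to(intermediate, (th, tw))     # pass 2: width only

x = torch.rand(1, 3, 100, 50)
print(two_pass_resize(x, (64, 80)).shape)  # torch.Size([1, 3, 64, 80])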
pixelart.py
@@ -3,14 +3,13 @@ from typing import List
 import math
 import torch
 import numpy as np
-from modules import images_sharpfin

 from PIL import Image
 from diffusers.utils import CONFIG_NAME
 from diffusers.image_processor import PipelineImageInput
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from transformers import ImageProcessingMixin

+from modules import images_sharpfin
 from modules import devices
pyproject.toml
@@ -24,6 +24,7 @@ exclude = [
 "modules/schedulers",
 "modules/teacache",
 "modules/seedvr",
+"modules/sharpfin",

 "modules/control/proc",
 "modules/control/units",

@@ -150,6 +151,7 @@ main.ignore-paths=[
 "modules/prompt_parser_xhinker.py",
 "modules/ras",
 "modules/seedvr",
+"modules/sharpfin",
 "modules/rife",
 "modules/schedulers",
 "modules/taesd",