mirror of https://github.com/vladmandic/automatic
83 lines
3.4 KiB
Python
83 lines
3.4 KiB
Python
import sys
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image
|
|
from installer import log
|
|
|
|
|
|
def to_tensor(image: Image.Image | np.ndarray):
|
|
"""PIL Image -> float32 CHW tensor [0,1]. Pure torch, no torchvision."""
|
|
if isinstance(image, Image.Image):
|
|
pic = np.array(image, copy=True)
|
|
elif isinstance(image, np.ndarray):
|
|
pic = image.copy()
|
|
else:
|
|
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
|
raise TypeError(f"convert: target=tensor type={type(image)} fn={fn} unsupported")
|
|
if pic.ndim == 2:
|
|
pic = pic[:, :, np.newaxis]
|
|
tensor = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
|
|
# log.debug(f'Convert: source={type(image)} target={tensor.shape} fn={fn}')
|
|
if tensor.dtype == torch.uint8:
|
|
return tensor.to(torch.float32).div_(255.0)
|
|
return tensor.to(torch.float32)
|
|
|
|
|
|
def to_pil(tensor: torch.Tensor | np.ndarray):
|
|
"""Float CHW/HWC or BCHW/BHWC tensor [0,1] -> PIL Image. Pure torch, no torchvision."""
|
|
if isinstance(tensor, torch.Tensor):
|
|
tensor = tensor.detach().cpu()
|
|
elif isinstance(tensor, np.ndarray):
|
|
tensor = torch.from_numpy(tensor)
|
|
else:
|
|
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
|
raise TypeError(f"convert: target=image type={type(tensor)} fn={fn} unsupported")
|
|
try:
|
|
if tensor.dim() == 4:
|
|
if tensor.shape[-1] in (1, 3, 4) and tensor.shape[-1] < tensor.shape[-2]: # BHWC
|
|
tensor = tensor.permute(0, 3, 1, 2)
|
|
tensor = tensor[0]
|
|
elif tensor.dim() == 3:
|
|
if tensor.shape[-1] in (1, 3, 4) and tensor.shape[-1] < tensor.shape[-2] and tensor.shape[-1] < tensor.shape[-3]: # HWC
|
|
tensor = tensor.permute(2, 0, 1)
|
|
if tensor.dtype != torch.uint8:
|
|
tensor = (tensor.clamp(0, 1) * 255).round().to(torch.uint8)
|
|
ndarr = tensor.permute(1, 2, 0).numpy()
|
|
if ndarr.shape[2] == 1:
|
|
ndarr = ndarr[:, :, 0]
|
|
mode = 'L'
|
|
elif ndarr.shape[2] == 3:
|
|
mode = 'RGB'
|
|
else:
|
|
mode = 'RGBA'
|
|
image = Image.fromarray(ndarr, mode=mode)
|
|
except Exception as e:
|
|
image = Image.new('RGB', (tensor.shape[-1], tensor.shape[-2]), color=(152, 32, 48))
|
|
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
|
log.error(f'Convert: source={type(tensor)} target={image} fn={fn} {e}')
|
|
return image
|
|
|
|
|
|
def pil_to_tensor(image):
|
|
"""PIL Image -> uint8 CHW tensor (no float scaling). Replaces TF.pil_to_tensor."""
|
|
if not isinstance(image, Image.Image):
|
|
raise TypeError(f"Expected PIL Image, got {type(image)}")
|
|
pic = np.array(image, copy=True)
|
|
if pic.ndim == 2:
|
|
pic = pic[:, :, np.newaxis]
|
|
return torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
|
|
|
|
|
|
def normalize(tensor, mean, std, inplace=False):
|
|
"""Tensor normalization. Replaces TF.normalize."""
|
|
if not inplace:
|
|
tensor = tensor.clone()
|
|
mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
|
|
std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
|
|
if mean_t.ndim == 1:
|
|
mean_t = mean_t[:, None, None]
|
|
if std_t.ndim == 1:
|
|
std_t = std_t[:, None, None]
|
|
tensor.sub_(mean_t).div_(std_t)
|
|
return tensor
|