diff --git a/modules/upscaler.py b/modules/upscaler.py index 44dfad48d..6bafa3db0 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -1,10 +1,16 @@ +from __future__ import annotations + import os from abc import abstractmethod +from typing import TYPE_CHECKING from PIL import Image from modules import modelloader, shared, paths from modules.logger import log +if TYPE_CHECKING: + from torch import Tensor + models = None @@ -92,10 +98,10 @@ class Upscaler: return scalers @abstractmethod - def do_upscale(self, img: Image, selected_model: str): + def do_upscale(self, img: Image.Image | Tensor, selected_model: str): return img - def upscale(self, img: Image, scale, selected_model: str = None): + def upscale(self, img: Image.Image | Tensor, scale, selected_model: str | None = None): jobid = shared.state.begin('Upscale') self.scale = scale if isinstance(img, Image.Image): @@ -153,10 +159,10 @@ class UpscalerData: name = None data_path = None scale: int = 4 - scaler: Upscaler = None + scaler: Upscaler | None = None model: None - def __init__(self, name: str, path: str = None, upscaler: Upscaler = None, scale: int = 4, model=None): + def __init__(self, name: str, path: str | None = None, upscaler: Upscaler | None = None, scale: int = 4, model=None): self.name = name self.data_path = path self.local_data_path = path