diff --git a/modules/image/grid.py b/modules/image/grid.py index 0a411b500..9ba8801fe 100644 --- a/modules/image/grid.py +++ b/modules/image/grid.py @@ -1,5 +1,6 @@ import math from collections import namedtuple +from typing import TYPE_CHECKING import numpy as np from PIL import Image, ImageFont, ImageDraw from modules import shared, script_callbacks @@ -26,7 +27,7 @@ def check_grid_size(imgs): return ok -def get_grid_size(imgs, batch_size=1, rows=None, cols=None): +def get_grid_size(imgs, batch_size=1, rows: int | None = None, cols: int | None = None): if rows and rows > len(imgs): rows = len(imgs) if cols and cols > len(imgs): @@ -34,12 +35,16 @@ def get_grid_size(imgs, batch_size=1, rows=None, cols=None): if rows is None and cols is None: if shared.opts.n_rows > 0: rows = shared.opts.n_rows + if TYPE_CHECKING: + assert isinstance(rows, int) cols = math.ceil(len(imgs) / rows) elif shared.opts.n_rows == 0: rows = batch_size cols = math.ceil(len(imgs) / rows) elif shared.opts.n_cols > 0: cols = shared.opts.n_cols + if TYPE_CHECKING: + assert isinstance(cols, int) rows = math.ceil(len(imgs) / cols) elif shared.opts.n_cols == 0: cols = batch_size @@ -49,16 +54,20 @@ def get_grid_size(imgs, batch_size=1, rows=None, cols=None): while len(imgs) % rows != 0: rows -= 1 cols = math.ceil(len(imgs) / rows) - elif cols is None: + return rows, cols + elif rows is not None and cols is None: cols = math.ceil(len(imgs) / rows) - elif rows is None: + elif rows is None and cols is not None: rows = math.ceil(len(imgs) / cols) else: + if TYPE_CHECKING: + assert isinstance(rows, int) + assert isinstance(cols, int) pass return rows, cols -def image_grid(imgs, batch_size:int=1, rows:int=None, cols:int=None): +def image_grid(imgs, batch_size=1, rows=1, cols=1): rows, cols = get_grid_size(imgs, batch_size, rows=rows, cols=cols) params = script_callbacks.ImageGridLoopParams(imgs, cols, rows) script_callbacks.image_grid_callback(params) diff --git a/modules/image/resize.py b/modules/image/resize.py index 1ff04b610..22841e6e5 100644 --- a/modules/image/resize.py +++ b/modules/image/resize.py @@ -8,7 +8,7 @@ from modules.logger import log from modules.image import sharpfin -def resize_image(resize_mode: int, im: Image.Image | torch.Tensor, width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None): +def resize_image(resize_mode: int, im: Image.Image | torch.Tensor, width: int, height: int, upscaler_name: str | None = None, output_type: str = 'image', context: str | None = None): upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img def verify_image(image): @@ -95,7 +95,7 @@ def resize_image(resize_mode: int, im: Image.Image | torch.Tensor, width: int, h res.paste(im, box=((width - im.width)//2, (height - im.height)//2)) return res - def context_aware(im: Image.Image, width, height, context): + def context_aware(im: Image.Image, width: int, height: int, context: str): from installer import install install('seam-carving') width, height = int(width), int(height)