diff --git a/scripts/iib/api.py b/scripts/iib/api.py index 575185e..30f8091 100644 --- a/scripts/iib/api.py +++ b/scripts/iib/api.py @@ -642,6 +642,8 @@ def infinite_image_browsing_api(app: FastAPI, **kwargs): "Cache-Control": "no-store", }, ) + if not is_media_file(path): + raise HTTPException(status_code=400, detail=f"{path} is not a video file") # 如果缓存文件不存在,则生成缩略图并保存 import imageio.v3 as iio diff --git a/scripts/iib/db/datamodel.py b/scripts/iib/db/datamodel.py index ea3ed75..e76e9e4 100644 --- a/scripts/iib/db/datamodel.py +++ b/scripts/iib/db/datamodel.py @@ -266,22 +266,27 @@ class Image: @classmethod def get_random_images(cls, conn: Connection, size: int) -> List["Image"]: + images = [] + max_cyc = 10 + curr_cyc = 0 with closing(conn.cursor()) as cur: - cur.execute("SELECT COUNT(*) FROM image") - total_count = cur.fetchone()[0] + while len(images) < size and curr_cyc < max_cyc: + curr_cyc += 1 + cur.execute("SELECT COUNT(*) FROM image") + total_count = cur.fetchone()[0] - if total_count == 0 or size <= 0: - return [] + if total_count == 0 or size <= 0: + return [] - step = max(1, total_count // size) + step = max(1, total_count // size) - start_indices = [random.randint(i * step, min((i + 1) * step - 1, total_count - 1)) for i in range(size)] - - placeholders = ",".join("?" * len(start_indices)) - cur.execute(f"SELECT * FROM image WHERE id IN ({placeholders})", start_indices) - rows = cur.fetchall() - - images = [cls.from_row(row) for row in rows if os.path.exists(row[1])] + start_indices = [random.randint(i * step, min((i + 1) * step - 1, total_count - 1)) for i in range(size)] + placeholders = ",".join("?" * len(start_indices)) + cur.execute(f"SELECT * FROM image WHERE id IN ({placeholders})", start_indices) + rows = cur.fetchall() + curr_images = [cls.from_row(row) for row in rows if os.path.exists(row[1])] + images.extend(curr_images) + images = unique_by(images, lambda x: x.path) return images diff --git a/scripts/iib/tool.py b/scripts/iib/tool.py index 618bb9c..8c5cf80 100755 --- a/scripts/iib/tool.py +++ b/scripts/iib/tool.py @@ -13,7 +13,8 @@ import json import zipfile from PIL import Image import shutil - +# import magic +import filetype sd_img_dirs = [ "outdir_txt2img_samples", @@ -200,6 +201,11 @@ def convert_to_bytes(file_size_str): else: raise ValueError(f"Invalid file size string '{file_size_str}'") + +def is_video_simple(filepath): + kind = filetype.guess(filepath) + return kind and kind.mime.startswith('video/') + def get_video_type(file_path): video_extensions = ['.mp4', '.m4v', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.ts'] file_extension = file_path[file_path.rfind('.'):].lower() @@ -218,7 +224,7 @@ def is_image_file(filename: str) -> bool: return f".{extension}" in extensions def is_video_file(filename: str) -> bool: - return isinstance(get_video_type(filename), str) + return isinstance(get_video_type(filename), str) and is_video_simple(filename) def is_valid_media_path(path): """