Merge pull request #804 from zanllp/fix/insufficient-random-media-count

fix(media): ensure minimum quantity in random image/video API
pull/807/head
zanllp 2025-05-25 20:35:55 +08:00 committed by GitHub
commit 387ef7875b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 14 deletions

View File

@ -642,6 +642,8 @@ def infinite_image_browsing_api(app: FastAPI, **kwargs):
"Cache-Control": "no-store",
},
)
if not is_media_file(path):
raise HTTPException(status_code=400, detail=f"{path} is not a video file")
# 如果缓存文件不存在,则生成缩略图并保存
import imageio.v3 as iio

View File

@ -266,22 +266,27 @@ class Image:
@classmethod
def get_random_images(cls, conn: Connection, size: int) -> List["Image"]:
images = []
max_cyc = 10
curr_cyc = 0
with closing(conn.cursor()) as cur:
cur.execute("SELECT COUNT(*) FROM image")
total_count = cur.fetchone()[0]
while len(images) < size and curr_cyc < max_cyc:
curr_cyc += 1
cur.execute("SELECT COUNT(*) FROM image")
total_count = cur.fetchone()[0]
if total_count == 0 or size <= 0:
return []
if total_count == 0 or size <= 0:
return []
step = max(1, total_count // size)
step = max(1, total_count // size)
start_indices = [random.randint(i * step, min((i + 1) * step - 1, total_count - 1)) for i in range(size)]
placeholders = ",".join("?" * len(start_indices))
cur.execute(f"SELECT * FROM image WHERE id IN ({placeholders})", start_indices)
rows = cur.fetchall()
images = [cls.from_row(row) for row in rows if os.path.exists(row[1])]
start_indices = [random.randint(i * step, min((i + 1) * step - 1, total_count - 1)) for i in range(size)]
placeholders = ",".join("?" * len(start_indices))
cur.execute(f"SELECT * FROM image WHERE id IN ({placeholders})", start_indices)
rows = cur.fetchall()
curr_images = [cls.from_row(row) for row in rows if os.path.exists(row[1])]
images.extend(curr_images)
images = unique_by(images, lambda x: x.path)
return images

View File

@ -13,7 +13,8 @@ import json
import zipfile
from PIL import Image
import shutil
# import magic
import filetype
sd_img_dirs = [
"outdir_txt2img_samples",
@ -200,6 +201,11 @@ def convert_to_bytes(file_size_str):
else:
raise ValueError(f"Invalid file size string '{file_size_str}'")
def is_video_simple(filepath):
kind = filetype.guess(filepath)
return kind and kind.mime.startswith('video/')
def get_video_type(file_path):
video_extensions = ['.mp4', '.m4v', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.ts']
file_extension = file_path[file_path.rfind('.'):].lower()
@ -218,7 +224,7 @@ def is_image_file(filename: str) -> bool:
return f".{extension}" in extensions
def is_video_file(filename: str) -> bool:
return isinstance(get_video_type(filename), str)
return isinstance(get_video_type(filename), str) and is_video_simple(filename)
def is_valid_media_path(path):
"""