diff --git a/scripts/iib/db/datamodel.py b/scripts/iib/db/datamodel.py index 45c84ff..3f0b5a5 100644 --- a/scripts/iib/db/datamodel.py +++ b/scripts/iib/db/datamodel.py @@ -279,40 +279,65 @@ class Image: @classmethod def get_random_images(cls, conn: Connection, size: int) -> List["Image"]: + from scripts.iib.logger import logger + logger.info(f"Starting to get random images, requested size: {size}") images = [] max_cyc = 10 curr_cyc = 0 with closing(conn.cursor()) as cur: while len(images) < size and curr_cyc < max_cyc: curr_cyc += 1 + logger.info(f"Starting attempt {curr_cyc} to get random images") cur.execute("SELECT COUNT(*) FROM image") total_count = cur.fetchone()[0] + logger.info(f"Total images in database: {total_count}") if total_count == 0 or size <= 0: + logger.warning(f"Cannot get random images: total_count={total_count}, requested_size={size}") return [] step = max(1, total_count // size) + logger.info(f"Calculated step size: {step}") start_indices = [] for i in range(size): min_val = i * step max_val = min((i + 1) * step - 1, total_count - 1) - # 确保 max_val 不小于 min_val + # Ensure max_val is not less than min_val if max_val < min_val: max_val = min_val - # 确保索引在有效范围内 (1 到 total_count) + # Ensure indices are within valid range (1 to total_count) min_val = max(1, min(min_val, total_count)) max_val = max(1, min(max_val, total_count)) if min_val <= max_val: - start_indices.append(random.randint(min_val, max_val)) + idx = random.randint(min_val, max_val) + start_indices.append(idx) + logger.debug(f"Generated random index [{i}]: range {min_val}-{max_val}, selected {idx}") + + logger.info(f"Generated random index list: {start_indices}") if start_indices: placeholders = ",".join("?" * len(start_indices)) - cur.execute(f"SELECT * FROM image WHERE id IN ({placeholders})", start_indices) + query = f"SELECT * FROM image WHERE id IN ({placeholders})" + logger.debug(f"Executing SQL query: {query}, parameters: {start_indices}") + cur.execute(query, start_indices) rows = cur.fetchall() - curr_images = [cls.from_row(row) for row in rows if os.path.exists(row[1])] + logger.info(f"Query returned {len(rows)} records") + + curr_images = [] + for row in rows: + path = row[1] + if os.path.exists(path): + curr_images.append(cls.from_row(row)) + else: + logger.warning(f"Image file does not exist: {path}") + + logger.info(f"Valid images found in this cycle: {len(curr_images)}") images.extend(curr_images) images = unique_by(images, lambda x: x.path) + logger.info(f"Total unique images after deduplication: {len(images)}") + + logger.info(f"Random image retrieval completed, final image count: {len(images)}") return images