update image crop function

training_guidance
yuxiao 2024-04-22 06:39:15 +00:00
parent d9aa4cb4a5
commit 53dcf2f001
1 changed files with 41 additions and 34 deletions

View File

@ -1,16 +1,15 @@
import glob
import os
import cv2
import argparse
import math
from PIL import Image
import numpy as np
from utils import split_s3_path, upload_file_to_s3, download_file_from_s3
def resize_image(src_img_s3_path, max_resolution="512x512", divisible_by=2, interpolation=None):
def resize_image(src_img_s3_path, max_resolution="512x512", interpolation='lanczos'):
#split s3 path
bucket_name, key_src = split_s3_path(src_img_s3_path)
key_dst = os.path.dirname(key_src) + '_crop/' + os.path.basename(key_src)
key_dst = os.path.dirname(key_src) + '_' + max_resolution + os.path.basename(key_src)
#download file for src_img_s3_path
local_src_folder = os.path.join('/tmp', os.path.dirname(key_src))
local_dst_folder = os.path.join('/tmp', os.path.dirname(key_src) + '_crop')
@ -23,12 +22,12 @@ def resize_image(src_img_s3_path, max_resolution="512x512", divisible_by=2, inte
download_file_from_s3(bucket_name, key_src, local_src_file_path)
# Select interpolation method
if interpolation == 'lanczos4':
cv2_interpolation = cv2.INTER_LANCZOS4
if interpolation == 'lanczos':
interpolation_type = Image.LANCZOS
elif interpolation == 'cubic':
cv2_interpolation = cv2.INTER_CUBIC
interpolation_type = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA
interpolation_type = Image.NEAREST
# Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
@ -42,43 +41,54 @@ def resize_image(src_img_s3_path, max_resolution="512x512", divisible_by=2, inte
image = Image.open(local_src_file_path)
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
width = int(max_resolution.split("x")[0])
height = int(max_resolution.split("x")[1])
# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]
current_height = image.size[1]
current_width = image.size[0]
# Check if the image needs resizing
if current_pixels > max_pixels:
if current_width > width and current_height > height:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels
scale_factor_width = width / current_width
scale_factor_height = height / current_height
# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
new_height, new_width = img.shape[0:2]
# Calculate the new height and width that are divisible by divisible_by (with/without resizing)
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
if scale_factor_height > scale_factor_width:
new_width = math.ceil(current_width * scale_factor_height)
image = image.resize((new_width, height), interpolation_type)
elif scale_factor_height < scale_factor_width:
new_height = math.ceil(current_height * scale_factor_width)
image = image.resize((width, new_height), interpolation_type)
else:
image = image.resize((width, height), interpolation_type)
resized_img = np.array(image)
new_img = np.zeros((height, width, 3), dtype=np.uint8)
# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]
new_y = 0
new_x = 0
height_dst = height
width_dst = width
y = int((resized_img.shape[0] - height) / 2)
if y < 0:
new_y = -y
height_dst = resized_img.shape[0]
y = 0
x = int((resized_img.shape[1] - width) / 2)
if x < 0:
new_x = -x
width_dst = resized_img.shape[1]
x = 0
new_img[new_y:new_y+height_dst, new_x:new_x+width_dst] = resized_img[y:y + height_dst, x:x + width_dst]
# Save resized image in dst_img_folder
image = Image.fromarray(img)
image = Image.fromarray(new_img)
image.save(local_dst_file_path, quality=100)
#proc = "Resized" if current_pixels > max_pixels else "Saved"
#print(f"{proc} image: {os.path.basename(key_src)} with size {img.shape[0]}x{img.shape[1]}")
upload_file_to_s3(local_dst_file_path, bucket_name, key_dst)
@ -88,8 +98,6 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument('src_img_s3_path', type=str, help='Source folder containing the images / 元画像のフォルダ')
parser.add_argument('--max_resolution', type=str,
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
@ -100,8 +108,7 @@ def main():
parser = setup_parser()
args = parser.parse_args()
resize_image(args.src_img_s3_path, args.dst_img_s3_path, args.max_resolution,
args.divisible_by, args.interpolation)
resize_image(args.src_img_s3_path, args.max_resolution, args.interpolation)
if __name__ == '__main__':