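"""Model post-processing entrypoint for the stable-diffusion-aws-extension.

Packages trained model artifacts (a .safetensors file, plus its .yaml config
for Stable-diffusion checkpoints) into a tar archive and uploads it to S3.
The flag handling at the bottom follows the extension's own `--key value`
convention rather than argparse.
"""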
import os
import re
import json
import sys
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)  # default StreamHandler writes to stderr

# Make the extension's helpers importable when running from the WebUI root.
sys.path.insert(0, os.path.join(os.getcwd(), "extensions/stable-diffusion-aws-extension/"))

from utils_cn import download_folder_from_s3_by_tar, download_folder_from_s3, upload_file_to_s3
from utils_cn import get_bucket_name_from_s3_path, get_path_from_s3_path

# Keep the WebUI from aborting on the custom CLI flags this script receives.
os.environ['IGNORE_CMD_ARGS_ERRORS'] = ""

from utils_cn import tar, mv  # mv is imported but currently unused


def upload_model_to_s3_v2(model_name, s3_output_path, model_type, region):
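    """Tar every .safetensors file under the local model directory and upload it.

    Stable-diffusion checkpoints are archived together with their .yaml config
    (when present) and uploaded under <s3_output_path>/<model_name>; Lora
    archives go directly under <s3_output_path>.
    """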
    output_bucket_name = get_bucket_name_from_s3_path(s3_output_path)
    s3_output_path = get_path_from_s3_path(s3_output_path).rstrip("/")
    logger.info("Upload the model file to s3.")
    if model_type == "Stable-diffusion":
        local_path = os.path.join(f"models/{model_type}", model_name)
    elif model_type == "Lora":
        local_path = f"models/{model_type}"
    else:
        # Guard against an UnboundLocalError below when the type is unexpected.
        logger.error(f"Unsupported model type: {model_type}")
        return
    logger.info(f"Search model file in {local_path}.")
    for root, _dirs, files in os.walk(local_path):
        logger.info(files)
        for file in files:
            if not file.endswith('.safetensors'):
                continue
            ckpt_name = re.sub(r'\.safetensors$', '', file)
            safetensors = os.path.join(root, file)
            logger.info(f"model type: {model_type}")
            output_tar = file
            if model_type == "Stable-diffusion":
                # A checkpoint may ship a .yaml config next to it; bundle it when present.
                yaml = os.path.join(root, f"{ckpt_name}.yaml")
                sfiles = [safetensors, yaml] if os.path.exists(yaml) else [safetensors]
                tar(mode='c', archive=output_tar, sfiles=sfiles, verbose=True)
                logger.info(f"Upload checkpoint to s3 {output_tar} {output_bucket_name} {s3_output_path}")
                upload_file_to_s3(output_tar, output_bucket_name, os.path.join(s3_output_path, model_name), region)
            elif model_type == "Lora":
                tar(mode='c', archive=output_tar, sfiles=[safetensors], verbose=True)
                logger.info(f"Upload checkpoint to s3 {output_tar} {output_bucket_name} {s3_output_path}")
                upload_file_to_s3(output_tar, output_bucket_name, s3_output_path, region)


def download_data(data_list, s3_data_path_list, s3_input_path, region):
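    """Fetch training/class data from S3 into the local target directories.

    Entries in s3_data_path_list may be either full s3:// prefixes (synced as
    folders) or bare tar file names resolved relative to s3_input_path.
    """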
    for data, data_tar in zip(data_list, s3_data_path_list):
        if len(data) == 0:
            continue
        target_dir = data
        os.makedirs(target_dir, exist_ok=True)
        if data_tar.startswith("s3://"):
            # Full S3 prefix: sync the folder contents directly.
            input_bucket_name = get_bucket_name_from_s3_path(data_tar)
            input_path = get_path_from_s3_path(data_tar)
            local_tar_path = data_tar.replace("s3://", "").replace("/", "-")
            logger.info(f"Download data from s3 {input_bucket_name} {input_path} to {target_dir} {local_tar_path}")
            download_folder_from_s3(input_bucket_name, input_path, target_dir, region)
        else:
            # Bare tar name: resolve it under s3_input_path and unpack locally.
            input_bucket_name = get_bucket_name_from_s3_path(s3_input_path)
            input_path = os.path.join(get_path_from_s3_path(s3_input_path), data_tar)
            local_tar_path = data_tar
            logger.info(f"Download data from s3 {input_bucket_name} {input_path} to {target_dir} {local_tar_path}")
            download_folder_from_s3_by_tar(input_bucket_name, input_path, local_tar_path, target_dir, region)


def main(s3_input_path, s3_output_path, params, region):
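    """Unpack the job parameters and upload the trained model to S3.

    The df/ls shell calls are diagnostic only (disk usage and model layout).
    """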
os.system("df -h")
|
|
# import launch
|
|
# launch.prepare_environment()
|
|
model_name = params["model_name"]
|
|
model_type = params["model_type"]
|
|
s3_model_path = params["s3_model_path"]
|
|
s3_data_path_list = params["data_tar_list"]
|
|
s3_class_data_path_list = params["class_data_tar_list"]
|
|
# s3_data_path_list = params["s3_data_path_list"]
|
|
# s3_class_data_path_list = params["s3_class_data_path_list"]
|
|
print(f"s3_model_path {s3_model_path} model_name:{model_name} s3_input_path: {s3_input_path} s3_data_path_list:{s3_data_path_list} s3_class_data_path_list:{s3_class_data_path_list}")
|
|
os.system("df -h")
|
|
# sync_status(job_id, bucket_name, model_dir)
|
|
os.system("df -h")
|
|
os.system("ls -R models")
|
|
upload_model_to_s3_v2(model_name, s3_output_path, model_type, region)
|
|
os.system("df -h")
|
|
|
|
|
|
if __name__ == "__main__":
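    # Flags arrive in the extension's own `--key value` convention, not argparse:
    # the `--params` value is a Python-dict-style string that is massaged into
    # valid JSON below, and the s3-*-path values are recovered from everything
    # after the first ':' (the colon of the `s3://` scheme itself).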
    os.environ['AWS_DEFAULT_REGION'] = 'cn-northwest-1'
    print(sys.argv)
    command_line_args = ' '.join(sys.argv[1:])
    params = {}
    s3_input_path = ''
    s3_output_path = ''
    args_list = command_line_args.split("--")
    for arg in args_list:
        if arg.strip().startswith("params"):
            start_idx = arg.find("{")
            end_idx = arg.rfind("}")
            if start_idx != -1 and end_idx != -1:
                params_str = arg[start_idx:end_idx + 1]
                try:
                    # Normalize the dict-style string: drop whitespace and quotes,
                    # then collapse the empty slots that removal leaves behind.
                    params_str = params_str.replace(" ", "").replace("\n", "").replace("'", "")
                    print(params_str)
                    params_str = params_str.replace(",,", ",")
                    params_str = params_str.replace(",,", ",")  # second pass catches overlapping runs
                    print(params_str)
                    # Re-quote keys and values so the string parses as JSON, then
                    # undo the damage that quoting does to lists and s3:// URIs.
                    fixed_string = params_str.replace('{', '{"').replace(':', '":"').replace(',', '","') \
                        .replace('}', '"}').replace('"[', '["').replace(']"', '"]').replace('s3":"//', 's3://')
                    print(fixed_string)
                    params = json.loads(fixed_string)
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")
        if arg.strip().startswith("s3-input-path"):
            start_idx = arg.find(":")
            # strip() drops the trailing space left over from the '--' split
            s3_input_path = f"s3{arg[start_idx:]}".strip()
        if arg.strip().startswith("s3-output-path"):
            start_idx = arg.find(":")
            s3_output_path = f"s3{arg[start_idx:]}".strip()
    training_params = params
    print(training_params)
    print(s3_input_path)
    print(s3_output_path)
    region = 'cn-northwest-1'
    main(s3_input_path, s3_output_path, training_params, region)
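
# Illustrative invocation (script name and all values hypothetical; in practice
# the extension's training container supplies these flags):
#   python postprocess.py \
#       --params {'model_name': 'mymodel', 'model_type': 'Lora', \
#                 's3_model_path': 's3://bucket/model', 'data_tar_list': ['data.tar'], \
#                 'class_data_tar_list': []} \
#       --s3-input-path s3://bucket/input --s3-output-path s3://bucket/output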