# SageMaker training entry script for the stable-diffusion-aws-extension dreambooth workflow.
import base64
import json
import logging
import os
import pickle
import re
import sys
import threading
import time

import boto3

logger = logging.getLogger(__name__)
# logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)  # Set logging level and STDOUT handler
# Make the AWS extension and the dreambooth extension importable from the webui root.
sys.path.insert(0, os.path.join(os.getcwd(), "extensions/stable-diffusion-aws-extension/"))
sys.path.append(os.path.join(os.getcwd(), "extensions/sd_dreambooth_extension"))

from utils import download_file_from_s3, download_folder_from_s3_by_tar, download_folder_from_s3, upload_file_to_s3, upload_folder_to_s3_by_tar
from utils import get_bucket_name_from_s3_path, get_path_from_s3_path

# The dreambooth extension aborts on unknown CLI args unless this env var is set
# (it must be set BEFORE the dreambooth modules are imported below).
os.environ['IGNORE_CMD_ARGS_ERRORS'] = ""

from dreambooth.ui_functions import start_training
from dreambooth.shared import status

# Fixed typo: was `from utls import tar, mv` — tar/mv live in the same `utils`
# module already imported above.
from utils import tar, mv
def sync_status_from_s3_json(bucket_name, webui_status_file_path, sagemaker_status_file_path):
    """Mirror training-control flags between the webui and this SageMaker job via JSON on S3.

    Once per second: download the webui's status JSON and copy its control flags
    onto the shared dreambooth ``status`` object, then upload this job's full
    ``status.__dict__`` as JSON. Runs forever — intended for a background thread.

    Args:
        bucket_name: S3 bucket holding both status files.
        webui_status_file_path: S3 key of the webui-written status JSON.
        sagemaker_status_file_path: S3 key to write this job's status JSON to.
    """
    # Flags the webui can toggle to control a running training job.
    control_flags = ('do_save_model', 'do_save_samples', 'interrupted',
                     'interrupted_after_save', 'interrupted_after_epoch')
    while True:
        time.sleep(1)
        print(f'sagemaker status: {status.__dict__}')
        try:
            download_file_from_s3(bucket_name, webui_status_file_path, 'webui_status.json')
            with open('webui_status.json') as webui_status_file:
                webui_status = json.load(webui_status_file)
            for flag in control_flags:
                setattr(status, flag, webui_status[flag])
        except Exception as e:
            # Best effort: the webui may not have written its status file yet.
            print('The webui status file does not exist')
            print(e)
        with open('sagemaker_status.json', 'w') as sagemaker_status_file:
            json.dump(status.__dict__, sagemaker_status_file)
        upload_file_to_s3('sagemaker_status.json', bucket_name, sagemaker_status_file_path)
def sync_status_from_s3_in_sagemaker(bucket_name, webui_status_file_path, sagemaker_status_file_path):
    """Pickle-based variant of the status sync loop (see sync_status_from_s3_json).

    Once per second: download the webui's pickled control flags onto the shared
    dreambooth ``status`` object, then upload the full ``status`` object as a
    pickle. Runs forever — intended for a background thread.

    Args:
        bucket_name: S3 bucket holding both status files.
        webui_status_file_path: S3 key of the webui-written status pickle.
        sagemaker_status_file_path: S3 key to write this job's status pickle to.
    """
    control_flags = ('do_save_model', 'do_save_samples', 'interrupted',
                     'interrupted_after_save', 'interrupted_after_epoch')
    while True:
        time.sleep(1)
        print(status.__dict__)
        try:
            download_file_from_s3(bucket_name, webui_status_file_path, 'webui_status.pickle')
            with open('webui_status.pickle', 'rb') as webui_status_file:
                # SECURITY NOTE(review): pickle.load on data fetched from S3 is only
                # acceptable if the bucket is fully controlled by this solution.
                webui_status = pickle.load(webui_status_file)
            for flag in control_flags:
                setattr(status, flag, webui_status[flag])
        except Exception as e:
            # Best effort: the webui may not have written its status file yet.
            print('The webui status file does not exist')
            print(f'error: {e}')
        with open('sagemaker_status.pickle', 'wb') as sagemaker_status_file:
            pickle.dump(status, sagemaker_status_file)
        upload_file_to_s3('sagemaker_status.pickle', bucket_name, sagemaker_status_file_path)
def train(model_dir):
    """Run dreambooth training for the given model (thin wrapper over start_training)."""
    start_training(model_dir)
def check_and_upload(local_path, bucket, s3_path):
    """Once per second, tar and upload `local_path` to S3 if it exists.

    Runs forever — intended for a background thread that keeps sample images
    synced to S3 while training runs.

    Args:
        local_path: local folder to upload (e.g. the samples directory).
        bucket: destination S3 bucket name.
        s3_path: destination S3 key prefix.
    """
    while True:
        time.sleep(1)
        if os.path.exists(local_path):
            # Fixed message: we upload the LOCAL path TO the S3 path (args were reversed).
            print(f'upload {local_path} to {s3_path}')
            upload_folder_to_s3_by_tar(local_path, bucket, s3_path)
        else:
            print(f'{local_path} does not exist')
def upload_model_to_s3(model_name, s3_output_path):
    """Tar the named Stable-diffusion checkpoint folder and upload it to S3.

    Args:
        model_name: checkpoint folder name under models/Stable-diffusion.
        s3_output_path: full s3:// destination URI.
    """
    bucket = get_bucket_name_from_s3_path(s3_output_path)
    key_prefix = get_path_from_s3_path(s3_output_path)
    checkpoint_dir = os.path.join("models/Stable-diffusion", model_name)

    message = f"Upload check point to s3 {checkpoint_dir} {bucket} {key_prefix}"
    logger.info(message)
    print(message)

    upload_folder_to_s3_by_tar(checkpoint_dir, bucket, key_prefix)
def upload_model_to_s3_v2(model_name, s3_output_path, model_type):
    """Tar each trained .safetensors checkpoint and upload the archives to S3.

    For "Stable-diffusion" models the matching ``<name>.yaml`` config is bundled
    into the tar alongside the checkpoint; for "Lora" models only the
    ``.safetensors`` file is archived.

    Args:
        model_name: trained model name (sub-folder under models/Stable-diffusion).
        s3_output_path: full s3:// destination URI (trailing "/" stripped).
        model_type: "Stable-diffusion" or "Lora".
    """
    output_bucket_name = get_bucket_name_from_s3_path(s3_output_path)
    s3_output_path = get_path_from_s3_path(s3_output_path).rstrip("/")
    if model_type == "Stable-diffusion":
        local_path = os.path.join(f"models/{model_type}", model_name)
    elif model_type == "Lora":
        local_path = f"models/{model_type}"
    else:
        # Previously an unknown type fell through and raised NameError on local_path.
        logger.info(f"Unsupported model type {model_type}, nothing to upload")
        return
    for root, dirs, files in os.walk(local_path):
        for file in files:
            if not file.endswith('.safetensors'):
                continue
            ckpt_name = re.sub(r'\.safetensors$', '', file)  # raw string: '\.' is a regex escape
            safetensors = os.path.join(root, file)
            output_tar = file
            if model_type == "Stable-diffusion":
                # Ship the matching .yaml config alongside the checkpoint.
                sfiles = [safetensors, os.path.join(root, f"{ckpt_name}.yaml")]
            else:
                # Lora: no yaml. Previously the tar() call below unconditionally
                # passed `yaml`, which is undefined (NameError) for Lora models.
                sfiles = [safetensors]
            tar_command = f"tar cvf {output_tar} {' '.join(sfiles)}"
            print(tar_command)
            tar(mode='c', archive=output_tar, sfiles=sfiles, verbose=True)
            logger.info(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}")
            print(f"Upload check point to s3 {output_tar} {output_bucket_name} {s3_output_path}")
            upload_file_to_s3(output_tar, output_bucket_name, s3_output_path)
def download_data(data_list, s3_data_path_list, s3_input_path):
    """Fetch each dataset referenced by db_config into its local directory.

    Entries in `s3_data_path_list` are either full s3:// URIs (downloaded as a
    plain folder) or tar names relative to `s3_input_path` (downloaded and
    extracted).

    Args:
        data_list: local target directories (from the dreambooth concepts list).
        s3_data_path_list: matching S3 locations, one per target directory.
        s3_input_path: base s3:// URI used for relative tar names.
    """
    for target_dir, data_location in zip(data_list, s3_data_path_list):
        if len(target_dir) == 0:
            # Concept without a data dir configured — nothing to fetch.
            continue
        os.makedirs(target_dir, exist_ok=True)
        if data_location.startswith("s3://"):
            bucket = get_bucket_name_from_s3_path(data_location)
            key = get_path_from_s3_path(data_location)
            local_tar = data_location.replace("s3://", "").replace("/", "-")
            logger.info(f"Download data from s3 {bucket} {key} to {target_dir} {local_tar}")
            download_folder_from_s3(bucket, key, target_dir)
        else:
            bucket = get_bucket_name_from_s3_path(s3_input_path)
            key = os.path.join(get_path_from_s3_path(s3_input_path), data_location)
            local_tar = data_location
            logger.info(f"Download data from s3 {bucket} {key} to {target_dir} {local_tar}")
            download_folder_from_s3_by_tar(bucket, key, local_tar, target_dir)
def prepare_for_training(s3_model_path, model_name, s3_input_path, data_tar_list, class_data_tar_list):
    """Download the source model, dreambooth config, and training data from S3.

    Args:
        s3_model_path: s3:// URI prefix of the source model tarball.
        model_name: model name; also names the tarball and config folders.
        s3_input_path: s3:// URI prefix of the job input (db_config, data tars).
        data_tar_list: instance-data locations, one per concept.
        class_data_tar_list: class-data locations, one per concept.
    """
    # Source model tarball.
    model_bucket = get_bucket_name_from_s3_path(s3_model_path)
    model_key = os.path.join(get_path_from_s3_path(s3_model_path), f'{model_name}.tar')
    logger.info(f"Download src model from s3 {model_bucket} {model_key} {model_name}.tar")
    print(f"Download src model from s3 {model_bucket} {model_key} {model_name}.tar")
    download_folder_from_s3_by_tar(model_bucket, model_key, f'{model_name}.tar')

    # Dreambooth config tarball.
    input_bucket = get_bucket_name_from_s3_path(s3_input_path)
    config_key = os.path.join(get_path_from_s3_path(s3_input_path), "db_config.tar")
    logger.info(f"Download db_config from s3 {input_bucket} {config_key} db_config.tar")
    download_folder_from_s3_by_tar(input_bucket, config_key, "db_config.tar")

    # Relocate the cloud db_config to where the dreambooth extension expects it.
    downloaded_config = f"models/sagemaker_dreambooth/{model_name}/db_config_cloud.json"
    final_config = f"models/dreambooth/{model_name}/db_config.json"
    logger.info(f"Move db_config to correct position {downloaded_config} {final_config}")
    mv(downloaded_config, final_config)

    with open(final_config) as config_file:
        db_config = json.load(config_file)
    instance_dirs = [concept["instance_data_dir"] for concept in db_config["concepts_list"]]
    class_dirs = [concept["class_data_dir"] for concept in db_config["concepts_list"]]

    download_data(instance_dirs, data_tar_list, s3_input_path)
    download_data(class_dirs, class_data_tar_list, s3_input_path)
def prepare_for_training_v2(s3_model_path, model_name, s3_input_path, s3_data_path_list, s3_class_data_path_list):
    """V2 data-prep: download model, db_config, and per-concept data tars from full s3:// URIs.

    Unlike prepare_for_training, every data location here is expected to be a
    full s3:// URI pointing at a tarball.

    Args:
        s3_model_path: s3:// URI prefix of the source model tarball.
        model_name: model name; also names the tarball and config folders.
        s3_input_path: s3:// URI prefix holding db_config.tar.
        s3_data_path_list: s3:// URIs of instance-data tars, one per concept.
        s3_class_data_path_list: s3:// URIs of class-data tars, one per concept.
    """
    model_bucket_name = get_bucket_name_from_s3_path(s3_model_path)
    s3_model_path = os.path.join(get_path_from_s3_path(s3_model_path), f'{model_name}.tar')
    logger.info(f"Download src model from s3 {model_bucket_name} {s3_model_path} {model_name}.tar")
    print(f"Download src model from s3 {model_bucket_name} {s3_model_path} {model_name}.tar")
    download_folder_from_s3_by_tar(model_bucket_name, s3_model_path, f'{model_name}.tar')

    # Bug fix: this assignment was commented out while still being used two
    # lines below, which raised NameError at runtime.
    input_bucket_name = get_bucket_name_from_s3_path(s3_input_path)
    input_path = os.path.join(get_path_from_s3_path(s3_input_path), "db_config.tar")
    logger.info(f"Download db_config from s3 {input_bucket_name} {input_path} db_config.tar")
    download_folder_from_s3_by_tar(input_bucket_name, input_path, "db_config.tar")

    # Relocate the cloud db_config to where the dreambooth extension expects it.
    download_db_config_path = f"models/sagemaker_dreambooth/{model_name}/db_config_cloud.json"
    target_db_config_path = f"models/dreambooth/{model_name}/db_config.json"
    logger.info(f"Move db_config to correct position {download_db_config_path} {target_db_config_path}")
    mv(download_db_config_path, target_db_config_path)

    with open(target_db_config_path) as db_config_file:
        db_config = json.load(db_config_file)
    data_list = [concept["instance_data_dir"] for concept in db_config["concepts_list"]]
    class_data_list = [concept["class_data_dir"] for concept in db_config["concepts_list"]]

    _download_tar_datasets(data_list, s3_data_path_list)
    _download_tar_datasets(class_data_list, s3_class_data_path_list)


def _download_tar_datasets(local_dir_list, s3_path_list):
    """Download and extract each S3 tarball into its matching local directory.

    Bug fix: the original zip unpacking had the names swapped — the local
    directory from db_config was parsed as an S3 URI and makedirs was called on
    the S3 path.
    """
    for local_dir, s3_data_path in zip(local_dir_list, s3_path_list):
        # Skip concepts with no data configured on either side.
        if len(local_dir) == 0 or len(s3_data_path) == 0:
            continue
        os.makedirs(local_dir, exist_ok=True)
        bucket = get_bucket_name_from_s3_path(s3_data_path)
        key = get_path_from_s3_path(s3_data_path)
        logger.info(f"Download data from s3 {bucket} {key} to {local_dir}")
        download_folder_from_s3_by_tar(bucket, key, local_dir)
def sync_status(job_id, bucket_name, model_dir):
    """Start background threads that mirror training samples and status to S3.

    Args:
        job_id: training job id, used to namespace the S3 keys.
        bucket_name: S3 bucket for samples and status files.
        model_dir: model folder name under models/dreambooth.
    """
    samples_dir = f'models/dreambooth/{model_dir}/samples'
    # Continuously upload generated sample images.
    threading.Thread(
        target=check_and_upload,
        args=(samples_dir, bucket_name, f'aigc-webui-test-samples/{job_id}'),
    ).start()
    # Continuously exchange control/status pickles with the webui.
    threading.Thread(
        target=sync_status_from_s3_in_sagemaker,
        args=(bucket_name,
              f'aigc-webui-test-status/{job_id}/webui_status.pickle',
              f'aigc-webui-test-status/{job_id}/sagemaker_status.pickle'),
    ).start()
def main(s3_input_path, s3_output_path, params):
    """End-to-end job: prepare environment and data, train, upload the result.

    Args:
        s3_input_path: s3:// URI prefix of the job input.
        s3_output_path: s3:// URI prefix for the trained model output.
        params: dict with a "training_params" sub-dict (model_name, model_type,
            s3_model_path, data_tar_list, class_data_tar_list).
    """
    # Deferred import: launch is only importable once sys.path is set up.
    import launch
    launch.prepare_environment()

    training_params = params["training_params"]
    model_name = training_params["model_name"]
    model_type = training_params["model_type"]
    s3_model_path = training_params["s3_model_path"]
    data_tar_list = training_params["data_tar_list"]
    class_data_tar_list = training_params["class_data_tar_list"]

    prepare_for_training(s3_model_path, model_name, s3_input_path, data_tar_list, class_data_tar_list)
    train(model_name)
    upload_model_to_s3_v2(model_name, s3_output_path, model_type)
def test():
    """Manual smoke test with hard-coded S3 paths (not run by default)."""
    model_name = "qiaohu-1-1"
    s3_model_path = "s3://stable-diffusion-aws-5c2b588b-2023-05-19-02/Stable-diffusion/model/qiaohu-1-1/066dc5e6-1ed7-470a-ab9e-aa9e01cfe5e4/output"
    s3_input_path = "s3://stable-diffusion-aws-5c2b588b-2023-05-19-02/Stable-diffusion/train/qiaohu-1-1/546290f8-9b0c-460b-be28-74a4ccad9a63/input"
    s3_output_path = "s3://stable-diffusion-aws-5c2b588b-2023-05-19-02/Stable-diffusion/model/qiaohu-1-1/066dc5e6-1ed7-470a-ab9e-aa9e01cfe5e4/output"
    training_params = {
        "training_params": {
            "model_name": model_name,
            # Bug fix: main() reads params["model_type"]; it was missing here
            # and raised KeyError.
            "model_type": "Stable-diffusion",
            "s3_model_path": s3_model_path,
            "data_tar_list": ["data_images-75.tar"],
            "class_data_tar_list": [""],
        }
    }
    main(s3_input_path, s3_output_path, training_params)
def parse_params(args):
    """Decode the base64-wrapped JSON CLI arguments.

    Each argument is a JSON value that was serialized and then base64-encoded
    (see the argparse setup in __main__).

    Args:
        args: namespace with string attributes s3_input_path, s3_output_path, params.

    Returns:
        Tuple of (s3_input_path, s3_output_path, params) as decoded JSON values.
    """
    def _decode(encoded):
        # base64 text -> bytes -> ascii text -> JSON value
        raw = base64.b64decode(encoded.encode('ascii')).decode('ascii')
        return json.loads(raw)

    return (_decode(args.s3_input_path),
            _decode(args.s3_output_path),
            _decode(args.params))
if __name__ == "__main__":
    # Local debugging entry (disabled):
    # test()
    # sys.exit()
    # Log the raw CLI args for debugging in the SageMaker job console.
    print(sys.argv)
    import argparse
    # add_help=False: avoid clashing with the webui launcher's own -h handling.
    parser = argparse.ArgumentParser(add_help=False)
    # All three arguments are base64-encoded JSON strings (see parse_params).
    parser.add_argument("--params", type=str)
    parser.add_argument("--s3-input-path", type=str)
    parser.add_argument("--s3-output-path", type=str)
    # parse_known_args: ignore any extra arguments injected by SageMaker/webui.
    args, _ = parser.parse_known_args()
    s3_input_path, s3_output_path, training_params = parse_params(args)
    main(s3_input_path, s3_output_path, training_params)


# upload_model_to_s3_v2("test-1-1", "s3://stable-diffusion-aws-extension-991301791329-us-east-1/dreambooth/checkpoint/test-sd-type/test/")