diff --git a/middleware_api/trainings/create_training_job.py b/middleware_api/trainings/create_training_job.py index 9e92adae..d5ed75af 100644 --- a/middleware_api/trainings/create_training_job.py +++ b/middleware_api/trainings/create_training_job.py @@ -100,10 +100,14 @@ def _trigger_sagemaker_training_job( train_job_name (str): training job name """ + site_packages_s3_path = (f"aws-gcr-solutions-{region}/" + f"stable-diffusion-aws-extension-github-mainline/{esd_version}/train.tar") + data = { "id": train_job.id, "training_id": train_job.id, "sagemaker_program": "extensions/sd-webui-sagemaker/sagemaker_entrypoint_json.py", + "site_packages_s3_path": site_packages_s3_path, "params": train_job.params, "s3-input-path": train_job.input_s3_location, "s3-output-path": ckpt_output_path, @@ -138,8 +142,7 @@ def _trigger_sagemaker_training_job( }, job_id=train_job.id, environment={ - "SITE_PACKAGES_S3_PATH": f"aws-gcr-solutions-{region}/" - f"stable-diffusion-aws-extension-github-mainline/{esd_version}/train.tar" + "SITE_PACKAGES_S3_PATH": site_packages_s3_path } ) est.fit(wait=False)