import base64
import datetime
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import Any, Optional

import boto3
import sagemaker
import tomli
import tomli_w

from common import const
from common.const import LoraTrainType, PERMISSION_TRAIN_ALL
from common.ddb_service.client import DynamoDbUtilsService
from common.response import (
    ok,
    not_found,
)
from libs.common_tools import DecimalEncoder
from libs.data_types import (
    CheckPoint,
    CheckPointStatus,
    TrainJob,
    TrainJobStatus,
)
from libs.utils import get_user_roles, permissions_check, response_error

# Deployment configuration, read from environment variables at import time.
bucket_name = os.environ.get("S3_BUCKET")
train_table = os.environ.get("TRAIN_TABLE")
checkpoint_table = os.environ.get("CHECKPOINT_TABLE")
user_table = os.environ.get("MULTI_USER_TABLE")
region = os.environ.get("AWS_REGION")
instance_type = os.environ.get("INSTANCE_TYPE")
sagemaker_role_arn = os.environ.get("TRAIN_JOB_ROLE")
image_uri = os.environ.get("TRAIN_ECR_URL")

logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("LOG_LEVEL") or logging.ERROR)

ddb_service = DynamoDbUtilsService(logger=logger)
s3 = boto3.client("s3", region_name=region)


@dataclass
class Event:
    """Request body accepted by the create-training-job endpoint."""

    # Training request; _create_training_job reads "training_params" (required)
    # and "config_params" (optional) from it.
    params: dict[str, Any]
    # Training backend; only LoraTrainType.KOHYA is accepted by _create_training_job.
    lora_train_type: Optional[str] = LoraTrainType.KOHYA.value


def _update_toml_file_in_s3(
    bucket_name: str, file_key: str, new_file_key: str, updated_params
):
    """Merge parameters into a TOML template and save the result in S3.

    Each top-level entry of ``updated_params`` is applied to the template:
    when both the template entry and the new value are tables (dicts), the
    new keys are merged into the existing table (new values win); otherwise
    the new value replaces (or adds) the template entry.

    Args:
        bucket_name (str): S3 bucket holding the template and receiving the result
        file_key (str): TOML template file key
        new_file_key (str): key the merged TOML file is written to
        updated_params (dict): section name -> parameters to be merged

    Raises:
        Exception: any S3 or TOML error is logged and re-raised, so the caller
            never launches a training job that points at a TOML file which
            was not written.
    """
    try:
        response = s3.get_object(Bucket=bucket_name, Key=file_key)
        toml_content = response["Body"].read().decode("utf-8")
        toml_data = tomli.loads(toml_content)

        # Update parameters in the TOML data
        for section, params in updated_params.items():
            existing = toml_data.get(section)
            if isinstance(existing, dict) and isinstance(params, dict):
                existing.update(params)
            else:
                toml_data[section] = params

        updated_toml_content = tomli_w.dumps(toml_data)
        s3.put_object(Bucket=bucket_name, Key=new_file_key, Body=updated_toml_content)
        logger.info(
            f"Updated '{file_key}' and saved it as '{new_file_key}' in '{bucket_name}' successfully."
        )
    except Exception as e:
        logger.error(f"An error occurred when updating Kohya toml: {e}")
        # Previously swallowed: the job would still be launched with a
        # missing/unmerged TOML file and the API would report success.
        raise


def _json_encode_hyperparameters(hyperparameters):
    """Encode hyperparameters for the SageMaker Estimator.

    Every value is serialized to JSON with DecimalEncoder. When the configured
    region name starts with "cn" the JSON text is used as-is; otherwise it is
    additionally base64-encoded.

    Args:
        hyperparameters: mapping of name -> JSON-serializable value

    Returns:
        dict: the same keys mapped to encoded string values
    """
    # AWS_REGION may be unset; treat a missing region as non-"cn".
    plain_json = (region or "").startswith("cn")
    new_params = {}
    for k, v in hyperparameters.items():
        json_v = json.dumps(v, cls=DecimalEncoder)
        if plain_json:
            new_params[k] = json_v
        else:
            # json.dumps escapes non-ASCII by default, so ASCII encoding is safe.
            new_params[k] = base64.b64encode(json_v.encode("ascii")).decode("ascii")
    return new_params


def _trigger_sagemaker_training_job(
    train_job: TrainJob,
    ckpt_output_path: str,
    train_job_name: str,
    job_name_timeout: float = 60.0,
):
    """Trigger a SageMaker training job and record it on the train job item.

    Starts the job without waiting for it to finish, then stores the SageMaker
    job name and the Training status both on ``train_job`` and in the train
    DynamoDB table.

    Args:
        train_job (TrainJob): training job metadata
        ckpt_output_path (str): S3 path to store the trained model file
        train_job_name (str): base name for the SageMaker training job
        job_name_timeout (float): max seconds to wait for the estimator to
            expose the SageMaker job name

    Raises:
        RuntimeError: if no job name is available within ``job_name_timeout``
    """
    hyperparameters = _json_encode_hyperparameters(
        {
            "sagemaker_program": "extensions/sd-webui-sagemaker/sagemaker_entrypoint_json.py",
            "params": train_job.params,
            "s3-input-path": train_job.input_s3_location,
            "s3-output-path": ckpt_output_path,
            "training-type": train_job.params["training_type"],  # Available value: "kohya"
        }
    )

    # A non-empty training_params.training_instance_type overrides INSTANCE_TYPE.
    training_params = train_job.params.get("training_params") or {}
    final_instance_type = training_params.get("training_instance_type") or instance_type

    est = sagemaker.estimator.Estimator(
        image_uri,
        sagemaker_role_arn,
        instance_count=1,
        instance_type=final_instance_type,
        volume_size=125,
        base_job_name=f"{train_job_name}",
        hyperparameters=hyperparameters,
        job_id=train_job.id,
    )
    est.fit(wait=False)

    # Poll the estimator's (private) job-name attribute, but bound the wait
    # instead of spinning until the Lambda itself times out.
    deadline = time.monotonic() + job_name_timeout
    while not est._current_job_name:
        if time.monotonic() >= deadline:
            raise RuntimeError(
                f"Timed out waiting for SageMaker job name of train job {train_job.id}"
            )
        time.sleep(1)
    sagemaker_job_name = est._current_job_name

    train_job.sagemaker_train_name = sagemaker_job_name
    search_key = {"id": train_job.id}
    ddb_service.update_item(
        table=train_table,
        key=search_key,
        field_name="sagemaker_train_name",
        value=sagemaker_job_name,
    )

    train_job.job_status = TrainJobStatus.Training
    ddb_service.update_item(
        table=train_table,
        key=search_key,
        field_name="job_status",
        value=TrainJobStatus.Training.value,
    )


def _start_training_job(train_job_id: str):
    """Load a train job and its checkpoint, then launch it on SageMaker.

    Args:
        train_job_id (str): id of the item in the train table

    Returns:
        A summary dict of the started job, or a ``not_found(...)`` response
        when the train job or its checkpoint item is missing.
    """
    raw_train_job = ddb_service.get_item(
        table=train_table, key_values={"id": train_job_id}
    )
    if raw_train_job is None or len(raw_train_job) == 0:
        return not_found(message=f"no such train job with id({train_job_id})")

    train_job = TrainJob(**raw_train_job)
    train_job_name = train_job.model_id

    raw_checkpoint = ddb_service.get_item(
        table=checkpoint_table, key_values={"id": train_job.checkpoint_id}
    )
    if raw_checkpoint is None:
        return not_found(
            message=f"checkpoint with id {train_job.checkpoint_id} is not found"
        )

    checkpoint = CheckPoint(**raw_checkpoint)
    # The checkpoint's S3 location doubles as the training output path.
    _trigger_sagemaker_training_job(train_job, checkpoint.s3_location, train_job_name)

    return {
        "id": train_job.id,
        "status": train_job.job_status.value,
        "created": str(train_job.timestamp),
        "params": train_job.params,
        "input_location": train_job.input_s3_location,
        "output_location": checkpoint.s3_location,
    }


def _create_training_job(raw_event, context):
    """Validate a create-training-job request and persist its records.

    Writes the Kohya TOML config to ``<train type>/train/<request id>/input``
    in the S3 bucket (merged with ``config_params`` when given, otherwise a
    copy of the template), then stores a CheckPoint item and a TrainJob item,
    both keyed by the request id.

    Args:
        raw_event: invocation event; ``raw_event["body"]`` is a JSON string
            matching ``Event``
        context: invocation context; ``aws_request_id`` becomes the job id

    Returns:
        str: id of the created training job

    Raises:
        ValueError: unsupported lora_train_type or fm_type, or required
            training_params missing
    """
    request_id = context.aws_request_id
    body = json.loads(raw_event["body"])
    event = Event(**body)
    logger.info(json.dumps(body))

    _lora_train_type = event.lora_train_type
    username = permissions_check(raw_event, [PERMISSION_TRAIN_ALL])

    # A null/non-string type falls through to the ValueError below instead of
    # an AttributeError on .lower().
    train_type = _lora_train_type.lower() if isinstance(_lora_train_type, str) else ""
    if train_type != LoraTrainType.KOHYA.value:
        raise ValueError(
            f"Invalid lora train type: {_lora_train_type}, the valid value is {LoraTrainType.KOHYA.value}."
        )

    # Kohya training
    base_key = f"{train_type}/train/{request_id}"
    input_location = f"{base_key}/input"

    training_params = event.params.get("training_params")
    required_keys = ("s3_model_path", "s3_data_path", "fm_type")
    if not isinstance(training_params, dict) or any(
        key not in training_params for key in required_keys
    ):
        raise ValueError(
            "Missing train parameters, fm_type, s3_model_path and s3_data_path should be in training_params"
        )

    fm_type = training_params["fm_type"]
    toml_file_names = {
        const.TrainFMType.SD_1_5.value: const.KOHYA_TOML_FILE_NAME,
        const.TrainFMType.SD_XL.value: const.KOHYA_XL_TOML_FILE_NAME,
    }
    toml_file_name = toml_file_names.get(str(fm_type).lower())
    if toml_file_name is None:
        raise ValueError(
            f"Invalid fm_type {fm_type}, the valid values are {const.TrainFMType.SD_1_5.value} and {const.TrainFMType.SD_XL.value}"
        )
    toml_dest_path = f"{input_location}/{toml_file_name}"
    toml_template_path = "template/" + toml_file_name

    # Merge user parameter, if no config_params is defined, use the default value in S3 bucket
    if "config_params" in event.params:
        _update_toml_file_in_s3(
            bucket_name, toml_template_path, toml_dest_path, event.params["config_params"]
        )
    else:
        # Copy template file and make no changes as no config parameters are defined
        s3.copy_object(
            CopySource={"Bucket": bucket_name, "Key": toml_template_path},
            Bucket=bucket_name,
            Key=toml_dest_path,
        )
    # training_params is event.params["training_params"], so this is persisted below.
    training_params["s3_toml_path"] = f"s3://{bucket_name}/{toml_dest_path}"

    event.params["training_type"] = train_type

    user_roles = get_user_roles(ddb_service, user_table, username)

    # Any network_module other than LoRA produces an SD checkpoint.
    ckpt_type = const.CheckPointType.LORA
    config_params = event.params.get("config_params")
    additional_network = (
        config_params.get("additional_network")
        if isinstance(config_params, dict)
        else None
    )
    if isinstance(additional_network, dict) and "network_module" in additional_network:
        network_module = str(additional_network["network_module"])
        # NOTE(review): NetworkModule.LORA was compared without `.value`, unlike
        # the other const enums here; unwrap it if it is an Enum member
        # (no-op when it is already a plain string) — confirm in common.const.
        lora_module = getattr(
            const.NetworkModule.LORA, "value", const.NetworkModule.LORA
        )
        if network_module.lower() != lora_module:
            ckpt_type = const.CheckPointType.SD

    checkpoint = CheckPoint(
        id=request_id,
        checkpoint_type=ckpt_type,
        s3_location=f"s3://{bucket_name}/{base_key}/output",
        checkpoint_status=CheckPointStatus.Initial,
        timestamp=datetime.datetime.now().timestamp(),
        allowed_roles_or_users=user_roles,
    )
    ddb_service.put_items(table=checkpoint_table, entries=checkpoint.__dict__)

    train_input_s3_location = f"s3://{bucket_name}/{input_location}"
    train_job = TrainJob(
        id=request_id,
        model_id=const.KOHYA_MODEL_ID,
        job_status=TrainJobStatus.Initial,
        params=event.params,
        train_type=const.TRAIN_TYPE,
        input_s3_location=train_input_s3_location,
        checkpoint_id=checkpoint.id,
        timestamp=datetime.datetime.now().timestamp(),
        allowed_roles_or_users=[username],
    )
    ddb_service.put_items(table=train_table, entries=train_job.__dict__)

    return train_job.id


def handler(raw_event, context):
    """Entry point: create a training job record and start it on SageMaker.

    If anything fails after the train job record exists, its ``job_status``
    is set to Failed. Any exception is converted by ``response_error``.
    """
    job_id = None
    try:
        logger.info(json.dumps(raw_event))
        job_id = _create_training_job(raw_event, context)
        job_info = _start_training_job(job_id)
        # NOTE(review): when _start_training_job returns a not_found(...)
        # response it is wrapped in ok(...) here and the job is not marked
        # Failed — confirm whether that is intended.
        return ok(data=job_info, decimal=True)
    except Exception as e:
        if job_id:
            ddb_service.update_item(
                table=train_table,
                key={"id": job_id},
                field_name="job_status",
                value=TrainJobStatus.Failed.value,
            )
        return response_error(e)