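"""SageMaker inference proxy for the Stable Diffusion WebUI.

Registers /invocations and /ping routes on the WebUI FastAPI app so the
container can serve a SageMaker endpoint. Requests are dispatched to the local
WebUI API (txt2img, img2img, extras, interrogate, rembg, Dreambooth model
creation, checkpoint merging); job start times are written to DynamoDB and
request counters are published to CloudWatch.
"""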
import hashlib
import json
import logging
import os
import traceback
import copy
import datetime
import sys

import boto3
import requests
from fastapi import FastAPI

from modules import sd_models
import modules.extras

from aws_extension.models import InvocationsRequest
from aws_extension.mme_utils import checkspace_and_update_models, download_model, models_path
from utils import get_bucket_name_from_s3_path, get_path_from_s3_path, download_folder_from_s3_by_tar, \
    upload_folder_to_s3_by_tar, read_from_s3

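# Admission control for /invocations: beyond THREAD_CHECK_COUNT queued requests a
# new request waits on the shared condition (up to CONDITION_WAIT_TIME_OUT seconds);
# more than CONDITION_POOL_MAX_COUNT queued requests are rejected.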
THREAD_CHECK_COUNT = 1
CONDITION_POOL_MAX_COUNT = 10
CONDITION_WAIT_TIME_OUT = 100000


def dummy_function(*args, **kwargs):
    return None


logger = logging.getLogger("sd_proxy")
logger.setLevel(os.environ.get('LOG_LEVEL') or logging.ERROR)

cloudwatch = boto3.client('cloudwatch')

endpoint_name = os.getenv('ENDPOINT_NAME')
endpoint_instance_id = os.getenv('ENDPOINT_INSTANCE_ID', 'default')

ddb_client = boto3.resource('dynamodb')
inference_table = ddb_client.Table('SDInferenceJobTable')


def update_execute_job_table(prompt_id, key, value):
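    """Set a single attribute on an inference job item in the SDInferenceJobTable DynamoDB table.

    The update is conditional on the item already existing, so unknown job ids
    raise instead of creating new rows.
    """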
    logger.info(f"Update job with prompt_id: {prompt_id}, key: {key}, value: {value}")
    try:
        inference_table.update_item(
            Key={
                "InferenceJobId": prompt_id,
            },
            UpdateExpression="set #k = :r",
            ExpressionAttributeNames={'#k': key},
            ExpressionAttributeValues={':r': value},
            ConditionExpression="attribute_exists(InferenceJobId)",
            ReturnValues="UPDATED_NEW"
        )
    except Exception as e:
        logger.error(f"Update execute job table error: {e}")
        raise e


def record_metric(req: InvocationsRequest):
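    """Publish per-request counters to CloudWatch under the 'ESD' namespace.

    Emits InferenceTotal per endpoint and per endpoint/instance, plus
    InferenceEndpointReceived per service, per endpoint and, when the request
    carries one, per workflow.
    """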
    data = [
        {
            'MetricName': 'InferenceTotal',
            'Dimensions': [
                {
                    'Name': 'Endpoint',
                    'Value': endpoint_name
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': 1,
            'Unit': 'Count'
        },
        {
            'MetricName': 'InferenceTotal',
            'Dimensions': [
                {
                    'Name': 'Endpoint',
                    'Value': endpoint_name
                },
                {
                    'Name': 'Instance',
                    'Value': endpoint_instance_id
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': 1,
            'Unit': 'Count'
        },
        {
            'MetricName': 'InferenceEndpointReceived',
            'Dimensions': [
                {
                    'Name': 'Service',
                    'Value': 'Stable-Diffusion'
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': 1,
            'Unit': 'Count'
        },
        {
            'MetricName': 'InferenceEndpointReceived',
            'Dimensions': [
                {
                    'Name': 'Endpoint',
                    'Value': endpoint_name
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': 1,
            'Unit': 'Count'
        },
    ]

    if req.workflow:
        data.append({
            'MetricName': 'InferenceEndpointReceived',
            'Dimensions': [
                {
                    'Name': 'Workflow',
                    'Value': req.workflow
                },
            ],
            'Timestamp': datetime.datetime.utcnow(),
            'Value': 1,
            'Unit': 'Count'
        })

    response = cloudwatch.put_metric_data(
        Namespace='ESD',
        MetricData=data
    )

    logger.info(f"record_metric response: {response}")


def merge_model_on_cloud(req):
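    """Merge checkpoints on the endpoint and upload the result to S3.

    Downloads the primary/secondary/tertiary checkpoints referenced by the
    request, runs the WebUI model merger, tars the merged checkpoint (plus its
    .yaml config when one exists) and syncs the archive to the requested S3
    prefix. Returns the local path of the merged model.
    """
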
    def modelmerger(*args):
        try:
            results = modules.extras.run_modelmerger(*args)
        except Exception as e:
            logger.info(f"Error loading/saving model file: {e}")
            logger.info(traceback.format_exc())
            # modules.sd_models.list_models()  # to remove the potentially missing models from the list
            return [None, None, None, None, f"Error merging checkpoints: {e}"]
        return results

    merge_checkpoint_payload = req.merge_checkpoint_payload
    primary_model_name = merge_checkpoint_payload["primary_model_name"]
    secondary_model_name = merge_checkpoint_payload["secondary_model_name"]
    # NOTE: the payload key is spelled "teritary" by the caller; kept as-is.
    tertiary_model_name = merge_checkpoint_payload["teritary_model_name"]
    interp_method = merge_checkpoint_payload["interp_method"]
    interp_amount = merge_checkpoint_payload["interp_amount"]
    save_as_half = merge_checkpoint_payload["save_as_half"]
    custom_name = merge_checkpoint_payload["custom_name"]
    checkpoint_format = merge_checkpoint_payload["checkpoint_format"]
    config_source = merge_checkpoint_payload["config_source"]
    bake_in_vae = merge_checkpoint_payload["bake_in_vae"]
    discard_weights = merge_checkpoint_payload["discard_weights"]
    save_metadata = merge_checkpoint_payload["save_metadata"]
    merge_model_s3_pos = merge_checkpoint_payload["merge_model_s3"]

    # download the source checkpoints from the cloud to local disk
    model_type = 'Stable-diffusion'
    checkpoint_info = req.checkpoint_info
    selected_model_s3_pos = checkpoint_info[model_type][primary_model_name]
    download_model(primary_model_name, selected_model_s3_pos)
    selected_model_s3_pos = checkpoint_info[model_type][secondary_model_name]
    download_model(secondary_model_name, selected_model_s3_pos)
    if tertiary_model_name:
        selected_model_s3_pos = checkpoint_info[model_type][tertiary_model_name]
        download_model(tertiary_model_name, selected_model_s3_pos)

    sd_models.list_models()

    # checkpoints_list keys look like "name.safetensors [shorthash]"; strip the
    # trailing 13-character " [shorthash]" suffix to recover the bare model name.
    for model_name in sd_models.checkpoints_list.keys():
        raw_name = model_name[:-13]
        if raw_name == primary_model_name:
            primary_model_name = model_name
        if raw_name == secondary_model_name:
            secondary_model_name = model_name
        if raw_name == tertiary_model_name:
            tertiary_model_name = model_name

    logger.info(f"sd model checkpoint list is {sd_models.checkpoints_list}")

    [primary_model_name, secondary_model_name, tertiary_model_name, component_dict_sd_model_checkpoints, modelmerger_result] = \
        modelmerger("fake_id_task", primary_model_name, secondary_model_name, tertiary_model_name,
                    interp_method, interp_amount, save_as_half, custom_name, checkpoint_format, config_source,
                    bake_in_vae, discard_weights, save_metadata)

    # the merger returns a status string such as "Checkpoint saved to <path>";
    # drop the fixed 20-character prefix to keep only the path.
    output_model_position = modelmerger_result[20:]

    # check whether a matching .yaml config exists next to the merged model
    merge_model_name = output_model_position.split('/')[-1].replace(' ', '\\ ')

    yaml_position = output_model_position[:-len(output_model_position.split('.')[-1])] + 'yaml'
    yaml_states = os.path.isfile(yaml_position)

    # parentheses confuse the shell commands below, replace them in the file name
    new_merge_model_name = merge_model_name.replace('(', '_').replace(')', '_')

    base_path = models_path[model_type]

    merge_model_name_complete_path = base_path + '/' + merge_model_name
    new_merge_model_name_complete_path = base_path + '/' + new_merge_model_name
    merge_model_name_complete_path = merge_model_name_complete_path.replace('(', '\\(').replace(')', '\\)')
    os.system(f"mv {merge_model_name_complete_path} {new_merge_model_name_complete_path}")

    model_yaml = (merge_model_name[:-len(merge_model_name.split('.')[-1])] + 'yaml').replace('(', '\\(').replace(')', '\\)')
    model_yaml_complete_path = base_path + '/' + model_yaml

    logger.info(f"m {merge_model_name_complete_path}, n_m {new_merge_model_name_complete_path}, yaml {model_yaml_complete_path}")

    if yaml_states:
        new_model_yaml = model_yaml.replace('(', '_').replace(')', '_')
        new_model_yaml_complete_path = base_path + '/' + new_model_yaml
        os.system(f"mv {model_yaml_complete_path} {new_model_yaml_complete_path}")
        os.system(f"tar cvf {new_merge_model_name} {new_merge_model_name_complete_path} {new_model_yaml_complete_path}")
    else:
        os.system(f"tar cvf {new_merge_model_name} {new_merge_model_name_complete_path}")

    os.system(f's5cmd sync {new_merge_model_name} {merge_model_s3_pos}{new_merge_model_name}')
    os.system(f'rm {new_merge_model_name_complete_path}')
    if yaml_states:
        os.system(f'rm {new_model_yaml_complete_path}')

    logger.info(f"output model path is {output_model_position}")

    return output_model_position


def get_output_img_type(payload: dict):
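    """Return the requested output media format, if any.

    Inspects payload['alwayson_scripts'] for an enabled AnimateDiff entry and
    returns its 'format' value; returns None when nothing applies or on error.
    """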
    try:
        if not payload or 'alwayson_scripts' not in payload or not payload['alwayson_scripts']:
            logger.debug("not using alwayson_scripts, image type not set")
            return None
        # AnimateDiff output format
        if 'animatediff' in payload['alwayson_scripts'] and payload['alwayson_scripts']['animatediff']:
            if ('args' not in payload['alwayson_scripts']['animatediff']
                    or not payload['alwayson_scripts']['animatediff']['args']
                    or not payload['alwayson_scripts']['animatediff']['args'][0]):
                logger.debug("animatediff args missing or empty, image type not set")
                return None

            if ('enable' not in payload['alwayson_scripts']['animatediff']['args'][0]
                    or not payload['alwayson_scripts']['animatediff']['args'][0]['enable']):
                logger.debug("animatediff not enabled, image type not set")
                return None

            if ('format' not in payload['alwayson_scripts']['animatediff']['args'][0]
                    or not payload['alwayson_scripts']['animatediff']['args'][0]['format']):
                logger.debug("animatediff format not set, image type not set")
                return None
            images_types = payload['alwayson_scripts']['animatediff']['args'][0]['format']
            logger.debug(f"using alwayson_scripts, image type set: {images_types}")
            return images_types
        return None
    except Exception as e:
        logger.debug(f"get_output_img_type error: {e}")
        return None


def parse_constant(c: str) -> float:
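    """json.loads parse_constant hook: reject NaN and clamp Infinity to the largest float."""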
    if c == "NaN":
        raise ValueError("NaN is not valid JSON")

    if c == 'Infinity':
        return sys.float_info.max

    return float(c)


def sagemaker_api(_, app: FastAPI):
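    """on_app_started callback: register the /invocations and /ping routes on the WebUI FastAPI app."""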
    logger.info(app.__dict__)
    logger.info(app)
    logger.debug("Loading Sagemaker API Endpoints.")
    import threading
    from collections import deque
    global condition
    condition = threading.Condition()
    global thread_deque
    thread_deque = deque()

    def wrap_response(start_time, data):
        data['start_time'] = start_time
        data['endpoint_name'] = os.getenv('ENDPOINT_NAME')
        data['endpoint_instance_id'] = os.getenv('ENDPOINT_INSTANCE_ID')
        return data

    @app.post("/invocations")
    def invocations(req: InvocationsRequest):
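        """SageMaker invocation entry point.

        Throttles concurrent requests via the shared condition/deque, then
        dispatches to the local WebUI HTTP API based on req.task.
        """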
        logger.info(f'-------invocation on port {req.port}------')
        logger.info(json.dumps(req.__dict__, default=str))

        start_time = datetime.datetime.now().isoformat()

        update_execute_job_table(req.id, 'startTime', start_time)

        record_metric(req)

        with condition:
            try:
                thread_deque.append(req)
                logger.info(f"{threading.current_thread().ident}_{threading.current_thread().name} {len(thread_deque)}")
                if THREAD_CHECK_COUNT < len(thread_deque) <= CONDITION_POOL_MAX_COUNT:
                    logger.info(f"wait {threading.current_thread().ident}_{threading.current_thread().name} {len(thread_deque)}")
                    condition.wait(timeout=CONDITION_WAIT_TIME_OUT)
                elif len(thread_deque) > CONDITION_POOL_MAX_COUNT:
                    logger.info(f"too many requests waiting in condition pool {len(thread_deque)}, max: {CONDITION_POOL_MAX_COUNT}")
                    raise MemoryError

                logger.info(f"task is {req.task}")
                logger.info(f"models is {req.models}")

                payload_string = None
                # prefer the inline payload_string, otherwise load the payload from S3
                if req.payload_string:
                    payload_string = req.payload_string
                elif req.param_s3:
                    payload_string = read_from_s3(req.param_s3)

                payload = json.loads(payload_string, parse_constant=parse_constant)

                if req.task == 'txt2img':
                    logger.info(f"{threading.current_thread().ident}_{threading.current_thread().name}_______ txt2img start !!!!!!!!")
                    checkspace_and_update_models(req.models)
                    logger.info(f"{threading.current_thread().ident}_{threading.current_thread().name}_______ txt2img models update !!!!!!!!")
                    image_type = get_output_img_type(payload)
                    logger.debug(f"image_type:{image_type}")
                    resp = {}
                    if image_type:
                        logger.debug(f"set output_img_type:{image_type}")
                        resp["output_img_type"] = image_type
                    response = requests.post(url=f'http://0.0.0.0:{req.port}/sdapi/v1/txt2img',
                                             json=payload)
                    logger.info(f"{threading.current_thread().ident}_{threading.current_thread().name}_______ txt2img end !!!!!!!! {len(response.json())}")
                    resp.update(response.json())
                    return wrap_response(start_time, resp)
                elif req.task == 'img2img':
                    logger.info(f"{threading.current_thread().ident}_{threading.current_thread().name}_______ img2img start !!!!!!!!")
                    checkspace_and_update_models(req.models)
                    logger.info(f"{threading.current_thread().ident}_{threading.current_thread().name}_______ img2img models update !!!!!!!!")
                    image_type = get_output_img_type(payload)
                    logger.debug(f"image_type:{image_type}")
                    resp = {}
                    if image_type:
                        logger.debug(f"set output_img_type:{image_type}")
                        resp["output_img_type"] = image_type
                    response = requests.post(url=f'http://0.0.0.0:{req.port}/sdapi/v1/img2img',
                                             json=payload)
                    logger.info(f"{threading.current_thread().ident}_{threading.current_thread().name}_______ img2img end !!!!!!!! {len(response.json())}")
                    resp.update(response.json())
                    return wrap_response(start_time, resp)
                elif req.task == 'interrogate_clip' or req.task == 'interrogate_deepbooru':
                    response = requests.post(url=f'http://0.0.0.0:{req.port}/sdapi/v1/interrogate',
                                             json=json.loads(req.interrogate_payload.json()))
                    return wrap_response(start_time, response.json())
                elif req.task == 'extra-single-image':
                    response = requests.post(url=f'http://0.0.0.0:{req.port}/sdapi/v1/extra-single-image',
                                             json=payload)
                    return wrap_response(start_time, response.json())
                elif req.task == 'extra-batch-images':
                    response = requests.post(url=f'http://0.0.0.0:{req.port}/sdapi/v1/extra-batch-images',
                                             json=payload)
                    return wrap_response(start_time, response.json())
                elif req.task == 'rembg':
                    response = requests.post(url=f'http://0.0.0.0:{req.port}/rembg', json=payload)
                    return wrap_response(start_time, response.json())
                elif req.task == 'db-create-model':
                    r"""
                    task: db-create-model
                    db_create_model_payload:
                        :s3_input_path: S3 path to download the source model from.
                        :s3_output_path: S3 path to upload the generated model to.
                        :ckpt_from_cloud: whether to fetch the ckpt from the cloud or use a local one.
                        :job_id: job id.
                        :param (create_model_params):
                            :new_model_name: name of the generated model.
                            :ckpt_path: local path of the source checkpoint.
                            :db_new_model_shared_src=""
                            :from_hub=False
                            :new_model_url=""
                            :new_model_token=""
                            :extract_ema=False
                            :train_unfrozen=False
                            :is_512=True
                    """
                    try:
                        db_create_model_payload = json.loads(req.db_create_model_payload)
                        job_id = db_create_model_payload["job_id"]
                        s3_output_path = db_create_model_payload["s3_output_path"]
                        output_bucket_name = get_bucket_name_from_s3_path(s3_output_path)
                        output_path = get_path_from_s3_path(s3_output_path)
                        db_create_model_params = db_create_model_payload["param"]["create_model_params"]
                        if "ckpt_from_cloud" in db_create_model_payload["param"]:
                            ckpt_from_s3 = db_create_model_payload["param"]["ckpt_from_cloud"]
                        else:
                            ckpt_from_s3 = False
                        if not db_create_model_params['from_hub']:
                            if ckpt_from_s3:
                                s3_input_path = db_create_model_payload["param"]["s3_ckpt_path"]
                                local_model_path = db_create_model_params["ckpt_path"]
                                input_path = get_path_from_s3_path(s3_input_path)
                                logger.info(f"ckpt from s3 {input_path} {local_model_path}")
                            else:
                                s3_input_path = db_create_model_payload["s3_input_path"]
                                local_model_path = db_create_model_params["ckpt_path"]
                                input_path = os.path.join(get_path_from_s3_path(s3_input_path), local_model_path)
                                logger.info(f"ckpt from local {input_path} {local_model_path}")
                            input_bucket_name = get_bucket_name_from_s3_path(s3_input_path)
                            logger.info("Check disk usage before download.")
                            os.system("df -h")
                            logger.info(f"Download src model from s3 {input_bucket_name} {input_path} {local_model_path}")
                            download_folder_from_s3_by_tar(input_bucket_name, input_path, local_model_path)
                            # Refresh the ckpt list.
                            sd_models.list_models()
                            logger.info("Check disk usage after download.")
                            os.system("df -h")
                        logger.info("Start creating model.")
                        create_model_func_args = copy.deepcopy(db_create_model_params)
                        # NOTE: create_model is expected to be provided by the Dreambooth extension at runtime.
                        local_response = create_model(**create_model_func_args)
                        target_local_model_dir = f'models/dreambooth/{db_create_model_params["new_model_name"]}'
                        logger.info(f"Upload tgt model to s3 {target_local_model_dir} {output_bucket_name} {output_path}")
                        upload_folder_to_s3_by_tar(target_local_model_dir, output_bucket_name, output_path)
                        config_file = os.path.join(target_local_model_dir, "db_config.json")
                        with open(config_file, 'r') as openfile:
                            config_dict = json.load(openfile)
                        message = {
                            "response": local_response,
                            "config_dict": config_dict
                        }
                        response = {
                            "id": job_id,
                            "statusCode": 200,
                            "message": message,
                            "outputLocation": [f'{s3_output_path}/{db_create_model_params["new_model_name"]}']
                        }
                        return response
                    except Exception as e:
                        response = {
                            "id": job_id,
                            "statusCode": 500,
                            "message": traceback.format_exc(),
                        }
                        logger.error(traceback.format_exc())
                        return response
                    finally:
                        # Clean up
                        logger.info("Delete src model.")
                        delete_src_command = f"rm -rf models/Stable-diffusion/{db_create_model_params['ckpt_path']}"
                        logger.info(delete_src_command)
                        os.system(delete_src_command)
                        logger.info("Delete tgt model.")
                        delete_tgt_command = f"rm -rf models/dreambooth/{db_create_model_params['new_model_name']}"
                        logger.info(delete_tgt_command)
                        os.system(delete_tgt_command)
                        logger.info("Check disk usage after request.")
                        os.system("df -h")
                elif req.task == 'merge-checkpoint':
                    try:
                        output_model_position = merge_model_on_cloud(req)
                        response = {
                            "statusCode": 200,
                            "message": output_model_position,
                        }
                        return response
                    except Exception as e:
                        traceback.print_exc()
                else:
                    raise NotImplementedError
            except Exception as e:
                traceback.print_exc()
            finally:
                thread_deque.popleft()
                condition.notify()

    @app.get("/ping")
    def ping():
        return {'status': 'Healthy'}


def md5(fname):
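    """Return the MD5 hex digest of a file, read in 4 KiB chunks."""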
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def get_file_md5_dict(path):
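    """Walk a directory tree and map each file name to its MD5 digest.

    Files are keyed by base name only, so same-named files in different
    sub-directories overwrite one another in the returned dict.
    """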
    file_dict = {}
    for root, dirs, files in os.walk(path):
        for file in files:
            file_dict[file] = md5(os.path.join(root, file))
    return file_dict


def move_model_to_tmp(_, app: FastAPI):
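    """Copy the models directory to /tmp and replace it with a symlink.

    Used on multi-GPU hosts so workers share one copy of the models. The copy
    is verified by comparing per-file MD5 digests before the original
    directory is removed; on mismatch the original directory is kept.
    """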
|
|
logger.info("Copy model dir to tmp")
|
|
model_tmp_dir = f"models"
|
|
# for mutil gpus
|
|
if os.path.exists(f"/tmp/{model_tmp_dir}"):
|
|
return
|
|
os.system(f"cp -rL models /tmp/{model_tmp_dir}")
|
|
src_file_dict = get_file_md5_dict("models")
|
|
tgt_file_dict = get_file_md5_dict(f"/tmp/{model_tmp_dir}")
|
|
is_complete = True
|
|
for file in src_file_dict:
|
|
logger.info(f"Src file {file} md5 {src_file_dict[file]}")
|
|
if file not in tgt_file_dict:
|
|
is_complete = False
|
|
break
|
|
if src_file_dict[file] != tgt_file_dict[file]:
|
|
is_complete = False
|
|
break
|
|
if is_complete:
|
|
os.system(f"rm -rf models")
|
|
logger.info("Link model dir")
|
|
os.system(f"ln -s /tmp/{model_tmp_dir} models")
|
|
else:
|
|
logger.info("Failed to copy model dir, use the original dir")
|
|
logger.info("Check disk usage on app started")
|
|
os.system("df -h")
|
|
|
|
try:
    import modules.script_callbacks as script_callbacks

    script_callbacks.on_app_started(sagemaker_api)
    on_sagemaker = os.environ.get('ON_SAGEMAKER', "false")
    if on_sagemaker == "true":
        from modules import shared
        shared.opts.data.update(control_net_max_models_num=10)
        script_callbacks.on_app_started(move_model_to_tmp)
    logger.debug("SD-Webui API layer loaded")
except Exception as e:
    logger.error(e)
    logger.debug("Unable to import script callbacks.")