stable-diffusion-aws-extension/aws_extension/mme_utils.py

169 lines
7.1 KiB
Python

import os
import io
import base64
from PIL import Image
from utils import ModelsRef
import subprocess
import logging
try:
import modules.shared as shared
from modules import sd_hijack, sd_models, sd_vae
except Exception:
print('default modules load fails')
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors"]
models_type_list = ['Stable-diffusion', 'hypernetworks', 'Lora', 'ControlNet', 'embeddings', 'VAE']
models_used_count = {key: ModelsRef() for key in models_type_list}
models_path = {key: None for key in models_type_list}
models_path['Stable-diffusion'] = 'models/Stable-diffusion'
models_path['ControlNet'] = 'models/ControlNet'
models_path['hypernetworks'] = 'models/hypernetworks'
models_path['Lora'] = 'models/Lora'
models_path['embeddings'] = 'embeddings'
models_path['VAE'] = 'models/VAE'
disk_path = '/tmp'
#disk_path = '/'
TAR_TYPE_FILE = 'application/x-tar'
def checkspace_and_update_models(selected_models):
print(selected_models)
models_num = len(models_type_list)
space_free_size = selected_models['space_free_size']
# os.system("df -h")
for type_id in range(models_num):
model_type = models_type_list[type_id]
if model_type not in selected_models:
continue
selected_models_name = selected_models[model_type]
local_models = []
for path, subdirs, files in os.walk(models_path[model_type]):
for name in files:
full_path_name = os.path.join(path, name)
name_local = os.path.relpath(full_path_name, models_path[model_type])
local_models.append(name_local)
for model in selected_models_name:
models_used_count[model_type].add_models_ref(model['model_name'])
if model['model_name'] in local_models:
continue
if model_type == 'VAE' and model['model_name'] in ['Automatic', 'None']:
print(f'skip vae download for {model["model_name"]}')
continue
else:
st = os.statvfs(disk_path)
free = (st.f_bavail * st.f_frsize)
print('!!!!!!!!!!!!current free space is', free)
if free < space_free_size:
#### delete least used model to get more space ########
space_check_success = False
for i in range(models_num):
type_id_check = (type_id + i)%models_num
type_check = models_type_list[type_id_check]
if type_check not in selected_models:
continue
selected_models_name_check = selected_models[type_check]
print(f'check current model folder: {os.listdir(models_path[type_check])}')
local_models_check = [f for f in os.listdir(models_path[type_check]) if os.path.splitext(f)[1] in CN_MODEL_EXTS]
if len(local_models_check) == 0:
continue
sorted_local_modles = models_used_count[type_check].get_sorted_models(local_models_check)
for local_model in sorted_local_modles:
if local_model in selected_models_name_check:
continue
else:
os.remove(os.path.join(models_path[type_check], local_model))
print('remove models', os.path.join(models_path[type_check], local_model))
models_used_count[type_check].remove_model_ref(local_model)
st = os.statvfs(disk_path)
free = (st.f_bavail * st.f_frsize)
print('!!!!!!!!!!!!current free space is', free)
if free > space_free_size:
space_check_success = True
break
if space_check_success:
break
if not space_check_success:
print('can not get enough space to download models!!!!!!')
return
####down load models######
download_and_update(model_type, f'{model["s3"]}/{model["model_name"]}')
shared.opts.sd_model_checkpoint = selected_models['Stable-diffusion'][0]["model_name"]
sd_models.reload_model_weights()
if 'VAE' in selected_models:
shared.opts.sd_vae = selected_models['VAE'][0]['model_name']
sd_vae.reload_vae_weights()
def download_model(model_name, model_s3_pos):
#download from s3
os.system(f'./tools/s5cmd cp {model_s3_pos} ./')
os.system(f"tar xvf {model_name}")
def upload_model(model_type, model_name, model_s3_pos):
#upload model to s3
os.system(f"tar cvf {model_name} {models_path[model_type]}/{model_name}")
os.system(f'./tools/s5cmd cp {model_name} {model_s3_pos}')
def download_and_update(model_type, model_s3_pos):
#download from s3
logging.info(f'./tools/s5cmd cp "{model_s3_pos}" ./')
os.system(f'./tools/s5cmd cp "{model_s3_pos}" ./')
tar_name = model_s3_pos.split('/')[-1]
logging.info(tar_name)
command = f'file --mime-type -b ./"{tar_name}"'
file_type = subprocess.check_output(command, shell=True).decode('utf-8').strip()
logging.info(f"file_type is {file_type}")
if file_type == TAR_TYPE_FILE:
logging.info("model type is tar")
os.system(f"tar xvf '{tar_name}'")
os.system(f"rm '{tar_name}'")
os.system("df -h")
else:
os.system(f"rm '{tar_name}'") # 使用引号括起文件名
logging.info(f"model type is origin file type: {file_type}")
prefix_name = model_s3_pos.split('.')[0]
if model_type == 'embeddings':
os.system(f'./tools/s5cmd cp "{prefix_name}"* ./{model_type}/')
else:
os.system(f'./tools/s5cmd cp "{prefix_name}"* ./models/{model_type}/')
logging.info("download finished")
if model_type == 'Stable-diffusion':
sd_models.list_models()
if model_type == 'hypernetworks':
shared.reload_hypernetworks()
if model_type == 'embeddings':
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
if model_type == 'ControlNet':
#sys.path.append("extensions/sd-webui-controlnet/scripts/")
from scripts import global_state
global_state.update_cn_models()
#sys.path.remove("extensions/sd-webui-controlnet/scripts/")
if model_type == 'VAE':
sd_vae.refresh_vae_list()
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
return Image.open(io.BytesIO(base64.b64decode(encoding)))
def file_to_base64(file_path) -> str:
with open(file_path, "rb") as f:
im_b64 = base64.b64encode(f.read())
return str(im_b64, 'utf-8')
def get_bucket_and_key(s3uri):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]
return bucket, key