import base64
import io
import os
import subprocess
import uuid

import boto3
from PIL import Image

import modules.shared as shared
from utils import ModelsRef
from modules import sd_hijack, sd_models, sd_vae

# File extensions treated as model weight files when choosing files to evict.
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors"]

# Model categories managed by this module, in the order they are processed.
models_type_list = ['Stable-diffusion', 'hypernetworks', 'Lora', 'ControlNet', 'embeddings']
# Per-type usage tracker consulted when deciding which local models to evict.
models_used_count = {key: ModelsRef() for key in models_type_list}
# Local directory holding each model type (relative to the working directory).
models_path = {
    'Stable-diffusion': 'models/Stable-diffusion',
    'hypernetworks': 'models/hypernetworks',
    'Lora': 'models/Lora',
    'ControlNet': 'models/ControlNet',
    'embeddings': 'embeddings',
}

# Filesystem whose free space is checked before each download ('/' was the
# previously used alternative).
# NOTE(review): archives are downloaded to ./ and extracted under models/...,
# which may not live on the same filesystem as /tmp -- confirm this is the
# disk that actually fills up.
disk_path = '/tmp'


def _run_command(args):
    """Run an external command without a shell and return its exit code.

    Arguments are passed as a list, so model names and S3 URIs containing
    spaces or shell metacharacters are handed to the program verbatim instead
    of being interpreted by a shell. Like the previous ``os.system`` calls,
    a non-zero exit status is not treated as an error; if the program cannot
    be started at all, the failure is printed and ``None`` is returned.
    """
    try:
        return subprocess.run(args, check=False).returncode
    except OSError as err:
        print('failed to run', args, err)
        return None


def _free_disk_space():
    """Return the bytes available to unprivileged users on ``disk_path``'s filesystem."""
    st = os.statvfs(disk_path)
    free = st.f_bavail * st.f_frsize
    print('!!!!!!!!!!!!current free space is', free)
    return free


def _local_model_names(model_type):
    """Return every file under the model directory of ``model_type``.

    Paths are relative to that directory, so files in subdirectories appear
    as ``'sub/name.ext'``. A set is returned because callers only test
    membership.
    """
    root = models_path[model_type]
    names = set()
    for dirpath, _subdirs, files in os.walk(root):
        for name in files:
            names.add(os.path.relpath(os.path.join(dirpath, name), root))
    return names


def _evict_unused_models(selected_models, start_type_id, space_free_size):
    """Delete unselected local models until at least ``space_free_size`` bytes are free.

    Model types are visited round-robin starting at ``start_type_id``. Within
    a type, candidates are visited in the order returned by
    ``ModelsRef.get_sorted_models`` (intended to be least-used first).
    Models listed in ``selected_models`` for that type are never deleted.

    Returns True as soon as enough space is free, False if every candidate
    was considered without reaching the threshold.
    """
    models_num = len(models_type_list)
    for offset in range(models_num):
        type_check = models_type_list[(start_type_id + offset) % models_num]
        type_dir = models_path[type_check]
        keep = selected_models[type_check]
        dir_entries = os.listdir(type_dir)
        print(dir_entries)
        # NOTE(review): only top-level files are eviction candidates, while
        # _local_model_names also sees files in subdirectories.
        candidates = [f for f in dir_entries if os.path.splitext(f)[1] in CN_MODEL_EXTS]
        if not candidates:
            continue
        for local_model in models_used_count[type_check].get_sorted_models(candidates):
            if local_model in keep:
                continue
            model_file = os.path.join(type_dir, local_model)
            os.remove(model_file)
            print('remove models', model_file)
            models_used_count[type_check].remove_model_ref(local_model)
            # Same threshold comparison that triggered eviction in the caller.
            if _free_disk_space() >= space_free_size:
                return True
    return False


def checkspace_and_update_models(selected_models, checkpoint_info):
    """Make sure every selected model is available locally, then reload weights.

    For each model type in ``models_type_list``, every model named in
    ``selected_models[model_type]`` has its usage reference recorded. Models
    not already on disk are downloaded via ``download_and_update``. Before
    each download, if free space on ``disk_path`` is below
    ``selected_models['space_free_size']``, unselected models are evicted
    first. If enough space cannot be freed, the function prints a message
    and returns without downloading anything further or reloading weights.

    Args:
        selected_models: dict with a ``'space_free_size'`` byte threshold and,
            for every entry of ``models_type_list``, a list of model names
            (relative to that type's directory).
        checkpoint_info: ``checkpoint_info[model_type][model_name]`` gives the
            S3 location of the archive for that model.
    """
    space_free_size = selected_models['space_free_size']
    _run_command(['df', '-h'])

    for type_id, model_type in enumerate(models_type_list):
        local_models = _local_model_names(model_type)
        for selected_model_name in selected_models[model_type]:
            models_used_count[model_type].add_models_ref(selected_model_name)
            if selected_model_name in local_models:
                continue

            if _free_disk_space() < space_free_size:
                if not _evict_unused_models(selected_models, type_id, space_free_size):
                    print('can not get enough space to download models!!!!!!')
                    return

            selected_model_s3_pos = checkpoint_info[model_type][selected_model_name]
            download_and_update(model_type, selected_model_name, selected_model_s3_pos)
            # Avoid downloading the same model again if it is selected twice.
            local_models.add(selected_model_name)

    sd_checkpoints = selected_models['Stable-diffusion']
    if sd_checkpoints:
        shared.opts.sd_model_checkpoint = sd_checkpoints[0]
    sd_models.reload_model_weights()
    sd_vae.reload_vae_weights()


def download_model(model_name, model_s3_pos):
    """Copy the archive at ``model_s3_pos`` into the working directory and extract ``model_name``.

    NOTE(review): s5cmd saves the file under the basename of ``model_s3_pos``,
    but the tar command extracts ``model_name``. This only works when the two
    match -- confirm with callers.
    """
    _run_command(['./tools/s5cmd', 'cp', model_s3_pos, './'])
    _run_command(['tar', 'xvf', model_name])


def upload_model(model_type, model_name, model_s3_pos):
    """Archive a local model into ``./<model_name>`` with tar and copy it to ``model_s3_pos``."""
    _run_command(['tar', 'cvf', model_name, f"{models_path[model_type]}/{model_name}"])
    _run_command(['./tools/s5cmd', 'cp', model_name, model_s3_pos])


def download_and_update(model_type, model_name, model_s3_pos):
    """Download and extract a model archive, then refresh the registry for its type.

    The archive is copied from ``model_s3_pos`` into the working directory,
    extracted with tar, and deleted. Afterwards the loader matching
    ``model_type`` is asked to rescan its models. ``model_name`` is accepted
    for interface compatibility but is not used.
    """
    _run_command(['./tools/s5cmd', 'cp', model_s3_pos, './'])
    tar_name = model_s3_pos.split('/')[-1]
    _run_command(['tar', 'xvf', tar_name])
    try:
        os.remove(tar_name)
    except OSError as err:
        # Best effort, like the previous `rm`: a leftover archive is not fatal.
        print('failed to remove', tar_name, err)
    _run_command(['df', '-h'])

    if model_type == 'Stable-diffusion':
        sd_models.list_models()
    elif model_type == 'hypernetworks':
        shared.reload_hypernetworks()
    elif model_type == 'embeddings':
        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
    elif model_type == 'ControlNet':
        # Imported lazily; resolved from the ControlNet extension's scripts package.
        from scripts import global_state
        global_state.update_cn_models()


def decode_base64_to_image(encoding):
    """Decode base64 image data, optionally wrapped in a ``data:image/...`` URL, into a PIL image.

    For data URLs, everything after the first comma is treated as the payload.
    This also accepts URLs that omit ``;base64`` or carry extra
    ``;key=value`` parameters.
    """
    if encoding.startswith("data:image/"):
        _header, _sep, encoding = encoding.partition(",")
    return Image.open(io.BytesIO(base64.b64decode(encoding)))


def file_to_base64(file_path) -> str:
    """Return the contents of ``file_path`` encoded as a base64 text string."""
    with open(file_path, "rb") as f:
        return base64.b64encode(f.read()).decode('utf-8')


def get_bucket_and_key(s3uri):
    """Split an ``s3://bucket/key`` URI into ``(bucket, key)``.

    Assumes the URI starts with the 5-character ``s3://`` prefix. A URI with
    no key (``s3://bucket`` or ``s3://bucket/``) yields an empty key.
    """
    bucket, _sep, key = s3uri[5:].partition('/')
    return bucket, key