# stable-diffusion-aws-extension/scripts/main.py

import datetime
import sys
from queue import Queue
from typing import Dict, List
import requests
import logging
import gradio as gr
import os
import modules.scripts as scripts
from aws_extension.cloud_models_manager.sd_manager import CloudSDModelsManager, postfix
from modules import script_callbacks
from modules.ui import create_refresh_button
from modules.ui_components import FormRow
from utils import get_variable_from_json
from utils import save_variable_to_json
# sys.path.append("extensions/stable-diffusion-aws-extension/scripts")
# import sagemaker_ui
from aws_extension import sagemaker_ui
# Whether the optional dreambooth_on_cloud package could be imported.
dreambooth_available = True


def dummy_function(*args, **kwargs):
    """Fallback for every cloud-training entry point when the
    dreambooth_on_cloud package is unavailable; accepts any call
    signature and returns an empty list."""
    return []


try:
    from dreambooth_on_cloud.train import (
        async_cloud_train,
        get_cloud_db_model_name_list,
        wrap_load_model_params,
        get_train_job_list,
        get_sorted_cloud_dataset
    )
    from dreambooth_on_cloud.create_model import (
        get_sd_cloud_models,
        get_create_model_job_list,
        cloud_create_model,
    )
except Exception as e:
    logging.warning(
        "[main]dreambooth_on_cloud is not installed or can not be imported, using dummy function to proceed.")
    dreambooth_available = False
    # BUG FIX: the fallback previously bound `cloud_train`, leaving
    # `async_cloud_train` (the name imported above and used by the
    # SageMaker-train button wiring) undefined when the import fails.
    async_cloud_train = dummy_function
    cloud_train = dummy_function  # kept for backward compatibility
    get_cloud_db_model_name_list = dummy_function
    wrap_load_model_params = dummy_function
    get_train_job_list = dummy_function
    get_sorted_cloud_dataset = dummy_function
    get_sd_cloud_models = dummy_function
    get_create_model_job_list = dummy_function
    cloud_create_model = dummy_function
# --- Module-level UI state -------------------------------------------------
# These globals are populated lazily by the gradio after-component callbacks
# below; None means "component not yet seen by the callback".
cloud_datasets = []
training_job_dashboard = None
db_model_name = None
cloud_db_model_name = None
cloud_train_instance_type = None
db_use_txt2img = None
db_sagemaker_train = None
db_save_config = None
txt2img_show_hook = None
txt2img_gallery = None
txt2img_generation_info = None
txt2img_html_info = None
img2img_show_hook = None
img2img_gallery = None
img2img_generation_info = None
img2img_html_info = None
modelmerger_merge_hook = None
modelmerger_merge_component = None
# SageMaker instance types offered for async inference endpoints.
# BUG FIX: "ml.g5.12xlarge" appeared twice; the duplicate is removed.
# NOTE(review): the duplicate may have been intended as a larger type
# (e.g. "ml.g5.16xlarge"/"ml.g5.24xlarge") — confirm before adding one.
async_inference_choices = [
    "ml.g4dn.2xlarge",
    "ml.g4dn.4xlarge",
    "ml.g4dn.8xlarge",
    "ml.g4dn.12xlarge",
    "ml.g5.2xlarge",
    "ml.g5.4xlarge",
    "ml.g5.8xlarge",
    "ml.g5.12xlarge",
]
class SageMakerUI(scripts.Script):
    """Always-on WebUI script that redirects txt2img/img2img generation to a
    selected SageMaker endpoint instead of running inference locally."""
    # NOTE(review): all of these are class-level (shared) attributes, not
    # per-instance state — confirm that is intended.
    # Result of the most recent cloud inference fetch.
    latest_result = None
    # Id of the inference job most recently submitted to the cloud.
    current_inference_id = None
    # Pending inference ids awaiting result retrieval.
    inference_queue = Queue(maxsize=30)
    # Original modules.processing.process_images_inner, saved so the hijack
    # installed in before_process can restore it.
    default_images_inner = None
    # Per-tab Generate buttons, captured in after_component.
    txt2img_generate_btn = None
    img2img_generate_btn = None
    # Syncs the cloud model list when an endpoint is selected.
    sd_model_manager = CloudSDModelsManager()
    # Per-tab result components, captured in after_component.
    txt2img_generation_info = None
    txt2img_gallery = None
    txt2img_html_info = None
    img2img_generation_info = None
    img2img_gallery = None
    img2img_html_info = None
    ph = None
def title(self):
return "SageMaker embeddings"
def show(self, is_img2img):
return scripts.AlwaysVisible
    def after_component(self, component, **kwargs):
        """Capture the gradio components this script needs (generate buttons,
        result galleries, info boxes) as the UI is built, and attach the
        polling hook that swaps cloud inference results into the gallery."""
        # Remember the per-tab Generate buttons so their label can later be
        # switched to "Generate on Cloud".
        if type(component) is gr.Button:
            if self.is_txt2img and getattr(component, 'elem_id', None) == f'txt2img_generate':
                self.txt2img_generate_btn = component
            elif self.is_img2img and getattr(component, 'elem_id', None) == f'img2img_generate':
                self.img2img_generate_btn = component
        if type(component) is gr.Textbox and getattr(component, 'elem_id', None) == 'generation_info_txt2img' and self.is_txt2img:
            self.txt2img_generation_info = component
        if type(component) is gr.Gallery and getattr(component, 'elem_id', None) == 'txt2img_gallery' and self.is_txt2img:
            self.txt2img_gallery = component
        if type(component) is gr.HTML and getattr(component, 'elem_id', None) == 'html_info_txt2img' and self.is_txt2img:
            self.txt2img_html_info = component

        async def _update_result():
            # Pop the oldest pending inference id and fetch its result.
            if self.inference_queue and not self.inference_queue.empty():
                inference_id = self.inference_queue.get()
                self.latest_result = sagemaker_ui.process_result_by_inference_id(inference_id)
                # NOTE(review): single value returned here while the change()
                # hooks below declare three outputs — presumably
                # process_result_by_inference_id yields a 3-tuple; confirm.
                return self.latest_result
            # Nothing pending: leave all three outputs unchanged.
            return gr.skip(), gr.skip(), gr.skip()

        # Once all three txt2img result components are known, poll for cloud
        # results whenever the generation-info textbox changes.
        if self.txt2img_html_info and self.txt2img_gallery and self.txt2img_generation_info:
            self.txt2img_generation_info.change(
                fn=lambda: sagemaker_ui.async_loop_wrapper(_update_result),
                inputs=None,
                outputs=[self.txt2img_gallery, self.txt2img_generation_info, self.txt2img_html_info]
            )
        if type(component) is gr.Textbox and getattr(component, 'elem_id', None) == 'generation_info_img2img' and self.is_img2img:
            self.img2img_generation_info = component
        if type(component) is gr.Gallery and getattr(component, 'elem_id', None) == 'img2img_gallery' and self.is_img2img:
            self.img2img_gallery = component
        if type(component) is gr.HTML and getattr(component, 'elem_id', None) == 'html_info_img2img' and self.is_img2img:
            self.img2img_html_info = component
        # Same polling hook for the img2img tab.
        if self.img2img_html_info and self.img2img_gallery and self.img2img_generation_info:
            self.img2img_generation_info.change(
                fn=lambda: sagemaker_ui.async_loop_wrapper(_update_result),
                inputs=None,
                outputs=[self.img2img_gallery, self.img2img_generation_info, self.img2img_html_info]
            )
        pass
def ui(self, is_img2img):
def _check_generate(endpoint):
if endpoint:
self.sd_model_manager.update_models()
else:
self.sd_model_manager.clear()
return f'Generate{" on Cloud" if endpoint else ""}'
if not is_img2img:
sagemaker_endpoint, inference_job_dropdown, txt2img_inference_job_ids_refresh_button, primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud = sagemaker_ui.create_ui(
is_img2img)
sagemaker_endpoint.change(_check_generate, inputs=sagemaker_endpoint,
outputs=[self.txt2img_generate_btn])
return [sagemaker_endpoint, inference_job_dropdown, txt2img_inference_job_ids_refresh_button,
primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud]
else:
sagemaker_endpoint, inference_job_dropdown, txt2img_inference_job_ids_refresh_button, primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud = sagemaker_ui.create_ui(
is_img2img)
sagemaker_endpoint.change(_check_generate, inputs=sagemaker_endpoint,
outputs=[self.img2img_generate_btn])
return [sagemaker_endpoint, inference_job_dropdown, txt2img_inference_job_ids_refresh_button,
primary_model_name, secondary_model_name, tertiary_model_name, modelmerger_merge_on_cloud]
    def before_process(self, p, *args):
        """Intercept generation: serialize the processing parameters, submit
        them as an inference job to the selected SageMaker endpoint, and
        hijack modules.processing.process_images_inner so the local pipeline
        returns a stub Processed object instead of rendering locally.

        args[0] is the endpoint dropdown value formatted as
        "<endpoint-name>+<status>"; an empty value or a non-InService status
        means "run locally" and this method returns without side effects.
        """
        on_docker = os.environ.get('ON_DOCKER', "false")
        if on_docker == "true":
            return
        # check if endpoint is inService
        sagemaker_endpoint = ''
        if args[0]:
            sagemaker_endpoint = args[0].split('+')[0] if args[0].split('+')[1] == 'InService' else ''
        if not sagemaker_endpoint:
            return
        # NOTE(review): unreachable — an empty args[0] already produced an
        # empty sagemaker_endpoint and returned above.
        if not args[0]:
            return
        import json
        from PIL import Image, PngImagePlugin
        from io import BytesIO
        import base64
        from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
        import numpy
        from modules import sd_models
        from modules import extra_networks
        current_model = sd_models.select_checkpoint()
        print(current_model.name)
        # Models the endpoint must load, keyed by model kind; the checkpoint
        # name has the cloud-model postfix stripped.
        models = {'Stable-diffusion': [current_model.name.replace(f'.{postfix}', '')]}
        api_param_cls = None
        if self.is_img2img:
            api_param_cls = StableDiffusionImg2ImgProcessingAPI
        if self.is_txt2img:
            api_param_cls = StableDiffusionTxt2ImgProcessingAPI
        if not api_param_cls:
            raise NotImplementedError
        api_param = api_param_cls(**p.__dict__)
        if self.is_img2img:
            api_param.mask = p.image_mask

        def get_pil_metadata(pil_image):
            # Copy any text-only metadata
            metadata = PngImagePlugin.PngInfo()
            for key, value in pil_image.info.items():
                if isinstance(key, str) and isinstance(value, str):
                    metadata.add_text(key, value)
            return metadata

        def encode_pil_to_base64(pil_image):
            # Serialize an image (with its text metadata) into a data URL.
            with BytesIO() as output_bytes:
                pil_image.save(output_bytes, "PNG", pnginfo=get_pil_metadata(pil_image))
                bytes_data = output_bytes.getvalue()
            base64_str = str(base64.b64encode(bytes_data), "utf-8")
            return "data:image/png;base64," + base64_str

        def encode_no_json(obj):
            # json.dumps fallback for values that are not JSON-serializable.
            import enum
            if isinstance(obj, numpy.ndarray):
                return encode_pil_to_base64(Image.fromarray(obj))
            elif isinstance(obj, Image.Image):
                return encode_pil_to_base64(obj)
            elif isinstance(obj, enum.Enum):
                return obj.value
            elif hasattr(obj, '__dict__'):
                return obj.__dict__
            else:
                print(f'may not able to json dumps {type(obj)}: {str(obj)}')
                return str(obj)

        # p.script_args[0] holds the 1-based index of the selected
        # (non-alwayson) script; -1 converts it to the scripts-list index.
        selected_script_index = p.script_args[0] - 1
        api_param.script_args = []
        for sid, script in enumerate(p.scripts.scripts):
            # escape sagemaker plugin
            if script.title() == self.title():
                continue
            all_used_models = []
            script_args = p.script_args[script.args_from:script.args_to]
            if script.alwayson:
                print(f'{script.name} {script.args_from} {script.args_to}')
                api_param.alwayson_scripts[script.name] = {}
                api_param.alwayson_scripts[script.name]['args'] = []
                for _id, arg in enumerate(script_args):
                    parsed_args, used_models = self._process_args_by_plugin(script.name, arg, _id, script_args)
                    all_used_models.append(used_models)
                    api_param.alwayson_scripts[script.name]['args'].append(parsed_args)
            elif selected_script_index == sid:
                api_param.script_name = script.name
                for _id, arg in enumerate(script_args):
                    parsed_args, used_models = self._process_args_by_plugin(script.name, arg, _id, script_args)
                    all_used_models.append(used_models)
                    api_param.script_args.append(parsed_args)
            # Merge the models referenced by this script into the payload,
            # de-duplicating per model kind.
            if all_used_models:
                for used_models in all_used_models:
                    for key, vals in used_models.items():
                        if key not in models:
                            models[key] = []
                        for val in vals:
                            if val not in models[key]:
                                models[key].append(val)
        api_param.sampler_index = p.sampler_name
        # finished construct api payload
        js = json.dumps(api_param, default=encode_no_json)
        # fixme: not handle batches yet
        from modules import shared
        # we not support automatic for simplicity because the default is Automatic
        # if user need, has to select a vae model manually in the setting page
        if shared.opts.sd_vae and shared.opts.sd_vae not in ['None', 'Automatic']:
            models['VAE'] = [shared.opts.sd_vae]
        from modules.processing import get_fixed_seed
        # Resolve -1 (random) seeds locally so the cloud run is reproducible.
        seed = get_fixed_seed(p.seed)
        subseed = get_fixed_seed(p.subseed)
        p.setup_prompts()
        if type(seed) == list:
            p.all_seeds = seed
        else:
            p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
        if type(subseed) == list:
            p.all_subseeds = subseed
        else:
            p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
        p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
        p.prompts = p.all_prompts
        p.negative_prompts = p.all_negative_prompts
        p.seeds = p.all_seeds
        p.subseeds = p.all_subseeds
        # Extract <lora:...>/<hypernet:...> references from the prompts so
        # the referenced model files can be added to the payload.
        _prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts)
        import importlib
        from modules.sd_hijack import model_hijack
        from modules import shared
        from modules.shared import cmd_opts
        lora_extensions_builtin = importlib.import_module("extensions-builtin.Lora.networks")
        lora_lookup = lora_extensions_builtin.available_network_aliases
        # load lora
        for key, vals in extra_network_data.items():
            if key == 'lora':
                for val in vals:
                    if 'Lora' not in models:
                        models['Lora'] = []
                    lora_filename = lora_lookup[val.positional[0]].filename.split(os.path.sep)[-1]
                    if lora_filename not in models['Lora']:
                        models['Lora'].append(lora_filename)
            if key == 'hypernet':
                print(key, vals)
                for val in vals:
                    if 'hypernetworks' not in models:
                        models['hypernetworks'] = []
                    hypernet_filename = shared.hypernetworks[val.positional[0]].split(os.path.sep)[-1]
                    if hypernet_filename not in models['hypernetworks']:
                        models['hypernetworks'].append(hypernet_filename)
        if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
            model_hijack.embedding_db.load_textual_inversion_embeddings()
        p.setup_conds()
        # load textual inversion
        for key, val in model_hijack.extra_generation_params.items():
            # Entries look like "<name>: <info>"; keep only names matching a
            # known textual-inversion embedding.
            if val.split(': ')[0] not in model_hijack.embedding_db.word_embeddings:
                continue
            textual_inv_name = \
                model_hijack.embedding_db.word_embeddings[val.split(': ')[0]].filename.split(os.path.sep)[-1]
            if 'embeddings' not in models:
                models['embeddings'] = []
            if textual_inv_name not in models['embeddings']:
                models['embeddings'].append(textual_inv_name)
        # create an inference and upload to s3
        url = get_variable_from_json('api_gateway_url')
        api_key = get_variable_from_json('api_token')
        if not url or not api_key:
            logging.debug("Url or API-Key is not setting.")
            return
        payload = {
            'sagemaker_endpoint_name': sagemaker_endpoint,
            'task_type': "txt2img" if self.is_txt2img else "img2img",
            'models': models,
            'filters': {
                'creator': datetime.datetime.now().timestamp()
            }
        }
        print(payload)
        err = None
        inference_id = None
        try:
            # 1) create the inference job and obtain a presigned S3 URL
            response = requests.post(f'{url}inference/v2', json=payload, headers={'x-api-key': api_key})
            response.raise_for_status()
            upload_param_response = response.json()
            if 'inference' in upload_param_response and 'api_params_s3_upload_url' in upload_param_response['inference']:
                # 2) upload the serialized API params, then 3) start the run
                upload_s3_resp = requests.put(upload_param_response['inference']['api_params_s3_upload_url'], data=js)
                upload_s3_resp.raise_for_status()
                inference_id = upload_param_response['inference']['id']
                # start run infer
                response = requests.put(f'{url}inference/v2/{inference_id}/run', json=payload,
                                        headers={'x-api-key': api_key})
                response.raise_for_status()
                self.current_inference_id = inference_id
                self.inference_queue.put(inference_id)
            elif upload_param_response['status'] != 200:
                err = upload_param_response['error']
        except Exception as e:
            err = str(e)
        from modules import processing
        from modules.processing import Processed

        def process_image_inner_hijack(processing_param):
            # One-shot replacement for processing.process_images_inner:
            # restores the original implementation, then returns a stub
            # Processed describing the cloud job (or its submission error).
            if not self.default_images_inner:
                default_processing = importlib.import_module("modules.processing")
                self.default_images_inner = default_processing.process_images_inner
            if self.default_images_inner:
                processing.process_images_inner = self.default_images_inner
            if err:
                # NOTE(review): err is a plain string here, so join() splits
                # it into characters — confirm the intended message format.
                return Processed(
                    p,
                    images_list=[],
                    seed=0,
                    info=f"Inference job is failed: {', '.join(err)}",
                    subseed=0,
                    index_of_first_image=0,
                    infotexts=[],
                )
            processed = Processed(
                p,
                images_list=[],
                seed=0,
                info=f'Inference job with id {inference_id} has created and running on cloud now. Use Inference job in the SageMaker part to see the result.',
                subseed=0,
                index_of_first_image=0,
                infotexts=[],
            )
            return processed

        # Save the genuine implementation and install the hijack so the
        # upcoming local generation call is short-circuited.
        default_processing = importlib.import_module("modules.processing")
        self.default_images_inner = default_processing.process_images_inner
        processing.process_images_inner = process_image_inner_hijack
        if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
            # debug only, may delete later
            with open(f'api_{"txt2img" if self.is_txt2img else "img2img"}_param.json', 'w') as f:
                f.write(js)
        pass
def process(self, p, *args):
pass
def _process_args_by_plugin(self, script_name, arg, current_index, args):
processors = {
'controlnet': self._controlnet_args,
'x/y/z plot': self._xyz_args,
}
models = {}
if script_name not in processors:
return arg, models
f = processors[script_name]
mdls = f(script_name, arg, current_index, args)
for key, val in mdls.items():
if not val:
continue
if key not in models:
models[key] = []
models[key].extend(val)
return arg, models
    def _xyz_args(self, script_name, arg, current_index, args) -> Dict[str, List[str]]:
        """Extract Stable-diffusion checkpoint names selected in the x/y/z
        plot script and strip the trailing hash from each entry in place."""
        if script_name != 'x/y/z plot':
            return {}
        if not arg or type(arg) is not list:
            return {}
        # 10 represent the checkpoint_name option for both img2img and txt2img
        # ref: xyz_grid.py#L204
        # Only the values list two positions after an axis-type selector of
        # 10 is a checkpoint list; anything else is left untouched.
        if current_index - 2 < 0 or args[current_index - 2] != 10:
            return {}
        # Each entry looks like "<model name> <hash>"; drop the hash token.
        models = [' '.join(md.split()[:-1]) for md in arg]
        for _id, val in enumerate(models):
            # NOTE(review): mutates the caller's args list in place so the
            # serialized payload carries hash-less names — confirm intended.
            args[current_index][_id] = val
        return {'Stable-diffusion': models}
def _controlnet_args(self, script_name, arg, *_) -> Dict[str, List[str]]:
if script_name != 'controlnet' or not arg.enabled:
return {}
model_name_parts = arg.model.split()
models = []
# make sure there is a hash, otherwise remain not changed
if len(model_name_parts) > 1:
arg.model = ' '.join(model_name_parts[:-1])
if arg.model == 'None':
return {}
models.append(f'{arg.model}.pth')
return {'ControlNet': models}
def on_after_component_callback(component, **_kwargs):
    """Global gradio after-component hook.

    Captures the dreambooth training widgets and the txt2img/img2img result
    widgets as the UI is built, then — once every prerequisite component has
    been seen — wires the SageMaker training button, the inference-job result
    display, the endpoint info display, interrogation buttons, and the cloud
    model-merge button. The *_show_hook sentinels ensure each group of event
    handlers is attached exactly once.
    """
    global db_model_name, db_use_txt2img, db_sagemaker_train, db_save_config, cloud_db_model_name, cloud_train_instance_type, training_job_dashboard
    # Recognize the dreambooth components by elem_id (the local model
    # dropdown is also matched by label + enclosing panel elem_id).
    is_dreambooth_train = type(component) is gr.Button and getattr(component, 'elem_id', None) == 'db_train'
    is_dreambooth_model_name = type(component) is gr.Dropdown and \
                               (getattr(component, 'elem_id', None) == 'model_name' or \
                                (getattr(component, 'label', None) == 'Model' and getattr(
                                    component.parent.parent.parent.parent, 'elem_id', None) == 'ModelPanel'))
    is_cloud_dreambooth_model_name = type(component) is gr.Dropdown and \
                                     getattr(component, 'elem_id', None) == 'cloud_db_model_name'
    is_machine_type_for_train = type(component) is gr.Dropdown and \
                                getattr(component, 'elem_id', None) == 'cloud_train_instance_type'
    is_dreambooth_use_txt2img = type(component) is gr.Checkbox and getattr(component, 'label', None) == 'Use txt2img'
    is_training_job_dashboard = type(component) is gr.Dataframe and getattr(component, 'elem_id',
                                                                            None) == 'training_job_dashboard'
    is_db_save_config = getattr(component, 'elem_id', None) == 'db_save_config'
    if is_dreambooth_train:
        # Insert a "SageMaker Train" button next to the local Train button.
        db_sagemaker_train = gr.Button(value="SageMaker Train", elem_id="db_sagemaker_train", variant='primary')
    if is_dreambooth_model_name:
        db_model_name = component
    if is_cloud_dreambooth_model_name:
        cloud_db_model_name = component
    if is_training_job_dashboard:
        training_job_dashboard = component
    if is_machine_type_for_train:
        cloud_train_instance_type = component
    if is_dreambooth_use_txt2img:
        db_use_txt2img = component
    if is_db_save_config:
        db_save_config = component
    # After all required components are loaded, add the SageMaker training
    # button click callback function.
    if training_job_dashboard is not None and cloud_train_instance_type is not None and \
            cloud_db_model_name is not None and db_model_name is not None and \
            db_use_txt2img is not None and db_sagemaker_train is not None and \
            (
                    is_dreambooth_train or is_dreambooth_model_name or is_dreambooth_use_txt2img or is_cloud_dreambooth_model_name or is_machine_type_for_train or is_training_job_dashboard):
        # Placeholder; the real model is chosen via the cloud dropdown.
        db_model_name.value = "dummy_local_model"
        db_sagemaker_train.click(
            fn=async_cloud_train,
            _js="db_start_sagemaker_train",
            inputs=[
                db_model_name,
                cloud_db_model_name,
                db_use_txt2img,
                cloud_train_instance_type
            ],
            outputs=[training_job_dashboard]
        )
    # Hook image display logic (txt2img tab).
    global txt2img_gallery, txt2img_generation_info, txt2img_html_info, txt2img_show_hook, txt2img_prompt
    is_txt2img_gallery = type(component) is gr.Gallery and getattr(component, 'elem_id', None) == 'txt2img_gallery'
    is_txt2img_generation_info = type(component) is gr.Textbox and getattr(component, 'elem_id',
                                                                           None) == 'generation_info_txt2img'
    is_txt2img_html_info = type(component) is gr.HTML and getattr(component, 'elem_id', None) == 'html_info_txt2img'
    is_txt2img_prompt = type(component) is gr.Textbox and getattr(component, 'elem_id', None) == 'txt2img_prompt'
    if is_txt2img_prompt:
        txt2img_prompt = component
    if is_txt2img_gallery:
        txt2img_gallery = component
    if is_txt2img_generation_info:
        txt2img_generation_info = component
    if is_txt2img_html_info:
        txt2img_html_info = component
    # NOTE(review): txt2img_prompt (and the img2img image globals below)
    # have no module-level default; the gate below relies on the component
    # being captured before the first evaluation — confirm build order.
    if sagemaker_ui.inference_job_dropdown is not None and \
            txt2img_gallery is not None and \
            txt2img_generation_info is not None and \
            txt2img_html_info is not None and \
            txt2img_show_hook is None and \
            txt2img_prompt is not None:
        # Wire the txt2img hooks exactly once.
        txt2img_show_hook = "finish"
        # Selecting a past inference job loads its images and parameters.
        sagemaker_ui.inference_job_dropdown.change(
            fn=sagemaker_ui.fake_gan,
            inputs=[sagemaker_ui.inference_job_dropdown, txt2img_prompt],
            outputs=[txt2img_gallery, txt2img_generation_info, txt2img_html_info, txt2img_prompt]
        )
        # Show endpoint details whenever the endpoint selection changes.
        sagemaker_ui.sagemaker_endpoint.change(
            fn=lambda selected_value: sagemaker_ui.displayEndpointInfo(selected_value),
            inputs=[sagemaker_ui.sagemaker_endpoint],
            outputs=[txt2img_html_info]
        )
        # Kick off a checkpoint merge job on the cloud.
        sagemaker_ui.modelmerger_merge_on_cloud.click(
            fn=sagemaker_ui.modelmerger_on_cloud_func,
            _js="txt2img_config_save",
            inputs=[sagemaker_ui.sagemaker_endpoint],
            outputs=[
            ])
    # Hook image display logic (img2img tab).
    global img2img_gallery, img2img_generation_info, img2img_html_info, img2img_show_hook, \
        img2img_prompt, \
        init_img, \
        sketch, \
        init_img_with_mask, \
        inpaint_color_sketch, \
        init_img_inpaint, \
        init_mask_inpaint
    is_img2img_gallery = type(component) is gr.Gallery and getattr(component, 'elem_id', None) == 'img2img_gallery'
    is_img2img_generation_info = type(component) is gr.Textbox and getattr(component, 'elem_id',
                                                                           None) == 'generation_info_img2img'
    is_img2img_html_info = type(component) is gr.HTML and getattr(component, 'elem_id', None) == 'html_info_img2img'
    is_img2img_prompt = type(component) is gr.Textbox and getattr(component, 'elem_id', None) == 'img2img_prompt'
    # The six img2img source-image widgets (plain, sketch, mask variants).
    is_init_img = type(component) is gr.Image and getattr(component, 'elem_id', None) == 'img2img_image'
    is_sketch = type(component) is gr.Image and getattr(component, 'elem_id', None) == 'img2img_sketch'
    is_init_img_with_mask = type(component) is gr.Image and getattr(component, 'elem_id', None) == 'img2maskimg'
    is_inpaint_color_sketch = type(component) is gr.Image and getattr(component, 'elem_id', None) == 'inpaint_sketch'
    is_init_img_inpaint = type(component) is gr.Image and getattr(component, 'elem_id', None) == 'img_inpaint_base'
    is_init_mask_inpaint = type(component) is gr.Image and getattr(component, 'elem_id', None) == 'img_inpaint_mask'
    if is_img2img_gallery:
        img2img_gallery = component
    if is_img2img_generation_info:
        img2img_generation_info = component
    if is_img2img_html_info:
        img2img_html_info = component
    if is_img2img_prompt:
        img2img_prompt = component
    if is_init_img:
        init_img = component
    if is_sketch:
        sketch = component
    if is_init_img_with_mask:
        init_img_with_mask = component
    if is_inpaint_color_sketch:
        inpaint_color_sketch = component
    if is_init_img_inpaint:
        init_img_inpaint = component
    if is_init_mask_inpaint:
        init_mask_inpaint = component
    if sagemaker_ui.inference_job_dropdown is not None and \
            img2img_gallery is not None and \
            img2img_generation_info is not None and \
            img2img_html_info is not None and \
            img2img_show_hook is None and \
            sagemaker_ui.interrogate_clip_on_cloud_button is not None and \
            sagemaker_ui.interrogate_deep_booru_on_cloud_button is not None and \
            img2img_prompt is not None and \
            init_img is not None and \
            sketch is not None and \
            init_img_with_mask is not None and \
            inpaint_color_sketch is not None and \
            init_img_inpaint is not None and \
            init_mask_inpaint is not None:
        # Wire the img2img hooks exactly once.
        img2img_show_hook = "finish"
        sagemaker_ui.inference_job_dropdown.change(
            fn=sagemaker_ui.fake_gan,
            inputs=[sagemaker_ui.inference_job_dropdown, img2img_prompt],
            outputs=[img2img_gallery, img2img_generation_info, img2img_html_info, img2img_prompt]
        )
        # Cloud-side CLIP / DeepBooru interrogation of the source image.
        sagemaker_ui.interrogate_clip_on_cloud_button.click(
            fn=sagemaker_ui.call_interrogate_clip,
            _js="img2img_config_save",
            inputs=[sagemaker_ui.sagemaker_endpoint, init_img, sketch, init_img_with_mask, inpaint_color_sketch,
                    init_img_inpaint, init_mask_inpaint],
            outputs=[img2img_gallery, img2img_generation_info, img2img_html_info, img2img_prompt]
        )
        sagemaker_ui.interrogate_deep_booru_on_cloud_button.click(
            fn=sagemaker_ui.call_interrogate_deepbooru,
            _js="img2img_config_save",
            inputs=[sagemaker_ui.sagemaker_endpoint, init_img, sketch, init_img_with_mask, inpaint_color_sketch,
                    init_img_inpaint, init_mask_inpaint],
            outputs=[img2img_gallery, img2img_generation_info, img2img_html_info, img2img_prompt]
        )
def update_connect_config(api_url, api_token):
    """Persist the API gateway URL (normalized to end with '/') and token,
    refresh the cached module-level copies, and re-pull the cloud resource
    lists. Returns the status string shown in the settings panel."""
    # Check if api_url ends with '/', if not append it
    if not api_url.endswith('/'):
        api_url = api_url + '/'
    save_variable_to_json('api_gateway_url', api_url)
    save_variable_to_json('api_token', api_token)
    # Re-read so the cached globals always mirror what was persisted.
    global api_gateway_url, api_key
    api_gateway_url = get_variable_from_json('api_gateway_url')
    api_key = get_variable_from_json('api_token')
    sagemaker_ui.init_refresh_resource_list_from_cloud()
    return "Setting updated"
def test_aws_connect_config(api_url, api_token):
    """Persist the connection settings, then probe the backend's
    test-connection endpoint and report the outcome as a UI string."""
    update_connect_config(api_url, api_token)
    # Re-read the persisted values so we test exactly what was saved.
    api_url = get_variable_from_json('api_gateway_url')
    api_token = get_variable_from_json('api_token')
    if not api_url.endswith('/'):
        api_url += '/'
    target_url = f'{api_url}inference/test-connection'
    headers = {
        "x-api-key": api_token,
        "Content-Type": "application/json"
    }
    try:
        response = requests.get(target_url, headers=headers)
        # Surface HTTP error statuses as a connection failure below.
        response.raise_for_status()
        # BUG FIX: the body was previously parsed into an unused variable
        # `r = response.json()`; a non-JSON body raised ValueError, which
        # escaped the except clause below (it only catches
        # RequestException) and crashed the UI callback instead of
        # reporting a failure. A 2xx response is sufficient here.
        return "Successfully Connected"
    except requests.exceptions.RequestException as e:
        print(f"Error: Failed to get server request. Details: {e}")
        return "failed to connect to backend server, please check the url and token"
def on_ui_tabs():
import modules.ui
buildin_model_list = ['AWS JumpStart Model', 'AWS BedRock Model', 'Hugging Face Model']
with gr.Blocks() as sagemaker_interface:
with gr.Row(equal_height=True, elem_id="aws_sagemaker_ui_row", visible=False):
sm_load_params = gr.Button(value="Load Settings", elem_id="aws_load_params", visible=False)
sm_save_params = gr.Button(value="Save Settings", elem_id="aws_save_params", visible=False)
sm_train_model = gr.Button(value="Train", variant="primary", elem_id="aws_train_model", visible=False)
sm_generate_checkpoint = gr.Button(value="Generate Ckpt", elem_id="aws_gen_ckpt", visible=False)
with gr.Row():
gr.HTML(value="Enter your API URL & Token to start the connection.", elem_id="hint_row")
with gr.Row():
with gr.Column(variant="panel", scale=1):
gr.HTML(value="<u><b>AWS Connection Setting</b></u>")
global api_gateway_url
api_gateway_url = get_variable_from_json('api_gateway_url')
global api_key
api_key = get_variable_from_json('api_token')
with gr.Row():
api_url_textbox = gr.Textbox(value=api_gateway_url, lines=1,
placeholder="Please enter API Url of Middle", label="API Url",
elem_id="aws_middleware_api")
def update_api_gateway_url():
global api_gateway_url
api_gateway_url = get_variable_from_json('api_gateway_url')
return api_gateway_url
# modules.ui.create_refresh_button(api_url_textbox, get_variable_from_json('api_gateway_url'), lambda: {"value": get_variable_from_json('api_gateway_url')}, "refresh_api_gate_way")
modules.ui.create_refresh_button(api_url_textbox, update_api_gateway_url,
lambda: {"value": api_gateway_url}, "refresh_api_gateway_url")
with gr.Row():
def update_api_key():
global api_key
api_key = get_variable_from_json('api_token')
return api_key
api_token_textbox = gr.Textbox(value=api_key, lines=1, placeholder="Please enter API Token",
label="API Token", elem_id="aws_middleware_token")
modules.ui.create_refresh_button(api_token_textbox, update_api_key, lambda: {"value": api_key},
"refresh_api_token")
global test_connection_result
test_connection_result = gr.Label(title="Output");
aws_connect_button = gr.Button(value="Update Setting", variant='primary', elem_id="aws_config_save")
aws_connect_button.click(_js="update_auth_settings",
fn=update_connect_config,
inputs=[api_url_textbox, api_token_textbox],
outputs=[test_connection_result])
aws_test_button = gr.Button(value="Test Connection", variant='primary', elem_id="aws_config_test")
aws_test_button.click(test_aws_connect_config, inputs=[api_url_textbox, api_token_textbox],
outputs=[test_connection_result])
with gr.Row():
with gr.Accordion("Disclaimer", open=False):
gr.HTML(
value="You should perform your own independent assessment, and take measures to ensure that you comply with your own specific quality control practices and standards, and the local rules, laws, regulations, licenses and terms of use that apply to you, your content, and the third-party generative AI service in this web UI. Amazon Web Services has no control or authority over the third-party generative AI service in this web UI, and does not make any representations or warranties that the third-party generative AI service is secure, virus-free, operational, or compatible with your production environment and standards.");
with gr.Column(variant="panel", scale=1.5):
gr.HTML(value="<u><b>Cloud Assets Management</b></u>")
sagemaker_html_log = gr.HTML(elem_id=f'html_log_sagemaker')
with gr.Accordion("Upload Model to S3 from WebUI", open=False):
gr.HTML(value="Refresh to select the model to upload to S3")
exts = (".bin", ".pt", ".pth", ".safetensors", ".ckpt")
root_path = os.getcwd()
model_folders = {
"ckpt": os.path.join(root_path, "models", "Stable-diffusion"),
"text": os.path.join(root_path, "embeddings"),
"lora": os.path.join(root_path, "models", "Lora"),
"control": os.path.join(root_path, "models", "ControlNet"),
"hyper": os.path.join(root_path, "models", "hypernetworks"),
"vae": os.path.join(root_path, "models", "VAE"),
}
def scan_sd_ckpt():
model_files = os.listdir(model_folders["ckpt"])
# filter non-model files not in exts
model_files = [f for f in model_files if os.path.splitext(f)[1] in exts]
model_files = [os.path.join(model_folders["ckpt"], f) for f in model_files]
return model_files
def scan_textual_inversion_model():
model_files = os.listdir(model_folders["text"])
# filter non-model files not in exts
model_files = [f for f in model_files if os.path.splitext(f)[1] in exts]
model_files = [os.path.join(model_folders["text"], f) for f in model_files]
return model_files
def scan_lora_model():
model_files = os.listdir(model_folders["lora"])
# filter non-model files not in exts
model_files = [f for f in model_files if os.path.splitext(f)[1] in exts]
model_files = [os.path.join(model_folders["lora"], f) for f in model_files]
return model_files
def scan_control_model():
model_files = os.listdir(model_folders["control"])
# filter non-model files not in exts
model_files = [f for f in model_files if os.path.splitext(f)[1] in exts]
model_files = [os.path.join(model_folders["control"], f) for f in model_files]
return model_files
def scan_hypernetwork_model():
model_files = os.listdir(model_folders["hyper"])
# filter non-model files not in exts
model_files = [f for f in model_files if os.path.splitext(f)[1] in exts]
model_files = [os.path.join(model_folders["hyper"], f) for f in model_files]
return model_files
def scan_vae_model():
model_files = os.listdir(model_folders["vae"])
# filter non-model files not in exts
model_files = [f for f in model_files if os.path.splitext(f)[1] in exts]
model_files = [os.path.join(model_folders["vae"], f) for f in model_files]
return model_files
                # One dropdown + refresh button per uploadable model type; choices
                # come from the scan_* helpers above (local model folders).
                with FormRow(elem_id="model_upload_form_row_01"):
                    sd_checkpoints_path = gr.Dropdown(label="SD Checkpoints", choices=sorted(scan_sd_ckpt()), elem_id="sd_ckpt_dropdown")
                    create_refresh_button(sd_checkpoints_path, scan_sd_ckpt, lambda: {"choices": sorted(scan_sd_ckpt())}, "refresh_sd_ckpt")
                    textual_inversion_path = gr.Dropdown(label="Textual Inversion", choices=sorted(scan_textual_inversion_model()),elem_id="textual_inversion_model_dropdown")
                    create_refresh_button(textual_inversion_path, scan_textual_inversion_model, lambda: {"choices": sorted(scan_textual_inversion_model())}, "refresh_textual_inversion_model")
                with FormRow(elem_id="model_upload_form_row_02"):
                    lora_path = gr.Dropdown(label="LoRA model", choices=sorted(scan_lora_model()), elem_id="lora_model_dropdown")
                    create_refresh_button(lora_path, scan_lora_model, lambda: {"choices": sorted(scan_lora_model())}, "refresh_lora_model",)
                    controlnet_model_path = gr.Dropdown(label="ControlNet model", choices=sorted(scan_control_model()), elem_id="controlnet_model_dropdown")
                    create_refresh_button(controlnet_model_path, scan_control_model, lambda: {"choices": sorted(scan_control_model())}, "refresh_controlnet_models")
                with FormRow(elem_id="model_upload_form_row_03"):
                    hypernetwork_path = gr.Dropdown(label="Hypernetwork", choices=sorted(scan_hypernetwork_model()),elem_id="hyper_model_dropdown")
                    create_refresh_button(hypernetwork_path, scan_hypernetwork_model, lambda: {"choices": sorted(scan_hypernetwork_model())}, "refresh_hyper_models")
                    vae_path = gr.Dropdown(label="VAE", choices=sorted(scan_vae_model()), elem_id="vae_model_dropdown")
                    create_refresh_button(vae_path, scan_vae_model, lambda: {"choices": sorted(scan_vae_model())}, "refresh_vae_models")
                with gr.Row():
                    model_update_button = gr.Button(value="Upload Models to Cloud", variant="primary",elem_id="sagemaker_model_update_button", size=(200, 50))
                    # Uploads the selected model files to S3; outputs update the
                    # connection-result textbox and the model dropdowns.
                    model_update_button.click(_js="model_update",
                                              fn=sagemaker_ui.sagemaker_upload_model_s3,
                                              inputs=[sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path, vae_path],
                                              outputs=[test_connection_result, sd_checkpoints_path, textual_inversion_path, lora_path, hypernetwork_path, controlnet_model_path, vae_path])
            # Alternative upload path: pick files on the local machine in the
            # browser; the transfer itself is driven by page JS (uploadFiles hook).
            with gr.Accordion("Upload Model to S3 from My Computer", open=False):
                gr.HTML(value="Refresh to select the model to upload to S3")
                with FormRow(elem_id="model_upload_local_form_row_01"):
                    model_type_drop_down = gr.Dropdown(label="Model Type", choices=["SD Checkpoints", "Textual Inversion", "LoRA model", "ControlNet model", "Hypernetwork", "VAE"], elem_id="model_type_ele_id")
                    # Hidden textbox mirroring the dropdown so page JS can read the value.
                    model_type_hiden_text = gr.Textbox(elem_id="model_type_value_ele_id", visible=False)
                    def change_model_type_value(model_type: str) -> str:
                        """Mirror the selected model type into the hidden textbox.

                        Wired with _js="getModelTypeValue"; the hidden textbox is
                        presumably read by the page's upload JS — confirm against
                        the extension's javascript files.
                        """
                        # NOTE(review): mutating .value after creation only changes the
                        # server-side default; the returned value is what updates the UI.
                        model_type_hiden_text.value = model_type
                        return model_type
                    model_type_drop_down.change(fn=change_model_type_value, _js="getModelTypeValue",
                                                inputs=[model_type_drop_down], outputs=model_type_hiden_text)
                # Plain HTML file input; the selected files are handled by the page's
                # JS (showFileName / uploadFiles) rather than by gradio itself.
                file_upload_html_component = gr.HTML('<div class="lg svelte-1ipelgc"><div class="lg svelte-1ipelgc"><input type="file" class="lg secondary gradio-button svelte-1ipelgc" id="file-uploader" multiple onchange="showFileName(event)" style="width:100%"></div></div>')
                with FormRow(elem_id="model_upload_local_form_row_02"):
                    hidden_bind_html = gr.HTML(elem_id="hidden_bind_upload_files", value="<div id='hidden_bind_upload_files_html'></div>")
                with FormRow(elem_id="model_upload_local_form_row_03"):
                    # Progress indicators updated during the browser-side upload.
                    upload_label = gr.HTML(label="upload process", elem_id="progress-bar")
                    upload_percent_label = gr.HTML(label="upload percent process", elem_id="progress-percent")
                with gr.Row():
                    model_update_button_local = gr.Button(value="Upload Models to Cloud", variant="primary", elem_id="sagemaker_model_update_button_local", size=(200, 50))
                    model_update_button_local.click(_js="uploadFiles",
                                                    fn=sagemaker_ui.sagemaker_upload_model_s3_local,
                                                    # inputs=[sagemaker_ui.checkpoint_info],
                                                    outputs=[upload_label]
                                                    )
            with gr.Blocks(title="Deploy New SageMaker Endpoint", variant='panel'):
                gr.HTML(value="<b>Deploy New SageMaker Endpoint</b>")
                # Static table describing the endpoint configuration used when the
                # advanced options below are left unchecked.
                default_table = """
                <table style="width:100%; border: 1px solid black; border-collapse: collapse;">
                    <tr>
                        <th style="border: 1px solid grey; padding: 15px; text-align: left; background-color: #f2f2f2;" colspan="2">Default SageMaker Endpoint Config</th>
                    </tr>
                    <tr>
                        <td style="border: 1px solid grey; padding: 15px; text-align: left;"><b>Instance Type: </b></td>
                        <td style="border: 1px solid grey; padding: 15px; text-align: left;">ml.g5.2xlarge</td>
                    </tr>
                    <tr>
                        <td style="border: 1px solid grey; padding: 15px; text-align: left;"><b>Instance Count</b></td>
                        <td style="border: 1px solid grey; padding: 15px; text-align: left;">1</td>
                    </tr>
                    <tr>
                        <td style="border: 1px solid grey; padding: 15px; text-align: left;"><b>Automatic Scaling</b></td>
                        <td style="border: 1px solid grey; padding: 15px; text-align: left;">yes(range:0-1)</td>
                    </tr>
                </table>
                """
                gr.HTML(value=default_table)
                with gr.Row():
                    # instance_type_dropdown = gr.Dropdown(label="SageMaker Instance Type", choices=async_inference_choices, elem_id="sagemaker_inference_instance_type_textbox", value="ml.g4dn.xlarge")
                    # instance_count_dropdown = gr.Dropdown(label="Please select Instance count", choices=["1","2","3","4"], elem_id="sagemaker_inference_instance_count_textbox", value="1")
                    endpoint_advance_config_enabled = gr.Checkbox(
                        label="Advanced Endpoint Configuration", value=False, visible=True
                    )
                # Advanced config row, hidden until the checkbox above is ticked.
                with gr.Row(variant='panel', visible=False) as filter_row:
                    endpoint_name_textbox = gr.Textbox(value="", lines=1, placeholder="custome endpoint name ", label="Specify Endpoint Name", visible=True)
                    instance_type_dropdown = gr.Dropdown(label="Instance Type", choices=async_inference_choices, elem_id="sagemaker_inference_instance_type_textbox", value="ml.g5.2xlarge")
                    instance_count_dropdown = gr.Dropdown(label="Max Instance count", choices=["1","2","3","4","5","6"], elem_id="sagemaker_inference_instance_count_textbox", value="1")
                    autoscaling_enabled = gr.Checkbox(
                        label="Enable Autoscaling (0 to Max Instance count)", value=True, visible=True
                    )
def toggle_new_rows(checkbox_state):
if checkbox_state:
return gr.update(visible=True)
else:
return gr.update(visible=False)
                # Flip the advanced row's visibility with the checkbox.
                endpoint_advance_config_enabled.change(
                    fn=toggle_new_rows,
                    inputs=endpoint_advance_config_enabled,
                    outputs=filter_row
                )
                with gr.Row():
                    sagemaker_deploy_button = gr.Button(value="Deploy", variant='primary',
                                                        elem_id="sagemaker_deploy_endpoint_buttion")
                    sagemaker_deploy_button.click(sagemaker_ui.sagemaker_deploy,
                                                  _js="deploy_endpoint", \
                                                  inputs = [endpoint_name_textbox, instance_type_dropdown, instance_count_dropdown, autoscaling_enabled],
                                                  outputs=[test_connection_result])
            with gr.Blocks(title="Delete SageMaker Endpoint", variant='panel'):
                gr.HTML(value="<u><b>Delete SageMaker Endpoint</b></u>")
                with gr.Row():
                    sagemaker_endpoint_delete_dropdown = gr.Dropdown(choices=sagemaker_ui.sagemaker_endpoints,
                                                                     multiselect=True,
                                                                     label="Select Cloud SageMaker Endpoint")
                    # NOTE(review): the fully-qualified modules.ui.create_refresh_button
                    # works only because `import modules.scripts` binds the `modules`
                    # package; the rest of this file uses the from-imported helper.
                    modules.ui.create_refresh_button(sagemaker_endpoint_delete_dropdown,
                                                     sagemaker_ui.update_sagemaker_endpoints,
                                                     lambda: {"choices": sagemaker_ui.sagemaker_endpoints},
                                                     "refresh_sagemaker_endpoints_delete")
                    sagemaker_endpoint_delete_button = gr.Button(value="Delete", variant='primary',
                                                                 elem_id="sagemaker_endpoint_delete_button")
                    sagemaker_endpoint_delete_button.click(sagemaker_ui.sagemaker_endpoint_delete,
                                                           _js="delete_sagemaker_endpoint", \
                                                           inputs=[sagemaker_endpoint_delete_dropdown],
                                                           outputs=[test_connection_result])
        with gr.Column(variant="panel", scale=1):
            # TODO: uncomment if implemented, comment since the tab component do not has visible parameter
            # with gr.Blocks(title="Deploy New SageMaker Endpoint", variant='panel', visible=False):
            #     gr.HTML(value="<u><b>AWS Model Setting</b></u>", visible=False)
            #     with gr.Tab("Select"):
            #         gr.HTML(value="AWS Built-in Model", visible=False)
            #         model_select_dropdown = gr.Dropdown(buildin_model_list, label="Select Built-In Model", elem_id="aws_select_model", visible=False)
            #     with gr.Tab("Create"):
            #         gr.HTML(value="AWS Custom Model", visible=False)
            #         model_name_textbox = gr.Textbox(value="", lines=1, placeholder="Please enter model name", label="Model Name", visible=False)
            #         model_create_button = gr.Button(value="Create Model", variant='primary',elem_id="aws_create_model", visible=False)
            # Dataset management: "Create" uploads images to the cloud,
            # "Browse" lists and previews existing cloud datasets.
            with gr.Blocks(title="Create AWS dataset", variant='panel'):
                gr.HTML(value="<u><b>AWS Dataset Management</b></u>")
                with gr.Tab("Create"):
def upload_file(files):
file_paths = [file.name for file in files]
return file_paths
file_output = gr.File()
upload_button = gr.UploadButton("Click to Upload a File", file_types=["image", "video"],
file_count="multiple")
upload_button.upload(fn=upload_file, inputs=[upload_button], outputs=[file_output])
def create_dataset(files, dataset_name, dataset_desc):
print(dataset_name)
dataset_content = []
file_path_lookup = {}
for file in files:
orig_name = file.name.split(os.sep)[-1]
file_path_lookup[orig_name] = file.name
dataset_content.append(
{
"filename": orig_name,
"name": orig_name,
"type": "image",
"params": {}
}
)
payload = {
"dataset_name": dataset_name,
"content": dataset_content,
"params": {
"description": dataset_desc
}
}
url = get_variable_from_json('api_gateway_url') + '/dataset'
api_key = get_variable_from_json('api_token')
raw_response = requests.post(url=url, json=payload, headers={'x-api-key': api_key})
raw_response.raise_for_status()
response = raw_response.json()
print(f"Start upload sample files response:\n{response}")
for filename, presign_url in response['s3PresignUrl'].items():
file_path = file_path_lookup[filename]
with open(file_path, 'rb') as f:
response = requests.put(presign_url, f)
print(response)
response.raise_for_status()
payload = {
"dataset_name": dataset_name,
"status": "Enabled"
}
raw_response = requests.put(url=url, json=payload, headers={'x-api-key': api_key})
raw_response.raise_for_status()
print(raw_response.json())
return f'Complete Dataset {dataset_name} creation', None, None, None, None
dataset_name_upload = gr.Textbox(value="", lines=1, placeholder="Please input dataset name",
label="Dataset Name", elem_id="sd_dataset_name_textbox")
dataset_description_upload = gr.Textbox(value="", lines=1,
placeholder="Please input dataset description",
label="Dataset Description",
elem_id="sd_dataset_description_textbox")
create_dataset_button = gr.Button("Create Dataset", variant="primary",
elem_id="sagemaker_dataset_create_button") # size=(200, 50)
dataset_create_result = gr.Textbox(value="", label="Create Result", interactive=False)
create_dataset_button.click(
fn=create_dataset,
inputs=[upload_button, dataset_name_upload, dataset_description_upload],
outputs=[
dataset_create_result,
dataset_name_upload,
dataset_description_upload,
file_output,
upload_button
],
show_progress=True
)
                with gr.Tab('Browse'):
                    with gr.Row():
                        # Cache the cloud dataset list at UI-build time; refreshed
                        # in place by the refresh button below.
                        global cloud_datasets
                        cloud_datasets = get_sorted_cloud_dataset()
                        cloud_dataset_name = gr.Dropdown(
                            label="Dataset From Cloud",
                            choices=[d['datasetName'] for d in cloud_datasets],
                            elem_id="cloud_dataset_dropdown",
                            type="index",  # the select handler receives an index into cloud_datasets
                            info='select datasets from cloud'
                        )
def refresh_datasets():
global cloud_datasets
cloud_datasets = get_sorted_cloud_dataset()
return cloud_datasets
def refresh_datasets_dropdown():
global cloud_datasets
cloud_datasets = get_sorted_cloud_dataset()
return {"choices": [d['datasetName'] for d in cloud_datasets]}
                        create_refresh_button(
                            cloud_dataset_name,
                            refresh_datasets,
                            refresh_datasets_dropdown,
                            "refresh_cloud_dataset",
                        )
                    # Read-only detail widgets filled in when a dataset is selected.
                    with gr.Row():
                        dataset_s3_output = gr.Textbox(label='dataset s3 location', show_label=True,
                                                       type='text').style(show_copy_button=True)
                    with gr.Row():
                        dataset_des_output = gr.Textbox(label='dataset description', show_label=True, type='text')
                    with gr.Row():
                        dataset_gallery = gr.Gallery(
                            label="Dataset images", show_label=False, elem_id="gallery",
                        ).style(columns=[2], rows=[2], object_fit="contain", height="auto")
def get_results_from_datasets(dataset_idx):
ds = cloud_datasets[dataset_idx]
url = f"{get_variable_from_json('api_gateway_url')}/dataset/{ds['datasetName']}/data"
api_key = get_variable_from_json('api_token')
raw_response = requests.get(url=url, headers={'x-api-key': api_key})
raw_response.raise_for_status()
dataset_items = [(item['preview_url'], item['key']) for item in
raw_response.json()['data']]
return ds['s3'], ds['description'], dataset_items
                    cloud_dataset_name.select(fn=get_results_from_datasets, inputs=[cloud_dataset_name],
                                              outputs=[dataset_s3_output, dataset_des_output, dataset_gallery])
    # Trailing comma is intentional: on_ui_tabs must return an iterable of
    # (component, tab title, element id) tuples, so this is a 1-tuple.
    return (sagemaker_interface, "Amazon SageMaker", "sagemaker_interface"),
script_callbacks.on_after_component(on_after_component_callback)
script_callbacks.on_ui_tabs(on_ui_tabs)
# Keep a handle on the original ui_tabs callback chain so it can be wrapped
# below to inject extra tabs into the Dreambooth extension's UI.
origin_callback = script_callbacks.ui_tabs_callback
def avoid_duplicate_from_restart_ui(res):
    """Return True when the Dreambooth UI already contains the injected
    'Select From Cloud' tab (happens after a gradio UI restart)."""
    for extension_ui in res:
        if extension_ui[1] != 'Dreambooth':
            continue
        for component in list(extension_ui[0].blocks.values()):
            if type(component) is gr.Tab and component.label == 'Select From Cloud':
                return True
    return False
def ui_tabs_callback():
    """Wrapper around the original ui_tabs callback that injects the
    'Select From Cloud' and 'Create From Cloud' tabs into the Dreambooth
    extension's UI, next to its existing 'Select' tab."""
    res = origin_callback()
    # After a gradio restart the tabs already exist; return unchanged.
    if avoid_duplicate_from_restart_ui(res):
        return res
    for extension_ui in res:
        if extension_ui[1] == 'Dreambooth':
            # Find the Dreambooth 'Select' tab and build our tabs inside
            # the same parent container.
            for key in list(extension_ui[0].blocks):
                val = extension_ui[0].blocks[key]
                if type(val) is gr.Tab:
                    if val.label == 'Select':
                        with extension_ui[0]:
                            with val.parent:
                                with gr.Tab('Select From Cloud'):
                                    with gr.Row():
                                        cloud_db_model_name = gr.Dropdown(
                                            label="Model", choices=sorted(get_cloud_db_model_name_list()),
                                            elem_id="cloud_db_model_name"
                                        )
                                        create_refresh_button(
                                            cloud_db_model_name,
                                            get_cloud_db_model_name_list,
                                            lambda: {"choices": sorted(get_cloud_db_model_name_list())},
                                            "refresh_db_models",
                                        )
                                    with gr.Row():
                                        cloud_db_snapshot = gr.Dropdown(
                                            label="Cloud Snapshot to Resume",
                                            choices=sorted(get_cloud_model_snapshots()),
                                            elem_id="cloud_snapshot_to_resume_dropdown"
                                        )
                                        create_refresh_button(
                                            cloud_db_snapshot,
                                            get_cloud_model_snapshots,
                                            lambda: {"choices": sorted(get_cloud_model_snapshots())},
                                            "refresh_db_snapshots",
                                        )
                                    with gr.Row():
                                        cloud_train_instance_type = gr.Dropdown(
                                            label="SageMaker Train Instance Type",
                                            choices=['ml.g4dn.2xlarge', 'ml.g5.2xlarge'],
                                            elem_id="cloud_train_instance_type",
                                            info='select SageMaker Train Instance Type'
                                        )
                                    with gr.Row(visible=False) as lora_model_row:
                                        cloud_db_lora_model_name = gr.Dropdown(
                                            label="Lora Model", choices=get_sorted_lora_cloud_models(),
                                            elem_id="cloud_lora_model_dropdown"
                                        )
                                        create_refresh_button(
                                            cloud_db_lora_model_name,
                                            get_sorted_lora_cloud_models,
                                            lambda: {"choices": get_sorted_lora_cloud_models()},
                                            "refresh_lora_models",
                                        )
                                    # Read-only fields populated by wrap_load_model_params
                                    # when a model is selected (see .change hook below).
                                    with gr.Row():
                                        gr.HTML(value="Loaded Model from Cloud:")
                                        cloud_db_model_path = gr.HTML()
                                    with gr.Row():
                                        gr.HTML(value="Cloud Model Revision:")
                                        cloud_db_revision = gr.HTML(elem_id="cloud_db_revision")
                                    with gr.Row():
                                        gr.HTML(value="Cloud Model Epoch:")
                                        cloud_db_epochs = gr.HTML(elem_id="cloud_db_epochs")
                                    with gr.Row():
                                        gr.HTML(value="V2 Model From Cloud:")
                                        cloud_db_v2 = gr.HTML(elem_id="cloud_db_v2")
                                    with gr.Row():
                                        gr.HTML(value="Has EMA:")
                                        cloud_db_has_ema = gr.HTML(elem_id="cloud_db_has_ema")
                                    with gr.Row():
                                        gr.HTML(value="Source Checkpoint From Cloud:")
                                        cloud_db_src = gr.HTML()
                                    with gr.Row():
                                        gr.HTML(value="Cloud DB Status:")
                                        cloud_db_status = gr.HTML(elem_id="db_status", value="")
                                    with gr.Row():
                                        gr.HTML(value="Experimental Shared Source:")
                                        cloud_db_shared_diffusers_path = gr.HTML()
                                    with gr.Row():
                                        gr.HTML(value="<b>Training Jobs Details:<b/>")
                                    with gr.Row():
                                        # Live dashboard: re-reads get_train_job_list every 3 s.
                                        training_job_dashboard = gr.Dataframe(
                                            headers=["id", "model name", "status", "SageMaker train name"],
                                            datatype=["str", "str", "str", "str"],
                                            col_count=(4, "fixed"),
                                            value=get_train_job_list,
                                            interactive=False,
                                            every=3,
                                            elem_id='training_job_dashboard'
                                            # show_progress=True
                                        )
                                with gr.Tab('Create From Cloud'):
                                    with gr.Column():
                                        cloud_db_create_model = gr.Button(
                                            value="Create Model From Cloud", variant="primary"
                                        )
                                        cloud_db_new_model_name = gr.Textbox(label="Name",
                                                                             placeholder="Model names can only contain alphanumeric and -")
                                    with gr.Row():
                                        cloud_db_create_from_hub = gr.Checkbox(
                                            label="Create From Hub", value=False, visible=False
                                        )
                                        cloud_db_512_model = gr.Checkbox(label="512x Model", value=True)
                                    # Hub-based inputs, shown only when "Create From Hub" is ticked.
                                    with gr.Column(visible=False) as hub_row:
                                        cloud_db_new_model_url = gr.Textbox(
                                            label="Model Path",
                                            placeholder="runwayml/stable-diffusion-v1-5",
                                            elem_id="cloud_db_model_path_text_box"
                                        )
                                        cloud_db_new_model_token = gr.Textbox(
                                            label="HuggingFace Token", value=""
                                        )
                                    # Cloud-checkpoint inputs, the default source.
                                    with gr.Column(visible=True) as local_row:
                                        with gr.Row():
                                            cloud_db_new_model_src = gr.Dropdown(
                                                label="Source Checkpoint",
                                                choices=sorted(get_sd_cloud_models()),
                                                elem_id="cloud_db_source_checkpoint_dropdown"
                                            )
                                            create_refresh_button(
                                                cloud_db_new_model_src,
                                                get_sd_cloud_models,
                                                lambda: {"choices": sorted(get_sd_cloud_models())},
                                                "refresh_sd_models",
                                            )
                                    with gr.Column(visible=False) as shared_row:
                                        with gr.Row():
                                            cloud_db_new_model_shared_src = gr.Dropdown(
                                                label="EXPERIMENTAL: LoRA Shared Diffusers Source",
                                                choices=[],
                                                value=""
                                            )
                                            cloud_db_new_model_extract_ema = gr.Checkbox(
                                                label="Extract EMA Weights", value=False
                                            )
                                            cloud_db_train_unfrozen = gr.Checkbox(label="Unfreeze Model", value=False,
                                                                                  elem_id="cloud_db_unfreeze_model_checkbox")
                                    with gr.Row():
                                        gr.HTML(value="<b>Model Creation Jobs Details:<b/>")
                                    with gr.Row():
                                        # Live dashboard: re-reads get_create_model_job_list every 3 s.
                                        createmodel_dashboard = gr.Dataframe(
                                            headers=["id", "model name", "status"],
                                            datatype=["str", "str", "str"],
                                            col_count=(3, "fixed"),
                                            value=get_create_model_job_list,
                                            interactive=False,
                                            every=3
                                            # show_progress=True
                                        )
                                # Toggle between hub-based and checkpoint-based inputs.
                                def toggle_new_rows(create_from):
                                    return gr.update(visible=create_from), gr.update(visible=not create_from)
                                cloud_db_create_from_hub.change(
                                    fn=toggle_new_rows,
                                    inputs=[cloud_db_create_from_hub],
                                    outputs=[hub_row, local_row],
                                )
                                cloud_db_model_name.change(
                                    _js="clear_loaded",
                                    fn=wrap_load_model_params,
                                    inputs=[cloud_db_model_name],
                                    outputs=[
                                        cloud_db_model_path,
                                        cloud_db_revision,
                                        cloud_db_epochs,
                                        cloud_db_v2,
                                        cloud_db_has_ema,
                                        cloud_db_src,
                                        cloud_db_shared_diffusers_path,
                                        cloud_db_snapshot,
                                        cloud_db_lora_model_name,
                                        cloud_db_status,
                                    ],
                                )
                                cloud_db_create_model.click(
                                    fn=cloud_create_model,
                                    _js="check_create_model_params",
                                    inputs=[
                                        cloud_db_new_model_name,
                                        cloud_db_new_model_src,
                                        cloud_db_new_model_shared_src,
                                        cloud_db_create_from_hub,
                                        cloud_db_new_model_url,
                                        cloud_db_new_model_token,
                                        cloud_db_new_model_extract_ema,
                                        cloud_db_train_unfrozen,
                                        cloud_db_512_model,
                                    ],
                                    outputs=[
                                        createmodel_dashboard
                                        # cloud_db_new_model_name
                                        # cloud_db_create_from_hub
                                        # cloud_db_512_model
                                        # cloud_db_new_model_url
                                        # cloud_db_new_model_token
                                        # cloud_db_new_model_src
                                    ]
                                )
                        break
    return res
script_callbacks.ui_tabs_callback = ui_tabs_callback
def get_sorted_lora_cloud_models():
    """Placeholder: always returns an empty list (presumably awaiting a real
    cloud LoRA listing implementation)."""
    return list()
def get_cloud_model_snapshots():
    """Placeholder: always returns an empty list (presumably awaiting a real
    cloud snapshot listing implementation)."""
    return list()