diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f781c71
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,17 @@
+# for vim
+.~
+.*.swp
+*~
+
+# for MacOS
+.DS_Store
+
+__pycache__/
+
+.idea/
+
+logs/
+flagged/
+
+configs/civitai_models.json
+configs/liandange_models.json
diff --git a/install.py b/install.py
new file mode 100644
index 0000000..6920515
--- /dev/null
+++ b/install.py
@@ -0,0 +1,30 @@
+import launch
+import os
+import gzip
+import io
+
+
+def install_preset_models_if_needed():
+ assets_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "assets")
+ configs_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")
+
+ for model_filename in ["civitai_models.json", "liandange_models.json"]:
+ gzip_file = os.path.join(assets_folder, f"{model_filename}.gz")
+ target_file = os.path.join(configs_folder, f"{model_filename}")
+ if not os.path.exists(target_file):
+ with gzip.open(gzip_file, "rb") as compressed_file:
+ with io.TextIOWrapper(compressed_file, encoding="utf-8") as decoder:
+ content = decoder.read()
+ with open(target_file, "w") as model_file:
+ model_file.write(content)
+
+
+req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
+
+with open(req_file) as file:
+ for lib in file:
+ lib = lib.strip()
+ if not launch.is_installed(lib):
+ launch.run_pip(f"install {lib}", f"Miaoshou assistant requirement: {lib}")
+
+install_preset_models_if_needed()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e8bf606
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+psutil
+rehash
+tqdm
diff --git a/scripts/assistant/__init__.py b/scripts/assistant/__init__.py
new file mode 100644
index 0000000..bf393ef
--- /dev/null
+++ b/scripts/assistant/__init__.py
@@ -0,0 +1 @@
+__all__ = ["miaoshou"]
diff --git a/scripts/assistant/miaoshou.py b/scripts/assistant/miaoshou.py
new file mode 100644
index 0000000..69edd52
--- /dev/null
+++ b/scripts/assistant/miaoshou.py
@@ -0,0 +1,256 @@
+import os
+import sys
+import typing as t
+
+import gradio as gr
+
+import launch
+import modules
+from scripts.logging.msai_logger import Logger
+from scripts.runtime.msai_prelude import MiaoshouPrelude
+from scripts.runtime.msai_runtime import MiaoshouRuntime
+
+
+class MiaoShouAssistant(object):
+ # default css definition
+ default_css = '#my_model_cover{width: 100px;} #my_model_trigger_words{width: 200px;}'
+
+ def __init__(self) -> None:
+ self.logger = Logger()
+ self.prelude = MiaoshouPrelude()
+ self.runtime = MiaoshouRuntime()
+ self.refresh_symbol = '\U0001f504'
+
+ def on_event_ui_tabs_opened(self) -> t.List[t.Optional[t.Tuple[t.Any, str, str]]]:
+ with gr.Blocks(analytics_enabled=False, css=MiaoShouAssistant.default_css) as miaoshou_assistant:
+ self.create_subtab_boot_assistant()
+ self.create_subtab_model_management()
+ self.create_subtab_model_download()
+
+ return [(miaoshou_assistant.queue(), "Miaoshou Assistant", "miaoshou_assistant")]
+
+ def create_subtab_boot_assistant(self) -> None:
+ with gr.TabItem('Boot Assistant', elem_id="boot_assistant_tab") as boot_assistant:
+ with gr.Row():
+ with gr.Column(elem_id="col_model_list"):
+ gpu, theme, port, chk_args, txt_args, webui_ver = self.runtime.get_default_args()
+ gr.Markdown(value="Argument settings")
+ with gr.Row():
+ drp_gpu = gr.Dropdown(label="", elem_id="drp_args_vram",
+ choices=list(self.prelude.gpu_setting.keys()),
+ value=gpu, interactive=True)
+ drp_theme = gr.Dropdown(label="UI Theme", choices=list(self.prelude.theme_setting.keys()),
+ value=theme,
+ elem_id="drp_args_theme", interactive=True)
+ txt_listen_port = gr.Text(label='Listen Port', value=port, elem_id="txt_args_listen_port",
+ interactive=True)
+
+ with gr.Row():
+ chk_group_args = gr.CheckboxGroup(choices=list(self.prelude.checkboxes.keys()), value=chk_args,
+ show_label=False)
+ additional_args = gr.Text(label='COMMANDLINE_ARGS (Divide by space)', value=txt_args,
+ elem_id="txt_args_more", interactive=True)
+
+ with gr.Row():
+ with gr.Column():
+ txt_save_status = gr.Markdown(visible=False, interactive=False, show_label=False)
+ drp_choose_version = gr.Dropdown(label="WebUI Version",
+ choices=['Official Release', 'Python Integrated'],
+ value=webui_ver, elem_id="drp_args_version",
+ interactive=True)
+ gr.HTML(
+ '
*Save your settings to webui-user.bat file. Use Python Integrated only if your'
+ ' WebUI is extracted from a zip file and does not need python installation
')
+ save_settings = gr.Button(value="Save settings", elem_id="btn_arg_save_setting")
+
+ with gr.Row():
+ # with gr.Column():
+ # settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="ms_settings_submit")
+ # with gr.Column():
+ restart_gradio = gr.Button(value='Apply & Restart WebUI', variant='primary',
+ elem_id="ms_settings_restart_gradio")
+
+ '''def mod_args(drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
+ global commandline_args
+
+ get_final_args(drp_gpu, drp_theme, txt_listen_port, hk_group_args, additional_args)
+
+ print(commandline_args)
+ print(sys.argv)
+ #if '--xformers' not in sys.argv:
+ #sys.argv.append('--xformers')
+
+ settings_submit.click(mod_args, inputs=[drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args], outputs=[])'''
+
+ save_settings.click(self.runtime.change_boot_setting,
+ inputs=[drp_choose_version, drp_gpu, drp_theme, txt_listen_port, chk_group_args,
+ additional_args], outputs=[txt_save_status])
+
+ restart_gradio.click(
+ fn=self.request_restart,
+ _js='restart_reload',
+ inputs=[],
+ outputs=[],
+ )
+
+ with gr.Column():
+ with gr.Row():
+ machine_settings = self.prelude.get_sys_info()
+ txt_sys_info = gr.TextArea(value=machine_settings, lines=20, max_lines=20,
+ label="System Info",
+ show_label=False, interactive=False)
+ with gr.Row():
+ sys_info_refbtn = gr.Button(value="Refresh")
+
+ drp_gpu.change(self.runtime.update_xformers, inputs=[drp_gpu, chk_group_args], outputs=[chk_group_args])
+ sys_info_refbtn.click(self.prelude.get_sys_info, None, txt_sys_info)
+
+ def create_subtab_model_management(self) -> None:
+ with gr.TabItem('Model Management', elem_id="model_management_tab") as tab_batch:
+ with gr.Row():
+ with gr.Column():
+ my_models = self.runtime.get_local_models()
+ ds_my_models = gr.Dataset(
+ components=[gr.HTML(visible=False, label='Cover', elem_id='my_model_cover'),
+ gr.Textbox(visible=False, label='Name/Version'),
+ gr.Textbox(visible=False, label='File Name'),
+ gr.Textbox(visible=False, label='Hash'), gr.Textbox(visible=False, label='Creator'),
+ gr.Textbox(visible=False, label='Type'), gr.Textbox(visible=False, label='NSFW'),
+ gr.Textbox(visible=False, label='Trigger Words', elem_id='my_model_trigger_words')],
+ elem_id='my_model_lib',
+ label="My Models",
+ headers=None,
+ samples=my_models,
+ samples_per_page=50)
+ with gr.Column():
+ html_model_prompt = gr.HTML(visible=True,
+ value='')
+
+ with gr.Row():
+ add = gr.Button(value="Add", variant="primary")
+ # delete = gr.Button(value="Delete")
+ with gr.Row():
+ reset_btn = gr.Button(value="Reset")
+ json_input = gr.Button(value="Load from JSON")
+ png_input = gr.Button(value="Detect from image")
+ png_input_area = gr.Image(label="Detect from image", elem_id="openpose_editor_input")
+ bg_input = gr.Button(value="Add Background image")
+
+ def create_subtab_model_download(self) -> None:
+ with gr.TabItem('Model Download', elem_id="model_download_tab") as tab_downloads:
+ with gr.Row():
+ with gr.Column(elem_id="col_model_list"):
+ with gr.Row().style(equal_height=True):
+ model_source_dropdown = gr.Dropdown(choices=["civitai", "liandange"],
+ value=self.runtime.model_source,
+ label="Select Model Source",
+ type="value",
+ show_label=True,
+ elem_id="model_source").style(full_width=True)
+ with gr.Row().style(equal_height=True):
+ search_text = gr.Textbox(
+ label="Model name",
+ show_label=False,
+ max_lines=1,
+ placeholder="Enter model name",
+ )
+ btn_search = gr.Button("Search")
+
+ with gr.Row().style(equal_height=True):
+ nsfw_checker = gr.Checkbox(label='NSFW', value=False, elem_id="chk_nsfw", interactive=True)
+ model_type = gr.Radio(["All", "Checkpoint", "LORA", "TextualInversion", "Hypernetwork"],
+ show_label=False, value='All', elem_id="rad_model_type",
+ interactive=True).style(full_width=True)
+
+ images = self.runtime.get_images_html()
+ self.runtime.ds_models = gr.Dataset(
+ components=[gr.HTML(visible=False)],
+ headers=None,
+ type="values",
+ label="Models",
+ samples=images,
+ samples_per_page=60,
+ elem_id="model_dataset").style(type="gallery", container=True)
+
+ with gr.Column(elem_id="col_model_info"):
+ with gr.Row():
+ cover_gallery = gr.Gallery(label="Cover", show_label=False, visible=True).style(grid=[4],
+ height="2")
+
+ with gr.Row():
+ with gr.Column():
+ download_summary = gr.HTML('No downloading tasks ongoing
')
+ downloading_status = gr.Button(value=f"{self.refresh_symbol} Refresh Downloading Status",
+ elem_id="ms_dwn_status")
+ with gr.Row():
+ model_dropdown = gr.Dropdown(choices=['Select Model'], label="Models", show_label=False,
+ value='Select Model', elem_id='ms_dwn_button',
+ interactive=True)
+
+ is_civitai_model_source_active = self.runtime.model_source == "civitai"
+ with gr.Row(variant="panel"):
+ dwn_button = gr.Button(value='Download',
+ visible=is_civitai_model_source_active, elem_id='ms_dwn_button')
+ open_url_in_browser_newtab_button = gr.HTML(
+ value=''
+ 'Download
',
+ visible=not is_civitai_model_source_active)
+ with gr.Row():
+ model_info = gr.HTML(visible=True)
+
+ nsfw_checker.change(self.runtime.set_nsfw, inputs=[search_text, nsfw_checker, model_type],
+ outputs=self.runtime.ds_models)
+
+ model_type.change(self.runtime.search_model, inputs=[search_text, model_type], outputs=self.runtime.ds_models)
+
+ btn_search.click(self.runtime.search_model, inputs=[search_text, model_type], outputs=self.runtime.ds_models)
+
+ self.runtime.ds_models.click(self.runtime.get_model_info,
+ inputs=[self.runtime.ds_models],
+ outputs=[
+ cover_gallery,
+ model_dropdown,
+ model_info,
+ open_url_in_browser_newtab_button
+ ])
+
+ dwn_button.click(self.runtime.download_model, inputs=[model_dropdown], outputs=[download_summary])
+ downloading_status.click(self.runtime.get_downloading_status, inputs=[], outputs=[download_summary])
+
+ model_source_dropdown.change(self.switch_model_source,
+ inputs=[model_source_dropdown],
+ outputs=[self.runtime.ds_models, dwn_button, open_url_in_browser_newtab_button])
+
+ def request_restart(self, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
+ print('request_restart: cmd_arg = ', self.runtime.cmdline_args)
+ print('request_restart: sys.argv = ', sys.argv)
+
+ modules.shared.state.interrupt()
+ modules.shared.state.need_restart = True
+
+ # reset args
+ sys.argv = [sys.argv[0]]
+ os.environ['COMMANDLINE_ARGS'] = ""
+ print('remove', sys.argv)
+
+ for arg in self.runtime.cmdline_args:
+ sys.argv.append(arg)
+
+ print('after', sys.argv)
+ launch.prepare_environment()
+ launch.start()
+
+ def switch_model_source(self, new_model_source: str):
+ self.runtime.model_source = new_model_source
+ show_download_button = self.runtime.model_source == "civitai"
+ images = self.runtime.get_images_html()
+ self.runtime.ds_models.samples = images
+ return (
+ gr.Dataset.update(samples=images),
+ gr.Button.update(visible=show_download_button),
+ gr.HTML.update(visible=not show_download_button)
+ )
+
+ def introception(self) -> None:
+ self.runtime.introception()
diff --git a/scripts/download/__init__.py b/scripts/download/__init__.py
new file mode 100644
index 0000000..8ce6385
--- /dev/null
+++ b/scripts/download/__init__.py
@@ -0,0 +1 @@
+__all__ = ["msai_downloader_manager"]
diff --git a/scripts/download/msai_downloader_manager.py b/scripts/download/msai_downloader_manager.py
new file mode 100644
index 0000000..ea11e20
--- /dev/null
+++ b/scripts/download/msai_downloader_manager.py
@@ -0,0 +1,248 @@
+import asyncio
+import os.path
+import queue
+import time
+import requests
+import typing as t
+from threading import Thread, Lock
+
+from scripts.download.msai_file_downloader import MiaoshouFileDownloader
+from scripts.logging.msai_logger import Logger
+from scripts.msai_utils.msai_singleton import MiaoshouSingleton
+import scripts.msai_utils.msai_toolkit as toolkit
+
+
+class DownloadingEntry(object):
+ def __init__(self, target_url: str = None, local_file: str = None,
+ local_directory: str = None, estimated_total_size: float = 0., expected_checksum: str = None):
+ self._target_url = target_url
+ self._local_file = local_file
+ self._local_directory = local_directory
+ self._expected_checksum = expected_checksum
+
+ self._estimated_total_size = estimated_total_size
+ self._total_size = 0
+ self._downloaded_size = 0
+
+ self._downloading = False
+ self._failure = False
+
+ @property
+ def target_url(self) -> str:
+ return self._target_url
+
+ @property
+ def local_file(self) -> str:
+ return self._local_file
+
+ @property
+ def local_directory(self) -> str:
+ return self._local_directory
+
+ @property
+ def expected_checksum(self) -> str:
+ return self._expected_checksum
+
+ @property
+ def total_size(self) -> int:
+ return self._total_size
+
+ @total_size.setter
+ def total_size(self, sz: int) -> None:
+ self._total_size = sz
+
+ @property
+ def downloaded_size(self) -> int:
+ return self._downloaded_size
+
+ @downloaded_size.setter
+ def downloaded_size(self, sz: int) -> None:
+ self._downloaded_size = sz
+
+ @property
+ def estimated_size(self) -> float:
+ return self._estimated_total_size
+
+ def is_downloading(self) -> bool:
+ return self._downloading
+
+ def start_download(self) -> None:
+ self._downloading = True
+
+ def update_final_status(self, result: bool) -> None:
+ self._failure = result
+ self._downloading = False
+
+ def is_failure(self) -> bool:
+ return self._failure
+
+
+class AsyncLoopThread(Thread):
+ def __init__(self):
+ super(AsyncLoopThread, self).__init__(daemon=True)
+ self.loop = asyncio.new_event_loop()
+ self.logger = Logger()
+ self.logger.info("looper thread is created")
+
+ def run(self):
+ asyncio.set_event_loop(self.loop)
+ self.logger.info("looper thread is running")
+ self.loop.run_forever()
+
+
+class MiaoshouDownloaderManager(metaclass=MiaoshouSingleton):
+ _downloading_entries: t.Dict[str, DownloadingEntry] = None
+
+ def __init__(self):
+ if self._downloading_entries is None:
+ self._downloading_entries = {}
+ self.message_queue = queue.Queue()
+
+ self.logger = Logger()
+ self.looper = AsyncLoopThread()
+ self.looper.start()
+ self.logger.info("download manager is ready")
+ self._mutex = Lock()
+
+ def consume_all_ready_messages(self) -> None:
+ """
+ capture all enqueued messages, this method should not be used if you are iterating over the message queue
+ :return:
+ None
+ :side-effect:
+ update downloading entries' status
+ """
+ while True:
+ # self.logger.info("fetching the enqueued message")
+ try:
+ (aurl, finished_size, total_size) = self.message_queue.get(block=False, timeout=0.2)
+ # self.logger.info(f"[+] message ([{finished_size}/{total_size}] {aurl}")
+ try:
+ self._mutex.acquire(blocking=True)
+ self._downloading_entries[aurl].total_size = total_size
+ self._downloading_entries[aurl].downloaded_size = finished_size
+ finally:
+ self._mutex.release()
+ except queue.Empty:
+ break
+
+ def iterator(self) -> t.Tuple[float, float]:
+
+ while True:
+ self.logger.info("waiting for incoming message")
+
+ try:
+ (aurl, finished_size, total_size) = self.message_queue.get(block=True)
+ self.logger.info(f"[+] message ([{finished_size}/{total_size}] {aurl}")
+ try:
+ self._mutex.acquire(blocking=True)
+ self._downloading_entries[aurl].total_size = total_size
+ self._downloading_entries[aurl].downloaded_size = finished_size
+
+ tasks_total_size = 0.
+ tasks_finished_size = 0.
+
+ for e in self._downloading_entries.values():
+ tasks_total_size += e.total_size
+ tasks_finished_size += e.downloaded_size
+
+ yield tasks_finished_size, tasks_total_size
+ finally:
+ self._mutex.release()
+ except queue.Empty:
+ if len(asyncio.all_tasks(self.looper.loop)) == 0:
+ self.logger.info("all downloading tasks finished")
+ break
+
+ async def _submit_task(self, download_entry: DownloadingEntry) -> None:
+ try:
+ self._mutex.acquire(blocking=True)
+ if download_entry.target_url in self._downloading_entries:
+ self.logger.warn(f"{download_entry.target_url} is already downloading")
+ return
+ else:
+ download_entry.start_download()
+ self._downloading_entries[download_entry.target_url] = download_entry
+ finally:
+ self._mutex.release()
+
+ file_downloader = MiaoshouFileDownloader(
+ target_url=download_entry.target_url,
+ local_file=download_entry.local_file,
+ local_directory=download_entry.local_directory,
+ channel=self.message_queue if download_entry.estimated_size else None,
+ estimated_total_length=download_entry.estimated_size,
+ expected_checksum=download_entry.expected_checksum,
+ )
+
+ result: bool = await self.looper.loop.run_in_executor(None, file_downloader.download_file)
+
+ try:
+ self._mutex.acquire(blocking=True)
+ self._downloading_entries[download_entry.target_url].update_final_status(result)
+ finally:
+ self._mutex.release()
+
+ def download(self, source_url: str, target_file: str, estimated_total_size: float,
+ expected_checksum: str = None) -> None:
+ target_dir = os.path.dirname(target_file)
+ target_filename = os.path.basename(target_file)
+ download_entry = DownloadingEntry(
+ target_url=source_url,
+ local_file=target_filename,
+ local_directory=target_dir,
+ estimated_total_size=estimated_total_size,
+ expected_checksum=expected_checksum
+ )
+
+ asyncio.run_coroutine_threadsafe(self._submit_task(download_entry), self.looper.loop)
+
+ def tasks_summary(self) -> t.Tuple[int, int, str]:
+ self.consume_all_ready_messages()
+
+ total_tasks_num = 0
+ ongoing_tasks_num = 0
+ failed_tasks_num = 0
+
+ try:
+ description = ""
+ self._mutex.acquire(blocking=True)
+ for name, entry in self._downloading_entries.items():
+ if entry.estimated_size is None:
+ continue
+
+ total_tasks_num += 1
+
+ if entry.total_size > 0.:
+ description += f"
{entry.local_file} ({toolkit.get_readable_size(entry.total_size)}) : "
+ else:
+ description += f"
{entry.local_file} ({toolkit.get_readable_size(entry.estimated_size)}) : "
+
+ if entry.is_downloading():
+ ongoing_tasks_num += 1
+ finished_percent = entry.downloaded_size/entry.estimated_size * 100
+ description += f'{round(finished_percent, 2)} %'
+ elif entry.is_failure():
+ failed_tasks_num += 1
+ description += 'failed!'
+ else:
+ description += 'finished'
+ description += "
"
+ finally:
+ self._mutex.release()
+ pass
+
+ description += "
"
+ overall = f"""
+
+ {ongoing_tasks_num} ongoing,
+ {total_tasks_num - ongoing_tasks_num - failed_tasks_num} finished,
+ {failed_tasks_num} failed.
+
+
+
+ """
+
+ return ongoing_tasks_num, total_tasks_num, overall + description
+
+
diff --git a/scripts/download/msai_file_downloader.py b/scripts/download/msai_file_downloader.py
new file mode 100644
index 0000000..e165c95
--- /dev/null
+++ b/scripts/download/msai_file_downloader.py
@@ -0,0 +1,226 @@
+import os
+import pickle
+import queue
+import time
+import typing as t
+from pathlib import Path
+from urllib.parse import urlparse
+
+import rehash
+import requests
+from requests.adapters import HTTPAdapter
+from tqdm import tqdm
+from urllib3.util import Retry
+
+import scripts.msai_utils.msai_toolkit as toolkit
+from scripts.logging.msai_logger import Logger
+
+
+class MiaoshouFileDownloader(object):
+ CHUNK_SIZE = 1024 * 1024
+
+ def __init__(self, target_url: str = None,
+ local_file: str = None, local_directory: str = None, estimated_total_length: float = 0.,
+ expected_checksum: str = None,
+ channel: queue.Queue = None,
+ max_retries=3) -> None:
+ self.logger = Logger()
+
+ self.target_url: str = target_url
+ self.local_file: str = local_file
+ self.local_directory = local_directory
+ self.expected_checksum = expected_checksum
+ self.max_retries = max_retries
+
+ self.accept_ranges: bool = False
+ self.estimated_content_length = estimated_total_length
+ self.content_length: int = -1
+ self.finished_chunk_size: int = 0
+
+ self.channel = channel # for communication
+
+ # Support 3 retries and backoff
+ retry_strategy = Retry(
+ total=3,
+ backoff_factor=1,
+ status_forcelist=[429, 500, 502, 503, 504],
+ method_whitelist=["HEAD", "GET", "OPTIONS"]
+ )
+ adapter = HTTPAdapter(max_retries=retry_strategy)
+ self.session = requests.Session()
+ self.session.mount("https://", adapter)
+ self.session.mount("http://", adapter)
+
+ # inform message receiver at once
+ if self.channel:
+ self.channel.put_nowait((
+ self.target_url,
+ self.finished_chunk_size,
+ self.estimated_content_length,
+ ))
+
+ # Head request to get file-length and check whether it supports ranges.
+ def get_file_info_from_server(self, target_url: str) -> t.Tuple[bool, float]:
+ try:
+ headers = {"Accept-Encoding": "identity"} # Avoid dealing with gzip
+ response = requests.head(target_url, headers=headers, allow_redirects=True)
+ response.raise_for_status()
+ content_length = None
+ if "Content-Length" in response.headers:
+ content_length = int(response.headers['Content-Length'])
+ accept_ranges = (response.headers.get("Accept-Ranges") == "bytes")
+ return accept_ranges, float(content_length)
+ except Exception as ex:
+ self.logger.info(f"HEAD Request Error: {ex}")
+ return False, self.estimated_content_length
+
+ def download_file_full(self, target_url: str, local_filepath: str) -> t.Optional[str]:
+ try:
+ checksum = rehash.sha256()
+ headers = {"Accept-Encoding": "identity"} # Avoid dealing with gzip
+
+ with tqdm(total=self.content_length, unit="byte", unit_scale=1, colour="GREEN",
+ desc=os.path.basename(self.local_file)) as progressbar, \
+ self.session.get(target_url, headers=headers, stream=True, timeout=5) as response, \
+ open(local_filepath, 'wb') as file_out:
+ response.raise_for_status()
+
+ for chunk in response.iter_content(MiaoshouFileDownloader.CHUNK_SIZE):
+ file_out.write(chunk)
+ checksum.update(chunk)
+ progressbar.update(len(chunk))
+ self.update_progress(len(chunk))
+
+ except Exception as ex:
+ self.logger.info(f"Download error: {ex}")
+ return None
+
+ return checksum.hexdigest()
+
+ def download_file_resumable(self, target_url: str, local_filepath: str) -> t.Optional[str]:
+ # Always go off the checkpoint as the file was flushed before writing.
+ download_checkpoint = local_filepath + ".downloading"
+ try:
+ resume_point, checksum = pickle.load(open(download_checkpoint, "rb"))
+ assert os.path.exists(local_filepath) # catch checkpoint without file
+ self.logger.info("File already exists, resuming download.")
+ except Exception as e:
+ self.logger.error(f"failed to load downloading checkpoint - {download_checkpoint} due to {e}")
+ resume_point = 0
+ checksum = rehash.sha256()
+ if os.path.exists(local_filepath):
+ os.remove(local_filepath)
+ Path(local_filepath).touch()
+
+ assert (resume_point < self.content_length)
+
+ self.finished_chunk_size = resume_point
+
+ # Support resuming
+ headers = {"Range": f"bytes={resume_point}-", "Accept-Encoding": "identity"}
+ try:
+ with tqdm(total=self.content_length, unit="byte", unit_scale=1, colour="GREEN",
+ desc=os.path.basename(self.local_file)) as progressbar, \
+ self.session.get(target_url, headers=headers, stream=True, timeout=5) as response, \
+ open(local_filepath, 'r+b') as file_out:
+ response.raise_for_status()
+ self.update_progress(resume_point)
+ file_out.seek(resume_point)
+
+ for chunk in response.iter_content(MiaoshouFileDownloader.CHUNK_SIZE):
+ file_out.write(chunk)
+ file_out.flush()
+ resume_point += len(chunk)
+ checksum.update(chunk)
+ pickle.dump((resume_point, checksum), open(download_checkpoint, "wb"))
+ progressbar.update(len(chunk))
+ self.update_progress(len(chunk))
+
+ # Only remove checkpoint at full size in case connection cut
+ if os.path.getsize(local_filepath) == self.content_length:
+ os.remove(download_checkpoint)
+ else:
+ return None
+
+ except Exception as ex:
+ self.logger.error(f"Download error: {ex}")
+ return None
+
+ return checksum.hexdigest()
+
+ def update_progress(self, finished_chunk_size: int) -> None:
+ self.finished_chunk_size += finished_chunk_size
+
+ if self.channel:
+ self.channel.put_nowait((
+ self.target_url,
+ self.finished_chunk_size,
+ self.content_length,
+ ))
+
+ # In order to avoid leaving extra garbage meta files behind this
+ # will overwrite any existing files found at local_file. If you don't want this
+ # behaviour you can handle this externally.
+ # local_file and local_directory could write to unexpected places if the source
+ # is untrusted, be careful!
+ def download_file(self) -> bool:
+ success = False
+ try:
+ # Need to rebuild local_file_final each time in case of different urls
+ if not self.local_file:
+ specific_local_file = os.path.basename(urlparse(self.target_url).path)
+ else:
+ specific_local_file = self.local_file
+
+ download_temp_dir = toolkit.get_user_temp_dir()
+ toolkit.assert_user_temp_dir()
+
+ if self.local_directory:
+ os.makedirs(self.local_directory, exist_ok=True)
+
+ specific_local_file = os.path.join(download_temp_dir, specific_local_file)
+
+ self.accept_ranges, self.content_length = self.get_file_info_from_server(self.target_url)
+ self.logger.info(f"Accept-Ranges: {self.accept_ranges}. content length: {self.content_length}")
+ if self.accept_ranges and self.content_length:
+ download_method = self.download_file_resumable
+ self.logger.info("Server supports resume")
+ else:
+ download_method = self.download_file_full
+ self.logger.info(f"Server doesn't support resume.")
+
+ for i in range(self.max_retries):
+ self.logger.info(f"Download Attempt {i + 1}")
+ checksum = download_method(self.target_url, specific_local_file)
+ if checksum:
+ match = ""
+ if self.expected_checksum:
+ match = ", Checksum Match"
+
+ if self.expected_checksum and self.expected_checksum != checksum:
+ self.logger.info(f"Checksum doesn't match. Calculated {checksum} "
+ f"Expecting: {self.expected_checksum}")
+ else:
+ self.logger.info(f"Download successful{match}. Checksum {checksum}")
+ success = True
+ break
+ time.sleep(1)
+
+ if success:
+ print("*" * 80)
+ self.logger.info(f"{self.target_url} [DOWNLOADED COMPLETELY]")
+ print("*" * 80)
+ if self.local_directory:
+ target_local_file = os.path.join(self.local_directory, self.local_file)
+ else:
+ target_local_file = self.local_file
+ toolkit.move_file(specific_local_file, target_local_file)
+ else:
+ print("*" * 80)
+ self.logger.info(f"{self.target_url} [ FAILED ]")
+ print("*" * 80)
+
+ except Exception as ex:
+ self.logger.info(f"Unexpected Error: {ex}") # Only from block above
+
+ return success
diff --git a/scripts/logging/__init__.py b/scripts/logging/__init__.py
new file mode 100644
index 0000000..0e75e22
--- /dev/null
+++ b/scripts/logging/__init__.py
@@ -0,0 +1 @@
+__all__ = ["msai_logger"]
diff --git a/scripts/logging/msai_logger.py b/scripts/logging/msai_logger.py
new file mode 100644
index 0000000..95dd71f
--- /dev/null
+++ b/scripts/logging/msai_logger.py
@@ -0,0 +1,100 @@
+import datetime
+import logging
+import logging.handlers
+import os
+import typing as t
+
+from scripts.msai_utils.msai_singleton import MiaoshouSingleton
+
+
+class Logger(metaclass=MiaoshouSingleton):
+ _dataset = None
+
+ KEY_TRACE_PATH = "trace_path"
+ KEY_INFO = "info"
+ KEY_ERROR = "error"
+ KEY_JOB = "job"
+
+ def _do_init(self, log_folder: str, disable_console_output: bool = False) -> None:
+ # Setup trace_path with empty string by default, it will be assigned with valid content if needed
+ self._dataset = {Logger.KEY_TRACE_PATH: ""}
+
+ print(f"logs_location: {log_folder}")
+ os.makedirs(log_folder, exist_ok=True)
+
+ # Setup basic logging configuration
+ logging.basicConfig(level=logging.INFO,
+ filemode='w',
+ format='%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s')
+
+ # Setup info logging
+ self._dataset[Logger.KEY_INFO] = logging.getLogger(Logger.KEY_INFO)
+ msg_handler = logging.FileHandler(os.path.join(log_folder, "info.log"),
+ "a",
+ encoding="UTF-8")
+ msg_handler.setLevel(logging.INFO)
+ msg_handler.setFormatter(
+ logging.Formatter(fmt='%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s'))
+ self._dataset[Logger.KEY_INFO].addHandler(msg_handler)
+
+ # Setup error logging
+ self._dataset[Logger.KEY_ERROR] = logging.getLogger(Logger.KEY_ERROR)
+ error_handler = logging.FileHandler(
+ os.path.join(log_folder, f'error_{datetime.date.today().strftime("%Y%m%d")}.log'),
+ mode="a",
+ encoding='UTF-8')
+ error_handler.setLevel(logging.ERROR)
+ error_handler.setFormatter(
+ logging.Formatter(
+ fmt=f"{self._dataset.get('trace_path')}:\n "
+ f"%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s"))
+ self._dataset[Logger.KEY_ERROR].addHandler(error_handler)
+
+ # Setup job logging
+ self._dataset[Logger.KEY_JOB] = logging.getLogger(Logger.KEY_JOB)
+ job_handler = logging.FileHandler(os.path.join(log_folder, "jobs.log"),
+ mode="a",
+ encoding="UTF-8")
+ self._dataset[Logger.KEY_JOB].addHandler(job_handler)
+
+ if disable_console_output:
+ for k in [Logger.KEY_INFO, Logger.KEY_JOB, Logger.KEY_ERROR]:
+ l: logging.Logger = self._dataset[k]
+ l.propagate = not disable_console_output
+
+ def __init__(self, log_folder: str = None, disable_console_output: bool = False) -> None:
+ if self._dataset is None:
+ try:
+ self._do_init(log_folder=log_folder, disable_console_output=disable_console_output)
+ except Exception as e:
+ print(e)
+
+ def update_path_info(self, current_path: str) -> None:
+ self._dataset[Logger.KEY_TRACE_PATH] = current_path
+
+ def callback_func(self, exc_type: t.Any, exc_value: t.Any, exc_tracback: t.Any) -> None:
+ self._dataset[Logger.KEY_JOB].error(f"job failed for {self._dataset[Logger.KEY_TRACE_PATH]}")
+ self._dataset[Logger.KEY_INFO].error(f"{self._dataset[Logger.KEY_TRACE_PATH]}\n, callback_func: ",
+ exc_info=(exc_type, exc_value, exc_tracback))
+
+ def debug(self, fmt, *args, **kwargs) -> None:
+ l: logging.Logger = self._dataset[Logger.KEY_INFO]
+ l.debug(fmt, *args, **kwargs, stacklevel=2)
+
+ def info(self, fmt, *args, **kwargs) -> None:
+ l: logging.Logger = self._dataset[Logger.KEY_INFO]
+ l.info(fmt, *args, **kwargs, stacklevel=2)
+
+ def warn(self, fmt, *args, **kwargs) -> None:
+ l: logging.Logger = self._dataset[Logger.KEY_INFO]
+ l.warn(fmt, *args, **kwargs, stacklevel=2)
+
+ def error(self, fmt, *args, **kwargs) -> None:
+ l: logging.Logger = self._dataset[Logger.KEY_ERROR]
+ l.error(fmt, *args, **kwargs, stacklevel=2)
+
+ def job(self, fmt, *args, **kwargs) -> None:
+ l: logging.Logger = self._dataset[Logger.KEY_JOB]
+ l.info(fmt, *args, **kwargs, stacklevel=2)
+
+
diff --git a/scripts/main.py b/scripts/main.py
new file mode 100644
index 0000000..82296ba
--- /dev/null
+++ b/scripts/main.py
@@ -0,0 +1,22 @@
+import modules
+import modules.scripts as scripts
+
+from scripts.assistant.miaoshou import MiaoShouAssistant
+
+
+class MiaoshouScript(scripts.Script):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def title(self):
+ return "Miaoshou Assistant"
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible
+
+ def ui(self, is_img2img):
+ return ()
+
+
+assistant = MiaoShouAssistant()
+modules.script_callbacks.on_ui_tabs(assistant.on_event_ui_tabs_opened)
diff --git a/scripts/msai_utils/__init__.py b/scripts/msai_utils/__init__.py
new file mode 100644
index 0000000..474a79a
--- /dev/null
+++ b/scripts/msai_utils/__init__.py
@@ -0,0 +1 @@
+__all__ = ["msai_singleton", "msai_toolkit"]
diff --git a/scripts/msai_utils/msai_singleton.py b/scripts/msai_utils/msai_singleton.py
new file mode 100644
index 0000000..c2f831c
--- /dev/null
+++ b/scripts/msai_utils/msai_singleton.py
@@ -0,0 +1,8 @@
+class MiaoshouSingleton(type):
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(MiaoshouSingleton, cls).__call__(*args, **kwargs)
+ cls._instances[cls].__init__(*args, **kwargs)
+ return cls._instances[cls]
diff --git a/scripts/msai_utils/msai_toolkit.py b/scripts/msai_utils/msai_toolkit.py
new file mode 100644
index 0000000..cda6706
--- /dev/null
+++ b/scripts/msai_utils/msai_toolkit.py
@@ -0,0 +1,88 @@
+import json
+import os
+import platform
+import shutil
+import typing as t
+from datetime import datetime
+from pathlib import Path
+
+
+def read_json(file) -> t.Any:
+ try:
+ with open(file, "r", encoding="utf-8-sig") as f:
+ return json.load(f)
+ except Exception as e:
+ print(e)
+ return None
+
+
+def write_json(file, content) -> None:
+ try:
+ with open(file, 'w') as f:
+ json.dump(content, f, indent=4)
+ except Exception as e:
+ print(e)
+
+
+def get_args(args) -> t.List[str]:
+ parameters = []
+ idx = 0
+ for arg in args:
+ if idx == 0 and '--' not in arg:
+ pass
+ elif '--' in arg:
+ parameters.append(rf'{arg}')
+ idx += 1
+ else:
+ parameters[idx - 1] = parameters[idx - 1] + ' ' + rf'{arg}'
+
+ return parameters
+
+
+def get_readable_size(size: int, precision=2) -> str:
+ if size is None:
+ return ""
+
+ suffixes = ['B', 'KB', 'MB', 'GB', 'TB']
+ suffixIndex = 0
+ while size >= 1024 and suffixIndex < len(suffixes):
+ suffixIndex += 1 # increment the index of the suffix
+ size = size / 1024.0 # apply the division
+ return "%.*f%s" % (precision, size, suffixes[suffixIndex])
+
+
+def get_file_last_modified_time(path_to_file: str) -> datetime:
+ if path_to_file is None:
+ return datetime.now()
+
+ if platform.system() == "Windows":
+ return datetime.fromtimestamp(os.path.getmtime(path_to_file))
+ else:
+ stat = os.stat(path_to_file)
+ return datetime.fromtimestamp(stat.st_mtime)
+
+
+def get_not_found_image_url() -> str:
+ return "https://msdn.miaoshouai.com/msdn/userimage/not-found.svg"
+
+
+def get_user_temp_dir() -> str:
+ return os.path.join(Path.home().absolute(), ".miaoshou_assistant_download")
+
+
+def assert_user_temp_dir() -> None:
+ os.makedirs(get_user_temp_dir(), exist_ok=True)
+
+
+def move_file(src: str, dst: str) -> None:
+ if not src or not dst:
+ return
+
+ if not os.path.exists(src):
+ return
+
+ if os.path.exists(dst):
+ os.remove(dst)
+
+ os.makedirs(os.path.dirname(dst), exist_ok=True)
+ shutil.move(src, dst)
diff --git a/scripts/runtime/__init__.py b/scripts/runtime/__init__.py
new file mode 100644
index 0000000..fe306c7
--- /dev/null
+++ b/scripts/runtime/__init__.py
@@ -0,0 +1,5 @@
+__all__ = ["msai_prelude", "msai_runtime"]
+
+from . import msai_prelude as prelude
+
+prelude.MiaoshouPrelude().load()
diff --git a/scripts/runtime/msai_prelude.py b/scripts/runtime/msai_prelude.py
new file mode 100644
index 0000000..bb14663
--- /dev/null
+++ b/scripts/runtime/msai_prelude.py
@@ -0,0 +1,158 @@
+import os
+import platform
+import sys
+import typing as t
+
+import psutil
+import torch
+
+import launch
+from scripts.logging.msai_logger import Logger
+from scripts.msai_utils import msai_toolkit as toolkit
+from scripts.msai_utils.msai_singleton import MiaoshouSingleton
+
+
+class MiaoshouPrelude(metaclass=MiaoshouSingleton):
+ _dataset = None
+
+ def __init__(self) -> None:
+ # Potential race condition, not call in multithread environment
+ if MiaoshouPrelude._dataset is None:
+ self._init_constants()
+
+ MiaoshouPrelude._dataset = {
+ "log_folder": os.path.join(self.ext_folder, "logs")
+ }
+
+ disable_log_console_output: bool = False
+ if self.all_settings.get("boot_settings"):
+ if self.all_settings["boot_settings"].get("disable_log_console_output") is not None:
+ disable_log_console_output = self.all_settings["boot_settings"].get("disable_log_console_output")
+
+ self._logger = Logger(self._dataset["log_folder"], disable_console_output=disable_log_console_output)
+
+ def _init_constants(self) -> None:
+ self._api_url = {
+ "civitai": "https://civitai.com/api/v1/models",
+ "liandange": "https://model-api.paomiantv.cn/model/api/models",
+ }
+ self._ext_folder = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", ".."))
+ self._setting_file = os.path.join(self.ext_folder, "configs", "settings.json")
+ self._model_hash_file = os.path.join(self.ext_folder, "configs", "model_hash.json")
+ self._model_json = {
+ 'civitai': os.path.join(self.ext_folder, 'configs', 'civitai_models.json'),
+ 'liandange': os.path.join(self.ext_folder, 'configs', 'liandange_models.json'),
+ }
+ self._checkboxes = {
+ 'Enable xFormers': '--xformers',
+ 'No Half': '--no-half',
+ 'No Half VAE': '--no-half-vae',
+ 'Enable API': '--api',
+ 'Auto Launch': '--autolaunch',
+ 'Allow Local Network Access': '--listen',
+ }
+
+ self._gpu_setting = {
+ 'CPU Only': '--precision full --no-half --use-cpu SD GFPGAN BSRGAN ESRGAN SCUNet CodeFormer --all',
+ 'GTX 16xx': '--lowvram --xformers --precision full --no-half',
+ 'Low: 4-6G VRAM': '--xformers --lowvram',
+ 'Med: 6-8G VRAM': '--xformers --medvram',
+ 'Normal: 8+G VRAM': '',
+ }
+
+ self._theme_setting = {
+ 'Auto': '',
+ 'Light Mode': '--theme=light',
+ 'Dark Mode': '--theme=dark',
+ }
+
+ @property
+ def ext_folder(self) -> str:
+ return self._ext_folder
+
+ @property
+ def log_folder(self) -> str:
+ return self._dataset.get("log_folder")
+
+ @property
+ def all_settings(self) -> t.Any:
+ return toolkit.read_json(self._setting_file)
+
+ @property
+ def boot_settings(self) -> t.Any:
+ all_setting = self.all_settings
+ if all_setting:
+ return all_setting['boot_settings']
+ else:
+ return None
+
+ def api_url(self, model_source: str) -> t.Optional[str]:
+ return self._api_url.get(model_source)
+
+ @property
+ def setting_file(self) -> str:
+ return self._setting_file
+
+ @property
+ def model_hash_file(self) -> str:
+ return self._model_hash_file
+
+ @property
+ def checkboxes(self) -> t.Dict[str, str]:
+ return self._checkboxes
+
+ @property
+ def gpu_setting(self) -> t.Dict[str, str]:
+ return self._gpu_setting
+
+ @property
+ def theme_setting(self) -> t.Dict[str, str]:
+ return self._theme_setting
+
+ @property
+ def model_json(self) -> t.Dict[str, t.Any]:
+ return self._model_json
+
+ def update_model_json(self, site: str, models: t.Dict[str, t.Any]) -> None:
+ if self._model_json.get(site) is None:
+ self._logger.error(f"cannot save model info for {site}")
+ return
+
+ self._logger.info(f"{self._model_json[site]} updated")
+ toolkit.write_json(self._model_json[site], models)
+
+ def load(self) -> None:
+ self._logger.info("start to do prelude")
+ self._logger.info(f"cmdline args: {' '.join(sys.argv[1:])}")
+
+ @classmethod
+ def get_sys_info(cls) -> str:
+ sys_info = 'System Information\n\n'
+
+ sys_info += r'OS Name: {0} {1}'.format(platform.system(), platform.release()) + '\n'
+ sys_info += r'OS Version: {0}'.format(platform.version()) + '\n'
+ sys_info += r'WebUI Version: {0}'.format(
+ f'https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{launch.commit_hash()}') + '\n'
+ sys_info += r'Torch Version: {0}'.format(getattr(torch, '__long_version__', torch.__version__)) + '\n'
+ sys_info += r'Python Version: {0}'.format(sys.version) + '\n\n'
+ sys_info += r'CPU: {0}'.format(platform.processor()) + '\n'
+ sys_info += r'CPU Cores: {0}/{1}'.format(psutil.cpu_count(logical=False), psutil.cpu_count(logical=True)) + '\n'
+
+ # FIXME: should uncomment line below and remove my own workaround for MacOS
+ # sys_info += r'CPU Frequency: {0} GHz'.format(round(psutil.cpu_freq().max/1000,2)) + '\n'
+ # workaround: let my macbook M1 happy
+ sys_info += r'CPU Frequency: {0} GHz'.format(round(2540.547 / 1000, 2)) + '\n'
+
+ sys_info += r'CPU Usage: {0}%'.format(psutil.cpu_percent()) + '\n\n'
+ sys_info += r'RAM: {0}'.format(toolkit.get_readable_size(psutil.virtual_memory().total)) + '\n'
+ sys_info += r'Memory Usage: {0}%'.format(psutil.virtual_memory().percent) + '\n\n'
+ for i in range(torch.cuda.device_count()):
+ sys_info += r'Graphics Card{0}: {1} ({2})'.format(i, torch.cuda.get_device_properties(i).name,
+ toolkit.get_readable_size(
+ torch.cuda.get_device_properties(
+ i).total_memory)) + '\n'
+ sys_info += r'Available VRAM: {0}'.format(toolkit.get_readable_size(torch.cuda.mem_get_info(i)[0])) + '\n'
+
+ return sys_info
+
+
diff --git a/scripts/runtime/msai_runtime.py b/scripts/runtime/msai_runtime.py
new file mode 100644
index 0000000..5475293
--- /dev/null
+++ b/scripts/runtime/msai_runtime.py
@@ -0,0 +1,630 @@
+import datetime
+import fileinput
+import os
+import platform
+import re
+import shutil
+import sys
+import time
+import typing as t
+
+import gradio as gr
+import requests
+
+import modules
+from modules.sd_models import CheckpointInfo
+from scripts.download.msai_downloader_manager import MiaoshouDownloaderManager
+from scripts.logging.msai_logger import Logger
+from scripts.msai_utils import msai_toolkit as toolkit
+from scripts.runtime.msai_prelude import MiaoshouPrelude
+
+
+class MiaoshouRuntime(object):
+ def __init__(self):
+ self.cmdline_args: t.List[str] = None
+ self.logger = Logger()
+ self.prelude = MiaoshouPrelude()
+ self._old_additional: str = None
+ self._model_set: t.List[t.Dict] = None
+ self._model_set_last_access_time: datetime.datetime = None
+ self._ds_models: gr.Dataset = None
+ self._allow_nsfw: bool = False
+ self._model_source: str = "civitai" # civitai is the default model source
+
+ # TODO: may be owned by downloader class
+ self.model_files = []
+
+ self.downloader_manager = MiaoshouDownloaderManager()
+
+ def get_default_args(self, commandline_args: t.List[str] = None):
+ if commandline_args is None:
+ commandline_args: t.List[str] = toolkit.get_args(sys.argv[1:])
+ self.cmdline_args = commandline_args
+ self.logger.info(f"default commandline args: {commandline_args}")
+
+ checkbox_values = []
+ additional_args = ""
+ saved_setting = self.prelude.boot_settings
+
+ gpu = saved_setting.get('drp_args_vram')
+ theme = saved_setting.get('drp_args_theme')
+ port = saved_setting.get('txt_args_listen_port')
+
+ for arg in commandline_args:
+ if 'theme' in arg:
+ theme = [k for k, v in self.prelude.theme_setting.items() if v == arg][0]
+ if 'port' in arg:
+ port = arg.split(' ')[-1]
+
+ for chk in self.prelude.checkboxes:
+ for arg in commandline_args:
+ if self.prelude.checkboxes[chk] == arg:
+ checkbox_values.append(chk)
+
+ gpu_arg_list = [f'--{i.strip()}' for i in ' '.join(list(self.prelude.gpu_setting.values())).split('--')]
+ for arg in commandline_args:
+ if 'port' not in arg \
+ and arg not in list(self.prelude.theme_setting.values()) \
+ and arg not in list(self.prelude.checkboxes.values()) \
+ and arg not in gpu_arg_list:
+ additional_args += (' ' + rf'{arg}')
+
+ self._old_additional = additional_args
+ webui_ver = saved_setting['drp_choose_version']
+
+ return gpu, theme, port, checkbox_values, additional_args.replace('\\', '\\\\').strip(), webui_ver
+
+ def add_arg(self, args: str = "") -> None:
+ for arg in args.split('--'):
+ self.logger.info(f'add arg: {arg.strip()}')
+ if f"--{arg.strip()}" not in self.cmdline_args and arg.strip() != '':
+ self.cmdline_args.append(f'--{arg.strip()}')
+
+ def remove_arg(self, args: str = "") -> None:
+ arg_keywords = ['port', 'theme']
+
+ for arg in args.split('--'):
+ if arg in arg_keywords:
+ for cmdl in self.cmdline_args:
+ if arg in cmdl:
+ self.cmdline_args.remove(cmdl)
+ break
+ elif f'--{arg.strip()}' in self.cmdline_args and arg.strip() != '':
+ print(f"remove args:{arg.strip()}")
+ self.cmdline_args.remove(f'--{arg.strip()}')
+
+ def get_final_args(self, gpu, theme, port, checkgroup, more_args) -> None:
+ # gpu settings
+ for s1 in self.prelude.gpu_setting:
+ if s1 in gpu:
+ for s2 in self.prelude.gpu_setting:
+ if s2 != s1:
+ self.remove_arg(self.prelude.gpu_setting[s2])
+ self.add_arg(self.prelude.gpu_setting[s1])
+
+ if port != '7860':
+ self.add_arg(f'--port {port}')
+ else:
+ self.remove_arg('--port')
+
+ # theme settings
+ self.remove_arg('--theme')
+ for t in self.prelude.theme_setting:
+ if t == theme:
+ self.add_arg(self.prelude.theme_setting[t])
+ break
+
+ # check box settings
+ for chked in checkgroup:
+ self.logger.info(f'checked:{self.prelude.checkboxes[chked]}')
+ self.add_arg(self.prelude.checkboxes[chked])
+
+ for unchk in list(set(list(self.prelude.checkboxes.keys())) - set(checkgroup)):
+ print(f'unchecked:{unchk}')
+ self.remove_arg(self.prelude.checkboxes[unchk])
+
+ # additional commandline settings
+ self.remove_arg(self._old_additional)
+ self.add_arg(more_args.replace('\\\\', '\\'))
+ self._old_additional = more_args.replace('\\\\', '\\')
+
+ def fetch_all_models(self) -> t.List[t.Dict]:
+ endpoint_url = self.prelude.api_url(self.model_source)
+ if endpoint_url is None:
+ self.logger.error(f"{self.model_source} is not supported")
+ return []
+
+ self.logger.info(f"start to fetch model info from '{self.model_source}':{endpoint_url}")
+
+ limit_threshold = 100
+
+ all_set = []
+ response = requests.get(endpoint_url + f'?page=1&limit={limit_threshold}')
+ num_of_pages = response.json()['metadata']['totalPages']
+ self.logger.info(f"total pages = {num_of_pages}")
+
+ continuous_error_counts = 0
+
+ for p in range(1, num_of_pages + 1):
+ try:
+ response = requests.get(endpoint_url + f'?page={p}&limit={limit_threshold}')
+ payload = response.json()
+ if payload.get("success") is not None and not payload.get("success"):
+ self.logger.error(f"failed to fetch page[{p + 1}]")
+ continuous_error_counts += 1
+ if continuous_error_counts > 10:
+ break
+ else:
+ continue
+
+ continuous_error_counts = 0 # reset error flag
+ self.logger.debug(f"start to process page[{p}]")
+
+ for model in payload['items']:
+ all_set.append(model)
+
+ self.logger.debug(f"page[{p}] : {len(payload['items'])} items added")
+ except Exception as e:
+ self.logger.error(f"failed to fetch page[{p + 1}] due to {e}")
+ time.sleep(3)
+
+ if len(all_set) > 0:
+ self.prelude.update_model_json(self.model_source, all_set)
+ else:
+ self.logger.error("fetch_all_models: emtpy body received")
+
+ return all_set
+
+ def refresh_all_models(self) -> None:
+ if self.fetch_all_models():
+ if self.ds_models:
+ self.ds_models.samples = self.model_set
+ self.ds_models.update(samples=self.model_set)
+ else:
+ self.logger.error(f"ds models is null")
+
+ def get_images_html(self, search: str = '', model_type: str = 'All') -> t.List[str]:
+ self.logger.info(f"get_image_html: model_type = {model_type}, and search pattern = '{search}'")
+
+ model_cover_thumbnails = []
+ model_format = []
+
+ if self.model_set is None:
+ self.logger.error("model_set is null")
+ return []
+
+ self.logger.info(f"{len(self.model_set)} items inside '{self.model_source}'")
+
+ search = search.lower()
+ for model in self.model_set:
+ try:
+ if model.get('type') is not None \
+ and model.get('type') not in model_format:
+ model_format.append(model['type'])
+
+ if search == '' or \
+ (model.get('name') is not None and search in model.get('name').lower()) \
+ or (model.get('description') is not None and search in model.get('description').lower()):
+
+ if (model_type == 'All' or model_type in model.get('type')) \
+ and (self.allow_nsfw or (not self.allow_nsfw and not model.get('nsfw'))):
+ model_cover_thumbnails.append([
+ [f"""
+
+
+
})
+
+
+
{model.get('name')}
+
Type: {model.get('type')}
+
Rating: {model.get('stats')['rating']}
+
+
+ """],
+ model['id']])
+ except Exception:
+ continue
+
+ return model_cover_thumbnails
+
+ # TODO: add typing hint
+ def update_boot_settings(self, version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
+ boot_settings = self.prelude.boot_settings
+ boot_settings['drp_args_vram'] = drp_gpu
+ boot_settings["drp_args_theme"] = drp_theme
+ boot_settings['txt_args_listen_port'] = txt_listen_port
+ for chk in chk_group_args:
+ self.logger.debug(chk)
+ boot_settings[chk] = self.prelude.checkboxes[chk]
+ boot_settings['txt_args_more'] = additional_args
+ boot_settings['drp_choose_version'] = version
+
+ all_settings = self.prelude.all_settings
+ all_settings['boot_settings'] = boot_settings
+
+ toolkit.write_json(self.prelude.setting_file, all_settings)
+
+ def get_all_models(self, site: str) -> t.Any:
+ return toolkit.read_json(self.prelude.model_json[site])
+
+ def update_model_json(self, site: str, models: t.Any) -> None:
+ toolkit.write_json(self.prelude.model_json[site], models)
+
+ def get_hash_from_json(self, chk_point: CheckpointInfo) -> CheckpointInfo:
+ model_hashes = toolkit.read_json(self.prelude.model_hash_file)
+
+ if len(model_hashes) == 0 or chk_point.title not in model_hashes.keys():
+ chk_point.shorthash = chk_point.calculate_shorthash()
+ model_hashes[chk_point.title] = chk_point.shorthash
+ toolkit.write_json(self.prelude.model_hash_file, model_hashes)
+ else:
+ chk_point.shorthash = model_hashes[chk_point.title]
+
+ return chk_point
+
+ def get_local_models(self) -> t.List[t.Any]:
+ models = []
+
+ for file in modules.sd_models.checkpoint_tiles():
+ chkpt_info = modules.sd_models.get_closet_checkpoint_match(file)
+ if chkpt_info.sha256 is None and chkpt_info.shorthash is None:
+ chkpt_info = self.get_hash_from_json(chkpt_info)
+
+ fname = re.sub(r'\[.*?\]', "", chkpt_info.title)
+ model_info = self.search_model_by_hash(chkpt_info.sha256, chkpt_info.shorthash, fname)
+ if model_info is not None:
+ models.append(model_info)
+ else:
+ self.logger.info(
+ f"{chkpt_info.title}, {chkpt_info.hash}, {chkpt_info.shorthash}, {chkpt_info.sha256}")
+ models.append([
+ [f'
'],
+ [os.path.basename(fname)],
+ [fname],
+ [chkpt_info.shorthash],
+ [], [], []])
+
+ return models
+
+ def search_model_by_hash(self, lookup_sha256: str, lookup_shash: str, fname: str) -> t.Optional[t.List[t.Any]]:
+ self.logger.info(f"lookup_sha256: {lookup_sha256}, lookup_shash: {lookup_shash}, fname: {fname}")
+
+ res = None
+ if lookup_sha256 is None and lookup_shash is None:
+ return None
+
+ for model in self.model_set:
+ match = False
+
+ for ver in model['modelVersions']:
+ for file in ver['files']:
+ if lookup_sha256 is not None and 'SHA256' in file['hashes'].keys():
+ match = (lookup_sha256.upper() == file['hashes']['SHA256'].upper())
+ elif lookup_shash is not None:
+ match = (lookup_shash[:10].upper() in [h.upper() for h in file['hashes'].values()])
+
+ if match:
+ cover_link = ver['images'][0]['url'].replace('width=450', 'width=100')
+ mid = model['id']
+ res = [
+ [
+ f'
'],
+ [f"{model['name']}/{ver['name']}"],
+ [fname],
+ [lookup_shash],
+ [model['creator']['username']],
+ [model['type']],
+ [model['nsfw']],
+ [ver['trainedWords']],
+ ]
+
+ if match:
+ break
+
+ return res
+
+ def update_xformers(self, gpu, checkgroup):
+ if '--xformers' in self.prelude.gpu_setting[gpu]:
+ if 'Enable xFormers' not in checkgroup:
+ checkgroup.append('Enable xFormers')
+
+ return checkgroup
+
+ def set_nsfw(self, search='', nsfw_checker=False, model_type='All') -> t.Dict:
+ self._allow_nsfw = nsfw_checker
+ new_list = self.get_images_html(search, model_type)
+ if self._ds_models is None:
+ self.logger.error(f"_ds_models is not initialized")
+ return {}
+
+ self._ds_models.samples = new_list
+ return self._ds_models.update(samples=new_list)
+
+ def search_model(self, search='', model_type='All') -> t.Dict:
+ if self._ds_models is None:
+ self.logger.error(f"_ds_models is not initialized")
+ return {}
+
+ new_list = self.get_images_html(search, model_type)
+
+ self._ds_models.samples = new_list
+ return self._ds_models.update(samples=new_list)
+
+ def get_model_info(self, models) -> t.Tuple[t.List[t.List[str]], t.Dict, str, t.Dict]:
+ drop_list = []
+ cover_imgs = []
+ htmlDetail = ""
+
+ mid = models[1]
+
+ # TODO: use map to enhance the performances
+ m = [e for e in self.model_set if e['id'] == mid][0]
+
+ self.model_files.clear()
+
+ download_url_by_default = None
+ if m and m.get('modelVersions') and len(m.get('modelVersions')) > 0:
+ latest_version = m['modelVersions'][0]
+
+ if latest_version.get('images') and isinstance(latest_version.get('images'), list):
+ for img in latest_version['images']:
+ if self.allow_nsfw or (not self.allow_nsfw and not img.get('nsfw')):
+ if img.get('url'):
+ cover_imgs.append([img['url'], ''])
+
+ if latest_version.get('files') and isinstance(latest_version.get('files'), list):
+ for file in latest_version['files']:
+ # error checking for mandatory fields
+ if file.get('id') is not None and file.get('downloadUrl') is not None:
+ item_name = None
+ if file.get('name'):
+ item_name = file.get('name')
+ if not item_name and latest_version.get('name'):
+ item_name = latest_version['name']
+ if not item_name:
+ item_name = "unknown"
+
+ self.model_files.append({
+ "id:": file['id'],
+ "url": file['downloadUrl'],
+ "name": item_name,
+ "type": m['type'] if m.get('type') else "unknown",
+ "size": file['sizeKB'] * 1024 if file.get('sizeKB') else "unknown",
+ "format": file['format'] if file.get('format') else "unknown",
+ "cover": cover_imgs[0][0] if len(cover_imgs) > 0 else toolkit.get_not_found_image_url(),
+ })
+ file_size = toolkit.get_readable_size(file['sizeKB'] * 1024) if file.get('sizeKB') else ""
+ if file_size:
+ drop_list.append(f"{item_name} ({file_size})")
+ else:
+ drop_list.append(f"{item_name}")
+
+ if not download_url_by_default:
+ download_url_by_default = file.get('downloadUrl')
+
+ htmlDetail = ''
+ if m.get('name'):
+ htmlDetail += f"
{m['name']}
"
+ if m.get('stats') and m.get('stats').get('downloadCount'):
+ htmlDetail += f"
Downloads: {m['stats']['downloadCount']}
"
+ if m.get('stats') and m.get('stats').get('rating'):
+ htmlDetail += f"
Rating: {m['stats']['rating']}
"
+ if m.get('creator') and m.get('creator').get('username'):
+ htmlDetail += f"
Author: {m['creator']['username']}
"
+ if latest_version.get('name'):
+ htmlDetail += f"| Version: | {latest_version['name']} |
"
+ if latest_version.get('updatedAt'):
+ htmlDetail += f"| Updated Time: | {latest_version['updatedAt']} |
"
+ if m.get('type'):
+ htmlDetail += f"| Type: | {m['type']} |
"
+ if latest_version.get('baseModel'):
+ htmlDetail += f"| Base Model: | {latest_version['baseModel']} |
"
+ htmlDetail += f"| NFSW: | {m.get('nsfw') if m.get('nsfw') is not None else 'false'} |
"
+ if m.get('tags') and isinstance(m.get('tags'), list):
+ htmlDetail += f"| Tags: | "
+ for t in m['tags']:
+ htmlDetail += f"{t}"
+ htmlDetail += " |
"
+ if latest_version.get('trainedWords'):
+ htmlDetail += f"| Trigger Words: | "
+ for t in latest_version['trainedWords']:
+ htmlDetail += f"{t}"
+ htmlDetail += " |
"
+ htmlDetail += "
"
+ htmlDetail += f"{m['description'] if m.get('description') else 'N/A'}
"
+
+ return (
+ cover_imgs,
+ gr.Dropdown.update(choices=drop_list, value=drop_list[0] if len(drop_list) > 0 else []),
+ htmlDetail,
+ gr.HTML.update(value=f''
+ f'Download
')
+ )
+
+ def get_downloading_status(self):
+ (_, _, desc) = self.downloader_manager.tasks_summary()
+ return gr.HTML.update(value=desc)
+
+ def download_model(self, filename: str):
+ model_path = modules.paths.models_path
+ script_path = modules.paths.script_path
+
+ urls = []
+ for _, f in enumerate(self.model_files):
+ if not f.get('name'):
+ continue
+ model_fname = re.sub(r"\s*\(\d+(?:\.\d*)?.B\)\s*$", "", f['name'])
+
+ if model_fname in filename:
+ m_pre, m_ext = os.path.splitext(model_fname)
+ cover_fname = f"{m_pre}.jpg"
+
+ if f['type'] == 'LORA':
+ cover_fname = os.path.join(model_path, 'Lora', cover_fname)
+ model_fname = os.path.join(model_path, 'Lora', model_fname)
+ elif f['format'] == 'VAE':
+ cover_fname = os.path.join(model_path, 'VAE', cover_fname)
+ model_fname = os.path.join(model_path, 'VAE', model_fname)
+ elif f['format'] == 'TextualInversion':
+ cover_fname = os.path.join(script_path, 'embeddings', cover_fname)
+ model_fname = os.path.join(script_path, 'embeddings', model_fname)
+ elif f['format'] == 'Hypernetwork':
+ cover_fname = os.path.join(model_path, 'hypernetworks', cover_fname)
+ model_fname = os.path.join(model_path, 'hypernetworks', model_fname)
+ else:
+ cover_fname = os.path.join(model_path, 'Stable-diffusion', cover_fname)
+ model_fname = os.path.join(model_path, 'Stable-diffusion', model_fname)
+
+ urls.append((f['cover'], f['url'], f['size'], cover_fname, model_fname))
+ break
+
+ for (cover_url, model_url, total_size, local_cover_name, local_model_name) in urls:
+ self.downloader_manager.download(
+ source_url=cover_url,
+ target_file=local_cover_name,
+ estimated_total_size=None,
+ )
+ self.downloader_manager.download(
+ source_url=model_url,
+ target_file=local_model_name,
+ estimated_total_size=total_size,
+ )
+
+ #
+ # currently, web-ui is without queue enabled.
+ #
+ # webui_queue_enabled = False
+ # if webui_queue_enabled:
+ # start = time.time()
+ # downloading_tasks_iter = self.downloader_manager.iterator()
+ # for i in progressbar.tqdm(range(100), unit="byte", desc="Models Downloading"):
+ # while True:
+ # try:
+ # finished_bytes, total_bytes = next(downloading_tasks_iter)
+ # v = finished_bytes / total_bytes
+ # print(f"\n v = {v}")
+ # if isinstance(v, float) and int(v * 100) < i:
+ # print(f"\nv({v}) < {i}")
+ # continue
+ # else:
+ # break
+ # except StopIteration:
+ # break
+ #
+ # time.sleep(0.5)
+ #
+ # self.logger.info(f"[downloading] finished after {time.time() - start} secs")
+
+ time.sleep(2)
+
+ self.model_files.clear()
+ return gr.HTML.update(value=f"{len(urls)} downloading tasks added into task list
")
+
+ def change_boot_setting(self, version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
+ self.get_final_args(drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args)
+ self.logger.info(f'saved_cmd: {self.cmdline_args}')
+
+ target_webui_user_file = "webui-user.bat"
+ script_export_keyword = "export"
+ if platform.system() == "Linux":
+ target_webui_user_file = "webui-user.sh"
+ elif platform.system() == "Darwin":
+ target_webui_user_file = "webui-macos-env.sh"
+ else:
+ script_export_keyword = "set"
+
+ filepath = os.path.join(modules.shared.script_path, target_webui_user_file)
+ self.logger.info(f"to update: {filepath}")
+
+ msg = 'Result: Setting Saved.'
+ if version == 'Official Release':
+ try:
+ if not os.path.exists(filepath):
+ shutil.copyfile(os.path.join(self.prelude.ext_folder, 'configs', target_webui_user_file), filepath)
+
+ with fileinput.FileInput(filepath, inplace=True, backup='.bak') as file:
+ for line in file:
+ if 'COMMANDLINE_ARGS' in line:
+ rep_txt = ' '.join(self.cmdline_args).replace('\\', '\\\\')
+ line = f'{script_export_keyword} COMMANDLINE_ARGS="{rep_txt}"\n'
+ sys.stdout.write(line)
+
+ except Exception as e:
+ msg = f'Error: {str(e)}'
+ else:
+ try:
+ if not os.path.exists(filepath):
+ shutil.copyfile(os.path.join(self.prelude.ext_folder, 'configs', target_webui_user_file), filepath)
+
+ with fileinput.FileInput(filepath, inplace=True, backup='.bak') as file:
+ for line in file:
+ if 'webui.py' in line:
+ rep_txt = ' '.join(self.cmdline_args).replace('\\', '\\\\')
+ line = f"python\python.exe webui.py {rep_txt}\n"
+ sys.stdout.write(line)
+ except Exception as e:
+ msg = f'Error: {str(e)}'
+
+ self.update_boot_settings(version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args)
+ return gr.update(value=msg, visible=True)
+
+ @property
+ def model_set(self) -> t.List[t.Dict]:
+ try:
+ self.logger.info(f"access to model info for '{self.model_source}'")
+ model_json_mtime = toolkit.get_file_last_modified_time(self.prelude.model_json[self.model_source])
+
+ if self._model_set is None or self._model_set_last_access_time is None \
+ or self._model_set_last_access_time < model_json_mtime:
+ self._model_set = self.get_all_models(self.model_source)
+ self._model_set_last_access_time = model_json_mtime
+ self.logger.info(f"load '{self.model_source}' model data from local file")
+ except Exception as e:
+ self._model_set = self.fetch_all_models()
+ self._model_set_last_access_time = datetime.datetime.now()
+
+ return self._model_set
+
+ @property
+ def allow_nsfw(self) -> bool:
+ return self._allow_nsfw
+
+ @property
+ def old_additional_args(self) -> str:
+ return self._old_additional
+
+ @property
+ def ds_models(self) -> gr.Dataset:
+ return self._ds_models
+
+ @ds_models.setter
+ def ds_models(self, newone: gr.Dataset):
+ self._ds_models = newone
+
+ @property
+ def model_source(self) -> str:
+ return self._model_source
+
+ @model_source.setter
+ def model_source(self, newone: str):
+ self.logger.info(f"model source changes from {self.model_source} to {newone}")
+ self._model_source = newone
+ self._model_set_last_access_time = None # reset timestamp
+
+ def introception(self) -> None:
+ (gpu, theme, port, checkbox_values, extra_args, ver) = self.get_default_args()
+
+ print("################################################################")
+ print("MIAOSHOU ASSISTANT ARGUMENTS:")
+
+ print(f" gpu = {gpu}")
+ print(f" theme = {theme}")
+ print(f" port = {port}")
+ print(f" checkbox_values = {checkbox_values}")
+ print(f" extra_args = {extra_args}")
+ print(f" webui ver = {ver}")
+
+ print("################################################################")
+