diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f781c71 --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +# for vim +.~ +.*.swp +*~ + +# for MacOS +.DS_Store + +__pycache__/ + +.idea/ + +logs/ +flagged/ + +configs/civitai_models.json +configs/liandange_models.json diff --git a/install.py b/install.py new file mode 100644 index 0000000..6920515 --- /dev/null +++ b/install.py @@ -0,0 +1,30 @@ +import launch +import os +import gzip +import io + + +def install_preset_models_if_needed(): + assets_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "assets") + configs_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs") + + for model_filename in ["civitai_models.json", "liandange_models.json"]: + gzip_file = os.path.join(assets_folder, f"{model_filename}.gz") + target_file = os.path.join(configs_folder, f"{model_filename}") + if not os.path.exists(target_file): + with gzip.open(gzip_file, "rb") as compressed_file: + with io.TextIOWrapper(compressed_file, encoding="utf-8") as decoder: + content = decoder.read() + with open(target_file, "w") as model_file: + model_file.write(content) + + +req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt") + +with open(req_file) as file: + for lib in file: + lib = lib.strip() + if not launch.is_installed(lib): + launch.run_pip(f"install {lib}", f"Miaoshou assistant requirement: {lib}") + +install_preset_models_if_needed() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e8bf606 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +psutil +rehash +tqdm diff --git a/scripts/assistant/__init__.py b/scripts/assistant/__init__.py new file mode 100644 index 0000000..bf393ef --- /dev/null +++ b/scripts/assistant/__init__.py @@ -0,0 +1 @@ +__all__ = ["miaoshou"] diff --git a/scripts/assistant/miaoshou.py b/scripts/assistant/miaoshou.py new file mode 100644 index 0000000..69edd52 --- /dev/null +++ b/scripts/assistant/miaoshou.py @@ -0,0 +1,256 @@ +import os +import sys +import typing as t + +import gradio as gr + +import launch +import modules +from scripts.logging.msai_logger import Logger +from scripts.runtime.msai_prelude import MiaoshouPrelude +from scripts.runtime.msai_runtime import MiaoshouRuntime + + +class MiaoShouAssistant(object): + # default css definition + default_css = '#my_model_cover{width: 100px;} #my_model_trigger_words{width: 200px;}' + + def __init__(self) -> None: + self.logger = Logger() + self.prelude = MiaoshouPrelude() + self.runtime = MiaoshouRuntime() + self.refresh_symbol = '\U0001f504' + + def on_event_ui_tabs_opened(self) -> t.List[t.Optional[t.Tuple[t.Any, str, str]]]: + with gr.Blocks(analytics_enabled=False, css=MiaoShouAssistant.default_css) as miaoshou_assistant: + self.create_subtab_boot_assistant() + self.create_subtab_model_management() + self.create_subtab_model_download() + + return [(miaoshou_assistant.queue(), "Miaoshou Assistant", "miaoshou_assistant")] + + def create_subtab_boot_assistant(self) -> None: + with gr.TabItem('Boot Assistant', elem_id="boot_assistant_tab") as boot_assistant: + with gr.Row(): + with gr.Column(elem_id="col_model_list"): + gpu, theme, port, chk_args, txt_args, webui_ver = self.runtime.get_default_args() + gr.Markdown(value="Argument settings") + with gr.Row(): + drp_gpu = gr.Dropdown(label="", elem_id="drp_args_vram", + choices=list(self.prelude.gpu_setting.keys()), + value=gpu, interactive=True) + drp_theme = gr.Dropdown(label="UI Theme", choices=list(self.prelude.theme_setting.keys()), + value=theme, + elem_id="drp_args_theme", interactive=True) + txt_listen_port = gr.Text(label='Listen Port', value=port, elem_id="txt_args_listen_port", + interactive=True) + + with gr.Row(): + chk_group_args = gr.CheckboxGroup(choices=list(self.prelude.checkboxes.keys()), value=chk_args, + show_label=False) + additional_args = gr.Text(label='COMMANDLINE_ARGS (Divide by space)', value=txt_args, + elem_id="txt_args_more", interactive=True) + + with gr.Row(): + with gr.Column(): + txt_save_status = gr.Markdown(visible=False, interactive=False, show_label=False) + drp_choose_version = gr.Dropdown(label="WebUI Version", + choices=['Official Release', 'Python Integrated'], + value=webui_ver, elem_id="drp_args_version", + interactive=True) + gr.HTML( + '

*Save your settings to webui-user.bat file. Use Python Integrated only if your' + ' WebUI is extracted from a zip file and does not need python installation

') + save_settings = gr.Button(value="Save settings", elem_id="btn_arg_save_setting") + + with gr.Row(): + # with gr.Column(): + # settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="ms_settings_submit") + # with gr.Column(): + restart_gradio = gr.Button(value='Apply & Restart WebUI', variant='primary', + elem_id="ms_settings_restart_gradio") + + '''def mod_args(drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args): + global commandline_args + + get_final_args(drp_gpu, drp_theme, txt_listen_port, hk_group_args, additional_args) + + print(commandline_args) + print(sys.argv) + #if '--xformers' not in sys.argv: + #sys.argv.append('--xformers') + + settings_submit.click(mod_args, inputs=[drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args], outputs=[])''' + + save_settings.click(self.runtime.change_boot_setting, + inputs=[drp_choose_version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, + additional_args], outputs=[txt_save_status]) + + restart_gradio.click( + fn=self.request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + with gr.Column(): + with gr.Row(): + machine_settings = self.prelude.get_sys_info() + txt_sys_info = gr.TextArea(value=machine_settings, lines=20, max_lines=20, + label="System Info", + show_label=False, interactive=False) + with gr.Row(): + sys_info_refbtn = gr.Button(value="Refresh") + + drp_gpu.change(self.runtime.update_xformers, inputs=[drp_gpu, chk_group_args], outputs=[chk_group_args]) + sys_info_refbtn.click(self.prelude.get_sys_info, None, txt_sys_info) + + def create_subtab_model_management(self) -> None: + with gr.TabItem('Model Management', elem_id="model_management_tab") as tab_batch: + with gr.Row(): + with gr.Column(): + my_models = self.runtime.get_local_models() + ds_my_models = gr.Dataset( + components=[gr.HTML(visible=False, label='Cover', elem_id='my_model_cover'), + gr.Textbox(visible=False, label='Name/Version'), + gr.Textbox(visible=False, label='File Name'), + gr.Textbox(visible=False, label='Hash'), gr.Textbox(visible=False, label='Creator'), + gr.Textbox(visible=False, label='Type'), gr.Textbox(visible=False, label='NSFW'), + gr.Textbox(visible=False, label='Trigger Words', elem_id='my_model_trigger_words')], + elem_id='my_model_lib', + label="My Models", + headers=None, + samples=my_models, + samples_per_page=50) + with gr.Column(): + html_model_prompt = gr.HTML(visible=True, + value='

No Model Selected

') + + with gr.Row(): + add = gr.Button(value="Add", variant="primary") + # delete = gr.Button(value="Delete") + with gr.Row(): + reset_btn = gr.Button(value="Reset") + json_input = gr.Button(value="Load from JSON") + png_input = gr.Button(value="Detect from image") + png_input_area = gr.Image(label="Detect from image", elem_id="openpose_editor_input") + bg_input = gr.Button(value="Add Background image") + + def create_subtab_model_download(self) -> None: + with gr.TabItem('Model Download', elem_id="model_download_tab") as tab_downloads: + with gr.Row(): + with gr.Column(elem_id="col_model_list"): + with gr.Row().style(equal_height=True): + model_source_dropdown = gr.Dropdown(choices=["civitai", "liandange"], + value=self.runtime.model_source, + label="Select Model Source", + type="value", + show_label=True, + elem_id="model_source").style(full_width=True) + with gr.Row().style(equal_height=True): + search_text = gr.Textbox( + label="Model name", + show_label=False, + max_lines=1, + placeholder="Enter model name", + ) + btn_search = gr.Button("Search") + + with gr.Row().style(equal_height=True): + nsfw_checker = gr.Checkbox(label='NSFW', value=False, elem_id="chk_nsfw", interactive=True) + model_type = gr.Radio(["All", "Checkpoint", "LORA", "TextualInversion", "Hypernetwork"], + show_label=False, value='All', elem_id="rad_model_type", + interactive=True).style(full_width=True) + + images = self.runtime.get_images_html() + self.runtime.ds_models = gr.Dataset( + components=[gr.HTML(visible=False)], + headers=None, + type="values", + label="Models", + samples=images, + samples_per_page=60, + elem_id="model_dataset").style(type="gallery", container=True) + + with gr.Column(elem_id="col_model_info"): + with gr.Row(): + cover_gallery = gr.Gallery(label="Cover", show_label=False, visible=True).style(grid=[4], + height="2") + + with gr.Row(): + with gr.Column(): + download_summary = gr.HTML('
No downloading tasks ongoing
') + downloading_status = gr.Button(value=f"{self.refresh_symbol} Refresh Downloading Status", + elem_id="ms_dwn_status") + with gr.Row(): + model_dropdown = gr.Dropdown(choices=['Select Model'], label="Models", show_label=False, + value='Select Model', elem_id='ms_dwn_button', + interactive=True) + + is_civitai_model_source_active = self.runtime.model_source == "civitai" + with gr.Row(variant="panel"): + dwn_button = gr.Button(value='Download', + visible=is_civitai_model_source_active, elem_id='ms_dwn_button') + open_url_in_browser_newtab_button = gr.HTML( + value='

' + 'Download

', + visible=not is_civitai_model_source_active) + with gr.Row(): + model_info = gr.HTML(visible=True) + + nsfw_checker.change(self.runtime.set_nsfw, inputs=[search_text, nsfw_checker, model_type], + outputs=self.runtime.ds_models) + + model_type.change(self.runtime.search_model, inputs=[search_text, model_type], outputs=self.runtime.ds_models) + + btn_search.click(self.runtime.search_model, inputs=[search_text, model_type], outputs=self.runtime.ds_models) + + self.runtime.ds_models.click(self.runtime.get_model_info, + inputs=[self.runtime.ds_models], + outputs=[ + cover_gallery, + model_dropdown, + model_info, + open_url_in_browser_newtab_button + ]) + + dwn_button.click(self.runtime.download_model, inputs=[model_dropdown], outputs=[download_summary]) + downloading_status.click(self.runtime.get_downloading_status, inputs=[], outputs=[download_summary]) + + model_source_dropdown.change(self.switch_model_source, + inputs=[model_source_dropdown], + outputs=[self.runtime.ds_models, dwn_button, open_url_in_browser_newtab_button]) + + def request_restart(self, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args): + print('request_restart: cmd_arg = ', self.runtime.cmdline_args) + print('request_restart: sys.argv = ', sys.argv) + + modules.shared.state.interrupt() + modules.shared.state.need_restart = True + + # reset args + sys.argv = [sys.argv[0]] + os.environ['COMMANDLINE_ARGS'] = "" + print('remove', sys.argv) + + for arg in self.runtime.cmdline_args: + sys.argv.append(arg) + + print('after', sys.argv) + launch.prepare_environment() + launch.start() + + def switch_model_source(self, new_model_source: str): + self.runtime.model_source = new_model_source + show_download_button = self.runtime.model_source == "civitai" + images = self.runtime.get_images_html() + self.runtime.ds_models.samples = images + return ( + gr.Dataset.update(samples=images), + gr.Button.update(visible=show_download_button), + gr.HTML.update(visible=not show_download_button) + ) + + def introception(self) -> None: + self.runtime.introception() diff --git a/scripts/download/__init__.py b/scripts/download/__init__.py new file mode 100644 index 0000000..8ce6385 --- /dev/null +++ b/scripts/download/__init__.py @@ -0,0 +1 @@ +__all__ = ["msai_downloader_manager"] diff --git a/scripts/download/msai_downloader_manager.py b/scripts/download/msai_downloader_manager.py new file mode 100644 index 0000000..ea11e20 --- /dev/null +++ b/scripts/download/msai_downloader_manager.py @@ -0,0 +1,248 @@ +import asyncio +import os.path +import queue +import time +import requests +import typing as t +from threading import Thread, Lock + +from scripts.download.msai_file_downloader import MiaoshouFileDownloader +from scripts.logging.msai_logger import Logger +from scripts.msai_utils.msai_singleton import MiaoshouSingleton +import scripts.msai_utils.msai_toolkit as toolkit + + +class DownloadingEntry(object): + def __init__(self, target_url: str = None, local_file: str = None, + local_directory: str = None, estimated_total_size: float = 0., expected_checksum: str = None): + self._target_url = target_url + self._local_file = local_file + self._local_directory = local_directory + self._expected_checksum = expected_checksum + + self._estimated_total_size = estimated_total_size + self._total_size = 0 + self._downloaded_size = 0 + + self._downloading = False + self._failure = False + + @property + def target_url(self) -> str: + return self._target_url + + @property + def local_file(self) -> str: + return self._local_file + + @property + def local_directory(self) -> str: + return self._local_directory + + @property + def expected_checksum(self) -> str: + return self._expected_checksum + + @property + def total_size(self) -> int: + return self._total_size + + @total_size.setter + def total_size(self, sz: int) -> None: + self._total_size = sz + + @property + def downloaded_size(self) -> int: + return self._downloaded_size + + @downloaded_size.setter + def downloaded_size(self, sz: int) -> None: + self._downloaded_size = sz + + @property + def estimated_size(self) -> float: + return self._estimated_total_size + + def is_downloading(self) -> bool: + return self._downloading + + def start_download(self) -> None: + self._downloading = True + + def update_final_status(self, result: bool) -> None: + self._failure = result + self._downloading = False + + def is_failure(self) -> bool: + return self._failure + + +class AsyncLoopThread(Thread): + def __init__(self): + super(AsyncLoopThread, self).__init__(daemon=True) + self.loop = asyncio.new_event_loop() + self.logger = Logger() + self.logger.info("looper thread is created") + + def run(self): + asyncio.set_event_loop(self.loop) + self.logger.info("looper thread is running") + self.loop.run_forever() + + +class MiaoshouDownloaderManager(metaclass=MiaoshouSingleton): + _downloading_entries: t.Dict[str, DownloadingEntry] = None + + def __init__(self): + if self._downloading_entries is None: + self._downloading_entries = {} + self.message_queue = queue.Queue() + + self.logger = Logger() + self.looper = AsyncLoopThread() + self.looper.start() + self.logger.info("download manager is ready") + self._mutex = Lock() + + def consume_all_ready_messages(self) -> None: + """ + capture all enqueued messages, this method should not be used if you are iterating over the message queue + :return: + None + :side-effect: + update downloading entries' status + """ + while True: + # self.logger.info("fetching the enqueued message") + try: + (aurl, finished_size, total_size) = self.message_queue.get(block=False, timeout=0.2) + # self.logger.info(f"[+] message ([{finished_size}/{total_size}] {aurl}") + try: + self._mutex.acquire(blocking=True) + self._downloading_entries[aurl].total_size = total_size + self._downloading_entries[aurl].downloaded_size = finished_size + finally: + self._mutex.release() + except queue.Empty: + break + + def iterator(self) -> t.Tuple[float, float]: + + while True: + self.logger.info("waiting for incoming message") + + try: + (aurl, finished_size, total_size) = self.message_queue.get(block=True) + self.logger.info(f"[+] message ([{finished_size}/{total_size}] {aurl}") + try: + self._mutex.acquire(blocking=True) + self._downloading_entries[aurl].total_size = total_size + self._downloading_entries[aurl].downloaded_size = finished_size + + tasks_total_size = 0. + tasks_finished_size = 0. + + for e in self._downloading_entries.values(): + tasks_total_size += e.total_size + tasks_finished_size += e.downloaded_size + + yield tasks_finished_size, tasks_total_size + finally: + self._mutex.release() + except queue.Empty: + if len(asyncio.all_tasks(self.looper.loop)) == 0: + self.logger.info("all downloading tasks finished") + break + + async def _submit_task(self, download_entry: DownloadingEntry) -> None: + try: + self._mutex.acquire(blocking=True) + if download_entry.target_url in self._downloading_entries: + self.logger.warn(f"{download_entry.target_url} is already downloading") + return + else: + download_entry.start_download() + self._downloading_entries[download_entry.target_url] = download_entry + finally: + self._mutex.release() + + file_downloader = MiaoshouFileDownloader( + target_url=download_entry.target_url, + local_file=download_entry.local_file, + local_directory=download_entry.local_directory, + channel=self.message_queue if download_entry.estimated_size else None, + estimated_total_length=download_entry.estimated_size, + expected_checksum=download_entry.expected_checksum, + ) + + result: bool = await self.looper.loop.run_in_executor(None, file_downloader.download_file) + + try: + self._mutex.acquire(blocking=True) + self._downloading_entries[download_entry.target_url].update_final_status(result) + finally: + self._mutex.release() + + def download(self, source_url: str, target_file: str, estimated_total_size: float, + expected_checksum: str = None) -> None: + target_dir = os.path.dirname(target_file) + target_filename = os.path.basename(target_file) + download_entry = DownloadingEntry( + target_url=source_url, + local_file=target_filename, + local_directory=target_dir, + estimated_total_size=estimated_total_size, + expected_checksum=expected_checksum + ) + + asyncio.run_coroutine_threadsafe(self._submit_task(download_entry), self.looper.loop) + + def tasks_summary(self) -> t.Tuple[int, int, str]: + self.consume_all_ready_messages() + + total_tasks_num = 0 + ongoing_tasks_num = 0 + failed_tasks_num = 0 + + try: + description = "
" + self._mutex.acquire(blocking=True) + for name, entry in self._downloading_entries.items(): + if entry.estimated_size is None: + continue + + total_tasks_num += 1 + + if entry.total_size > 0.: + description += f"

{entry.local_file} ({toolkit.get_readable_size(entry.total_size)}) : " + else: + description += f"

{entry.local_file} ({toolkit.get_readable_size(entry.estimated_size)}) : " + + if entry.is_downloading(): + ongoing_tasks_num += 1 + finished_percent = entry.downloaded_size/entry.estimated_size * 100 + description += f'{round(finished_percent, 2)} %' + elif entry.is_failure(): + failed_tasks_num += 1 + description += 'failed!' + else: + description += 'finished' + description += "


" + finally: + self._mutex.release() + pass + + description += "
" + overall = f""" +

+ {ongoing_tasks_num} ongoing, + {total_tasks_num - ongoing_tasks_num - failed_tasks_num} finished, + {failed_tasks_num} failed. +

+
+
+ """ + + return ongoing_tasks_num, total_tasks_num, overall + description + + diff --git a/scripts/download/msai_file_downloader.py b/scripts/download/msai_file_downloader.py new file mode 100644 index 0000000..e165c95 --- /dev/null +++ b/scripts/download/msai_file_downloader.py @@ -0,0 +1,226 @@ +import os +import pickle +import queue +import time +import typing as t +from pathlib import Path +from urllib.parse import urlparse + +import rehash +import requests +from requests.adapters import HTTPAdapter +from tqdm import tqdm +from urllib3.util import Retry + +import scripts.msai_utils.msai_toolkit as toolkit +from scripts.logging.msai_logger import Logger + + +class MiaoshouFileDownloader(object): + CHUNK_SIZE = 1024 * 1024 + + def __init__(self, target_url: str = None, + local_file: str = None, local_directory: str = None, estimated_total_length: float = 0., + expected_checksum: str = None, + channel: queue.Queue = None, + max_retries=3) -> None: + self.logger = Logger() + + self.target_url: str = target_url + self.local_file: str = local_file + self.local_directory = local_directory + self.expected_checksum = expected_checksum + self.max_retries = max_retries + + self.accept_ranges: bool = False + self.estimated_content_length = estimated_total_length + self.content_length: int = -1 + self.finished_chunk_size: int = 0 + + self.channel = channel # for communication + + # Support 3 retries and backoff + retry_strategy = Retry( + total=3, + backoff_factor=1, + status_forcelist=[429, 500, 502, 503, 504], + method_whitelist=["HEAD", "GET", "OPTIONS"] + ) + adapter = HTTPAdapter(max_retries=retry_strategy) + self.session = requests.Session() + self.session.mount("https://", adapter) + self.session.mount("http://", adapter) + + # inform message receiver at once + if self.channel: + self.channel.put_nowait(( + self.target_url, + self.finished_chunk_size, + self.estimated_content_length, + )) + + # Head request to get file-length and check whether it supports ranges. + def get_file_info_from_server(self, target_url: str) -> t.Tuple[bool, float]: + try: + headers = {"Accept-Encoding": "identity"} # Avoid dealing with gzip + response = requests.head(target_url, headers=headers, allow_redirects=True) + response.raise_for_status() + content_length = None + if "Content-Length" in response.headers: + content_length = int(response.headers['Content-Length']) + accept_ranges = (response.headers.get("Accept-Ranges") == "bytes") + return accept_ranges, float(content_length) + except Exception as ex: + self.logger.info(f"HEAD Request Error: {ex}") + return False, self.estimated_content_length + + def download_file_full(self, target_url: str, local_filepath: str) -> t.Optional[str]: + try: + checksum = rehash.sha256() + headers = {"Accept-Encoding": "identity"} # Avoid dealing with gzip + + with tqdm(total=self.content_length, unit="byte", unit_scale=1, colour="GREEN", + desc=os.path.basename(self.local_file)) as progressbar, \ + self.session.get(target_url, headers=headers, stream=True, timeout=5) as response, \ + open(local_filepath, 'wb') as file_out: + response.raise_for_status() + + for chunk in response.iter_content(MiaoshouFileDownloader.CHUNK_SIZE): + file_out.write(chunk) + checksum.update(chunk) + progressbar.update(len(chunk)) + self.update_progress(len(chunk)) + + except Exception as ex: + self.logger.info(f"Download error: {ex}") + return None + + return checksum.hexdigest() + + def download_file_resumable(self, target_url: str, local_filepath: str) -> t.Optional[str]: + # Always go off the checkpoint as the file was flushed before writing. + download_checkpoint = local_filepath + ".downloading" + try: + resume_point, checksum = pickle.load(open(download_checkpoint, "rb")) + assert os.path.exists(local_filepath) # catch checkpoint without file + self.logger.info("File already exists, resuming download.") + except Exception as e: + self.logger.error(f"failed to load downloading checkpoint - {download_checkpoint} due to {e}") + resume_point = 0 + checksum = rehash.sha256() + if os.path.exists(local_filepath): + os.remove(local_filepath) + Path(local_filepath).touch() + + assert (resume_point < self.content_length) + + self.finished_chunk_size = resume_point + + # Support resuming + headers = {"Range": f"bytes={resume_point}-", "Accept-Encoding": "identity"} + try: + with tqdm(total=self.content_length, unit="byte", unit_scale=1, colour="GREEN", + desc=os.path.basename(self.local_file)) as progressbar, \ + self.session.get(target_url, headers=headers, stream=True, timeout=5) as response, \ + open(local_filepath, 'r+b') as file_out: + response.raise_for_status() + self.update_progress(resume_point) + file_out.seek(resume_point) + + for chunk in response.iter_content(MiaoshouFileDownloader.CHUNK_SIZE): + file_out.write(chunk) + file_out.flush() + resume_point += len(chunk) + checksum.update(chunk) + pickle.dump((resume_point, checksum), open(download_checkpoint, "wb")) + progressbar.update(len(chunk)) + self.update_progress(len(chunk)) + + # Only remove checkpoint at full size in case connection cut + if os.path.getsize(local_filepath) == self.content_length: + os.remove(download_checkpoint) + else: + return None + + except Exception as ex: + self.logger.error(f"Download error: {ex}") + return None + + return checksum.hexdigest() + + def update_progress(self, finished_chunk_size: int) -> None: + self.finished_chunk_size += finished_chunk_size + + if self.channel: + self.channel.put_nowait(( + self.target_url, + self.finished_chunk_size, + self.content_length, + )) + + # In order to avoid leaving extra garbage meta files behind this + # will overwrite any existing files found at local_file. If you don't want this + # behaviour you can handle this externally. + # local_file and local_directory could write to unexpected places if the source + # is untrusted, be careful! + def download_file(self) -> bool: + success = False + try: + # Need to rebuild local_file_final each time in case of different urls + if not self.local_file: + specific_local_file = os.path.basename(urlparse(self.target_url).path) + else: + specific_local_file = self.local_file + + download_temp_dir = toolkit.get_user_temp_dir() + toolkit.assert_user_temp_dir() + + if self.local_directory: + os.makedirs(self.local_directory, exist_ok=True) + + specific_local_file = os.path.join(download_temp_dir, specific_local_file) + + self.accept_ranges, self.content_length = self.get_file_info_from_server(self.target_url) + self.logger.info(f"Accept-Ranges: {self.accept_ranges}. content length: {self.content_length}") + if self.accept_ranges and self.content_length: + download_method = self.download_file_resumable + self.logger.info("Server supports resume") + else: + download_method = self.download_file_full + self.logger.info(f"Server doesn't support resume.") + + for i in range(self.max_retries): + self.logger.info(f"Download Attempt {i + 1}") + checksum = download_method(self.target_url, specific_local_file) + if checksum: + match = "" + if self.expected_checksum: + match = ", Checksum Match" + + if self.expected_checksum and self.expected_checksum != checksum: + self.logger.info(f"Checksum doesn't match. Calculated {checksum} " + f"Expecting: {self.expected_checksum}") + else: + self.logger.info(f"Download successful{match}. Checksum {checksum}") + success = True + break + time.sleep(1) + + if success: + print("*" * 80) + self.logger.info(f"{self.target_url} [DOWNLOADED COMPLETELY]") + print("*" * 80) + if self.local_directory: + target_local_file = os.path.join(self.local_directory, self.local_file) + else: + target_local_file = self.local_file + toolkit.move_file(specific_local_file, target_local_file) + else: + print("*" * 80) + self.logger.info(f"{self.target_url} [ FAILED ]") + print("*" * 80) + + except Exception as ex: + self.logger.info(f"Unexpected Error: {ex}") # Only from block above + + return success diff --git a/scripts/logging/__init__.py b/scripts/logging/__init__.py new file mode 100644 index 0000000..0e75e22 --- /dev/null +++ b/scripts/logging/__init__.py @@ -0,0 +1 @@ +__all__ = ["msai_logger"] diff --git a/scripts/logging/msai_logger.py b/scripts/logging/msai_logger.py new file mode 100644 index 0000000..95dd71f --- /dev/null +++ b/scripts/logging/msai_logger.py @@ -0,0 +1,100 @@ +import datetime +import logging +import logging.handlers +import os +import typing as t + +from scripts.msai_utils.msai_singleton import MiaoshouSingleton + + +class Logger(metaclass=MiaoshouSingleton): + _dataset = None + + KEY_TRACE_PATH = "trace_path" + KEY_INFO = "info" + KEY_ERROR = "error" + KEY_JOB = "job" + + def _do_init(self, log_folder: str, disable_console_output: bool = False) -> None: + # Setup trace_path with empty string by default, it will be assigned with valid content if needed + self._dataset = {Logger.KEY_TRACE_PATH: ""} + + print(f"logs_location: {log_folder}") + os.makedirs(log_folder, exist_ok=True) + + # Setup basic logging configuration + logging.basicConfig(level=logging.INFO, + filemode='w', + format='%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s') + + # Setup info logging + self._dataset[Logger.KEY_INFO] = logging.getLogger(Logger.KEY_INFO) + msg_handler = logging.FileHandler(os.path.join(log_folder, "info.log"), + "a", + encoding="UTF-8") + msg_handler.setLevel(logging.INFO) + msg_handler.setFormatter( + logging.Formatter(fmt='%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s')) + self._dataset[Logger.KEY_INFO].addHandler(msg_handler) + + # Setup error logging + self._dataset[Logger.KEY_ERROR] = logging.getLogger(Logger.KEY_ERROR) + error_handler = logging.FileHandler( + os.path.join(log_folder, f'error_{datetime.date.today().strftime("%Y%m%d")}.log'), + mode="a", + encoding='UTF-8') + error_handler.setLevel(logging.ERROR) + error_handler.setFormatter( + logging.Formatter( + fmt=f"{self._dataset.get('trace_path')}:\n " + f"%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s")) + self._dataset[Logger.KEY_ERROR].addHandler(error_handler) + + # Setup job logging + self._dataset[Logger.KEY_JOB] = logging.getLogger(Logger.KEY_JOB) + job_handler = logging.FileHandler(os.path.join(log_folder, "jobs.log"), + mode="a", + encoding="UTF-8") + self._dataset[Logger.KEY_JOB].addHandler(job_handler) + + if disable_console_output: + for k in [Logger.KEY_INFO, Logger.KEY_JOB, Logger.KEY_ERROR]: + l: logging.Logger = self._dataset[k] + l.propagate = not disable_console_output + + def __init__(self, log_folder: str = None, disable_console_output: bool = False) -> None: + if self._dataset is None: + try: + self._do_init(log_folder=log_folder, disable_console_output=disable_console_output) + except Exception as e: + print(e) + + def update_path_info(self, current_path: str) -> None: + self._dataset[Logger.KEY_TRACE_PATH] = current_path + + def callback_func(self, exc_type: t.Any, exc_value: t.Any, exc_tracback: t.Any) -> None: + self._dataset[Logger.KEY_JOB].error(f"job failed for {self._dataset[Logger.KEY_TRACE_PATH]}") + self._dataset[Logger.KEY_INFO].error(f"{self._dataset[Logger.KEY_TRACE_PATH]}\n, callback_func: ", + exc_info=(exc_type, exc_value, exc_tracback)) + + def debug(self, fmt, *args, **kwargs) -> None: + l: logging.Logger = self._dataset[Logger.KEY_INFO] + l.debug(fmt, *args, **kwargs, stacklevel=2) + + def info(self, fmt, *args, **kwargs) -> None: + l: logging.Logger = self._dataset[Logger.KEY_INFO] + l.info(fmt, *args, **kwargs, stacklevel=2) + + def warn(self, fmt, *args, **kwargs) -> None: + l: logging.Logger = self._dataset[Logger.KEY_INFO] + l.warn(fmt, *args, **kwargs, stacklevel=2) + + def error(self, fmt, *args, **kwargs) -> None: + l: logging.Logger = self._dataset[Logger.KEY_ERROR] + l.error(fmt, *args, **kwargs, stacklevel=2) + + def job(self, fmt, *args, **kwargs) -> None: + l: logging.Logger = self._dataset[Logger.KEY_JOB] + l.info(fmt, *args, **kwargs, stacklevel=2) + + diff --git a/scripts/main.py b/scripts/main.py new file mode 100644 index 0000000..82296ba --- /dev/null +++ b/scripts/main.py @@ -0,0 +1,22 @@ +import modules +import modules.scripts as scripts + +from scripts.assistant.miaoshou import MiaoShouAssistant + + +class MiaoshouScript(scripts.Script): + def __init__(self) -> None: + super().__init__() + + def title(self): + return "Miaoshou Assistant" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + return () + + +assistant = MiaoShouAssistant() +modules.script_callbacks.on_ui_tabs(assistant.on_event_ui_tabs_opened) diff --git a/scripts/msai_utils/__init__.py b/scripts/msai_utils/__init__.py new file mode 100644 index 0000000..474a79a --- /dev/null +++ b/scripts/msai_utils/__init__.py @@ -0,0 +1 @@ +__all__ = ["msai_singleton", "msai_toolkit"] diff --git a/scripts/msai_utils/msai_singleton.py b/scripts/msai_utils/msai_singleton.py new file mode 100644 index 0000000..c2f831c --- /dev/null +++ b/scripts/msai_utils/msai_singleton.py @@ -0,0 +1,8 @@ +class MiaoshouSingleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(MiaoshouSingleton, cls).__call__(*args, **kwargs) + cls._instances[cls].__init__(*args, **kwargs) + return cls._instances[cls] diff --git a/scripts/msai_utils/msai_toolkit.py b/scripts/msai_utils/msai_toolkit.py new file mode 100644 index 0000000..cda6706 --- /dev/null +++ b/scripts/msai_utils/msai_toolkit.py @@ -0,0 +1,88 @@ +import json +import os +import platform +import shutil +import typing as t +from datetime import datetime +from pathlib import Path + + +def read_json(file) -> t.Any: + try: + with open(file, "r", encoding="utf-8-sig") as f: + return json.load(f) + except Exception as e: + print(e) + return None + + +def write_json(file, content) -> None: + try: + with open(file, 'w') as f: + json.dump(content, f, indent=4) + except Exception as e: + print(e) + + +def get_args(args) -> t.List[str]: + parameters = [] + idx = 0 + for arg in args: + if idx == 0 and '--' not in arg: + pass + elif '--' in arg: + parameters.append(rf'{arg}') + idx += 1 + else: + parameters[idx - 1] = parameters[idx - 1] + ' ' + rf'{arg}' + + return parameters + + +def get_readable_size(size: int, precision=2) -> str: + if size is None: + return "" + + suffixes = ['B', 'KB', 'MB', 'GB', 'TB'] + suffixIndex = 0 + while size >= 1024 and suffixIndex < len(suffixes): + suffixIndex += 1 # increment the index of the suffix + size = size / 1024.0 # apply the division + return "%.*f%s" % (precision, size, suffixes[suffixIndex]) + + +def get_file_last_modified_time(path_to_file: str) -> datetime: + if path_to_file is None: + return datetime.now() + + if platform.system() == "Windows": + return datetime.fromtimestamp(os.path.getmtime(path_to_file)) + else: + stat = os.stat(path_to_file) + return datetime.fromtimestamp(stat.st_mtime) + + +def get_not_found_image_url() -> str: + return "https://msdn.miaoshouai.com/msdn/userimage/not-found.svg" + + +def get_user_temp_dir() -> str: + return os.path.join(Path.home().absolute(), ".miaoshou_assistant_download") + + +def assert_user_temp_dir() -> None: + os.makedirs(get_user_temp_dir(), exist_ok=True) + + +def move_file(src: str, dst: str) -> None: + if not src or not dst: + return + + if not os.path.exists(src): + return + + if os.path.exists(dst): + os.remove(dst) + + os.makedirs(os.path.dirname(dst), exist_ok=True) + shutil.move(src, dst) diff --git a/scripts/runtime/__init__.py b/scripts/runtime/__init__.py new file mode 100644 index 0000000..fe306c7 --- /dev/null +++ b/scripts/runtime/__init__.py @@ -0,0 +1,5 @@ +__all__ = ["msai_prelude", "msai_runtime"] + +from . import msai_prelude as prelude + +prelude.MiaoshouPrelude().load() diff --git a/scripts/runtime/msai_prelude.py b/scripts/runtime/msai_prelude.py new file mode 100644 index 0000000..bb14663 --- /dev/null +++ b/scripts/runtime/msai_prelude.py @@ -0,0 +1,158 @@ +import os +import platform +import sys +import typing as t + +import psutil +import torch + +import launch +from scripts.logging.msai_logger import Logger +from scripts.msai_utils import msai_toolkit as toolkit +from scripts.msai_utils.msai_singleton import MiaoshouSingleton + + +class MiaoshouPrelude(metaclass=MiaoshouSingleton): + _dataset = None + + def __init__(self) -> None: + # Potential race condition, not call in multithread environment + if MiaoshouPrelude._dataset is None: + self._init_constants() + + MiaoshouPrelude._dataset = { + "log_folder": os.path.join(self.ext_folder, "logs") + } + + disable_log_console_output: bool = False + if self.all_settings.get("boot_settings"): + if self.all_settings["boot_settings"].get("disable_log_console_output") is not None: + disable_log_console_output = self.all_settings["boot_settings"].get("disable_log_console_output") + + self._logger = Logger(self._dataset["log_folder"], disable_console_output=disable_log_console_output) + + def _init_constants(self) -> None: + self._api_url = { + "civitai": "https://civitai.com/api/v1/models", + "liandange": "https://model-api.paomiantv.cn/model/api/models", + } + self._ext_folder = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")) + self._setting_file = os.path.join(self.ext_folder, "configs", "settings.json") + self._model_hash_file = os.path.join(self.ext_folder, "configs", "model_hash.json") + self._model_json = { + 'civitai': os.path.join(self.ext_folder, 'configs', 'civitai_models.json'), + 'liandange': os.path.join(self.ext_folder, 'configs', 'liandange_models.json'), + } + self._checkboxes = { + 'Enable xFormers': '--xformers', + 'No Half': '--no-half', + 'No Half VAE': '--no-half-vae', + 'Enable API': '--api', + 'Auto Launch': '--autolaunch', + 'Allow Local Network Access': '--listen', + } + + self._gpu_setting = { + 'CPU Only': '--precision full --no-half --use-cpu SD GFPGAN BSRGAN ESRGAN SCUNet CodeFormer --all', + 'GTX 16xx': '--lowvram --xformers --precision full --no-half', + 'Low: 4-6G VRAM': '--xformers --lowvram', + 'Med: 6-8G VRAM': '--xformers --medvram', + 'Normal: 8+G VRAM': '', + } + + self._theme_setting = { + 'Auto': '', + 'Light Mode': '--theme=light', + 'Dark Mode': '--theme=dark', + } + + @property + def ext_folder(self) -> str: + return self._ext_folder + + @property + def log_folder(self) -> str: + return self._dataset.get("log_folder") + + @property + def all_settings(self) -> t.Any: + return toolkit.read_json(self._setting_file) + + @property + def boot_settings(self) -> t.Any: + all_setting = self.all_settings + if all_setting: + return all_setting['boot_settings'] + else: + return None + + def api_url(self, model_source: str) -> t.Optional[str]: + return self._api_url.get(model_source) + + @property + def setting_file(self) -> str: + return self._setting_file + + @property + def model_hash_file(self) -> str: + return self._model_hash_file + + @property + def checkboxes(self) -> t.Dict[str, str]: + return self._checkboxes + + @property + def gpu_setting(self) -> t.Dict[str, str]: + return self._gpu_setting + + @property + def theme_setting(self) -> t.Dict[str, str]: + return self._theme_setting + + @property + def model_json(self) -> t.Dict[str, t.Any]: + return self._model_json + + def update_model_json(self, site: str, models: t.Dict[str, t.Any]) -> None: + if self._model_json.get(site) is None: + self._logger.error(f"cannot save model info for {site}") + return + + self._logger.info(f"{self._model_json[site]} updated") + toolkit.write_json(self._model_json[site], models) + + def load(self) -> None: + self._logger.info("start to do prelude") + self._logger.info(f"cmdline args: {' '.join(sys.argv[1:])}") + + @classmethod + def get_sys_info(cls) -> str: + sys_info = 'System Information\n\n' + + sys_info += r'OS Name: {0} {1}'.format(platform.system(), platform.release()) + '\n' + sys_info += r'OS Version: {0}'.format(platform.version()) + '\n' + sys_info += r'WebUI Version: {0}'.format( + f'https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{launch.commit_hash()}') + '\n' + sys_info += r'Torch Version: {0}'.format(getattr(torch, '__long_version__', torch.__version__)) + '\n' + sys_info += r'Python Version: {0}'.format(sys.version) + '\n\n' + sys_info += r'CPU: {0}'.format(platform.processor()) + '\n' + sys_info += r'CPU Cores: {0}/{1}'.format(psutil.cpu_count(logical=False), psutil.cpu_count(logical=True)) + '\n' + + # FIXME: should uncomment line below and remove my own workaround for MacOS + # sys_info += r'CPU Frequency: {0} GHz'.format(round(psutil.cpu_freq().max/1000,2)) + '\n' + # workaround: let my macbook M1 happy + sys_info += r'CPU Frequency: {0} GHz'.format(round(2540.547 / 1000, 2)) + '\n' + + sys_info += r'CPU Usage: {0}%'.format(psutil.cpu_percent()) + '\n\n' + sys_info += r'RAM: {0}'.format(toolkit.get_readable_size(psutil.virtual_memory().total)) + '\n' + sys_info += r'Memory Usage: {0}%'.format(psutil.virtual_memory().percent) + '\n\n' + for i in range(torch.cuda.device_count()): + sys_info += r'Graphics Card{0}: {1} ({2})'.format(i, torch.cuda.get_device_properties(i).name, + toolkit.get_readable_size( + torch.cuda.get_device_properties( + i).total_memory)) + '\n' + sys_info += r'Available VRAM: {0}'.format(toolkit.get_readable_size(torch.cuda.mem_get_info(i)[0])) + '\n' + + return sys_info + + diff --git a/scripts/runtime/msai_runtime.py b/scripts/runtime/msai_runtime.py new file mode 100644 index 0000000..5475293 --- /dev/null +++ b/scripts/runtime/msai_runtime.py @@ -0,0 +1,630 @@ +import datetime +import fileinput +import os +import platform +import re +import shutil +import sys +import time +import typing as t + +import gradio as gr +import requests + +import modules +from modules.sd_models import CheckpointInfo +from scripts.download.msai_downloader_manager import MiaoshouDownloaderManager +from scripts.logging.msai_logger import Logger +from scripts.msai_utils import msai_toolkit as toolkit +from scripts.runtime.msai_prelude import MiaoshouPrelude + + +class MiaoshouRuntime(object): + def __init__(self): + self.cmdline_args: t.List[str] = None + self.logger = Logger() + self.prelude = MiaoshouPrelude() + self._old_additional: str = None + self._model_set: t.List[t.Dict] = None + self._model_set_last_access_time: datetime.datetime = None + self._ds_models: gr.Dataset = None + self._allow_nsfw: bool = False + self._model_source: str = "civitai" # civitai is the default model source + + # TODO: may be owned by downloader class + self.model_files = [] + + self.downloader_manager = MiaoshouDownloaderManager() + + def get_default_args(self, commandline_args: t.List[str] = None): + if commandline_args is None: + commandline_args: t.List[str] = toolkit.get_args(sys.argv[1:]) + self.cmdline_args = commandline_args + self.logger.info(f"default commandline args: {commandline_args}") + + checkbox_values = [] + additional_args = "" + saved_setting = self.prelude.boot_settings + + gpu = saved_setting.get('drp_args_vram') + theme = saved_setting.get('drp_args_theme') + port = saved_setting.get('txt_args_listen_port') + + for arg in commandline_args: + if 'theme' in arg: + theme = [k for k, v in self.prelude.theme_setting.items() if v == arg][0] + if 'port' in arg: + port = arg.split(' ')[-1] + + for chk in self.prelude.checkboxes: + for arg in commandline_args: + if self.prelude.checkboxes[chk] == arg: + checkbox_values.append(chk) + + gpu_arg_list = [f'--{i.strip()}' for i in ' '.join(list(self.prelude.gpu_setting.values())).split('--')] + for arg in commandline_args: + if 'port' not in arg \ + and arg not in list(self.prelude.theme_setting.values()) \ + and arg not in list(self.prelude.checkboxes.values()) \ + and arg not in gpu_arg_list: + additional_args += (' ' + rf'{arg}') + + self._old_additional = additional_args + webui_ver = saved_setting['drp_choose_version'] + + return gpu, theme, port, checkbox_values, additional_args.replace('\\', '\\\\').strip(), webui_ver + + def add_arg(self, args: str = "") -> None: + for arg in args.split('--'): + self.logger.info(f'add arg: {arg.strip()}') + if f"--{arg.strip()}" not in self.cmdline_args and arg.strip() != '': + self.cmdline_args.append(f'--{arg.strip()}') + + def remove_arg(self, args: str = "") -> None: + arg_keywords = ['port', 'theme'] + + for arg in args.split('--'): + if arg in arg_keywords: + for cmdl in self.cmdline_args: + if arg in cmdl: + self.cmdline_args.remove(cmdl) + break + elif f'--{arg.strip()}' in self.cmdline_args and arg.strip() != '': + print(f"remove args:{arg.strip()}") + self.cmdline_args.remove(f'--{arg.strip()}') + + def get_final_args(self, gpu, theme, port, checkgroup, more_args) -> None: + # gpu settings + for s1 in self.prelude.gpu_setting: + if s1 in gpu: + for s2 in self.prelude.gpu_setting: + if s2 != s1: + self.remove_arg(self.prelude.gpu_setting[s2]) + self.add_arg(self.prelude.gpu_setting[s1]) + + if port != '7860': + self.add_arg(f'--port {port}') + else: + self.remove_arg('--port') + + # theme settings + self.remove_arg('--theme') + for t in self.prelude.theme_setting: + if t == theme: + self.add_arg(self.prelude.theme_setting[t]) + break + + # check box settings + for chked in checkgroup: + self.logger.info(f'checked:{self.prelude.checkboxes[chked]}') + self.add_arg(self.prelude.checkboxes[chked]) + + for unchk in list(set(list(self.prelude.checkboxes.keys())) - set(checkgroup)): + print(f'unchecked:{unchk}') + self.remove_arg(self.prelude.checkboxes[unchk]) + + # additional commandline settings + self.remove_arg(self._old_additional) + self.add_arg(more_args.replace('\\\\', '\\')) + self._old_additional = more_args.replace('\\\\', '\\') + + def fetch_all_models(self) -> t.List[t.Dict]: + endpoint_url = self.prelude.api_url(self.model_source) + if endpoint_url is None: + self.logger.error(f"{self.model_source} is not supported") + return [] + + self.logger.info(f"start to fetch model info from '{self.model_source}':{endpoint_url}") + + limit_threshold = 100 + + all_set = [] + response = requests.get(endpoint_url + f'?page=1&limit={limit_threshold}') + num_of_pages = response.json()['metadata']['totalPages'] + self.logger.info(f"total pages = {num_of_pages}") + + continuous_error_counts = 0 + + for p in range(1, num_of_pages + 1): + try: + response = requests.get(endpoint_url + f'?page={p}&limit={limit_threshold}') + payload = response.json() + if payload.get("success") is not None and not payload.get("success"): + self.logger.error(f"failed to fetch page[{p + 1}]") + continuous_error_counts += 1 + if continuous_error_counts > 10: + break + else: + continue + + continuous_error_counts = 0 # reset error flag + self.logger.debug(f"start to process page[{p}]") + + for model in payload['items']: + all_set.append(model) + + self.logger.debug(f"page[{p}] : {len(payload['items'])} items added") + except Exception as e: + self.logger.error(f"failed to fetch page[{p + 1}] due to {e}") + time.sleep(3) + + if len(all_set) > 0: + self.prelude.update_model_json(self.model_source, all_set) + else: + self.logger.error("fetch_all_models: emtpy body received") + + return all_set + + def refresh_all_models(self) -> None: + if self.fetch_all_models(): + if self.ds_models: + self.ds_models.samples = self.model_set + self.ds_models.update(samples=self.model_set) + else: + self.logger.error(f"ds models is null") + + def get_images_html(self, search: str = '', model_type: str = 'All') -> t.List[str]: + self.logger.info(f"get_image_html: model_type = {model_type}, and search pattern = '{search}'") + + model_cover_thumbnails = [] + model_format = [] + + if self.model_set is None: + self.logger.error("model_set is null") + return [] + + self.logger.info(f"{len(self.model_set)} items inside '{self.model_source}'") + + search = search.lower() + for model in self.model_set: + try: + if model.get('type') is not None \ + and model.get('type') not in model_format: + model_format.append(model['type']) + + if search == '' or \ + (model.get('name') is not None and search in model.get('name').lower()) \ + or (model.get('description') is not None and search in model.get('description').lower()): + + if (model_type == 'All' or model_type in model.get('type')) \ + and (self.allow_nsfw or (not self.allow_nsfw and not model.get('nsfw'))): + model_cover_thumbnails.append([ + [f""" +
+
+ +
+
+

{model.get('name')}

+

Type: {model.get('type')}

+

Rating: {model.get('stats')['rating']}

+
+
+ """], + model['id']]) + except Exception: + continue + + return model_cover_thumbnails + + # TODO: add typing hint + def update_boot_settings(self, version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args): + boot_settings = self.prelude.boot_settings + boot_settings['drp_args_vram'] = drp_gpu + boot_settings["drp_args_theme"] = drp_theme + boot_settings['txt_args_listen_port'] = txt_listen_port + for chk in chk_group_args: + self.logger.debug(chk) + boot_settings[chk] = self.prelude.checkboxes[chk] + boot_settings['txt_args_more'] = additional_args + boot_settings['drp_choose_version'] = version + + all_settings = self.prelude.all_settings + all_settings['boot_settings'] = boot_settings + + toolkit.write_json(self.prelude.setting_file, all_settings) + + def get_all_models(self, site: str) -> t.Any: + return toolkit.read_json(self.prelude.model_json[site]) + + def update_model_json(self, site: str, models: t.Any) -> None: + toolkit.write_json(self.prelude.model_json[site], models) + + def get_hash_from_json(self, chk_point: CheckpointInfo) -> CheckpointInfo: + model_hashes = toolkit.read_json(self.prelude.model_hash_file) + + if len(model_hashes) == 0 or chk_point.title not in model_hashes.keys(): + chk_point.shorthash = chk_point.calculate_shorthash() + model_hashes[chk_point.title] = chk_point.shorthash + toolkit.write_json(self.prelude.model_hash_file, model_hashes) + else: + chk_point.shorthash = model_hashes[chk_point.title] + + return chk_point + + def get_local_models(self) -> t.List[t.Any]: + models = [] + + for file in modules.sd_models.checkpoint_tiles(): + chkpt_info = modules.sd_models.get_closet_checkpoint_match(file) + if chkpt_info.sha256 is None and chkpt_info.shorthash is None: + chkpt_info = self.get_hash_from_json(chkpt_info) + + fname = re.sub(r'\[.*?\]', "", chkpt_info.title) + model_info = self.search_model_by_hash(chkpt_info.sha256, chkpt_info.shorthash, fname) + if model_info is not None: + models.append(model_info) + else: + self.logger.info( + f"{chkpt_info.title}, {chkpt_info.hash}, {chkpt_info.shorthash}, {chkpt_info.sha256}") + models.append([ + [f''], + [os.path.basename(fname)], + [fname], + [chkpt_info.shorthash], + [], [], []]) + + return models + + def search_model_by_hash(self, lookup_sha256: str, lookup_shash: str, fname: str) -> t.Optional[t.List[t.Any]]: + self.logger.info(f"lookup_sha256: {lookup_sha256}, lookup_shash: {lookup_shash}, fname: {fname}") + + res = None + if lookup_sha256 is None and lookup_shash is None: + return None + + for model in self.model_set: + match = False + + for ver in model['modelVersions']: + for file in ver['files']: + if lookup_sha256 is not None and 'SHA256' in file['hashes'].keys(): + match = (lookup_sha256.upper() == file['hashes']['SHA256'].upper()) + elif lookup_shash is not None: + match = (lookup_shash[:10].upper() in [h.upper() for h in file['hashes'].values()]) + + if match: + cover_link = ver['images'][0]['url'].replace('width=450', 'width=100') + mid = model['id'] + res = [ + [ + f''], + [f"{model['name']}/{ver['name']}"], + [fname], + [lookup_shash], + [model['creator']['username']], + [model['type']], + [model['nsfw']], + [ver['trainedWords']], + ] + + if match: + break + + return res + + def update_xformers(self, gpu, checkgroup): + if '--xformers' in self.prelude.gpu_setting[gpu]: + if 'Enable xFormers' not in checkgroup: + checkgroup.append('Enable xFormers') + + return checkgroup + + def set_nsfw(self, search='', nsfw_checker=False, model_type='All') -> t.Dict: + self._allow_nsfw = nsfw_checker + new_list = self.get_images_html(search, model_type) + if self._ds_models is None: + self.logger.error(f"_ds_models is not initialized") + return {} + + self._ds_models.samples = new_list + return self._ds_models.update(samples=new_list) + + def search_model(self, search='', model_type='All') -> t.Dict: + if self._ds_models is None: + self.logger.error(f"_ds_models is not initialized") + return {} + + new_list = self.get_images_html(search, model_type) + + self._ds_models.samples = new_list + return self._ds_models.update(samples=new_list) + + def get_model_info(self, models) -> t.Tuple[t.List[t.List[str]], t.Dict, str, t.Dict]: + drop_list = [] + cover_imgs = [] + htmlDetail = "

Empty

" + + mid = models[1] + + # TODO: use map to enhance the performances + m = [e for e in self.model_set if e['id'] == mid][0] + + self.model_files.clear() + + download_url_by_default = None + if m and m.get('modelVersions') and len(m.get('modelVersions')) > 0: + latest_version = m['modelVersions'][0] + + if latest_version.get('images') and isinstance(latest_version.get('images'), list): + for img in latest_version['images']: + if self.allow_nsfw or (not self.allow_nsfw and not img.get('nsfw')): + if img.get('url'): + cover_imgs.append([img['url'], '']) + + if latest_version.get('files') and isinstance(latest_version.get('files'), list): + for file in latest_version['files']: + # error checking for mandatory fields + if file.get('id') is not None and file.get('downloadUrl') is not None: + item_name = None + if file.get('name'): + item_name = file.get('name') + if not item_name and latest_version.get('name'): + item_name = latest_version['name'] + if not item_name: + item_name = "unknown" + + self.model_files.append({ + "id:": file['id'], + "url": file['downloadUrl'], + "name": item_name, + "type": m['type'] if m.get('type') else "unknown", + "size": file['sizeKB'] * 1024 if file.get('sizeKB') else "unknown", + "format": file['format'] if file.get('format') else "unknown", + "cover": cover_imgs[0][0] if len(cover_imgs) > 0 else toolkit.get_not_found_image_url(), + }) + file_size = toolkit.get_readable_size(file['sizeKB'] * 1024) if file.get('sizeKB') else "" + if file_size: + drop_list.append(f"{item_name} ({file_size})") + else: + drop_list.append(f"{item_name}") + + if not download_url_by_default: + download_url_by_default = file.get('downloadUrl') + + htmlDetail = '
' + if m.get('name'): + htmlDetail += f"

{m['name']}


" + if m.get('stats') and m.get('stats').get('downloadCount'): + htmlDetail += f"

Downloads: {m['stats']['downloadCount']}

" + if m.get('stats') and m.get('stats').get('rating'): + htmlDetail += f"

Rating: {m['stats']['rating']}

" + if m.get('creator') and m.get('creator').get('username'): + htmlDetail += f"

Author: {m['creator']['username']}



" + if latest_version.get('name'): + htmlDetail += f"
" + if latest_version.get('updatedAt'): + htmlDetail += f"" + if m.get('type'): + htmlDetail += f"" + if latest_version.get('baseModel'): + htmlDetail += f"" + htmlDetail += f"" + if m.get('tags') and isinstance(m.get('tags'), list): + htmlDetail += f"" + if latest_version.get('trainedWords'): + htmlDetail += f"" + htmlDetail += "
Version:{latest_version['name']}
Updated Time:{latest_version['updatedAt']}
Type:{m['type']}
Base Model:{latest_version['baseModel']}
NFSW:{m.get('nsfw') if m.get('nsfw') is not None else 'false'}
Tags:" + for t in m['tags']: + htmlDetail += f"{t}" + htmlDetail += "
Trigger Words:" + for t in latest_version['trainedWords']: + htmlDetail += f"{t}" + htmlDetail += "
" + htmlDetail += f"
{m['description'] if m.get('description') else 'N/A'}
" + + return ( + cover_imgs, + gr.Dropdown.update(choices=drop_list, value=drop_list[0] if len(drop_list) > 0 else []), + htmlDetail, + gr.HTML.update(value=f'

' + f'Download

') + ) + + def get_downloading_status(self): + (_, _, desc) = self.downloader_manager.tasks_summary() + return gr.HTML.update(value=desc) + + def download_model(self, filename: str): + model_path = modules.paths.models_path + script_path = modules.paths.script_path + + urls = [] + for _, f in enumerate(self.model_files): + if not f.get('name'): + continue + model_fname = re.sub(r"\s*\(\d+(?:\.\d*)?.B\)\s*$", "", f['name']) + + if model_fname in filename: + m_pre, m_ext = os.path.splitext(model_fname) + cover_fname = f"{m_pre}.jpg" + + if f['type'] == 'LORA': + cover_fname = os.path.join(model_path, 'Lora', cover_fname) + model_fname = os.path.join(model_path, 'Lora', model_fname) + elif f['format'] == 'VAE': + cover_fname = os.path.join(model_path, 'VAE', cover_fname) + model_fname = os.path.join(model_path, 'VAE', model_fname) + elif f['format'] == 'TextualInversion': + cover_fname = os.path.join(script_path, 'embeddings', cover_fname) + model_fname = os.path.join(script_path, 'embeddings', model_fname) + elif f['format'] == 'Hypernetwork': + cover_fname = os.path.join(model_path, 'hypernetworks', cover_fname) + model_fname = os.path.join(model_path, 'hypernetworks', model_fname) + else: + cover_fname = os.path.join(model_path, 'Stable-diffusion', cover_fname) + model_fname = os.path.join(model_path, 'Stable-diffusion', model_fname) + + urls.append((f['cover'], f['url'], f['size'], cover_fname, model_fname)) + break + + for (cover_url, model_url, total_size, local_cover_name, local_model_name) in urls: + self.downloader_manager.download( + source_url=cover_url, + target_file=local_cover_name, + estimated_total_size=None, + ) + self.downloader_manager.download( + source_url=model_url, + target_file=local_model_name, + estimated_total_size=total_size, + ) + + # + # currently, web-ui is without queue enabled. + # + # webui_queue_enabled = False + # if webui_queue_enabled: + # start = time.time() + # downloading_tasks_iter = self.downloader_manager.iterator() + # for i in progressbar.tqdm(range(100), unit="byte", desc="Models Downloading"): + # while True: + # try: + # finished_bytes, total_bytes = next(downloading_tasks_iter) + # v = finished_bytes / total_bytes + # print(f"\n v = {v}") + # if isinstance(v, float) and int(v * 100) < i: + # print(f"\nv({v}) < {i}") + # continue + # else: + # break + # except StopIteration: + # break + # + # time.sleep(0.5) + # + # self.logger.info(f"[downloading] finished after {time.time() - start} secs") + + time.sleep(2) + + self.model_files.clear() + return gr.HTML.update(value=f"

{len(urls)} downloading tasks added into task list

") + + def change_boot_setting(self, version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args): + self.get_final_args(drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args) + self.logger.info(f'saved_cmd: {self.cmdline_args}') + + target_webui_user_file = "webui-user.bat" + script_export_keyword = "export" + if platform.system() == "Linux": + target_webui_user_file = "webui-user.sh" + elif platform.system() == "Darwin": + target_webui_user_file = "webui-macos-env.sh" + else: + script_export_keyword = "set" + + filepath = os.path.join(modules.shared.script_path, target_webui_user_file) + self.logger.info(f"to update: {filepath}") + + msg = 'Result: Setting Saved.' + if version == 'Official Release': + try: + if not os.path.exists(filepath): + shutil.copyfile(os.path.join(self.prelude.ext_folder, 'configs', target_webui_user_file), filepath) + + with fileinput.FileInput(filepath, inplace=True, backup='.bak') as file: + for line in file: + if 'COMMANDLINE_ARGS' in line: + rep_txt = ' '.join(self.cmdline_args).replace('\\', '\\\\') + line = f'{script_export_keyword} COMMANDLINE_ARGS="{rep_txt}"\n' + sys.stdout.write(line) + + except Exception as e: + msg = f'Error: {str(e)}' + else: + try: + if not os.path.exists(filepath): + shutil.copyfile(os.path.join(self.prelude.ext_folder, 'configs', target_webui_user_file), filepath) + + with fileinput.FileInput(filepath, inplace=True, backup='.bak') as file: + for line in file: + if 'webui.py' in line: + rep_txt = ' '.join(self.cmdline_args).replace('\\', '\\\\') + line = f"python\python.exe webui.py {rep_txt}\n" + sys.stdout.write(line) + except Exception as e: + msg = f'Error: {str(e)}' + + self.update_boot_settings(version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args) + return gr.update(value=msg, visible=True) + + @property + def model_set(self) -> t.List[t.Dict]: + try: + self.logger.info(f"access to model info for '{self.model_source}'") + model_json_mtime = toolkit.get_file_last_modified_time(self.prelude.model_json[self.model_source]) + + if self._model_set is None or self._model_set_last_access_time is None \ + or self._model_set_last_access_time < model_json_mtime: + self._model_set = self.get_all_models(self.model_source) + self._model_set_last_access_time = model_json_mtime + self.logger.info(f"load '{self.model_source}' model data from local file") + except Exception as e: + self._model_set = self.fetch_all_models() + self._model_set_last_access_time = datetime.datetime.now() + + return self._model_set + + @property + def allow_nsfw(self) -> bool: + return self._allow_nsfw + + @property + def old_additional_args(self) -> str: + return self._old_additional + + @property + def ds_models(self) -> gr.Dataset: + return self._ds_models + + @ds_models.setter + def ds_models(self, newone: gr.Dataset): + self._ds_models = newone + + @property + def model_source(self) -> str: + return self._model_source + + @model_source.setter + def model_source(self, newone: str): + self.logger.info(f"model source changes from {self.model_source} to {newone}") + self._model_source = newone + self._model_set_last_access_time = None # reset timestamp + + def introception(self) -> None: + (gpu, theme, port, checkbox_values, extra_args, ver) = self.get_default_args() + + print("################################################################") + print("MIAOSHOU ASSISTANT ARGUMENTS:") + + print(f" gpu = {gpu}") + print(f" theme = {theme}") + print(f" port = {port}") + print(f" checkbox_values = {checkbox_values}") + print(f" extra_args = {extra_args}") + print(f" webui ver = {ver}") + + print("################################################################") +