implement partial features

Signed-off-by: yanshoutong <yanshoutong@sina.cn>

Branch: pull/38/head
parent
c3017fef4e
commit
e67f68e7c3
|
|
@ -0,0 +1,17 @@
|
|||
# for vim
|
||||
.~
|
||||
.*.swp
|
||||
*~
|
||||
|
||||
# for MacOS
|
||||
.DS_Store
|
||||
|
||||
__pycache__/
|
||||
|
||||
.idea/
|
||||
|
||||
logs/
|
||||
flagged/
|
||||
|
||||
configs/civitai_models.json
|
||||
configs/liandange_models.json
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
import launch
|
||||
import os
|
||||
import gzip
|
||||
import io
|
||||
|
||||
|
||||
def install_preset_models_if_needed():
    """Decompress the bundled model-index files into ``configs/``.

    Each preset model index ships as a gzip archive under ``assets/``; the
    decompressed JSON is written to ``configs/`` only when not already
    present, so user-modified copies are never overwritten.
    """
    extension_root = os.path.dirname(os.path.realpath(__file__))
    assets_folder = os.path.join(extension_root, "assets")
    configs_folder = os.path.join(extension_root, "configs")
    # configs/ is git-ignored (see .gitignore), so it may not exist on a
    # fresh checkout — create it before writing.
    os.makedirs(configs_folder, exist_ok=True)

    for model_filename in ["civitai_models.json", "liandange_models.json"]:
        gzip_file = os.path.join(assets_folder, f"{model_filename}.gz")
        target_file = os.path.join(configs_folder, model_filename)
        if not os.path.exists(target_file):
            with gzip.open(gzip_file, "rb") as compressed_file:
                with io.TextIOWrapper(compressed_file, encoding="utf-8") as decoder:
                    content = decoder.read()
            # Write with an explicit encoding so the JSON round-trips
            # identically on Windows default codepages.
            with open(target_file, "w", encoding="utf-8") as model_file:
                model_file.write(content)


# Install any missing third-party requirements through the WebUI's pip helper.
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")

with open(req_file) as file:
    for lib in file:
        lib = lib.strip()
        # Skip blank lines and comments; is_installed() would treat them as
        # (never-installed) package names and trigger pointless pip runs.
        if not lib or lib.startswith("#"):
            continue
        if not launch.is_installed(lib):
            launch.run_pip(f"install {lib}", f"Miaoshou assistant requirement: {lib}")

install_preset_models_if_needed()
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
psutil
|
||||
rehash
|
||||
tqdm
|
||||
|
|
@ -0,0 +1 @@
|
|||
__all__ = ["miaoshou"]
|
||||
|
|
@ -0,0 +1,256 @@
|
|||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import gradio as gr
|
||||
|
||||
import launch
|
||||
import modules
|
||||
from scripts.logging.msai_logger import Logger
|
||||
from scripts.runtime.msai_prelude import MiaoshouPrelude
|
||||
from scripts.runtime.msai_runtime import MiaoshouRuntime
|
||||
|
||||
|
||||
class MiaoShouAssistant(object):
    """Gradio UI for the Miaoshou assistant extension tab."""

    # default css definition
    # Fixed widths for the cover image and trigger-words columns of the
    # "My Models" dataset in the Model Management sub-tab.
    default_css = '#my_model_cover{width: 100px;} #my_model_trigger_words{width: 200px;}'

    def __init__(self) -> None:
        self.logger = Logger()
        self.prelude = MiaoshouPrelude()
        self.runtime = MiaoshouRuntime()
        # U+1F504 "anticlockwise arrows" glyph, used as a refresh icon on buttons.
        self.refresh_symbol = '\U0001f504'
|
||||
|
||||
    def on_event_ui_tabs_opened(self) -> t.List[t.Optional[t.Tuple[t.Any, str, str]]]:
        """Build the extension's tab and return the (component, title, elem_id)
        triple list that the WebUI's on_ui_tabs callback protocol expects."""
        with gr.Blocks(analytics_enabled=False, css=MiaoShouAssistant.default_css) as miaoshou_assistant:
            self.create_subtab_boot_assistant()
            self.create_subtab_model_management()
            self.create_subtab_model_download()

        # .queue() enables Gradio's request queue on this Blocks instance.
        return [(miaoshou_assistant.queue(), "Miaoshou Assistant", "miaoshou_assistant")]
|
||||
|
||||
    def create_subtab_boot_assistant(self) -> None:
        """Build the "Boot Assistant" sub-tab: launch-argument settings on the
        left, a read-only system-info panel on the right."""
        with gr.TabItem('Boot Assistant', elem_id="boot_assistant_tab") as boot_assistant:
            with gr.Row():
                with gr.Column(elem_id="col_model_list"):
                    # Seed every control with the values currently saved by the runtime.
                    gpu, theme, port, chk_args, txt_args, webui_ver = self.runtime.get_default_args()
                    gr.Markdown(value="Argument settings")
                    with gr.Row():
                        drp_gpu = gr.Dropdown(label="", elem_id="drp_args_vram",
                                              choices=list(self.prelude.gpu_setting.keys()),
                                              value=gpu, interactive=True)
                        drp_theme = gr.Dropdown(label="UI Theme", choices=list(self.prelude.theme_setting.keys()),
                                                value=theme,
                                                elem_id="drp_args_theme", interactive=True)
                        txt_listen_port = gr.Text(label='Listen Port', value=port, elem_id="txt_args_listen_port",
                                                  interactive=True)

                    with gr.Row():
                        chk_group_args = gr.CheckboxGroup(choices=list(self.prelude.checkboxes.keys()), value=chk_args,
                                                          show_label=False)
                        additional_args = gr.Text(label='COMMANDLINE_ARGS (Divide by space)', value=txt_args,
                                                  elem_id="txt_args_more", interactive=True)

                    with gr.Row():
                        with gr.Column():
                            # Hidden until change_boot_setting reports a result.
                            txt_save_status = gr.Markdown(visible=False, interactive=False, show_label=False)
                            drp_choose_version = gr.Dropdown(label="WebUI Version",
                                                             choices=['Official Release', 'Python Integrated'],
                                                             value=webui_ver, elem_id="drp_args_version",
                                                             interactive=True)
                            gr.HTML(
                                '<div><p>*Save your settings to webui-user.bat file. Use Python Integrated only if your'
                                ' WebUI is extracted from a zip file and does not need python installation</p></div>')
                            save_settings = gr.Button(value="Save settings", elem_id="btn_arg_save_setting")

                    with gr.Row():
                        # with gr.Column():
                        # settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="ms_settings_submit")
                        # with gr.Column():
                        restart_gradio = gr.Button(value='Apply & Restart WebUI', variant='primary',
                                                   elem_id="ms_settings_restart_gradio")

                    # Dead prototype kept by the original author for reference.
                    '''def mod_args(drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
                        global commandline_args

                        get_final_args(drp_gpu, drp_theme, txt_listen_port, hk_group_args, additional_args)

                        print(commandline_args)
                        print(sys.argv)
                        #if '--xformers' not in sys.argv:
                            #sys.argv.append('--xformers')

                    settings_submit.click(mod_args, inputs=[drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args], outputs=[])'''

                    save_settings.click(self.runtime.change_boot_setting,
                                        inputs=[drp_choose_version, drp_gpu, drp_theme, txt_listen_port, chk_group_args,
                                                additional_args], outputs=[txt_save_status])

                    # NOTE(review): request_restart is invoked with inputs=[]
                    # even though its signature declares five parameters —
                    # confirm the callback tolerates being called with no args.
                    restart_gradio.click(
                        fn=self.request_restart,
                        _js='restart_reload',
                        inputs=[],
                        outputs=[],
                    )

                with gr.Column():
                    with gr.Row():
                        machine_settings = self.prelude.get_sys_info()
                        txt_sys_info = gr.TextArea(value=machine_settings, lines=20, max_lines=20,
                                                   label="System Info",
                                                   show_label=False, interactive=False)
                    with gr.Row():
                        sys_info_refbtn = gr.Button(value="Refresh")

            # Changing the GPU preset may toggle the xformers checkbox.
            drp_gpu.change(self.runtime.update_xformers, inputs=[drp_gpu, chk_group_args], outputs=[chk_group_args])
            sys_info_refbtn.click(self.prelude.get_sys_info, None, txt_sys_info)
|
||||
|
||||
    def create_subtab_model_management(self) -> None:
        """Build the "Model Management" sub-tab: a dataset of local models on
        the left, a prompt/preview panel plus action buttons on the right.

        NOTE(review): none of the buttons below is wired to a handler yet —
        this matches the commit message ("implement partial features").
        """
        with gr.TabItem('Model Management', elem_id="model_management_tab") as tab_batch:
            with gr.Row():
                with gr.Column():
                    my_models = self.runtime.get_local_models()
                    # One invisible component per column of the samples rows.
                    ds_my_models = gr.Dataset(
                        components=[gr.HTML(visible=False, label='Cover', elem_id='my_model_cover'),
                                    gr.Textbox(visible=False, label='Name/Version'),
                                    gr.Textbox(visible=False, label='File Name'),
                                    gr.Textbox(visible=False, label='Hash'), gr.Textbox(visible=False, label='Creator'),
                                    gr.Textbox(visible=False, label='Type'), gr.Textbox(visible=False, label='NSFW'),
                                    gr.Textbox(visible=False, label='Trigger Words', elem_id='my_model_trigger_words')],
                        elem_id='my_model_lib',
                        label="My Models",
                        headers=None,
                        samples=my_models,
                        samples_per_page=50)
                with gr.Column():
                    html_model_prompt = gr.HTML(visible=True,
                                                value='<div style="height:400px;"><p>No Model Selected</p></div>')

                    with gr.Row():
                        add = gr.Button(value="Add", variant="primary")
                        # delete = gr.Button(value="Delete")
                    with gr.Row():
                        reset_btn = gr.Button(value="Reset")
                        json_input = gr.Button(value="Load from JSON")
                        png_input = gr.Button(value="Detect from image")
                        png_input_area = gr.Image(label="Detect from image", elem_id="openpose_editor_input")
                        bg_input = gr.Button(value="Add Background image")
|
||||
|
||||
    def create_subtab_model_download(self) -> None:
        """Build the "Model Download" sub-tab: source selection, search and a
        model gallery on the left; cover images, download controls and model
        details on the right."""
        with gr.TabItem('Model Download', elem_id="model_download_tab") as tab_downloads:
            with gr.Row():
                with gr.Column(elem_id="col_model_list"):
                    with gr.Row().style(equal_height=True):
                        model_source_dropdown = gr.Dropdown(choices=["civitai", "liandange"],
                                                            value=self.runtime.model_source,
                                                            label="Select Model Source",
                                                            type="value",
                                                            show_label=True,
                                                            elem_id="model_source").style(full_width=True)
                    with gr.Row().style(equal_height=True):
                        search_text = gr.Textbox(
                            label="Model name",
                            show_label=False,
                            max_lines=1,
                            placeholder="Enter model name",
                        )
                        btn_search = gr.Button("Search")

                    with gr.Row().style(equal_height=True):
                        nsfw_checker = gr.Checkbox(label='NSFW', value=False, elem_id="chk_nsfw", interactive=True)
                        model_type = gr.Radio(["All", "Checkpoint", "LORA", "TextualInversion", "Hypernetwork"],
                                              show_label=False, value='All', elem_id="rad_model_type",
                                              interactive=True).style(full_width=True)

                    # The dataset is kept on the runtime so other callbacks
                    # (search/nsfw/source switch) can refresh its samples.
                    images = self.runtime.get_images_html()
                    self.runtime.ds_models = gr.Dataset(
                        components=[gr.HTML(visible=False)],
                        headers=None,
                        type="values",
                        label="Models",
                        samples=images,
                        samples_per_page=60,
                        elem_id="model_dataset").style(type="gallery", container=True)

                with gr.Column(elem_id="col_model_info"):
                    with gr.Row():
                        cover_gallery = gr.Gallery(label="Cover", show_label=False, visible=True).style(grid=[4],
                                                                                                        height="2")

                    with gr.Row():
                        with gr.Column():
                            download_summary = gr.HTML('<div><span>No downloading tasks ongoing</span></div>')
                            downloading_status = gr.Button(value=f"{self.refresh_symbol} Refresh Downloading Status",
                                                           elem_id="ms_dwn_status")
                    with gr.Row():
                        # NOTE(review): this dropdown reuses elem_id
                        # 'ms_dwn_button', which is also assigned to the
                        # download button below — confirm the duplicate id.
                        model_dropdown = gr.Dropdown(choices=['Select Model'], label="Models", show_label=False,
                                                     value='Select Model', elem_id='ms_dwn_button',
                                                     interactive=True)

                    # civitai downloads go through our downloader; liandange
                    # models are fetched via an external link in a new tab.
                    is_civitai_model_source_active = self.runtime.model_source == "civitai"
                    with gr.Row(variant="panel"):
                        dwn_button = gr.Button(value='Download',
                                               visible=is_civitai_model_source_active, elem_id='ms_dwn_button')
                        open_url_in_browser_newtab_button = gr.HTML(
                            value='<p style="text-align: center;">'
                                  '<a style="text-align: center;" href="https://models.paomiantv.cn/models" '
                                  'target="_blank">Download</a></p>',
                            visible=not is_civitai_model_source_active)
                    with gr.Row():
                        model_info = gr.HTML(visible=True)

            # --- event wiring -------------------------------------------------
            nsfw_checker.change(self.runtime.set_nsfw, inputs=[search_text, nsfw_checker, model_type],
                                outputs=self.runtime.ds_models)

            model_type.change(self.runtime.search_model, inputs=[search_text, model_type], outputs=self.runtime.ds_models)

            btn_search.click(self.runtime.search_model, inputs=[search_text, model_type], outputs=self.runtime.ds_models)

            self.runtime.ds_models.click(self.runtime.get_model_info,
                                         inputs=[self.runtime.ds_models],
                                         outputs=[
                                             cover_gallery,
                                             model_dropdown,
                                             model_info,
                                             open_url_in_browser_newtab_button
                                         ])

            dwn_button.click(self.runtime.download_model, inputs=[model_dropdown], outputs=[download_summary])
            downloading_status.click(self.runtime.get_downloading_status, inputs=[], outputs=[download_summary])

            model_source_dropdown.change(self.switch_model_source,
                                         inputs=[model_source_dropdown],
                                         outputs=[self.runtime.ds_models, dwn_button, open_url_in_browser_newtab_button])
|
||||
|
||||
def request_restart(self, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
|
||||
print('request_restart: cmd_arg = ', self.runtime.cmdline_args)
|
||||
print('request_restart: sys.argv = ', sys.argv)
|
||||
|
||||
modules.shared.state.interrupt()
|
||||
modules.shared.state.need_restart = True
|
||||
|
||||
# reset args
|
||||
sys.argv = [sys.argv[0]]
|
||||
os.environ['COMMANDLINE_ARGS'] = ""
|
||||
print('remove', sys.argv)
|
||||
|
||||
for arg in self.runtime.cmdline_args:
|
||||
sys.argv.append(arg)
|
||||
|
||||
print('after', sys.argv)
|
||||
launch.prepare_environment()
|
||||
launch.start()
|
||||
|
||||
def switch_model_source(self, new_model_source: str):
|
||||
self.runtime.model_source = new_model_source
|
||||
show_download_button = self.runtime.model_source == "civitai"
|
||||
images = self.runtime.get_images_html()
|
||||
self.runtime.ds_models.samples = images
|
||||
return (
|
||||
gr.Dataset.update(samples=images),
|
||||
gr.Button.update(visible=show_download_button),
|
||||
gr.HTML.update(visible=not show_download_button)
|
||||
)
|
||||
|
||||
    def introception(self) -> None:
        # Delegate to the runtime's introspection routine.
        self.runtime.introception()
|
||||
|
|
@ -0,0 +1 @@
|
|||
__all__ = ["msai_downloader_manager"]
|
||||
|
|
@ -0,0 +1,248 @@
|
|||
import asyncio
|
||||
import os.path
|
||||
import queue
|
||||
import time
|
||||
import requests
|
||||
import typing as t
|
||||
from threading import Thread, Lock
|
||||
|
||||
from scripts.download.msai_file_downloader import MiaoshouFileDownloader
|
||||
from scripts.logging.msai_logger import Logger
|
||||
from scripts.msai_utils.msai_singleton import MiaoshouSingleton
|
||||
import scripts.msai_utils.msai_toolkit as toolkit
|
||||
|
||||
|
||||
class DownloadingEntry(object):
    """Book-keeping record for a single tracked download.

    Holds the immutable request parameters (url, destination, optional
    checksum) plus the mutable progress/outcome state updated by the manager.
    """

    def __init__(self, target_url: str = None, local_file: str = None,
                 local_directory: str = None, estimated_total_size: float = 0., expected_checksum: str = None):
        self._target_url = target_url
        self._local_file = local_file
        self._local_directory = local_directory
        self._expected_checksum = expected_checksum

        # Caller-supplied size guess; _total_size is filled in later from the
        # server-reported Content-Length via progress messages.
        self._estimated_total_size = estimated_total_size
        self._total_size = 0
        self._downloaded_size = 0

        self._downloading = False
        self._failure = False

    @property
    def target_url(self) -> str:
        return self._target_url

    @property
    def local_file(self) -> str:
        return self._local_file

    @property
    def local_directory(self) -> str:
        return self._local_directory

    @property
    def expected_checksum(self) -> str:
        return self._expected_checksum

    @property
    def total_size(self) -> int:
        return self._total_size

    @total_size.setter
    def total_size(self, sz: int) -> None:
        self._total_size = sz

    @property
    def downloaded_size(self) -> int:
        return self._downloaded_size

    @downloaded_size.setter
    def downloaded_size(self, sz: int) -> None:
        self._downloaded_size = sz

    @property
    def estimated_size(self) -> float:
        return self._estimated_total_size

    def is_downloading(self) -> bool:
        return self._downloading

    def start_download(self) -> None:
        self._downloading = True

    def update_final_status(self, result: bool) -> None:
        """Mark the download finished; ``result`` is the downloader's success flag.

        Bug fix: the original stored ``result`` directly in ``_failure``.
        Since the caller passes ``MiaoshouFileDownloader.download_file()``'s
        *success* boolean, a successful download was reported as failed by
        ``is_failure()`` (and vice versa).
        """
        self._failure = not result
        self._downloading = False

    def is_failure(self) -> bool:
        return self._failure
|
||||
|
||||
|
||||
class AsyncLoopThread(Thread):
    """Daemon thread that owns a private asyncio event loop and runs it forever."""

    def __init__(self):
        # Daemonize so pending downloads never block interpreter shutdown.
        super(AsyncLoopThread, self).__init__(daemon=True)
        self.logger = Logger()
        self.loop = asyncio.new_event_loop()
        self.logger.info("looper thread is created")

    def run(self):
        # Bind the loop to this thread, then serve coroutines scheduled onto
        # it (via run_coroutine_threadsafe) until the process exits.
        asyncio.set_event_loop(self.loop)
        self.logger.info("looper thread is running")
        self.loop.run_forever()
|
||||
|
||||
|
||||
class MiaoshouDownloaderManager(metaclass=MiaoshouSingleton):
|
||||
_downloading_entries: t.Dict[str, DownloadingEntry] = None
|
||||
|
||||
    def __init__(self):
        # Singleton (metaclass=MiaoshouSingleton): only populate shared state
        # on first construction; presumably later constructions reuse it —
        # see MiaoshouSingleton for the exact semantics.
        if self._downloading_entries is None:
            self._downloading_entries = {}
            # Downloader workers post (url, finished, total) tuples here.
            self.message_queue = queue.Queue()

            self.logger = Logger()
            self.looper = AsyncLoopThread()
            self.looper.start()
            self.logger.info("download manager is ready")
            # Guards _downloading_entries against concurrent access from the
            # looper thread and the UI thread.
            self._mutex = Lock()
|
||||
|
||||
def consume_all_ready_messages(self) -> None:
|
||||
"""
|
||||
capture all enqueued messages, this method should not be used if you are iterating over the message queue
|
||||
:return:
|
||||
None
|
||||
:side-effect:
|
||||
update downloading entries' status
|
||||
"""
|
||||
while True:
|
||||
# self.logger.info("fetching the enqueued message")
|
||||
try:
|
||||
(aurl, finished_size, total_size) = self.message_queue.get(block=False, timeout=0.2)
|
||||
# self.logger.info(f"[+] message ([{finished_size}/{total_size}] {aurl}")
|
||||
try:
|
||||
self._mutex.acquire(blocking=True)
|
||||
self._downloading_entries[aurl].total_size = total_size
|
||||
self._downloading_entries[aurl].downloaded_size = finished_size
|
||||
finally:
|
||||
self._mutex.release()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
    def iterator(self) -> t.Tuple[float, float]:
        # NOTE(review): despite the annotation this is a *generator* yielding
        # aggregated (finished_size, total_size) tuples.

        while True:
            self.logger.info("waiting for incoming message")

            try:
                (aurl, finished_size, total_size) = self.message_queue.get(block=True)
                # NOTE(review): a blocking get() without a timeout never raises
                # queue.Empty, so the termination branch below is unreachable
                # as written — confirm whether a timeout was intended.
                self.logger.info(f"[+] message ([{finished_size}/{total_size}] {aurl}")
                try:
                    self._mutex.acquire(blocking=True)
                    self._downloading_entries[aurl].total_size = total_size
                    self._downloading_entries[aurl].downloaded_size = finished_size

                    # Aggregate progress across every tracked entry.
                    tasks_total_size = 0.
                    tasks_finished_size = 0.

                    for e in self._downloading_entries.values():
                        tasks_total_size += e.total_size
                        tasks_finished_size += e.downloaded_size

                    # NOTE(review): the mutex stays held while the consumer
                    # runs between yields — a slow consumer stalls the workers.
                    yield tasks_finished_size, tasks_total_size
                finally:
                    self._mutex.release()
            except queue.Empty:
                if len(asyncio.all_tasks(self.looper.loop)) == 0:
                    self.logger.info("all downloading tasks finished")
                    break
|
||||
|
||||
    async def _submit_task(self, download_entry: DownloadingEntry) -> None:
        """Register *download_entry* and run its download on the loop's executor."""
        try:
            self._mutex.acquire(blocking=True)
            if download_entry.target_url in self._downloading_entries:
                # De-duplicate: a URL that is already tracked is not restarted.
                self.logger.warn(f"{download_entry.target_url} is already downloading")
                return
            else:
                download_entry.start_download()
                self._downloading_entries[download_entry.target_url] = download_entry
        finally:
            self._mutex.release()

        # channel=None disables progress reporting when the caller supplied no
        # size estimate (the downloader only posts through a non-None channel).
        file_downloader = MiaoshouFileDownloader(
            target_url=download_entry.target_url,
            local_file=download_entry.local_file,
            local_directory=download_entry.local_directory,
            channel=self.message_queue if download_entry.estimated_size else None,
            estimated_total_length=download_entry.estimated_size,
            expected_checksum=download_entry.expected_checksum,
        )

        # Run the blocking download in the default thread-pool executor so the
        # event loop stays responsive; result is the downloader's success flag.
        result: bool = await self.looper.loop.run_in_executor(None, file_downloader.download_file)

        try:
            self._mutex.acquire(blocking=True)
            self._downloading_entries[download_entry.target_url].update_final_status(result)
        finally:
            self._mutex.release()
|
||||
|
||||
def download(self, source_url: str, target_file: str, estimated_total_size: float,
|
||||
expected_checksum: str = None) -> None:
|
||||
target_dir = os.path.dirname(target_file)
|
||||
target_filename = os.path.basename(target_file)
|
||||
download_entry = DownloadingEntry(
|
||||
target_url=source_url,
|
||||
local_file=target_filename,
|
||||
local_directory=target_dir,
|
||||
estimated_total_size=estimated_total_size,
|
||||
expected_checksum=expected_checksum
|
||||
)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(self._submit_task(download_entry), self.looper.loop)
|
||||
|
||||
def tasks_summary(self) -> t.Tuple[int, int, str]:
|
||||
self.consume_all_ready_messages()
|
||||
|
||||
total_tasks_num = 0
|
||||
ongoing_tasks_num = 0
|
||||
failed_tasks_num = 0
|
||||
|
||||
try:
|
||||
description = "<div>"
|
||||
self._mutex.acquire(blocking=True)
|
||||
for name, entry in self._downloading_entries.items():
|
||||
if entry.estimated_size is None:
|
||||
continue
|
||||
|
||||
total_tasks_num += 1
|
||||
|
||||
if entry.total_size > 0.:
|
||||
description += f"<p>{entry.local_file} ({toolkit.get_readable_size(entry.total_size)}) : "
|
||||
else:
|
||||
description += f"<p>{entry.local_file} ({toolkit.get_readable_size(entry.estimated_size)}) : "
|
||||
|
||||
if entry.is_downloading():
|
||||
ongoing_tasks_num += 1
|
||||
finished_percent = entry.downloaded_size/entry.estimated_size * 100
|
||||
description += f'<span style="color:blue;font-weight:bold">{round(finished_percent, 2)} %</span>'
|
||||
elif entry.is_failure():
|
||||
failed_tasks_num += 1
|
||||
description += '<span style="color:red;font-weight:bold">failed!</span>'
|
||||
else:
|
||||
description += '<span style="color:green;font-weight:bold">finished</span>'
|
||||
description += "</p><br>"
|
||||
finally:
|
||||
self._mutex.release()
|
||||
pass
|
||||
|
||||
description += "</div>"
|
||||
overall = f"""
|
||||
<h4>
|
||||
<span style="color:blue;font-weight:bold">{ongoing_tasks_num}</span> ongoing,
|
||||
<span style="color:green;font-weight:bold">{total_tasks_num - ongoing_tasks_num - failed_tasks_num}</span> finished,
|
||||
<span style="color:red;font-weight:bold">{failed_tasks_num}</span> failed.
|
||||
</h4>
|
||||
<br>
|
||||
<br>
|
||||
"""
|
||||
|
||||
return ongoing_tasks_num, total_tasks_num, overall + description
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,226 @@
|
|||
import os
|
||||
import pickle
|
||||
import queue
|
||||
import time
|
||||
import typing as t
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import rehash
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from tqdm import tqdm
|
||||
from urllib3.util import Retry
|
||||
|
||||
import scripts.msai_utils.msai_toolkit as toolkit
|
||||
from scripts.logging.msai_logger import Logger
|
||||
|
||||
|
||||
class MiaoshouFileDownloader(object):
    """Downloads one file over HTTP with retry/backoff and optional resume."""

    # Responses are streamed and hashed in 1 MiB chunks.
    CHUNK_SIZE = 1024 * 1024

    def __init__(self, target_url: str = None,
                 local_file: str = None, local_directory: str = None, estimated_total_length: float = 0.,
                 expected_checksum: str = None,
                 channel: queue.Queue = None,
                 max_retries=3) -> None:
        self.logger = Logger()

        self.target_url: str = target_url
        self.local_file: str = local_file
        self.local_directory = local_directory
        # Optional sha256 hex digest used to verify the finished file.
        self.expected_checksum = expected_checksum
        self.max_retries = max_retries

        self.accept_ranges: bool = False
        # Caller-supplied size guess, superseded by the server's
        # Content-Length once the HEAD probe succeeds.
        self.estimated_content_length = estimated_total_length
        self.content_length: int = -1
        self.finished_chunk_size: int = 0

        self.channel = channel  # for communication

        # Support 3 retries and backoff
        # NOTE(review): `method_whitelist` was deprecated in urllib3 1.26 and
        # removed in 2.0 (renamed to `allowed_methods`) — confirm the pinned
        # urllib3 version still accepts it.
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            method_whitelist=["HEAD", "GET", "OPTIONS"]
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session = requests.Session()
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

        # inform message receiver at once
        # Seed the channel with a zero-progress message so the UI can show
        # the task immediately, before any bytes arrive.
        if self.channel:
            self.channel.put_nowait((
                self.target_url,
                self.finished_chunk_size,
                self.estimated_content_length,
            ))
|
||||
|
||||
# Head request to get file-length and check whether it supports ranges.
|
||||
def get_file_info_from_server(self, target_url: str) -> t.Tuple[bool, float]:
|
||||
try:
|
||||
headers = {"Accept-Encoding": "identity"} # Avoid dealing with gzip
|
||||
response = requests.head(target_url, headers=headers, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
content_length = None
|
||||
if "Content-Length" in response.headers:
|
||||
content_length = int(response.headers['Content-Length'])
|
||||
accept_ranges = (response.headers.get("Accept-Ranges") == "bytes")
|
||||
return accept_ranges, float(content_length)
|
||||
except Exception as ex:
|
||||
self.logger.info(f"HEAD Request Error: {ex}")
|
||||
return False, self.estimated_content_length
|
||||
|
||||
    def download_file_full(self, target_url: str, local_filepath: str) -> t.Optional[str]:
        """Download the whole file in a single pass (no resume support).

        :return: sha256 hex digest of the downloaded bytes, or None on error.
        """
        try:
            # rehash checksums are resumable, matching the resumable code path.
            checksum = rehash.sha256()
            headers = {"Accept-Encoding": "identity"}  # Avoid dealing with gzip

            with tqdm(total=self.content_length, unit="byte", unit_scale=1, colour="GREEN",
                      desc=os.path.basename(self.local_file)) as progressbar, \
                    self.session.get(target_url, headers=headers, stream=True, timeout=5) as response, \
                    open(local_filepath, 'wb') as file_out:
                response.raise_for_status()

                for chunk in response.iter_content(MiaoshouFileDownloader.CHUNK_SIZE):
                    file_out.write(chunk)
                    checksum.update(chunk)
                    progressbar.update(len(chunk))
                    # Also forward progress to the manager's message queue.
                    self.update_progress(len(chunk))

        except Exception as ex:
            self.logger.info(f"Download error: {ex}")
            return None

        return checksum.hexdigest()
|
||||
|
||||
    def download_file_resumable(self, target_url: str, local_filepath: str) -> t.Optional[str]:
        """Download with HTTP Range resume, checkpointing progress to disk.

        A ``<file>.downloading`` pickle stores ``(resume_point, checksum)`` so
        an interrupted transfer can continue where it left off.

        :return: sha256 hex digest of the complete file, or None on error.
        """
        # Always go off the checkpoint as the file was flushed before writing.
        download_checkpoint = local_filepath + ".downloading"
        try:
            # NOTE(review): unpickles a checkpoint from the local temp dir;
            # acceptable only while that directory is trusted.
            resume_point, checksum = pickle.load(open(download_checkpoint, "rb"))
            assert os.path.exists(local_filepath)  # catch checkpoint without file
            self.logger.info("File already exists, resuming download.")
        except Exception as e:
            # No/invalid checkpoint: restart from scratch with an empty file.
            self.logger.error(f"failed to load downloading checkpoint - {download_checkpoint} due to {e}")
            resume_point = 0
            checksum = rehash.sha256()
            if os.path.exists(local_filepath):
                os.remove(local_filepath)
            Path(local_filepath).touch()

        assert (resume_point < self.content_length)

        self.finished_chunk_size = resume_point

        # Support resuming
        headers = {"Range": f"bytes={resume_point}-", "Accept-Encoding": "identity"}
        try:
            with tqdm(total=self.content_length, unit="byte", unit_scale=1, colour="GREEN",
                      desc=os.path.basename(self.local_file)) as progressbar, \
                    self.session.get(target_url, headers=headers, stream=True, timeout=5) as response, \
                    open(local_filepath, 'r+b') as file_out:
                response.raise_for_status()
                self.update_progress(resume_point)
                file_out.seek(resume_point)

                for chunk in response.iter_content(MiaoshouFileDownloader.CHUNK_SIZE):
                    file_out.write(chunk)
                    # Flush before checkpointing so the checkpoint never
                    # claims more bytes than are actually on disk.
                    file_out.flush()
                    resume_point += len(chunk)
                    checksum.update(chunk)
                    pickle.dump((resume_point, checksum), open(download_checkpoint, "wb"))
                    progressbar.update(len(chunk))
                    self.update_progress(len(chunk))

            # Only remove checkpoint at full size in case connection cut
            if os.path.getsize(local_filepath) == self.content_length:
                os.remove(download_checkpoint)
            else:
                return None

        except Exception as ex:
            self.logger.error(f"Download error: {ex}")
            return None

        return checksum.hexdigest()
|
||||
|
||||
def update_progress(self, finished_chunk_size: int) -> None:
|
||||
self.finished_chunk_size += finished_chunk_size
|
||||
|
||||
if self.channel:
|
||||
self.channel.put_nowait((
|
||||
self.target_url,
|
||||
self.finished_chunk_size,
|
||||
self.content_length,
|
||||
))
|
||||
|
||||
    # In order to avoid leaving extra garbage meta files behind this
    # will overwrite any existing files found at local_file. If you don't want this
    # behaviour you can handle this externally.
    # local_file and local_directory could write to unexpected places if the source
    # is untrusted, be careful!
    def download_file(self) -> bool:
        """Download target_url into a temp dir, verify it, then move it into place.

        :return: True when the file was downloaded (and moved) successfully.
        """
        success = False
        try:
            # Need to rebuild local_file_final each time in case of different urls
            if not self.local_file:
                specific_local_file = os.path.basename(urlparse(self.target_url).path)
            else:
                specific_local_file = self.local_file

            download_temp_dir = toolkit.get_user_temp_dir()
            toolkit.assert_user_temp_dir()

            if self.local_directory:
                os.makedirs(self.local_directory, exist_ok=True)

            # Download into the temp dir first; only move on verified success.
            specific_local_file = os.path.join(download_temp_dir, specific_local_file)

            self.accept_ranges, self.content_length = self.get_file_info_from_server(self.target_url)
            self.logger.info(f"Accept-Ranges: {self.accept_ranges}. content length: {self.content_length}")
            if self.accept_ranges and self.content_length:
                download_method = self.download_file_resumable
                self.logger.info("Server supports resume")
            else:
                download_method = self.download_file_full
                self.logger.info(f"Server doesn't support resume.")

            for i in range(self.max_retries):
                self.logger.info(f"Download Attempt {i + 1}")
                checksum = download_method(self.target_url, specific_local_file)
                if checksum:
                    match = ""
                    if self.expected_checksum:
                        match = ", Checksum Match"

                    if self.expected_checksum and self.expected_checksum != checksum:
                        self.logger.info(f"Checksum doesn't match. Calculated {checksum} "
                                         f"Expecting: {self.expected_checksum}")
                    else:
                        self.logger.info(f"Download successful{match}. Checksum {checksum}")
                        success = True
                        break
                # Back off briefly before the next attempt.
                time.sleep(1)

            if success:
                print("*" * 80)
                self.logger.info(f"{self.target_url} [DOWNLOADED COMPLETELY]")
                print("*" * 80)
                if self.local_directory:
                    target_local_file = os.path.join(self.local_directory, self.local_file)
                else:
                    target_local_file = self.local_file
                # Publish the verified file to its final destination.
                toolkit.move_file(specific_local_file, target_local_file)
            else:
                print("*" * 80)
                self.logger.info(f"{self.target_url} [ FAILED ]")
                print("*" * 80)

        except Exception as ex:
            self.logger.info(f"Unexpected Error: {ex}")  # Only from block above

        return success
|
||||
|
|
@ -0,0 +1 @@
|
|||
__all__ = ["msai_logger"]
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
import datetime
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from scripts.msai_utils.msai_singleton import MiaoshouSingleton
|
||||
|
||||
|
||||
class Logger(metaclass=MiaoshouSingleton):
    """Process-wide logging facade (shared instance via MiaoshouSingleton)."""

    # Shared singleton state; populated once by _do_init on first __init__.
    _dataset = None

    # Keys into _dataset: the trace-path string plus the three named loggers.
    KEY_TRACE_PATH = "trace_path"
    KEY_INFO = "info"
    KEY_ERROR = "error"
    KEY_JOB = "job"
|
||||
|
||||
    def _do_init(self, log_folder: str, disable_console_output: bool = False) -> None:
        """Create the info/error/job loggers with file handlers under *log_folder*."""
        # Setup trace_path with empty string by default, it will be assigned with valid content if needed
        self._dataset = {Logger.KEY_TRACE_PATH: ""}

        print(f"logs_location: {log_folder}")
        os.makedirs(log_folder, exist_ok=True)

        # Setup basic logging configuration
        # NOTE(review): without a `filename` argument, `filemode='w'` has no
        # effect here — confirm whether a root log file was intended.
        logging.basicConfig(level=logging.INFO,
                            filemode='w',
                            format='%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s')

        # Setup info logging
        self._dataset[Logger.KEY_INFO] = logging.getLogger(Logger.KEY_INFO)
        msg_handler = logging.FileHandler(os.path.join(log_folder, "info.log"),
                                          "a",
                                          encoding="UTF-8")
        msg_handler.setLevel(logging.INFO)
        msg_handler.setFormatter(
            logging.Formatter(fmt='%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s'))
        self._dataset[Logger.KEY_INFO].addHandler(msg_handler)

        # Setup error logging (one file per calendar day)
        self._dataset[Logger.KEY_ERROR] = logging.getLogger(Logger.KEY_ERROR)
        error_handler = logging.FileHandler(
            os.path.join(log_folder, f'error_{datetime.date.today().strftime("%Y%m%d")}.log'),
            mode="a",
            encoding='UTF-8')
        error_handler.setLevel(logging.ERROR)
        # NOTE(review): trace_path is interpolated into the format string once
        # at handler-creation time (always "" here); later update_path_info()
        # calls will not be reflected — confirm this is intended.
        error_handler.setFormatter(
            logging.Formatter(
                fmt=f"{self._dataset.get('trace_path')}:\n "
                    f"%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s"))
        self._dataset[Logger.KEY_ERROR].addHandler(error_handler)

        # Setup job logging
        self._dataset[Logger.KEY_JOB] = logging.getLogger(Logger.KEY_JOB)
        job_handler = logging.FileHandler(os.path.join(log_folder, "jobs.log"),
                                          mode="a",
                                          encoding="UTF-8")
        # NOTE(review): no formatter or level is set on the job handler, so
        # job records are written with the default format.
        self._dataset[Logger.KEY_JOB].addHandler(job_handler)

        if disable_console_output:
            # Stop records propagating to the root logger's console handler.
            for k in [Logger.KEY_INFO, Logger.KEY_JOB, Logger.KEY_ERROR]:
                l: logging.Logger = self._dataset[k]
                l.propagate = not disable_console_output
|
||||
|
||||
def __init__(self, log_folder: str = None, disable_console_output: bool = False) -> None:
    """Perform one-time logger setup; constructing again is a no-op."""
    if self._dataset is not None:
        return
    try:
        self._do_init(log_folder=log_folder, disable_console_output=disable_console_output)
    except Exception as e:
        # Logging setup must never crash the host application.
        print(e)
||||
|
||||
def update_path_info(self, current_path: str) -> None:
    """Remember *current_path* so later error records can reference it."""
    self._dataset[Logger.KEY_TRACE_PATH] = current_path
||||
|
||||
def callback_func(self, exc_type: t.Any, exc_value: t.Any, exc_tracback: t.Any) -> None:
    """Exception-hook style callback: record a job failure plus the traceback."""
    trace_path = self._dataset[Logger.KEY_TRACE_PATH]
    self._dataset[Logger.KEY_JOB].error(f"job failed for {trace_path}")
    self._dataset[Logger.KEY_INFO].error(
        f"{trace_path}\n, callback_func: ",
        exc_info=(exc_type, exc_value, exc_tracback),
    )
||||
|
||||
def debug(self, fmt, *args, **kwargs) -> None:
    """Emit a DEBUG record through the shared info logger."""
    # stacklevel=2 attributes the record to our caller, not this wrapper.
    self._dataset[Logger.KEY_INFO].debug(fmt, *args, **kwargs, stacklevel=2)
||||
|
||||
def info(self, fmt, *args, **kwargs) -> None:
    """Emit an INFO record through the shared info logger."""
    info_logger: logging.Logger = self._dataset[Logger.KEY_INFO]
    # stacklevel=2 attributes the record to our caller, not this wrapper.
    info_logger.info(fmt, *args, **kwargs, stacklevel=2)
||||
|
||||
def warn(self, fmt, *args, **kwargs) -> None:
    """Emit a WARNING record through the shared info logger.

    Uses ``Logger.warning`` instead of the deprecated ``Logger.warn`` alias,
    which emits a DeprecationWarning and has been removed in recent Python
    versions. stacklevel=2 attributes the record to the caller rather than
    this wrapper.
    """
    l: logging.Logger = self._dataset[Logger.KEY_INFO]
    l.warning(fmt, *args, **kwargs, stacklevel=2)
||||
|
||||
def error(self, fmt, *args, **kwargs) -> None:
    """Emit an ERROR record through the dedicated error logger."""
    error_logger: logging.Logger = self._dataset[Logger.KEY_ERROR]
    # stacklevel=2 attributes the record to our caller, not this wrapper.
    error_logger.error(fmt, *args, **kwargs, stacklevel=2)
||||
|
||||
def job(self, fmt, *args, **kwargs) -> None:
    """Record a job-progress message in the dedicated jobs log."""
    job_logger: logging.Logger = self._dataset[Logger.KEY_JOB]
    # stacklevel=2 attributes the record to our caller, not this wrapper.
    job_logger.info(fmt, *args, **kwargs, stacklevel=2)
||||
|
||||
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
import modules
|
||||
import modules.scripts as scripts
|
||||
|
||||
from scripts.assistant.miaoshou import MiaoShouAssistant
|
||||
|
||||
|
||||
class MiaoshouScript(scripts.Script):
    """Script stub that keeps Miaoshou Assistant permanently visible in the UI."""

    def __init__(self) -> None:
        super().__init__()

    def title(self):
        """Human-readable name shown by the web UI."""
        return "Miaoshou Assistant"

    def show(self, is_img2img):
        """Always display, regardless of txt2img/img2img mode."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """No per-script controls; the assistant renders its own tab."""
        return ()
||||
|
||||
|
||||
assistant = MiaoShouAssistant()
|
||||
modules.script_callbacks.on_ui_tabs(assistant.on_event_ui_tabs_opened)
|
||||
|
|
@ -0,0 +1 @@
|
|||
__all__ = ["msai_singleton", "msai_toolkit"]
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
class MiaoshouSingleton(type):
    """Metaclass implementing the singleton pattern.

    The first instantiation of a class using this metaclass is cached;
    every later call returns the cached instance unchanged.

    Fix: ``type.__call__`` already runs ``__init__`` on the freshly created
    instance, so the previous explicit ``__init__`` re-invocation made
    constructors run twice on first creation; it has been removed.

    Not thread-safe: concurrent first calls may race on the cache.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            # type.__call__ performs __new__ + __init__ exactly once.
            cls._instances[cls] = super(MiaoshouSingleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]
||||
|
|
@ -0,0 +1,88 @@
|
|||
import json
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def read_json(file) -> t.Any:
    """Parse *file* as JSON (BOM-tolerant via utf-8-sig); None on any failure."""
    try:
        with open(file, "r", encoding="utf-8-sig") as fp:
            return json.load(fp)
    except Exception as err:
        # Best-effort contract: callers treat None as "no data".
        print(err)
        return None
||||
|
||||
|
||||
def write_json(file, content) -> None:
    """Serialize *content* as pretty-printed JSON into *file*.

    Fix: opens the file with an explicit UTF-8 encoding so output no longer
    depends on the platform's locale default. Failures are reported and
    swallowed, matching read_json's best-effort contract.
    """
    try:
        with open(file, 'w', encoding='utf-8') as f:
            json.dump(content, f, indent=4)
    except Exception as e:
        print(e)
||||
|
||||
|
||||
def get_args(args) -> t.List[str]:
    """Group raw argv tokens into '--flag value...' strings.

    A leading non-option token (the program name) is skipped; bare value
    tokens are appended to the most recently seen option string.
    """
    grouped: t.List[str] = []
    option_count = 0
    for token in args:
        if '--' in token:
            grouped.append(rf'{token}')
            option_count += 1
        elif option_count == 0:
            # Leading non-option token (script name / stray value) — ignore.
            pass
        else:
            grouped[option_count - 1] = grouped[option_count - 1] + ' ' + rf'{token}'

    return grouped
||||
|
||||
|
||||
def get_readable_size(size: int, precision: int = 2) -> str:
    """Render a byte count as a human-readable string (e.g. ``1.50MB``).

    Returns "" when *size* is None. Fix: the previous loop bound allowed the
    suffix index to step past the last suffix and raise IndexError for sizes
    of 1024 TB and above; such values now clamp to the TB suffix.
    """
    if size is None:
        return ""

    suffixes = ['B', 'KB', 'MB', 'GB', 'TB']
    suffix_index = 0
    # Stop at the last suffix so huge values clamp to TB instead of crashing.
    while size >= 1024 and suffix_index < len(suffixes) - 1:
        suffix_index += 1       # move to the next-larger unit
        size = size / 1024.0    # rescale into that unit
    return "%.*f%s" % (precision, size, suffixes[suffix_index])
||||
|
||||
|
||||
def get_file_last_modified_time(path_to_file: str) -> datetime:
    """Return the file's mtime as a local datetime; now() when path is None."""
    if path_to_file is None:
        return datetime.now()

    if platform.system() == "Windows":
        mtime = os.path.getmtime(path_to_file)
    else:
        # os.path.getmtime and os.stat report the same st_mtime value.
        mtime = os.stat(path_to_file).st_mtime
    return datetime.fromtimestamp(mtime)
||||
|
||||
|
||||
def get_not_found_image_url() -> str:
    """Placeholder cover image used when a model has no preview."""
    return "https://msdn.miaoshouai.com/msdn/userimage/not-found.svg"
||||
|
||||
|
||||
def get_user_temp_dir() -> str:
    """Per-user staging directory for in-progress downloads."""
    home = Path.home().absolute()
    return os.path.join(home, ".miaoshou_assistant_download")
||||
|
||||
|
||||
def assert_user_temp_dir() -> None:
    """Ensure the per-user download staging directory exists (idempotent)."""
    os.makedirs(get_user_temp_dir(), exist_ok=True)
||||
|
||||
|
||||
def move_file(src: str, dst: str) -> None:
    """Move *src* to *dst*, replacing any existing destination file.

    Silently ignores empty arguments and a missing source. Destination
    parent directories are created on demand. Fix: a destination with no
    directory component (bare filename) previously caused makedirs('') to
    raise FileNotFoundError; it is now skipped.
    """
    if not src or not dst:
        return

    if not os.path.exists(src):
        return

    # Replace semantics: clear any existing destination first.
    if os.path.exists(dst):
        os.remove(dst)

    dst_dir = os.path.dirname(dst)
    if dst_dir:  # dirname is '' for a bare filename; makedirs('') would raise
        os.makedirs(dst_dir, exist_ok=True)
    shutil.move(src, dst)
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Public submodules of the runtime package.
__all__ = ["msai_prelude", "msai_runtime"]

from . import msai_prelude as prelude

# Eagerly build the (singleton) prelude at import time so logging and
# configuration are ready before any other extension code runs.
prelude.MiaoshouPrelude().load()
||||
|
|
@ -0,0 +1,158 @@
|
|||
import os
|
||||
import platform
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
import launch
|
||||
from scripts.logging.msai_logger import Logger
|
||||
from scripts.msai_utils import msai_toolkit as toolkit
|
||||
from scripts.msai_utils.msai_singleton import MiaoshouSingleton
|
||||
|
||||
|
||||
class MiaoshouPrelude(metaclass=MiaoshouSingleton):
    """Process-wide bootstrap/configuration singleton for the extension.

    Exposes API endpoints, well-known config file paths, the UI option
    tables (checkboxes, GPU presets, themes), and the shared Logger.
    Constructed once; enforced by the MiaoshouSingleton metaclass.
    """

    # Shared class-level state; stays None until the first construction.
    _dataset = None

    def __init__(self) -> None:
        # Potential race condition, not call in multithread environment
        if MiaoshouPrelude._dataset is None:
            self._init_constants()

            MiaoshouPrelude._dataset = {
                "log_folder": os.path.join(self.ext_folder, "logs")
            }

            # Honor the persisted console-output flag when settings exist.
            # Fix: all_settings is None when settings.json is missing or
            # unreadable; previously that crashed with AttributeError here.
            disable_log_console_output: bool = False
            all_settings = self.all_settings or {}
            if all_settings.get("boot_settings"):
                if all_settings["boot_settings"].get("disable_log_console_output") is not None:
                    disable_log_console_output = all_settings["boot_settings"].get("disable_log_console_output")

            self._logger = Logger(self._dataset["log_folder"], disable_console_output=disable_log_console_output)

    def _init_constants(self) -> None:
        """Populate static endpoints, file locations, and UI option tables."""
        self._api_url = {
            "civitai": "https://civitai.com/api/v1/models",
            "liandange": "https://model-api.paomiantv.cn/model/api/models",
        }
        # Extension root = two directories above this file (scripts/runtime/..).
        self._ext_folder = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", ".."))
        self._setting_file = os.path.join(self.ext_folder, "configs", "settings.json")
        self._model_hash_file = os.path.join(self.ext_folder, "configs", "model_hash.json")
        self._model_json = {
            'civitai': os.path.join(self.ext_folder, 'configs', 'civitai_models.json'),
            'liandange': os.path.join(self.ext_folder, 'configs', 'liandange_models.json'),
        }
        # UI checkbox label -> webui commandline flag.
        self._checkboxes = {
            'Enable xFormers': '--xformers',
            'No Half': '--no-half',
            'No Half VAE': '--no-half-vae',
            'Enable API': '--api',
            'Auto Launch': '--autolaunch',
            'Allow Local Network Access': '--listen',
        }

        # GPU preset label -> webui commandline flags.
        self._gpu_setting = {
            'CPU Only': '--precision full --no-half --use-cpu SD GFPGAN BSRGAN ESRGAN SCUNet CodeFormer --all',
            'GTX 16xx': '--lowvram --xformers --precision full --no-half',
            'Low: 4-6G VRAM': '--xformers --lowvram',
            'Med: 6-8G VRAM': '--xformers --medvram',
            'Normal: 8+G VRAM': '',
        }

        # Theme label -> webui commandline flag ('' lets the web UI decide).
        self._theme_setting = {
            'Auto': '',
            'Light Mode': '--theme=light',
            'Dark Mode': '--theme=dark',
        }

    @property
    def ext_folder(self) -> str:
        """Absolute path of the extension's root directory."""
        return self._ext_folder

    @property
    def log_folder(self) -> str:
        """Directory that receives the extension's log files."""
        return self._dataset.get("log_folder")

    @property
    def all_settings(self) -> t.Any:
        """Whole settings.json content, re-read on every access; None on error."""
        return toolkit.read_json(self._setting_file)

    @property
    def boot_settings(self) -> t.Any:
        """The 'boot_settings' section of settings.json, or None.

        Fix: uses .get() so a settings file without that key yields None
        instead of raising KeyError.
        """
        all_setting = self.all_settings
        if all_setting:
            return all_setting.get('boot_settings')
        else:
            return None

    def api_url(self, model_source: str) -> t.Optional[str]:
        """Catalog endpoint for *model_source*, or None if unsupported."""
        return self._api_url.get(model_source)

    @property
    def setting_file(self) -> str:
        """Path of settings.json."""
        return self._setting_file

    @property
    def model_hash_file(self) -> str:
        """Path of the cached checkpoint-hash JSON file."""
        return self._model_hash_file

    @property
    def checkboxes(self) -> t.Dict[str, str]:
        """UI checkbox label -> commandline flag table."""
        return self._checkboxes

    @property
    def gpu_setting(self) -> t.Dict[str, str]:
        """GPU preset label -> commandline flags table."""
        return self._gpu_setting

    @property
    def theme_setting(self) -> t.Dict[str, str]:
        """Theme label -> commandline flag table."""
        return self._theme_setting

    @property
    def model_json(self) -> t.Dict[str, t.Any]:
        """Model source name -> cached catalog JSON path."""
        return self._model_json

    def update_model_json(self, site: str, models: t.Dict[str, t.Any]) -> None:
        """Persist the downloaded catalog for *site* to its JSON file."""
        if self._model_json.get(site) is None:
            self._logger.error(f"cannot save model info for {site}")
            return

        self._logger.info(f"{self._model_json[site]} updated")
        toolkit.write_json(self._model_json[site], models)

    def load(self) -> None:
        """Entry point called at import time; records startup context."""
        self._logger.info("start to do prelude")
        self._logger.info(f"cmdline args: {' '.join(sys.argv[1:])}")

    @classmethod
    def get_sys_info(cls) -> str:
        """Human-readable summary of OS, Python/torch, CPU, RAM and GPUs."""
        sys_info = 'System Information\n\n'

        sys_info += r'OS Name: {0} {1}'.format(platform.system(), platform.release()) + '\n'
        sys_info += r'OS Version: {0}'.format(platform.version()) + '\n'
        sys_info += r'WebUI Version: {0}'.format(
            f'https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{launch.commit_hash()}') + '\n'
        sys_info += r'Torch Version: {0}'.format(getattr(torch, '__long_version__', torch.__version__)) + '\n'
        sys_info += r'Python Version: {0}'.format(sys.version) + '\n\n'
        sys_info += r'CPU: {0}'.format(platform.processor()) + '\n'
        sys_info += r'CPU Cores: {0}/{1}'.format(psutil.cpu_count(logical=False), psutil.cpu_count(logical=True)) + '\n'

        # Fix for the previous hard-coded 2.54 GHz workaround: query psutil
        # and fall back gracefully where cpu_freq() is unsupported (e.g.
        # Apple Silicon), instead of reporting a fabricated frequency.
        try:
            cpu_freq = psutil.cpu_freq()
            freq_mhz = cpu_freq.max if cpu_freq and cpu_freq.max else (cpu_freq.current if cpu_freq else None)
        except Exception:
            freq_mhz = None
        if freq_mhz:
            sys_info += r'CPU Frequency: {0} GHz'.format(round(freq_mhz / 1000, 2)) + '\n'
        else:
            sys_info += 'CPU Frequency: N/A\n'

        sys_info += r'CPU Usage: {0}%'.format(psutil.cpu_percent()) + '\n\n'
        sys_info += r'RAM: {0}'.format(toolkit.get_readable_size(psutil.virtual_memory().total)) + '\n'
        sys_info += r'Memory Usage: {0}%'.format(psutil.virtual_memory().percent) + '\n\n'
        for i in range(torch.cuda.device_count()):
            sys_info += r'Graphics Card{0}: {1} ({2})'.format(i, torch.cuda.get_device_properties(i).name,
                                                              toolkit.get_readable_size(
                                                                  torch.cuda.get_device_properties(
                                                                      i).total_memory)) + '\n'
            sys_info += r'Available VRAM: {0}'.format(toolkit.get_readable_size(torch.cuda.mem_get_info(i)[0])) + '\n'

        return sys_info
||||
|
||||
|
||||
|
|
@ -0,0 +1,630 @@
|
|||
import datetime
|
||||
import fileinput
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
|
||||
import modules
|
||||
from modules.sd_models import CheckpointInfo
|
||||
from scripts.download.msai_downloader_manager import MiaoshouDownloaderManager
|
||||
from scripts.logging.msai_logger import Logger
|
||||
from scripts.msai_utils import msai_toolkit as toolkit
|
||||
from scripts.runtime.msai_prelude import MiaoshouPrelude
|
||||
|
||||
|
||||
class MiaoshouRuntime(object):
|
||||
def __init__(self):
    """Wire up logging, configuration, and the download manager."""
    self.logger = Logger()
    self.prelude = MiaoshouPrelude()
    self.cmdline_args: t.List[str] = None
    self._old_additional: str = None
    self._model_set: t.List[t.Dict] = None
    self._model_set_last_access_time: datetime.datetime = None
    self._ds_models: gr.Dataset = None
    self._allow_nsfw: bool = False
    # civitai is the default model source
    self._model_source: str = "civitai"

    # TODO: may be owned by downloader class
    self.model_files = []

    self.downloader_manager = MiaoshouDownloaderManager()
||||
|
||||
def get_default_args(self, commandline_args: t.List[str] = None):
    """Classify the current commandline into the launcher UI's fields.

    Splits args into: GPU preset, theme, port, known checkbox flags, and a
    free-form remainder. Also caches the args on self.cmdline_args and the
    remainder on self._old_additional (needed later by get_final_args).
    Returns (gpu, theme, port, checkbox_values, additional_args, webui_ver).
    """
    if commandline_args is None:
        # Default to the process's own argv, grouped into '--flag value' strings.
        commandline_args: t.List[str] = toolkit.get_args(sys.argv[1:])
    self.cmdline_args = commandline_args
    self.logger.info(f"default commandline args: {commandline_args}")

    checkbox_values = []
    additional_args = ""
    saved_setting = self.prelude.boot_settings

    # Start from the persisted settings; live args below take precedence.
    gpu = saved_setting.get('drp_args_vram')
    theme = saved_setting.get('drp_args_theme')
    port = saved_setting.get('txt_args_listen_port')

    for arg in commandline_args:
        if 'theme' in arg:
            # Reverse-map the flag back to its UI label.
            theme = [k for k, v in self.prelude.theme_setting.items() if v == arg][0]
        if 'port' in arg:
            # Stored as '--port NNNN'; the value is the last token.
            port = arg.split(' ')[-1]

    # A checkbox is ticked when its exact flag appears in the args.
    for chk in self.prelude.checkboxes:
        for arg in commandline_args:
            if self.prelude.checkboxes[chk] == arg:
                checkbox_values.append(chk)

    # Flatten every GPU preset's flags into one lookup list.
    gpu_arg_list = [f'--{i.strip()}' for i in ' '.join(list(self.prelude.gpu_setting.values())).split('--')]
    # Whatever is not recognized above becomes the free-form "additional" text.
    for arg in commandline_args:
        if 'port' not in arg \
                and arg not in list(self.prelude.theme_setting.values()) \
                and arg not in list(self.prelude.checkboxes.values()) \
                and arg not in gpu_arg_list:
            additional_args += (' ' + rf'{arg}')

    self._old_additional = additional_args
    webui_ver = saved_setting['drp_choose_version']

    # Backslashes are doubled for safe display in the textbox.
    return gpu, theme, port, checkbox_values, additional_args.replace('\\', '\\\\').strip(), webui_ver
||||
|
||||
def add_arg(self, args: str = "") -> None:
    """Append each '--token' found in *args* to cmdline_args, skipping duplicates."""
    for piece in args.split('--'):
        token = piece.strip()
        self.logger.info(f'add arg: {token}')
        if f"--{token}" not in self.cmdline_args and token != '':
            self.cmdline_args.append(f'--{token}')
||||
|
||||
def remove_arg(self, args: str = "") -> None:
    """Remove each '--token' in *args* from self.cmdline_args.

    'port' and 'theme' are matched by substring because their stored form
    embeds a value (e.g. '--port 7860', '--theme=dark'); every other token
    must match an existing argument exactly.
    """
    # Options whose stored argument carries a value; exact match won't work.
    arg_keywords = ['port', 'theme']

    for arg in args.split('--'):
        if arg in arg_keywords:
            # Substring match: drop the first stored argument mentioning it.
            for cmdl in self.cmdline_args:
                if arg in cmdl:
                    self.cmdline_args.remove(cmdl)
                    break
        elif f'--{arg.strip()}' in self.cmdline_args and arg.strip() != '':
            print(f"remove args:{arg.strip()}")
            self.cmdline_args.remove(f'--{arg.strip()}')
||||
|
||||
def get_final_args(self, gpu, theme, port, checkgroup, more_args) -> None:
    """Rebuild self.cmdline_args from the launcher UI's current settings.

    Mutates self.cmdline_args in place via add_arg/remove_arg; statement
    order matters (removals must precede additions for the same option).
    """
    # gpu settings
    for s1 in self.prelude.gpu_setting:
        if s1 in gpu:
            # Strip every other preset's flags before adding this one's.
            for s2 in self.prelude.gpu_setting:
                if s2 != s1:
                    self.remove_arg(self.prelude.gpu_setting[s2])
            self.add_arg(self.prelude.gpu_setting[s1])

    # 7860 is the web-ui default, so only a non-default port is stored.
    if port != '7860':
        self.add_arg(f'--port {port}')
    else:
        self.remove_arg('--port')

    # theme settings
    self.remove_arg('--theme')
    for t in self.prelude.theme_setting:
        if t == theme:
            self.add_arg(self.prelude.theme_setting[t])
            break

    # check box settings
    for chked in checkgroup:
        self.logger.info(f'checked:{self.prelude.checkboxes[chked]}')
        self.add_arg(self.prelude.checkboxes[chked])

    # Unchecked boxes: remove their flags if previously present.
    for unchk in list(set(list(self.prelude.checkboxes.keys())) - set(checkgroup)):
        print(f'unchecked:{unchk}')
        self.remove_arg(self.prelude.checkboxes[unchk])

    # additional commandline settings
    # _old_additional holds last round's free-form args so they can be
    # replaced wholesale; backslashes were doubled for UI display.
    self.remove_arg(self._old_additional)
    self.add_arg(more_args.replace('\\\\', '\\'))
    self._old_additional = more_args.replace('\\\\', '\\')
||||
|
||||
def fetch_all_models(self) -> t.List[t.Dict]:
    """Download the full model catalog from the active model source.

    Pages through the source's REST API, tolerating up to 10 consecutive
    failed pages before giving up, and persists a non-empty result via the
    prelude. Returns the accumulated model list (empty on total failure).

    Fixes: log messages previously reported ``p + 1`` although ``p`` is
    already the 1-based page number; also corrects the "emtpy" typo.
    """
    endpoint_url = self.prelude.api_url(self.model_source)
    if endpoint_url is None:
        self.logger.error(f"{self.model_source} is not supported")
        return []

    self.logger.info(f"start to fetch model info from '{self.model_source}':{endpoint_url}")

    # Page size for each API request.
    limit_threshold = 100

    all_set = []
    # First request only discovers how many pages exist.
    response = requests.get(endpoint_url + f'?page=1&limit={limit_threshold}')
    num_of_pages = response.json()['metadata']['totalPages']
    self.logger.info(f"total pages = {num_of_pages}")

    continuous_error_counts = 0

    for p in range(1, num_of_pages + 1):
        try:
            response = requests.get(endpoint_url + f'?page={p}&limit={limit_threshold}')
            payload = response.json()
            if payload.get("success") is not None and not payload.get("success"):
                self.logger.error(f"failed to fetch page[{p}]")
                continuous_error_counts += 1
                # Give up after too many consecutive failures.
                if continuous_error_counts > 10:
                    break
                else:
                    continue

            continuous_error_counts = 0  # reset error flag
            self.logger.debug(f"start to process page[{p}]")

            for model in payload['items']:
                all_set.append(model)

            self.logger.debug(f"page[{p}] : {len(payload['items'])} items added")
        except Exception as e:
            self.logger.error(f"failed to fetch page[{p}] due to {e}")
            # Back off briefly before the next page after a hard failure.
            time.sleep(3)

    if len(all_set) > 0:
        self.prelude.update_model_json(self.model_source, all_set)
    else:
        self.logger.error("fetch_all_models: empty body received")

    return all_set
||||
|
||||
def refresh_all_models(self) -> None:
    """Re-download the model catalog and push it into the Dataset widget."""
    if self.fetch_all_models():
        if not self.ds_models:
            self.logger.error(f"ds models is null")
        else:
            self.ds_models.samples = self.model_set
            self.ds_models.update(samples=self.model_set)
||||
|
||||
def get_images_html(self, search: str = '', model_type: str = 'All') -> t.List[str]:
    """Build [html, model_id] gallery entries for models matching the filters.

    Filtering: case-insensitive substring match on name/description, an
    optional type filter, and NSFW exclusion unless allowed. Models with
    malformed entries are skipped silently by the broad except below.
    """
    self.logger.info(f"get_image_html: model_type = {model_type}, and search pattern = '{search}'")

    model_cover_thumbnails = []
    # Distinct model types seen while scanning (collected but not returned).
    model_format = []

    if self.model_set is None:
        self.logger.error("model_set is null")
        return []

    self.logger.info(f"{len(self.model_set)} items inside '{self.model_source}'")

    search = search.lower()
    for model in self.model_set:
        try:
            if model.get('type') is not None \
                    and model.get('type') not in model_format:
                model_format.append(model['type'])

            if search == '' or \
                    (model.get('name') is not None and search in model.get('name').lower()) \
                    or (model.get('description') is not None and search in model.get('description').lower()):

                if (model_type == 'All' or model_type in model.get('type')) \
                        and (self.allow_nsfw or (not self.allow_nsfw and not model.get('nsfw'))):
                    # Card uses the first image of the first (latest) version,
                    # shrunk by rewriting the CDN width parameter.
                    model_cover_thumbnails.append([
                        [f"""
                        <div style="display: flex; align-items: center;">
                          <div id="{str(model.get('id'))}" style="margin-right: 10px;" class="model-item">
                            <img src="{model['modelVersions'][0]['images'][0]['url'].replace('width=450', 'width=100')}" style="width:100px;">
                          </div>
                          <div style="flex:1; width:100px;">
                            <h3 style="text-align:left; word-wrap:break-word;">{model.get('name')}</h3>
                            <p style="text-align:left;">Type: {model.get('type')}</p>
                            <p style="text-align:left;">Rating: {model.get('stats')['rating']}</p>
                          </div>
                        </div>
                        """],
                        model['id']])
        except Exception:
            # Any missing key / None field disqualifies the model quietly.
            continue

    return model_cover_thumbnails
||||
|
||||
# TODO: add typing hint
def update_boot_settings(self, version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
    """Persist the launcher options into settings.json under 'boot_settings'."""
    boot_settings = self.prelude.boot_settings
    boot_settings.update({
        'drp_args_vram': drp_gpu,
        "drp_args_theme": drp_theme,
        'txt_args_listen_port': txt_listen_port,
    })
    # Each checked option is stored under its label, mapped to its CLI flag.
    for chk in chk_group_args:
        self.logger.debug(chk)
        boot_settings[chk] = self.prelude.checkboxes[chk]
    boot_settings.update({
        'txt_args_more': additional_args,
        'drp_choose_version': version,
    })

    all_settings = self.prelude.all_settings
    all_settings['boot_settings'] = boot_settings

    toolkit.write_json(self.prelude.setting_file, all_settings)
||||
|
||||
def get_all_models(self, site: str) -> t.Any:
    """Load the cached model catalog for *site* from its JSON file."""
    return toolkit.read_json(self.prelude.model_json[site])
||||
|
||||
def update_model_json(self, site: str, models: t.Any) -> None:
    """Overwrite the cached model catalog for *site* with *models*."""
    toolkit.write_json(self.prelude.model_json[site], models)
||||
|
||||
def get_hash_from_json(self, chk_point: CheckpointInfo) -> CheckpointInfo:
    """Fill in chk_point.shorthash, using the on-disk hash cache when possible.

    On a cache miss the short hash is computed (expensive) and persisted.
    Fix: a missing or unreadable cache file makes read_json return None,
    which previously crashed on ``len(None)``; it is now treated as an
    empty cache.
    """
    model_hashes = toolkit.read_json(self.prelude.model_hash_file)
    if model_hashes is None:
        model_hashes = {}  # cache file absent or unreadable

    if len(model_hashes) == 0 or chk_point.title not in model_hashes.keys():
        # Cache miss: compute and persist the short hash.
        chk_point.shorthash = chk_point.calculate_shorthash()
        model_hashes[chk_point.title] = chk_point.shorthash
        toolkit.write_json(self.prelude.model_hash_file, model_hashes)
    else:
        chk_point.shorthash = model_hashes[chk_point.title]

    return chk_point
||||
|
||||
def get_local_models(self) -> t.List[t.Any]:
    """Describe every locally installed checkpoint for the gallery.

    Each checkpoint is matched against the downloaded catalog by hash;
    unmatched ones get a placeholder card with empty metadata columns.
    """
    models = []

    for file in modules.sd_models.checkpoint_tiles():
        chkpt_info = modules.sd_models.get_closet_checkpoint_match(file)
        # No hash recorded yet: compute (or load) one from the local cache.
        if chkpt_info.sha256 is None and chkpt_info.shorthash is None:
            chkpt_info = self.get_hash_from_json(chkpt_info)

        # Strip the trailing "[hash]" suffix from the checkpoint title.
        fname = re.sub(r'\[.*?\]', "", chkpt_info.title)
        model_info = self.search_model_by_hash(chkpt_info.sha256, chkpt_info.shorthash, fname)
        if model_info is not None:
            models.append(model_info)
        else:
            self.logger.info(
                f"{chkpt_info.title}, {chkpt_info.hash}, {chkpt_info.shorthash}, {chkpt_info.sha256}")
            # Fallback card: webui's stock "no preview" image, empty metadata.
            models.append([
                [f'<img src="file={os.path.join(modules.paths.script_path, "html", "card-no-preview.png")}" '
                 'style="width:100px;height:150px;">'],
                [os.path.basename(fname)],
                [fname],
                [chkpt_info.shorthash],
                [], [], []])

    return models
||||
|
||||
def search_model_by_hash(self, lookup_sha256: str, lookup_shash: str, fname: str) -> t.Optional[t.List[t.Any]]:
    """Find catalog metadata for a local checkpoint by its hash.

    Tries a full SHA256 match first; otherwise compares the first 10 chars
    of the short hash against every hash the file advertises. Returns a
    gallery row (cover link, name/version, filename, hash, author, type,
    nsfw flag, trigger words) or None when nothing matches.
    NOTE(review): the outer model loop has no early exit, so the LAST
    matching model in the catalog wins — presumably harmless, but confirm.
    """
    self.logger.info(f"lookup_sha256: {lookup_sha256}, lookup_shash: {lookup_shash}, fname: {fname}")

    res = None
    if lookup_sha256 is None and lookup_shash is None:
        # Nothing to match on.
        return None

    for model in self.model_set:
        match = False

        for ver in model['modelVersions']:
            for file in ver['files']:
                if lookup_sha256 is not None and 'SHA256' in file['hashes'].keys():
                    match = (lookup_sha256.upper() == file['hashes']['SHA256'].upper())
                elif lookup_shash is not None:
                    # Case-insensitive 10-char prefix against every hash kind.
                    match = (lookup_shash[:10].upper() in [h.upper() for h in file['hashes'].values()])

                if match:
                    # Shrink the cover by rewriting the CDN width parameter.
                    cover_link = ver['images'][0]['url'].replace('width=450', 'width=100')
                    mid = model['id']
                    res = [
                        [
                            f'<a href="https://civitai.com/models/{mid}" target="_blank"><img src="{cover_link}"></a>'],
                        [f"{model['name']}/{ver['name']}"],
                        [fname],
                        [lookup_shash],
                        [model['creator']['username']],
                        [model['type']],
                        [model['nsfw']],
                        [ver['trainedWords']],
                    ]

            if match:
                # Stop scanning further versions of this model.
                break

    return res
||||
|
||||
def update_xformers(self, gpu, checkgroup):
    """Ensure 'Enable xFormers' is checked when the GPU preset implies it."""
    preset_flags = self.prelude.gpu_setting[gpu]
    if '--xformers' in preset_flags and 'Enable xFormers' not in checkgroup:
        checkgroup.append('Enable xFormers')

    return checkgroup
||||
|
||||
def set_nsfw(self, search='', nsfw_checker=False, model_type='All') -> t.Dict:
    """Toggle the NSFW filter and rebuild the model gallery dataset."""
    self._allow_nsfw = nsfw_checker
    refreshed = self.get_images_html(search, model_type)
    if self._ds_models is None:
        self.logger.error(f"_ds_models is not initialized")
        return {}

    self._ds_models.samples = refreshed
    return self._ds_models.update(samples=refreshed)
||||
|
||||
def search_model(self, search='', model_type='All') -> t.Dict:
    """Filter the model gallery by search text and model type."""
    if self._ds_models is None:
        self.logger.error(f"_ds_models is not initialized")
        return {}

    filtered = self.get_images_html(search, model_type)
    self._ds_models.samples = filtered
    return self._ds_models.update(samples=filtered)
||||
|
||||
def get_model_info(self, models) -> t.Tuple[t.List[t.List[str]], t.Dict, str, t.Dict]:
    """Collect covers, file choices, and an HTML detail panel for one model.

    *models* is a [html, model_id] pair produced by get_images_html; only
    the id (models[1]) is used. Side effect: repopulates self.model_files
    with the downloadable files of the model's latest version.
    Returns (cover images, Dropdown update, detail HTML, download-link update).

    Fix: the per-file dict key was misspelled ``"id:"`` (trailing colon);
    it is now ``"id"``.
    """
    drop_list = []
    cover_imgs = []
    htmlDetail = "<div><p>Empty</p></div>"

    mid = models[1]

    # TODO: use map to enhance the performances
    m = [e for e in self.model_set if e['id'] == mid][0]

    self.model_files.clear()

    download_url_by_default = None
    if m and m.get('modelVersions') and len(m.get('modelVersions')) > 0:
        # The API lists versions newest-first; only the latest is shown.
        latest_version = m['modelVersions'][0]

        if latest_version.get('images') and isinstance(latest_version.get('images'), list):
            for img in latest_version['images']:
                if self.allow_nsfw or (not self.allow_nsfw and not img.get('nsfw')):
                    if img.get('url'):
                        cover_imgs.append([img['url'], ''])

        if latest_version.get('files') and isinstance(latest_version.get('files'), list):
            for file in latest_version['files']:
                # error checking for mandatory fields
                if file.get('id') is not None and file.get('downloadUrl') is not None:
                    # Prefer the file's own name, then the version name.
                    item_name = None
                    if file.get('name'):
                        item_name = file.get('name')
                    if not item_name and latest_version.get('name'):
                        item_name = latest_version['name']
                    if not item_name:
                        item_name = "unknown"

                    self.model_files.append({
                        "id": file['id'],
                        "url": file['downloadUrl'],
                        "name": item_name,
                        "type": m['type'] if m.get('type') else "unknown",
                        "size": file['sizeKB'] * 1024 if file.get('sizeKB') else "unknown",
                        "format": file['format'] if file.get('format') else "unknown",
                        "cover": cover_imgs[0][0] if len(cover_imgs) > 0 else toolkit.get_not_found_image_url(),
                    })
                    file_size = toolkit.get_readable_size(file['sizeKB'] * 1024) if file.get('sizeKB') else ""
                    if file_size:
                        drop_list.append(f"{item_name} ({file_size})")
                    else:
                        drop_list.append(f"{item_name}")

                    # First valid file becomes the default download target.
                    if not download_url_by_default:
                        download_url_by_default = file.get('downloadUrl')

        # Assemble the detail panel; every field is optional.
        htmlDetail = '<div>'
        if m.get('name'):
            htmlDetail += f"<h1>{m['name']}</h1></br>"
        if m.get('stats') and m.get('stats').get('downloadCount'):
            htmlDetail += f"<p>Downloads: {m['stats']['downloadCount']}</p>"
        if m.get('stats') and m.get('stats').get('rating'):
            htmlDetail += f"<p>Rating: {m['stats']['rating']}</p>"
        if m.get('creator') and m.get('creator').get('username'):
            htmlDetail += f"<p>Author: {m['creator']['username']}</p></div></br></br>"
        if latest_version.get('name'):
            htmlDetail += f"<div><table><tbody><tr><td>Version:</td><td>{latest_version['name']}</td></tr>"
        if latest_version.get('updatedAt'):
            htmlDetail += f"<tr><td>Updated Time:</td><td>{latest_version['updatedAt']}</td></tr>"
        if m.get('type'):
            htmlDetail += f"<tr><td>Type:</td><td>{m['type']}</td></tr>"
        if latest_version.get('baseModel'):
            htmlDetail += f"<tr><td>Base Model:</td><td>{latest_version['baseModel']}</td></tr>"
        htmlDetail += f"<tr><td>NFSW:</td><td>{m.get('nsfw') if m.get('nsfw') is not None else 'false'}</td></tr>"
        if m.get('tags') and isinstance(m.get('tags'), list):
            htmlDetail += f"<tr><td>Tags:</td><td>"
            for t in m['tags']:
                htmlDetail += f"<span>{t}</span>"
            htmlDetail += "</td></tr>"
        if latest_version.get('trainedWords'):
            htmlDetail += f"<tr><td>Trigger Words:</td><td>"
            for t in latest_version['trainedWords']:
                htmlDetail += f"<span>{t}</span>"
            htmlDetail += "</td></tr>"
        htmlDetail += "</tbody></table></div>"
        htmlDetail += f"<div>{m['description'] if m.get('description') else 'N/A'}</div>"

    return (
        cover_imgs,
        gr.Dropdown.update(choices=drop_list, value=drop_list[0] if len(drop_list) > 0 else []),
        htmlDetail,
        gr.HTML.update(value=f'<p style="text-align: center;">'
                             f'<a style="text-align: center;" href="{download_url_by_default}" '
                             'target="_blank">Download</a></p>')
    )
||||
|
||||
def get_downloading_status(self):
    """HTML snapshot of the download manager's current task summary."""
    _, _, summary_html = self.downloader_manager.tasks_summary()
    return gr.HTML.update(value=summary_html)
||||
|
||||
def download_model(self, filename: str):
    """Queue download tasks (cover image + weights) for the selected file.

    *filename* is the dropdown label, i.e. a file name possibly suffixed
    with " (<size>)". The matching entry in self.model_files decides the
    destination directory by model type/format. Returns an HTML status
    update describing how many tasks were queued.
    """
    model_path = modules.paths.models_path
    script_path = modules.paths.script_path

    urls = []
    for _, f in enumerate(self.model_files):
        if not f.get('name'):
            continue
        # Strip the " (123.4MB)"-style size suffix added for the dropdown.
        model_fname = re.sub(r"\s*\(\d+(?:\.\d*)?.B\)\s*$", "", f['name'])

        if model_fname in filename:
            m_pre, m_ext = os.path.splitext(model_fname)
            # Cover image is saved beside the model with a .jpg extension.
            cover_fname = f"{m_pre}.jpg"

            # Route the file into the directory webui expects for its kind.
            if f['type'] == 'LORA':
                cover_fname = os.path.join(model_path, 'Lora', cover_fname)
                model_fname = os.path.join(model_path, 'Lora', model_fname)
            elif f['format'] == 'VAE':
                cover_fname = os.path.join(model_path, 'VAE', cover_fname)
                model_fname = os.path.join(model_path, 'VAE', model_fname)
            elif f['format'] == 'TextualInversion':
                cover_fname = os.path.join(script_path, 'embeddings', cover_fname)
                model_fname = os.path.join(script_path, 'embeddings', model_fname)
            elif f['format'] == 'Hypernetwork':
                cover_fname = os.path.join(model_path, 'hypernetworks', cover_fname)
                model_fname = os.path.join(model_path, 'hypernetworks', model_fname)
            else:
                # Default: treat as a checkpoint.
                cover_fname = os.path.join(model_path, 'Stable-diffusion', cover_fname)
                model_fname = os.path.join(model_path, 'Stable-diffusion', model_fname)

            urls.append((f['cover'], f['url'], f['size'], cover_fname, model_fname))
            # Only the first matching file is queued.
            break

    for (cover_url, model_url, total_size, local_cover_name, local_model_name) in urls:
        # Cover size is unknown up front; the model size came from the API.
        self.downloader_manager.download(
            source_url=cover_url,
            target_file=local_cover_name,
            estimated_total_size=None,
        )
        self.downloader_manager.download(
            source_url=model_url,
            target_file=local_model_name,
            estimated_total_size=total_size,
        )

    #
    # currently, web-ui is without queue enabled.
    #
    # webui_queue_enabled = False
    # if webui_queue_enabled:
    #     start = time.time()
    #     downloading_tasks_iter = self.downloader_manager.iterator()
    #     for i in progressbar.tqdm(range(100), unit="byte", desc="Models Downloading"):
    #         while True:
    #             try:
    #                 finished_bytes, total_bytes = next(downloading_tasks_iter)
    #                 v = finished_bytes / total_bytes
    #                 print(f"\n v = {v}")
    #                 if isinstance(v, float) and int(v * 100) < i:
    #                     print(f"\nv({v}) < {i}")
    #                     continue
    #                 else:
    #                     break
    #             except StopIteration:
    #                 break
    #
    #         time.sleep(0.5)
    #
    #     self.logger.info(f"[downloading] finished after {time.time() - start} secs")

    # Give the downloader a moment to register the new tasks before the
    # UI polls for status.
    time.sleep(2)

    self.model_files.clear()
    return gr.HTML.update(value=f"<h4>{len(urls)} downloading tasks added into task list</h4>")
||||
|
||||
def change_boot_setting(self, version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
    """Persist the chosen launch options into the web-ui startup script.

    Rewrites the platform-specific startup script in place (keeping a
    ``.bak`` backup), then records the settings via
    ``update_boot_settings``.

    :param version: webui flavor; 'Official Release' rewrites the
        COMMANDLINE_ARGS line, anything else rewrites the 'webui.py'
        launch line.
    :param drp_gpu: selected GPU option.
    :param drp_theme: selected UI theme.
    :param txt_listen_port: listen port text value.
    :param chk_group_args: checked boolean argument flags.
    :param additional_args: free-form extra command-line arguments.
    :return: a ``gr.update`` carrying a result/error message to show.
    """
    self.get_final_args(drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args)
    self.logger.info(f'saved_cmd: {self.cmdline_args}')

    # Pick the startup script and shell assignment keyword per platform.
    system = platform.system()
    if system == "Linux":
        target_webui_user_file = "webui-user.sh"
        script_export_keyword = "export"
    elif system == "Darwin":
        target_webui_user_file = "webui-macos-env.sh"
        script_export_keyword = "export"
    else:
        # Windows batch script uses `set` instead of `export`.
        target_webui_user_file = "webui-user.bat"
        script_export_keyword = "set"

    filepath = os.path.join(modules.shared.script_path, target_webui_user_file)
    self.logger.info(f"to update: {filepath}")

    msg = 'Result: Setting Saved.'
    try:
        # Seed the script from the bundled template on first use.
        if not os.path.exists(filepath):
            shutil.copyfile(os.path.join(self.prelude.ext_folder, 'configs', target_webui_user_file), filepath)

        # Double the backslashes so Windows paths survive re-parsing.
        rep_txt = ' '.join(self.cmdline_args).replace('\\', '\\\\')
        if version == 'Official Release':
            marker = 'COMMANDLINE_ARGS'
            replacement = f'{script_export_keyword} COMMANDLINE_ARGS="{rep_txt}"\n'
        else:
            marker = 'webui.py'
            # '\\p' writes a literal backslash — same bytes the original
            # produced via the (deprecated) invalid escape '\p'.
            replacement = f"python\\python.exe webui.py {rep_txt}\n"

        # fileinput with inplace=True redirects stdout into the file,
        # so every line must be written back, modified or not.
        with fileinput.FileInput(filepath, inplace=True, backup='.bak') as file:
            for line in file:
                if marker in line:
                    line = replacement
                sys.stdout.write(line)
    except Exception as e:
        msg = f'Error: {str(e)}'

    self.update_boot_settings(version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args)
    return gr.update(value=msg, visible=True)
|
||||
|
||||
@property
def model_set(self) -> t.List[t.Dict]:
    """Model metadata for the active source, cached against file mtime.

    Reloads from the local json file only when the cache is cold or the
    file has been modified since the last read; on any failure, falls
    back to fetching the model list remotely.
    """
    try:
        self.logger.info(f"access to model info for '{self.model_source}'")
        model_json_mtime = toolkit.get_file_last_modified_time(self.prelude.model_json[self.model_source])

        # Refresh only when the cache is empty or the file changed.
        if self._model_set is None or self._model_set_last_access_time is None \
                or self._model_set_last_access_time < model_json_mtime:
            self._model_set = self.get_all_models(self.model_source)
            self._model_set_last_access_time = model_json_mtime
            self.logger.info(f"load '{self.model_source}' model data from local file")
    except Exception as e:
        # Local cache unusable (missing/corrupt json, etc.): say why
        # before silently falling back to the remote fetch.
        self.logger.warning(f"local model data unavailable ({e}), fetching remotely")
        self._model_set = self.fetch_all_models()
        self._model_set_last_access_time = datetime.datetime.now()

    return self._model_set
|
||||
|
||||
@property
def allow_nsfw(self) -> bool:
    """Whether NSFW-flagged content is permitted (backing field ``_allow_nsfw``)."""
    return self._allow_nsfw
|
||||
|
||||
@property
def old_additional_args(self) -> str:
    """Extra command-line arguments recorded earlier (backing field ``_old_additional``)."""
    return self._old_additional
|
||||
|
||||
@property
def ds_models(self) -> gr.Dataset:
    """The gradio ``Dataset`` component held by this assistant (backing field ``_ds_models``)."""
    return self._ds_models

@ds_models.setter
def ds_models(self, newone: gr.Dataset):
    """Replace the stored gradio ``Dataset`` component."""
    self._ds_models = newone
|
||||
|
||||
@property
def model_source(self) -> str:
    """Name of the currently selected model source (backing field ``_model_source``)."""
    return self._model_source

@model_source.setter
def model_source(self, newone: str):
    """Switch model sources and invalidate the cached model-set timestamp."""
    self.logger.info(f"model source changes from {self.model_source} to {newone}")
    self._model_source = newone
    # Clearing the timestamp forces `model_set` to reload on next access.
    self._model_set_last_access_time = None
|
||||
|
||||
def introception(self) -> None:
    """Print the assistant's current default launch arguments to stdout."""
    gpu, theme, port, checkbox_values, extra_args, ver = self.get_default_args()

    banner = "################################################################"
    print(banner)
    print("MIAOSHOU ASSISTANT ARGUMENTS:")

    # Emit each setting on its own line, identically formatted.
    for label, value in (
        ("gpu", gpu),
        ("theme", theme),
        ("port", port),
        ("checkbox_values", checkbox_values),
        ("extra_args", extra_args),
        ("webui ver", ver),
    ):
        print(f" {label} = {value}")

    print(banner)
|
||||
|
||||
Loading…
Reference in New Issue