implement partial features

Signed-off-by: yanshoutong <yanshoutong@sina.cn>
pull/38/head
yanshoutong 2023-04-08 09:14:58 +08:00
parent c3017fef4e
commit e67f68e7c3
17 changed files with 1795 additions and 0 deletions

17
.gitignore vendored Normal file
View File

@ -0,0 +1,17 @@
# for vim
.~
.*.swp
*~
# for MacOS
.DS_Store
__pycache__/
.idea/
logs/
flagged/
configs/civitai_models.json
configs/liandange_models.json

30
install.py Normal file
View File

@ -0,0 +1,30 @@
import launch
import os
import gzip
import io
def install_preset_models_if_needed():
assets_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "assets")
configs_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")
for model_filename in ["civitai_models.json", "liandange_models.json"]:
gzip_file = os.path.join(assets_folder, f"{model_filename}.gz")
target_file = os.path.join(configs_folder, f"{model_filename}")
if not os.path.exists(target_file):
with gzip.open(gzip_file, "rb") as compressed_file:
with io.TextIOWrapper(compressed_file, encoding="utf-8") as decoder:
content = decoder.read()
with open(target_file, "w") as model_file:
model_file.write(content)
req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
with open(req_file) as file:
for lib in file:
lib = lib.strip()
if not launch.is_installed(lib):
launch.run_pip(f"install {lib}", f"Miaoshou assistant requirement: {lib}")
install_preset_models_if_needed()

3
requirements.txt Normal file
View File

@ -0,0 +1,3 @@
psutil
rehash
tqdm

View File

@ -0,0 +1 @@
__all__ = ["miaoshou"]

View File

@ -0,0 +1,256 @@
import os
import sys
import typing as t
import gradio as gr
import launch
import modules
from scripts.logging.msai_logger import Logger
from scripts.runtime.msai_prelude import MiaoshouPrelude
from scripts.runtime.msai_runtime import MiaoshouRuntime
class MiaoShouAssistant(object):
# default css definition
default_css = '#my_model_cover{width: 100px;} #my_model_trigger_words{width: 200px;}'
def __init__(self) -> None:
self.logger = Logger()
self.prelude = MiaoshouPrelude()
self.runtime = MiaoshouRuntime()
self.refresh_symbol = '\U0001f504'
def on_event_ui_tabs_opened(self) -> t.List[t.Optional[t.Tuple[t.Any, str, str]]]:
with gr.Blocks(analytics_enabled=False, css=MiaoShouAssistant.default_css) as miaoshou_assistant:
self.create_subtab_boot_assistant()
self.create_subtab_model_management()
self.create_subtab_model_download()
return [(miaoshou_assistant.queue(), "Miaoshou Assistant", "miaoshou_assistant")]
def create_subtab_boot_assistant(self) -> None:
with gr.TabItem('Boot Assistant', elem_id="boot_assistant_tab") as boot_assistant:
with gr.Row():
with gr.Column(elem_id="col_model_list"):
gpu, theme, port, chk_args, txt_args, webui_ver = self.runtime.get_default_args()
gr.Markdown(value="Argument settings")
with gr.Row():
drp_gpu = gr.Dropdown(label="", elem_id="drp_args_vram",
choices=list(self.prelude.gpu_setting.keys()),
value=gpu, interactive=True)
drp_theme = gr.Dropdown(label="UI Theme", choices=list(self.prelude.theme_setting.keys()),
value=theme,
elem_id="drp_args_theme", interactive=True)
txt_listen_port = gr.Text(label='Listen Port', value=port, elem_id="txt_args_listen_port",
interactive=True)
with gr.Row():
chk_group_args = gr.CheckboxGroup(choices=list(self.prelude.checkboxes.keys()), value=chk_args,
show_label=False)
additional_args = gr.Text(label='COMMANDLINE_ARGS (Divide by space)', value=txt_args,
elem_id="txt_args_more", interactive=True)
with gr.Row():
with gr.Column():
txt_save_status = gr.Markdown(visible=False, interactive=False, show_label=False)
drp_choose_version = gr.Dropdown(label="WebUI Version",
choices=['Official Release', 'Python Integrated'],
value=webui_ver, elem_id="drp_args_version",
interactive=True)
gr.HTML(
'<div><p>*Save your settings to webui-user.bat file. Use Python Integrated only if your'
' WebUI is extracted from a zip file and does not need python installation</p></div>')
save_settings = gr.Button(value="Save settings", elem_id="btn_arg_save_setting")
with gr.Row():
# with gr.Column():
# settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="ms_settings_submit")
# with gr.Column():
restart_gradio = gr.Button(value='Apply & Restart WebUI', variant='primary',
elem_id="ms_settings_restart_gradio")
'''def mod_args(drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
global commandline_args
get_final_args(drp_gpu, drp_theme, txt_listen_port, hk_group_args, additional_args)
print(commandline_args)
print(sys.argv)
#if '--xformers' not in sys.argv:
#sys.argv.append('--xformers')
settings_submit.click(mod_args, inputs=[drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args], outputs=[])'''
save_settings.click(self.runtime.change_boot_setting,
inputs=[drp_choose_version, drp_gpu, drp_theme, txt_listen_port, chk_group_args,
additional_args], outputs=[txt_save_status])
restart_gradio.click(
fn=self.request_restart,
_js='restart_reload',
inputs=[],
outputs=[],
)
with gr.Column():
with gr.Row():
machine_settings = self.prelude.get_sys_info()
txt_sys_info = gr.TextArea(value=machine_settings, lines=20, max_lines=20,
label="System Info",
show_label=False, interactive=False)
with gr.Row():
sys_info_refbtn = gr.Button(value="Refresh")
drp_gpu.change(self.runtime.update_xformers, inputs=[drp_gpu, chk_group_args], outputs=[chk_group_args])
sys_info_refbtn.click(self.prelude.get_sys_info, None, txt_sys_info)
def create_subtab_model_management(self) -> None:
with gr.TabItem('Model Management', elem_id="model_management_tab") as tab_batch:
with gr.Row():
with gr.Column():
my_models = self.runtime.get_local_models()
ds_my_models = gr.Dataset(
components=[gr.HTML(visible=False, label='Cover', elem_id='my_model_cover'),
gr.Textbox(visible=False, label='Name/Version'),
gr.Textbox(visible=False, label='File Name'),
gr.Textbox(visible=False, label='Hash'), gr.Textbox(visible=False, label='Creator'),
gr.Textbox(visible=False, label='Type'), gr.Textbox(visible=False, label='NSFW'),
gr.Textbox(visible=False, label='Trigger Words', elem_id='my_model_trigger_words')],
elem_id='my_model_lib',
label="My Models",
headers=None,
samples=my_models,
samples_per_page=50)
with gr.Column():
html_model_prompt = gr.HTML(visible=True,
value='<div style="height:400px;"><p>No Model Selected</p></div>')
with gr.Row():
add = gr.Button(value="Add", variant="primary")
# delete = gr.Button(value="Delete")
with gr.Row():
reset_btn = gr.Button(value="Reset")
json_input = gr.Button(value="Load from JSON")
png_input = gr.Button(value="Detect from image")
png_input_area = gr.Image(label="Detect from image", elem_id="openpose_editor_input")
bg_input = gr.Button(value="Add Background image")
def create_subtab_model_download(self) -> None:
with gr.TabItem('Model Download', elem_id="model_download_tab") as tab_downloads:
with gr.Row():
with gr.Column(elem_id="col_model_list"):
with gr.Row().style(equal_height=True):
model_source_dropdown = gr.Dropdown(choices=["civitai", "liandange"],
value=self.runtime.model_source,
label="Select Model Source",
type="value",
show_label=True,
elem_id="model_source").style(full_width=True)
with gr.Row().style(equal_height=True):
search_text = gr.Textbox(
label="Model name",
show_label=False,
max_lines=1,
placeholder="Enter model name",
)
btn_search = gr.Button("Search")
with gr.Row().style(equal_height=True):
nsfw_checker = gr.Checkbox(label='NSFW', value=False, elem_id="chk_nsfw", interactive=True)
model_type = gr.Radio(["All", "Checkpoint", "LORA", "TextualInversion", "Hypernetwork"],
show_label=False, value='All', elem_id="rad_model_type",
interactive=True).style(full_width=True)
images = self.runtime.get_images_html()
self.runtime.ds_models = gr.Dataset(
components=[gr.HTML(visible=False)],
headers=None,
type="values",
label="Models",
samples=images,
samples_per_page=60,
elem_id="model_dataset").style(type="gallery", container=True)
with gr.Column(elem_id="col_model_info"):
with gr.Row():
cover_gallery = gr.Gallery(label="Cover", show_label=False, visible=True).style(grid=[4],
height="2")
with gr.Row():
with gr.Column():
download_summary = gr.HTML('<div><span>No downloading tasks ongoing</span></div>')
downloading_status = gr.Button(value=f"{self.refresh_symbol} Refresh Downloading Status",
elem_id="ms_dwn_status")
with gr.Row():
model_dropdown = gr.Dropdown(choices=['Select Model'], label="Models", show_label=False,
value='Select Model', elem_id='ms_dwn_button',
interactive=True)
is_civitai_model_source_active = self.runtime.model_source == "civitai"
with gr.Row(variant="panel"):
dwn_button = gr.Button(value='Download',
visible=is_civitai_model_source_active, elem_id='ms_dwn_button')
open_url_in_browser_newtab_button = gr.HTML(
value='<p style="text-align: center;">'
'<a style="text-align: center;" href="https://models.paomiantv.cn/models" '
'target="_blank">Download</a></p>',
visible=not is_civitai_model_source_active)
with gr.Row():
model_info = gr.HTML(visible=True)
nsfw_checker.change(self.runtime.set_nsfw, inputs=[search_text, nsfw_checker, model_type],
outputs=self.runtime.ds_models)
model_type.change(self.runtime.search_model, inputs=[search_text, model_type], outputs=self.runtime.ds_models)
btn_search.click(self.runtime.search_model, inputs=[search_text, model_type], outputs=self.runtime.ds_models)
self.runtime.ds_models.click(self.runtime.get_model_info,
inputs=[self.runtime.ds_models],
outputs=[
cover_gallery,
model_dropdown,
model_info,
open_url_in_browser_newtab_button
])
dwn_button.click(self.runtime.download_model, inputs=[model_dropdown], outputs=[download_summary])
downloading_status.click(self.runtime.get_downloading_status, inputs=[], outputs=[download_summary])
model_source_dropdown.change(self.switch_model_source,
inputs=[model_source_dropdown],
outputs=[self.runtime.ds_models, dwn_button, open_url_in_browser_newtab_button])
def request_restart(self, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
print('request_restart: cmd_arg = ', self.runtime.cmdline_args)
print('request_restart: sys.argv = ', sys.argv)
modules.shared.state.interrupt()
modules.shared.state.need_restart = True
# reset args
sys.argv = [sys.argv[0]]
os.environ['COMMANDLINE_ARGS'] = ""
print('remove', sys.argv)
for arg in self.runtime.cmdline_args:
sys.argv.append(arg)
print('after', sys.argv)
launch.prepare_environment()
launch.start()
def switch_model_source(self, new_model_source: str):
self.runtime.model_source = new_model_source
show_download_button = self.runtime.model_source == "civitai"
images = self.runtime.get_images_html()
self.runtime.ds_models.samples = images
return (
gr.Dataset.update(samples=images),
gr.Button.update(visible=show_download_button),
gr.HTML.update(visible=not show_download_button)
)
def introception(self) -> None:
self.runtime.introception()

View File

@ -0,0 +1 @@
__all__ = ["msai_downloader_manager"]

View File

@ -0,0 +1,248 @@
import asyncio
import os.path
import queue
import time
import requests
import typing as t
from threading import Thread, Lock
from scripts.download.msai_file_downloader import MiaoshouFileDownloader
from scripts.logging.msai_logger import Logger
from scripts.msai_utils.msai_singleton import MiaoshouSingleton
import scripts.msai_utils.msai_toolkit as toolkit
class DownloadingEntry(object):
def __init__(self, target_url: str = None, local_file: str = None,
local_directory: str = None, estimated_total_size: float = 0., expected_checksum: str = None):
self._target_url = target_url
self._local_file = local_file
self._local_directory = local_directory
self._expected_checksum = expected_checksum
self._estimated_total_size = estimated_total_size
self._total_size = 0
self._downloaded_size = 0
self._downloading = False
self._failure = False
@property
def target_url(self) -> str:
return self._target_url
@property
def local_file(self) -> str:
return self._local_file
@property
def local_directory(self) -> str:
return self._local_directory
@property
def expected_checksum(self) -> str:
return self._expected_checksum
@property
def total_size(self) -> int:
return self._total_size
@total_size.setter
def total_size(self, sz: int) -> None:
self._total_size = sz
@property
def downloaded_size(self) -> int:
return self._downloaded_size
@downloaded_size.setter
def downloaded_size(self, sz: int) -> None:
self._downloaded_size = sz
@property
def estimated_size(self) -> float:
return self._estimated_total_size
def is_downloading(self) -> bool:
return self._downloading
def start_download(self) -> None:
self._downloading = True
def update_final_status(self, result: bool) -> None:
self._failure = result
self._downloading = False
def is_failure(self) -> bool:
return self._failure
class AsyncLoopThread(Thread):
def __init__(self):
super(AsyncLoopThread, self).__init__(daemon=True)
self.loop = asyncio.new_event_loop()
self.logger = Logger()
self.logger.info("looper thread is created")
def run(self):
asyncio.set_event_loop(self.loop)
self.logger.info("looper thread is running")
self.loop.run_forever()
class MiaoshouDownloaderManager(metaclass=MiaoshouSingleton):
_downloading_entries: t.Dict[str, DownloadingEntry] = None
def __init__(self):
if self._downloading_entries is None:
self._downloading_entries = {}
self.message_queue = queue.Queue()
self.logger = Logger()
self.looper = AsyncLoopThread()
self.looper.start()
self.logger.info("download manager is ready")
self._mutex = Lock()
def consume_all_ready_messages(self) -> None:
"""
capture all enqueued messages, this method should not be used if you are iterating over the message queue
:return:
None
:side-effect:
update downloading entries' status
"""
while True:
# self.logger.info("fetching the enqueued message")
try:
(aurl, finished_size, total_size) = self.message_queue.get(block=False, timeout=0.2)
# self.logger.info(f"[+] message ([{finished_size}/{total_size}] {aurl}")
try:
self._mutex.acquire(blocking=True)
self._downloading_entries[aurl].total_size = total_size
self._downloading_entries[aurl].downloaded_size = finished_size
finally:
self._mutex.release()
except queue.Empty:
break
def iterator(self) -> t.Tuple[float, float]:
while True:
self.logger.info("waiting for incoming message")
try:
(aurl, finished_size, total_size) = self.message_queue.get(block=True)
self.logger.info(f"[+] message ([{finished_size}/{total_size}] {aurl}")
try:
self._mutex.acquire(blocking=True)
self._downloading_entries[aurl].total_size = total_size
self._downloading_entries[aurl].downloaded_size = finished_size
tasks_total_size = 0.
tasks_finished_size = 0.
for e in self._downloading_entries.values():
tasks_total_size += e.total_size
tasks_finished_size += e.downloaded_size
yield tasks_finished_size, tasks_total_size
finally:
self._mutex.release()
except queue.Empty:
if len(asyncio.all_tasks(self.looper.loop)) == 0:
self.logger.info("all downloading tasks finished")
break
async def _submit_task(self, download_entry: DownloadingEntry) -> None:
try:
self._mutex.acquire(blocking=True)
if download_entry.target_url in self._downloading_entries:
self.logger.warn(f"{download_entry.target_url} is already downloading")
return
else:
download_entry.start_download()
self._downloading_entries[download_entry.target_url] = download_entry
finally:
self._mutex.release()
file_downloader = MiaoshouFileDownloader(
target_url=download_entry.target_url,
local_file=download_entry.local_file,
local_directory=download_entry.local_directory,
channel=self.message_queue if download_entry.estimated_size else None,
estimated_total_length=download_entry.estimated_size,
expected_checksum=download_entry.expected_checksum,
)
result: bool = await self.looper.loop.run_in_executor(None, file_downloader.download_file)
try:
self._mutex.acquire(blocking=True)
self._downloading_entries[download_entry.target_url].update_final_status(result)
finally:
self._mutex.release()
def download(self, source_url: str, target_file: str, estimated_total_size: float,
expected_checksum: str = None) -> None:
target_dir = os.path.dirname(target_file)
target_filename = os.path.basename(target_file)
download_entry = DownloadingEntry(
target_url=source_url,
local_file=target_filename,
local_directory=target_dir,
estimated_total_size=estimated_total_size,
expected_checksum=expected_checksum
)
asyncio.run_coroutine_threadsafe(self._submit_task(download_entry), self.looper.loop)
def tasks_summary(self) -> t.Tuple[int, int, str]:
self.consume_all_ready_messages()
total_tasks_num = 0
ongoing_tasks_num = 0
failed_tasks_num = 0
try:
description = "<div>"
self._mutex.acquire(blocking=True)
for name, entry in self._downloading_entries.items():
if entry.estimated_size is None:
continue
total_tasks_num += 1
if entry.total_size > 0.:
description += f"<p>{entry.local_file} ({toolkit.get_readable_size(entry.total_size)}) : "
else:
description += f"<p>{entry.local_file} ({toolkit.get_readable_size(entry.estimated_size)}) : "
if entry.is_downloading():
ongoing_tasks_num += 1
finished_percent = entry.downloaded_size/entry.estimated_size * 100
description += f'<span style="color:blue;font-weight:bold">{round(finished_percent, 2)} %</span>'
elif entry.is_failure():
failed_tasks_num += 1
description += '<span style="color:red;font-weight:bold">failed!</span>'
else:
description += '<span style="color:green;font-weight:bold">finished</span>'
description += "</p><br>"
finally:
self._mutex.release()
pass
description += "</div>"
overall = f"""
<h4>
<span style="color:blue;font-weight:bold">{ongoing_tasks_num}</span> ongoing,
<span style="color:green;font-weight:bold">{total_tasks_num - ongoing_tasks_num - failed_tasks_num}</span> finished,
<span style="color:red;font-weight:bold">{failed_tasks_num}</span> failed.
</h4>
<br>
<br>
"""
return ongoing_tasks_num, total_tasks_num, overall + description

View File

@ -0,0 +1,226 @@
import os
import pickle
import queue
import time
import typing as t
from pathlib import Path
from urllib.parse import urlparse
import rehash
import requests
from requests.adapters import HTTPAdapter
from tqdm import tqdm
from urllib3.util import Retry
import scripts.msai_utils.msai_toolkit as toolkit
from scripts.logging.msai_logger import Logger
class MiaoshouFileDownloader(object):
CHUNK_SIZE = 1024 * 1024
def __init__(self, target_url: str = None,
local_file: str = None, local_directory: str = None, estimated_total_length: float = 0.,
expected_checksum: str = None,
channel: queue.Queue = None,
max_retries=3) -> None:
self.logger = Logger()
self.target_url: str = target_url
self.local_file: str = local_file
self.local_directory = local_directory
self.expected_checksum = expected_checksum
self.max_retries = max_retries
self.accept_ranges: bool = False
self.estimated_content_length = estimated_total_length
self.content_length: int = -1
self.finished_chunk_size: int = 0
self.channel = channel # for communication
# Support 3 retries and backoff
retry_strategy = Retry(
total=3,
backoff_factor=1,
status_forcelist=[429, 500, 502, 503, 504],
method_whitelist=["HEAD", "GET", "OPTIONS"]
)
adapter = HTTPAdapter(max_retries=retry_strategy)
self.session = requests.Session()
self.session.mount("https://", adapter)
self.session.mount("http://", adapter)
# inform message receiver at once
if self.channel:
self.channel.put_nowait((
self.target_url,
self.finished_chunk_size,
self.estimated_content_length,
))
# Head request to get file-length and check whether it supports ranges.
def get_file_info_from_server(self, target_url: str) -> t.Tuple[bool, float]:
try:
headers = {"Accept-Encoding": "identity"} # Avoid dealing with gzip
response = requests.head(target_url, headers=headers, allow_redirects=True)
response.raise_for_status()
content_length = None
if "Content-Length" in response.headers:
content_length = int(response.headers['Content-Length'])
accept_ranges = (response.headers.get("Accept-Ranges") == "bytes")
return accept_ranges, float(content_length)
except Exception as ex:
self.logger.info(f"HEAD Request Error: {ex}")
return False, self.estimated_content_length
def download_file_full(self, target_url: str, local_filepath: str) -> t.Optional[str]:
try:
checksum = rehash.sha256()
headers = {"Accept-Encoding": "identity"} # Avoid dealing with gzip
with tqdm(total=self.content_length, unit="byte", unit_scale=1, colour="GREEN",
desc=os.path.basename(self.local_file)) as progressbar, \
self.session.get(target_url, headers=headers, stream=True, timeout=5) as response, \
open(local_filepath, 'wb') as file_out:
response.raise_for_status()
for chunk in response.iter_content(MiaoshouFileDownloader.CHUNK_SIZE):
file_out.write(chunk)
checksum.update(chunk)
progressbar.update(len(chunk))
self.update_progress(len(chunk))
except Exception as ex:
self.logger.info(f"Download error: {ex}")
return None
return checksum.hexdigest()
def download_file_resumable(self, target_url: str, local_filepath: str) -> t.Optional[str]:
# Always go off the checkpoint as the file was flushed before writing.
download_checkpoint = local_filepath + ".downloading"
try:
resume_point, checksum = pickle.load(open(download_checkpoint, "rb"))
assert os.path.exists(local_filepath) # catch checkpoint without file
self.logger.info("File already exists, resuming download.")
except Exception as e:
self.logger.error(f"failed to load downloading checkpoint - {download_checkpoint} due to {e}")
resume_point = 0
checksum = rehash.sha256()
if os.path.exists(local_filepath):
os.remove(local_filepath)
Path(local_filepath).touch()
assert (resume_point < self.content_length)
self.finished_chunk_size = resume_point
# Support resuming
headers = {"Range": f"bytes={resume_point}-", "Accept-Encoding": "identity"}
try:
with tqdm(total=self.content_length, unit="byte", unit_scale=1, colour="GREEN",
desc=os.path.basename(self.local_file)) as progressbar, \
self.session.get(target_url, headers=headers, stream=True, timeout=5) as response, \
open(local_filepath, 'r+b') as file_out:
response.raise_for_status()
self.update_progress(resume_point)
file_out.seek(resume_point)
for chunk in response.iter_content(MiaoshouFileDownloader.CHUNK_SIZE):
file_out.write(chunk)
file_out.flush()
resume_point += len(chunk)
checksum.update(chunk)
pickle.dump((resume_point, checksum), open(download_checkpoint, "wb"))
progressbar.update(len(chunk))
self.update_progress(len(chunk))
# Only remove checkpoint at full size in case connection cut
if os.path.getsize(local_filepath) == self.content_length:
os.remove(download_checkpoint)
else:
return None
except Exception as ex:
self.logger.error(f"Download error: {ex}")
return None
return checksum.hexdigest()
def update_progress(self, finished_chunk_size: int) -> None:
self.finished_chunk_size += finished_chunk_size
if self.channel:
self.channel.put_nowait((
self.target_url,
self.finished_chunk_size,
self.content_length,
))
# In order to avoid leaving extra garbage meta files behind this
# will overwrite any existing files found at local_file. If you don't want this
# behaviour you can handle this externally.
# local_file and local_directory could write to unexpected places if the source
# is untrusted, be careful!
def download_file(self) -> bool:
success = False
try:
# Need to rebuild local_file_final each time in case of different urls
if not self.local_file:
specific_local_file = os.path.basename(urlparse(self.target_url).path)
else:
specific_local_file = self.local_file
download_temp_dir = toolkit.get_user_temp_dir()
toolkit.assert_user_temp_dir()
if self.local_directory:
os.makedirs(self.local_directory, exist_ok=True)
specific_local_file = os.path.join(download_temp_dir, specific_local_file)
self.accept_ranges, self.content_length = self.get_file_info_from_server(self.target_url)
self.logger.info(f"Accept-Ranges: {self.accept_ranges}. content length: {self.content_length}")
if self.accept_ranges and self.content_length:
download_method = self.download_file_resumable
self.logger.info("Server supports resume")
else:
download_method = self.download_file_full
self.logger.info(f"Server doesn't support resume.")
for i in range(self.max_retries):
self.logger.info(f"Download Attempt {i + 1}")
checksum = download_method(self.target_url, specific_local_file)
if checksum:
match = ""
if self.expected_checksum:
match = ", Checksum Match"
if self.expected_checksum and self.expected_checksum != checksum:
self.logger.info(f"Checksum doesn't match. Calculated {checksum} "
f"Expecting: {self.expected_checksum}")
else:
self.logger.info(f"Download successful{match}. Checksum {checksum}")
success = True
break
time.sleep(1)
if success:
print("*" * 80)
self.logger.info(f"{self.target_url} [DOWNLOADED COMPLETELY]")
print("*" * 80)
if self.local_directory:
target_local_file = os.path.join(self.local_directory, self.local_file)
else:
target_local_file = self.local_file
toolkit.move_file(specific_local_file, target_local_file)
else:
print("*" * 80)
self.logger.info(f"{self.target_url} [ FAILED ]")
print("*" * 80)
except Exception as ex:
self.logger.info(f"Unexpected Error: {ex}") # Only from block above
return success

View File

@ -0,0 +1 @@
__all__ = ["msai_logger"]

View File

@ -0,0 +1,100 @@
import datetime
import logging
import logging.handlers
import os
import typing as t
from scripts.msai_utils.msai_singleton import MiaoshouSingleton
class Logger(metaclass=MiaoshouSingleton):
_dataset = None
KEY_TRACE_PATH = "trace_path"
KEY_INFO = "info"
KEY_ERROR = "error"
KEY_JOB = "job"
def _do_init(self, log_folder: str, disable_console_output: bool = False) -> None:
# Setup trace_path with empty string by default, it will be assigned with valid content if needed
self._dataset = {Logger.KEY_TRACE_PATH: ""}
print(f"logs_location: {log_folder}")
os.makedirs(log_folder, exist_ok=True)
# Setup basic logging configuration
logging.basicConfig(level=logging.INFO,
filemode='w',
format='%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s')
# Setup info logging
self._dataset[Logger.KEY_INFO] = logging.getLogger(Logger.KEY_INFO)
msg_handler = logging.FileHandler(os.path.join(log_folder, "info.log"),
"a",
encoding="UTF-8")
msg_handler.setLevel(logging.INFO)
msg_handler.setFormatter(
logging.Formatter(fmt='%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s'))
self._dataset[Logger.KEY_INFO].addHandler(msg_handler)
# Setup error logging
self._dataset[Logger.KEY_ERROR] = logging.getLogger(Logger.KEY_ERROR)
error_handler = logging.FileHandler(
os.path.join(log_folder, f'error_{datetime.date.today().strftime("%Y%m%d")}.log'),
mode="a",
encoding='UTF-8')
error_handler.setLevel(logging.ERROR)
error_handler.setFormatter(
logging.Formatter(
fmt=f"{self._dataset.get('trace_path')}:\n "
f"%(asctime)s - %(filename)s [line:%(lineno)d] - %(levelname)s: %(message)s"))
self._dataset[Logger.KEY_ERROR].addHandler(error_handler)
# Setup job logging
self._dataset[Logger.KEY_JOB] = logging.getLogger(Logger.KEY_JOB)
job_handler = logging.FileHandler(os.path.join(log_folder, "jobs.log"),
mode="a",
encoding="UTF-8")
self._dataset[Logger.KEY_JOB].addHandler(job_handler)
if disable_console_output:
for k in [Logger.KEY_INFO, Logger.KEY_JOB, Logger.KEY_ERROR]:
l: logging.Logger = self._dataset[k]
l.propagate = not disable_console_output
def __init__(self, log_folder: str = None, disable_console_output: bool = False) -> None:
if self._dataset is None:
try:
self._do_init(log_folder=log_folder, disable_console_output=disable_console_output)
except Exception as e:
print(e)
def update_path_info(self, current_path: str) -> None:
self._dataset[Logger.KEY_TRACE_PATH] = current_path
def callback_func(self, exc_type: t.Any, exc_value: t.Any, exc_tracback: t.Any) -> None:
self._dataset[Logger.KEY_JOB].error(f"job failed for {self._dataset[Logger.KEY_TRACE_PATH]}")
self._dataset[Logger.KEY_INFO].error(f"{self._dataset[Logger.KEY_TRACE_PATH]}\n, callback_func: ",
exc_info=(exc_type, exc_value, exc_tracback))
def debug(self, fmt, *args, **kwargs) -> None:
l: logging.Logger = self._dataset[Logger.KEY_INFO]
l.debug(fmt, *args, **kwargs, stacklevel=2)
def info(self, fmt, *args, **kwargs) -> None:
l: logging.Logger = self._dataset[Logger.KEY_INFO]
l.info(fmt, *args, **kwargs, stacklevel=2)
def warn(self, fmt, *args, **kwargs) -> None:
l: logging.Logger = self._dataset[Logger.KEY_INFO]
l.warn(fmt, *args, **kwargs, stacklevel=2)
def error(self, fmt, *args, **kwargs) -> None:
l: logging.Logger = self._dataset[Logger.KEY_ERROR]
l.error(fmt, *args, **kwargs, stacklevel=2)
def job(self, fmt, *args, **kwargs) -> None:
l: logging.Logger = self._dataset[Logger.KEY_JOB]
l.info(fmt, *args, **kwargs, stacklevel=2)

22
scripts/main.py Normal file
View File

@ -0,0 +1,22 @@
import modules
import modules.scripts as scripts
from scripts.assistant.miaoshou import MiaoShouAssistant
class MiaoshouScript(scripts.Script):
def __init__(self) -> None:
super().__init__()
def title(self):
return "Miaoshou Assistant"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
return ()
assistant = MiaoShouAssistant()
modules.script_callbacks.on_ui_tabs(assistant.on_event_ui_tabs_opened)

View File

@ -0,0 +1 @@
__all__ = ["msai_singleton", "msai_toolkit"]

View File

@ -0,0 +1,8 @@
class MiaoshouSingleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(MiaoshouSingleton, cls).__call__(*args, **kwargs)
cls._instances[cls].__init__(*args, **kwargs)
return cls._instances[cls]

View File

@ -0,0 +1,88 @@
import json
import os
import platform
import shutil
import typing as t
from datetime import datetime
from pathlib import Path
def read_json(file) -> t.Any:
try:
with open(file, "r", encoding="utf-8-sig") as f:
return json.load(f)
except Exception as e:
print(e)
return None
def write_json(file, content) -> None:
try:
with open(file, 'w') as f:
json.dump(content, f, indent=4)
except Exception as e:
print(e)
def get_args(args) -> t.List[str]:
parameters = []
idx = 0
for arg in args:
if idx == 0 and '--' not in arg:
pass
elif '--' in arg:
parameters.append(rf'{arg}')
idx += 1
else:
parameters[idx - 1] = parameters[idx - 1] + ' ' + rf'{arg}'
return parameters
def get_readable_size(size: int, precision=2) -> str:
if size is None:
return ""
suffixes = ['B', 'KB', 'MB', 'GB', 'TB']
suffixIndex = 0
while size >= 1024 and suffixIndex < len(suffixes):
suffixIndex += 1 # increment the index of the suffix
size = size / 1024.0 # apply the division
return "%.*f%s" % (precision, size, suffixes[suffixIndex])
def get_file_last_modified_time(path_to_file: str) -> datetime:
if path_to_file is None:
return datetime.now()
if platform.system() == "Windows":
return datetime.fromtimestamp(os.path.getmtime(path_to_file))
else:
stat = os.stat(path_to_file)
return datetime.fromtimestamp(stat.st_mtime)
def get_not_found_image_url() -> str:
return "https://msdn.miaoshouai.com/msdn/userimage/not-found.svg"
def get_user_temp_dir() -> str:
return os.path.join(Path.home().absolute(), ".miaoshou_assistant_download")
def assert_user_temp_dir() -> None:
os.makedirs(get_user_temp_dir(), exist_ok=True)
def move_file(src: str, dst: str) -> None:
if not src or not dst:
return
if not os.path.exists(src):
return
if os.path.exists(dst):
os.remove(dst)
os.makedirs(os.path.dirname(dst), exist_ok=True)
shutil.move(src, dst)

View File

@ -0,0 +1,5 @@
__all__ = ["msai_prelude", "msai_runtime"]
from . import msai_prelude as prelude
prelude.MiaoshouPrelude().load()

View File

@ -0,0 +1,158 @@
import os
import platform
import sys
import typing as t
import psutil
import torch
import launch
from scripts.logging.msai_logger import Logger
from scripts.msai_utils import msai_toolkit as toolkit
from scripts.msai_utils.msai_singleton import MiaoshouSingleton
class MiaoshouPrelude(metaclass=MiaoshouSingleton):
_dataset = None
def __init__(self) -> None:
# Potential race condition, not call in multithread environment
if MiaoshouPrelude._dataset is None:
self._init_constants()
MiaoshouPrelude._dataset = {
"log_folder": os.path.join(self.ext_folder, "logs")
}
disable_log_console_output: bool = False
if self.all_settings.get("boot_settings"):
if self.all_settings["boot_settings"].get("disable_log_console_output") is not None:
disable_log_console_output = self.all_settings["boot_settings"].get("disable_log_console_output")
self._logger = Logger(self._dataset["log_folder"], disable_console_output=disable_log_console_output)
def _init_constants(self) -> None:
self._api_url = {
"civitai": "https://civitai.com/api/v1/models",
"liandange": "https://model-api.paomiantv.cn/model/api/models",
}
self._ext_folder = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", ".."))
self._setting_file = os.path.join(self.ext_folder, "configs", "settings.json")
self._model_hash_file = os.path.join(self.ext_folder, "configs", "model_hash.json")
self._model_json = {
'civitai': os.path.join(self.ext_folder, 'configs', 'civitai_models.json'),
'liandange': os.path.join(self.ext_folder, 'configs', 'liandange_models.json'),
}
self._checkboxes = {
'Enable xFormers': '--xformers',
'No Half': '--no-half',
'No Half VAE': '--no-half-vae',
'Enable API': '--api',
'Auto Launch': '--autolaunch',
'Allow Local Network Access': '--listen',
}
self._gpu_setting = {
'CPU Only': '--precision full --no-half --use-cpu SD GFPGAN BSRGAN ESRGAN SCUNet CodeFormer --all',
'GTX 16xx': '--lowvram --xformers --precision full --no-half',
'Low: 4-6G VRAM': '--xformers --lowvram',
'Med: 6-8G VRAM': '--xformers --medvram',
'Normal: 8+G VRAM': '',
}
self._theme_setting = {
'Auto': '',
'Light Mode': '--theme=light',
'Dark Mode': '--theme=dark',
}
@property
def ext_folder(self) -> str:
return self._ext_folder
@property
def log_folder(self) -> str:
return self._dataset.get("log_folder")
@property
def all_settings(self) -> t.Any:
return toolkit.read_json(self._setting_file)
@property
def boot_settings(self) -> t.Any:
all_setting = self.all_settings
if all_setting:
return all_setting['boot_settings']
else:
return None
def api_url(self, model_source: str) -> t.Optional[str]:
return self._api_url.get(model_source)
@property
def setting_file(self) -> str:
return self._setting_file
@property
def model_hash_file(self) -> str:
return self._model_hash_file
@property
def checkboxes(self) -> t.Dict[str, str]:
return self._checkboxes
@property
def gpu_setting(self) -> t.Dict[str, str]:
return self._gpu_setting
@property
def theme_setting(self) -> t.Dict[str, str]:
return self._theme_setting
@property
def model_json(self) -> t.Dict[str, t.Any]:
return self._model_json
def update_model_json(self, site: str, models: t.Dict[str, t.Any]) -> None:
if self._model_json.get(site) is None:
self._logger.error(f"cannot save model info for {site}")
return
self._logger.info(f"{self._model_json[site]} updated")
toolkit.write_json(self._model_json[site], models)
def load(self) -> None:
self._logger.info("start to do prelude")
self._logger.info(f"cmdline args: {' '.join(sys.argv[1:])}")
@classmethod
def get_sys_info(cls) -> str:
sys_info = 'System Information\n\n'
sys_info += r'OS Name: {0} {1}'.format(platform.system(), platform.release()) + '\n'
sys_info += r'OS Version: {0}'.format(platform.version()) + '\n'
sys_info += r'WebUI Version: {0}'.format(
f'https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{launch.commit_hash()}') + '\n'
sys_info += r'Torch Version: {0}'.format(getattr(torch, '__long_version__', torch.__version__)) + '\n'
sys_info += r'Python Version: {0}'.format(sys.version) + '\n\n'
sys_info += r'CPU: {0}'.format(platform.processor()) + '\n'
sys_info += r'CPU Cores: {0}/{1}'.format(psutil.cpu_count(logical=False), psutil.cpu_count(logical=True)) + '\n'
# FIXME: should uncomment line below and remove my own workaround for MacOS
# sys_info += r'CPU Frequency: {0} GHz'.format(round(psutil.cpu_freq().max/1000,2)) + '\n'
# workaround: let my macbook M1 happy
sys_info += r'CPU Frequency: {0} GHz'.format(round(2540.547 / 1000, 2)) + '\n'
sys_info += r'CPU Usage: {0}%'.format(psutil.cpu_percent()) + '\n\n'
sys_info += r'RAM: {0}'.format(toolkit.get_readable_size(psutil.virtual_memory().total)) + '\n'
sys_info += r'Memory Usage: {0}%'.format(psutil.virtual_memory().percent) + '\n\n'
for i in range(torch.cuda.device_count()):
sys_info += r'Graphics Card{0}: {1} ({2})'.format(i, torch.cuda.get_device_properties(i).name,
toolkit.get_readable_size(
torch.cuda.get_device_properties(
i).total_memory)) + '\n'
sys_info += r'Available VRAM: {0}'.format(toolkit.get_readable_size(torch.cuda.mem_get_info(i)[0])) + '\n'
return sys_info

View File

@ -0,0 +1,630 @@
import datetime
import fileinput
import os
import platform
import re
import shutil
import sys
import time
import typing as t
import gradio as gr
import requests
import modules
from modules.sd_models import CheckpointInfo
from scripts.download.msai_downloader_manager import MiaoshouDownloaderManager
from scripts.logging.msai_logger import Logger
from scripts.msai_utils import msai_toolkit as toolkit
from scripts.runtime.msai_prelude import MiaoshouPrelude
class MiaoshouRuntime(object):
def __init__(self):
self.cmdline_args: t.List[str] = None
self.logger = Logger()
self.prelude = MiaoshouPrelude()
self._old_additional: str = None
self._model_set: t.List[t.Dict] = None
self._model_set_last_access_time: datetime.datetime = None
self._ds_models: gr.Dataset = None
self._allow_nsfw: bool = False
self._model_source: str = "civitai" # civitai is the default model source
# TODO: may be owned by downloader class
self.model_files = []
self.downloader_manager = MiaoshouDownloaderManager()
def get_default_args(self, commandline_args: t.List[str] = None):
if commandline_args is None:
commandline_args: t.List[str] = toolkit.get_args(sys.argv[1:])
self.cmdline_args = commandline_args
self.logger.info(f"default commandline args: {commandline_args}")
checkbox_values = []
additional_args = ""
saved_setting = self.prelude.boot_settings
gpu = saved_setting.get('drp_args_vram')
theme = saved_setting.get('drp_args_theme')
port = saved_setting.get('txt_args_listen_port')
for arg in commandline_args:
if 'theme' in arg:
theme = [k for k, v in self.prelude.theme_setting.items() if v == arg][0]
if 'port' in arg:
port = arg.split(' ')[-1]
for chk in self.prelude.checkboxes:
for arg in commandline_args:
if self.prelude.checkboxes[chk] == arg:
checkbox_values.append(chk)
gpu_arg_list = [f'--{i.strip()}' for i in ' '.join(list(self.prelude.gpu_setting.values())).split('--')]
for arg in commandline_args:
if 'port' not in arg \
and arg not in list(self.prelude.theme_setting.values()) \
and arg not in list(self.prelude.checkboxes.values()) \
and arg not in gpu_arg_list:
additional_args += (' ' + rf'{arg}')
self._old_additional = additional_args
webui_ver = saved_setting['drp_choose_version']
return gpu, theme, port, checkbox_values, additional_args.replace('\\', '\\\\').strip(), webui_ver
def add_arg(self, args: str = "") -> None:
for arg in args.split('--'):
self.logger.info(f'add arg: {arg.strip()}')
if f"--{arg.strip()}" not in self.cmdline_args and arg.strip() != '':
self.cmdline_args.append(f'--{arg.strip()}')
def remove_arg(self, args: str = "") -> None:
arg_keywords = ['port', 'theme']
for arg in args.split('--'):
if arg in arg_keywords:
for cmdl in self.cmdline_args:
if arg in cmdl:
self.cmdline_args.remove(cmdl)
break
elif f'--{arg.strip()}' in self.cmdline_args and arg.strip() != '':
print(f"remove args:{arg.strip()}")
self.cmdline_args.remove(f'--{arg.strip()}')
def get_final_args(self, gpu, theme, port, checkgroup, more_args) -> None:
# gpu settings
for s1 in self.prelude.gpu_setting:
if s1 in gpu:
for s2 in self.prelude.gpu_setting:
if s2 != s1:
self.remove_arg(self.prelude.gpu_setting[s2])
self.add_arg(self.prelude.gpu_setting[s1])
if port != '7860':
self.add_arg(f'--port {port}')
else:
self.remove_arg('--port')
# theme settings
self.remove_arg('--theme')
for t in self.prelude.theme_setting:
if t == theme:
self.add_arg(self.prelude.theme_setting[t])
break
# check box settings
for chked in checkgroup:
self.logger.info(f'checked:{self.prelude.checkboxes[chked]}')
self.add_arg(self.prelude.checkboxes[chked])
for unchk in list(set(list(self.prelude.checkboxes.keys())) - set(checkgroup)):
print(f'unchecked:{unchk}')
self.remove_arg(self.prelude.checkboxes[unchk])
# additional commandline settings
self.remove_arg(self._old_additional)
self.add_arg(more_args.replace('\\\\', '\\'))
self._old_additional = more_args.replace('\\\\', '\\')
def fetch_all_models(self) -> t.List[t.Dict]:
endpoint_url = self.prelude.api_url(self.model_source)
if endpoint_url is None:
self.logger.error(f"{self.model_source} is not supported")
return []
self.logger.info(f"start to fetch model info from '{self.model_source}':{endpoint_url}")
limit_threshold = 100
all_set = []
response = requests.get(endpoint_url + f'?page=1&limit={limit_threshold}')
num_of_pages = response.json()['metadata']['totalPages']
self.logger.info(f"total pages = {num_of_pages}")
continuous_error_counts = 0
for p in range(1, num_of_pages + 1):
try:
response = requests.get(endpoint_url + f'?page={p}&limit={limit_threshold}')
payload = response.json()
if payload.get("success") is not None and not payload.get("success"):
self.logger.error(f"failed to fetch page[{p + 1}]")
continuous_error_counts += 1
if continuous_error_counts > 10:
break
else:
continue
continuous_error_counts = 0 # reset error flag
self.logger.debug(f"start to process page[{p}]")
for model in payload['items']:
all_set.append(model)
self.logger.debug(f"page[{p}] : {len(payload['items'])} items added")
except Exception as e:
self.logger.error(f"failed to fetch page[{p + 1}] due to {e}")
time.sleep(3)
if len(all_set) > 0:
self.prelude.update_model_json(self.model_source, all_set)
else:
self.logger.error("fetch_all_models: emtpy body received")
return all_set
def refresh_all_models(self) -> None:
if self.fetch_all_models():
if self.ds_models:
self.ds_models.samples = self.model_set
self.ds_models.update(samples=self.model_set)
else:
self.logger.error(f"ds models is null")
def get_images_html(self, search: str = '', model_type: str = 'All') -> t.List[str]:
self.logger.info(f"get_image_html: model_type = {model_type}, and search pattern = '{search}'")
model_cover_thumbnails = []
model_format = []
if self.model_set is None:
self.logger.error("model_set is null")
return []
self.logger.info(f"{len(self.model_set)} items inside '{self.model_source}'")
search = search.lower()
for model in self.model_set:
try:
if model.get('type') is not None \
and model.get('type') not in model_format:
model_format.append(model['type'])
if search == '' or \
(model.get('name') is not None and search in model.get('name').lower()) \
or (model.get('description') is not None and search in model.get('description').lower()):
if (model_type == 'All' or model_type in model.get('type')) \
and (self.allow_nsfw or (not self.allow_nsfw and not model.get('nsfw'))):
model_cover_thumbnails.append([
[f"""
<div style="display: flex; align-items: center;">
<div id="{str(model.get('id'))}" style="margin-right: 10px;" class="model-item">
<img src="{model['modelVersions'][0]['images'][0]['url'].replace('width=450', 'width=100')}" style="width:100px;">
</div>
<div style="flex:1; width:100px;">
<h3 style="text-align:left; word-wrap:break-word;">{model.get('name')}</h3>
<p style="text-align:left;">Type: {model.get('type')}</p>
<p style="text-align:left;">Rating: {model.get('stats')['rating']}</p>
</div>
</div>
"""],
model['id']])
except Exception:
continue
return model_cover_thumbnails
# TODO: add typing hint
def update_boot_settings(self, version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
boot_settings = self.prelude.boot_settings
boot_settings['drp_args_vram'] = drp_gpu
boot_settings["drp_args_theme"] = drp_theme
boot_settings['txt_args_listen_port'] = txt_listen_port
for chk in chk_group_args:
self.logger.debug(chk)
boot_settings[chk] = self.prelude.checkboxes[chk]
boot_settings['txt_args_more'] = additional_args
boot_settings['drp_choose_version'] = version
all_settings = self.prelude.all_settings
all_settings['boot_settings'] = boot_settings
toolkit.write_json(self.prelude.setting_file, all_settings)
def get_all_models(self, site: str) -> t.Any:
return toolkit.read_json(self.prelude.model_json[site])
def update_model_json(self, site: str, models: t.Any) -> None:
toolkit.write_json(self.prelude.model_json[site], models)
def get_hash_from_json(self, chk_point: CheckpointInfo) -> CheckpointInfo:
model_hashes = toolkit.read_json(self.prelude.model_hash_file)
if len(model_hashes) == 0 or chk_point.title not in model_hashes.keys():
chk_point.shorthash = chk_point.calculate_shorthash()
model_hashes[chk_point.title] = chk_point.shorthash
toolkit.write_json(self.prelude.model_hash_file, model_hashes)
else:
chk_point.shorthash = model_hashes[chk_point.title]
return chk_point
def get_local_models(self) -> t.List[t.Any]:
models = []
for file in modules.sd_models.checkpoint_tiles():
chkpt_info = modules.sd_models.get_closet_checkpoint_match(file)
if chkpt_info.sha256 is None and chkpt_info.shorthash is None:
chkpt_info = self.get_hash_from_json(chkpt_info)
fname = re.sub(r'\[.*?\]', "", chkpt_info.title)
model_info = self.search_model_by_hash(chkpt_info.sha256, chkpt_info.shorthash, fname)
if model_info is not None:
models.append(model_info)
else:
self.logger.info(
f"{chkpt_info.title}, {chkpt_info.hash}, {chkpt_info.shorthash}, {chkpt_info.sha256}")
models.append([
[f'<img src="file={os.path.join(modules.paths.script_path, "html", "card-no-preview.png")}" '
'style="width:100px;height:150px;">'],
[os.path.basename(fname)],
[fname],
[chkpt_info.shorthash],
[], [], []])
return models
def search_model_by_hash(self, lookup_sha256: str, lookup_shash: str, fname: str) -> t.Optional[t.List[t.Any]]:
self.logger.info(f"lookup_sha256: {lookup_sha256}, lookup_shash: {lookup_shash}, fname: {fname}")
res = None
if lookup_sha256 is None and lookup_shash is None:
return None
for model in self.model_set:
match = False
for ver in model['modelVersions']:
for file in ver['files']:
if lookup_sha256 is not None and 'SHA256' in file['hashes'].keys():
match = (lookup_sha256.upper() == file['hashes']['SHA256'].upper())
elif lookup_shash is not None:
match = (lookup_shash[:10].upper() in [h.upper() for h in file['hashes'].values()])
if match:
cover_link = ver['images'][0]['url'].replace('width=450', 'width=100')
mid = model['id']
res = [
[
f'<a href="https://civitai.com/models/{mid}" target="_blank"><img src="{cover_link}"></a>'],
[f"{model['name']}/{ver['name']}"],
[fname],
[lookup_shash],
[model['creator']['username']],
[model['type']],
[model['nsfw']],
[ver['trainedWords']],
]
if match:
break
return res
def update_xformers(self, gpu, checkgroup):
if '--xformers' in self.prelude.gpu_setting[gpu]:
if 'Enable xFormers' not in checkgroup:
checkgroup.append('Enable xFormers')
return checkgroup
def set_nsfw(self, search='', nsfw_checker=False, model_type='All') -> t.Dict:
self._allow_nsfw = nsfw_checker
new_list = self.get_images_html(search, model_type)
if self._ds_models is None:
self.logger.error(f"_ds_models is not initialized")
return {}
self._ds_models.samples = new_list
return self._ds_models.update(samples=new_list)
def search_model(self, search='', model_type='All') -> t.Dict:
if self._ds_models is None:
self.logger.error(f"_ds_models is not initialized")
return {}
new_list = self.get_images_html(search, model_type)
self._ds_models.samples = new_list
return self._ds_models.update(samples=new_list)
def get_model_info(self, models) -> t.Tuple[t.List[t.List[str]], t.Dict, str, t.Dict]:
drop_list = []
cover_imgs = []
htmlDetail = "<div><p>Empty</p></div>"
mid = models[1]
# TODO: use map to enhance the performances
m = [e for e in self.model_set if e['id'] == mid][0]
self.model_files.clear()
download_url_by_default = None
if m and m.get('modelVersions') and len(m.get('modelVersions')) > 0:
latest_version = m['modelVersions'][0]
if latest_version.get('images') and isinstance(latest_version.get('images'), list):
for img in latest_version['images']:
if self.allow_nsfw or (not self.allow_nsfw and not img.get('nsfw')):
if img.get('url'):
cover_imgs.append([img['url'], ''])
if latest_version.get('files') and isinstance(latest_version.get('files'), list):
for file in latest_version['files']:
# error checking for mandatory fields
if file.get('id') is not None and file.get('downloadUrl') is not None:
item_name = None
if file.get('name'):
item_name = file.get('name')
if not item_name and latest_version.get('name'):
item_name = latest_version['name']
if not item_name:
item_name = "unknown"
self.model_files.append({
"id:": file['id'],
"url": file['downloadUrl'],
"name": item_name,
"type": m['type'] if m.get('type') else "unknown",
"size": file['sizeKB'] * 1024 if file.get('sizeKB') else "unknown",
"format": file['format'] if file.get('format') else "unknown",
"cover": cover_imgs[0][0] if len(cover_imgs) > 0 else toolkit.get_not_found_image_url(),
})
file_size = toolkit.get_readable_size(file['sizeKB'] * 1024) if file.get('sizeKB') else ""
if file_size:
drop_list.append(f"{item_name} ({file_size})")
else:
drop_list.append(f"{item_name}")
if not download_url_by_default:
download_url_by_default = file.get('downloadUrl')
htmlDetail = '<div>'
if m.get('name'):
htmlDetail += f"<h1>{m['name']}</h1></br>"
if m.get('stats') and m.get('stats').get('downloadCount'):
htmlDetail += f"<p>Downloads: {m['stats']['downloadCount']}</p>"
if m.get('stats') and m.get('stats').get('rating'):
htmlDetail += f"<p>Rating: {m['stats']['rating']}</p>"
if m.get('creator') and m.get('creator').get('username'):
htmlDetail += f"<p>Author: {m['creator']['username']}</p></div></br></br>"
if latest_version.get('name'):
htmlDetail += f"<div><table><tbody><tr><td>Version:</td><td>{latest_version['name']}</td></tr>"
if latest_version.get('updatedAt'):
htmlDetail += f"<tr><td>Updated Time:</td><td>{latest_version['updatedAt']}</td></tr>"
if m.get('type'):
htmlDetail += f"<tr><td>Type:</td><td>{m['type']}</td></tr>"
if latest_version.get('baseModel'):
htmlDetail += f"<tr><td>Base Model:</td><td>{latest_version['baseModel']}</td></tr>"
htmlDetail += f"<tr><td>NFSW:</td><td>{m.get('nsfw') if m.get('nsfw') is not None else 'false'}</td></tr>"
if m.get('tags') and isinstance(m.get('tags'), list):
htmlDetail += f"<tr><td>Tags:</td><td>"
for t in m['tags']:
htmlDetail += f"<span>{t}</span>"
htmlDetail += "</td></tr>"
if latest_version.get('trainedWords'):
htmlDetail += f"<tr><td>Trigger Words:</td><td>"
for t in latest_version['trainedWords']:
htmlDetail += f"<span>{t}</span>"
htmlDetail += "</td></tr>"
htmlDetail += "</tbody></table></div>"
htmlDetail += f"<div>{m['description'] if m.get('description') else 'N/A'}</div>"
return (
cover_imgs,
gr.Dropdown.update(choices=drop_list, value=drop_list[0] if len(drop_list) > 0 else []),
htmlDetail,
gr.HTML.update(value=f'<p style="text-align: center;">'
f'<a style="text-align: center;" href="{download_url_by_default}" '
'target="_blank">Download</a></p>')
)
def get_downloading_status(self):
(_, _, desc) = self.downloader_manager.tasks_summary()
return gr.HTML.update(value=desc)
def download_model(self, filename: str):
model_path = modules.paths.models_path
script_path = modules.paths.script_path
urls = []
for _, f in enumerate(self.model_files):
if not f.get('name'):
continue
model_fname = re.sub(r"\s*\(\d+(?:\.\d*)?.B\)\s*$", "", f['name'])
if model_fname in filename:
m_pre, m_ext = os.path.splitext(model_fname)
cover_fname = f"{m_pre}.jpg"
if f['type'] == 'LORA':
cover_fname = os.path.join(model_path, 'Lora', cover_fname)
model_fname = os.path.join(model_path, 'Lora', model_fname)
elif f['format'] == 'VAE':
cover_fname = os.path.join(model_path, 'VAE', cover_fname)
model_fname = os.path.join(model_path, 'VAE', model_fname)
elif f['format'] == 'TextualInversion':
cover_fname = os.path.join(script_path, 'embeddings', cover_fname)
model_fname = os.path.join(script_path, 'embeddings', model_fname)
elif f['format'] == 'Hypernetwork':
cover_fname = os.path.join(model_path, 'hypernetworks', cover_fname)
model_fname = os.path.join(model_path, 'hypernetworks', model_fname)
else:
cover_fname = os.path.join(model_path, 'Stable-diffusion', cover_fname)
model_fname = os.path.join(model_path, 'Stable-diffusion', model_fname)
urls.append((f['cover'], f['url'], f['size'], cover_fname, model_fname))
break
for (cover_url, model_url, total_size, local_cover_name, local_model_name) in urls:
self.downloader_manager.download(
source_url=cover_url,
target_file=local_cover_name,
estimated_total_size=None,
)
self.downloader_manager.download(
source_url=model_url,
target_file=local_model_name,
estimated_total_size=total_size,
)
#
# currently, web-ui is without queue enabled.
#
# webui_queue_enabled = False
# if webui_queue_enabled:
# start = time.time()
# downloading_tasks_iter = self.downloader_manager.iterator()
# for i in progressbar.tqdm(range(100), unit="byte", desc="Models Downloading"):
# while True:
# try:
# finished_bytes, total_bytes = next(downloading_tasks_iter)
# v = finished_bytes / total_bytes
# print(f"\n v = {v}")
# if isinstance(v, float) and int(v * 100) < i:
# print(f"\nv({v}) < {i}")
# continue
# else:
# break
# except StopIteration:
# break
#
# time.sleep(0.5)
#
# self.logger.info(f"[downloading] finished after {time.time() - start} secs")
time.sleep(2)
self.model_files.clear()
return gr.HTML.update(value=f"<h4>{len(urls)} downloading tasks added into task list</h4>")
def change_boot_setting(self, version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args):
self.get_final_args(drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args)
self.logger.info(f'saved_cmd: {self.cmdline_args}')
target_webui_user_file = "webui-user.bat"
script_export_keyword = "export"
if platform.system() == "Linux":
target_webui_user_file = "webui-user.sh"
elif platform.system() == "Darwin":
target_webui_user_file = "webui-macos-env.sh"
else:
script_export_keyword = "set"
filepath = os.path.join(modules.shared.script_path, target_webui_user_file)
self.logger.info(f"to update: {filepath}")
msg = 'Result: Setting Saved.'
if version == 'Official Release':
try:
if not os.path.exists(filepath):
shutil.copyfile(os.path.join(self.prelude.ext_folder, 'configs', target_webui_user_file), filepath)
with fileinput.FileInput(filepath, inplace=True, backup='.bak') as file:
for line in file:
if 'COMMANDLINE_ARGS' in line:
rep_txt = ' '.join(self.cmdline_args).replace('\\', '\\\\')
line = f'{script_export_keyword} COMMANDLINE_ARGS="{rep_txt}"\n'
sys.stdout.write(line)
except Exception as e:
msg = f'Error: {str(e)}'
else:
try:
if not os.path.exists(filepath):
shutil.copyfile(os.path.join(self.prelude.ext_folder, 'configs', target_webui_user_file), filepath)
with fileinput.FileInput(filepath, inplace=True, backup='.bak') as file:
for line in file:
if 'webui.py' in line:
rep_txt = ' '.join(self.cmdline_args).replace('\\', '\\\\')
line = f"python\python.exe webui.py {rep_txt}\n"
sys.stdout.write(line)
except Exception as e:
msg = f'Error: {str(e)}'
self.update_boot_settings(version, drp_gpu, drp_theme, txt_listen_port, chk_group_args, additional_args)
return gr.update(value=msg, visible=True)
@property
def model_set(self) -> t.List[t.Dict]:
try:
self.logger.info(f"access to model info for '{self.model_source}'")
model_json_mtime = toolkit.get_file_last_modified_time(self.prelude.model_json[self.model_source])
if self._model_set is None or self._model_set_last_access_time is None \
or self._model_set_last_access_time < model_json_mtime:
self._model_set = self.get_all_models(self.model_source)
self._model_set_last_access_time = model_json_mtime
self.logger.info(f"load '{self.model_source}' model data from local file")
except Exception as e:
self._model_set = self.fetch_all_models()
self._model_set_last_access_time = datetime.datetime.now()
return self._model_set
@property
def allow_nsfw(self) -> bool:
return self._allow_nsfw
@property
def old_additional_args(self) -> str:
return self._old_additional
@property
def ds_models(self) -> gr.Dataset:
return self._ds_models
@ds_models.setter
def ds_models(self, newone: gr.Dataset):
self._ds_models = newone
@property
def model_source(self) -> str:
return self._model_source
@model_source.setter
def model_source(self, newone: str):
self.logger.info(f"model source changes from {self.model_source} to {newone}")
self._model_source = newone
self._model_set_last_access_time = None # reset timestamp
def introception(self) -> None:
(gpu, theme, port, checkbox_values, extra_args, ver) = self.get_default_args()
print("################################################################")
print("MIAOSHOU ASSISTANT ARGUMENTS:")
print(f" gpu = {gpu}")
print(f" theme = {theme}")
print(f" port = {port}")
print(f" checkbox_values = {checkbox_values}")
print(f" extra_args = {extra_args}")
print(f" webui ver = {ver}")
print("################################################################")