sd-webui-infinite-image-bro.../scripts/setup.py

289 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
import modules.scripts as scripts
import gradio as gr
import re
import subprocess
import logging
import uuid
import asyncio
import subprocess
from modules import script_callbacks, shared
from typing import List,Dict,Union
from modules.shared import opts
from scripts.log_parser import parse_log_line
from scripts.bin import download_bin_file, get_matched_summary, check_bin_exists,cwd, bin_file_path, bin_file_name
# Module-level logger. The logger itself accepts everything from DEBUG upwards;
# each handler below applies its own, possibly stricter, threshold.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Console handler: only INFO and above.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

# File handler: DEBUG and above, written to log.log inside the directory named
# by `cwd` (imported from scripts.bin).
file_handler = logging.FileHandler(f"{cwd}/log.log")
file_handler.setLevel(logging.DEBUG)

# One record layout shared by both handlers.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

# Apply the formatter and attach the handlers (console first, then file).
for _handler in (console_handler, file_handler):
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)
del _handler
def get_global_conf():
    """Return the effective upload configuration.

    Each value is read from the saved options in ``opts.data``; when the saved
    value is missing or falsy, the corresponding entry of
    ``get_default_conf()`` is used instead.

    Returns:
        A dict with the keys ``"output_dirs"`` and ``"upload_dir"``.
    """
    fallback = get_default_conf()
    output_dirs = opts.data.get("baidu_netdisk_output_dirs") or fallback.get("output_dirs")
    upload_dir = opts.data.get("baidu_netdisk_upload_dir") or fallback.get("upload_dir")
    return {
        "output_dirs": output_dirs,
        "upload_dir": upload_dir,
    }
def exec_ops(args: Union[List[str], str]):
    """Run the binary at ``bin_file_path`` with ``args`` and return its stdout.

    Args:
        args: A single argument string, or a list of argument strings.

    Returns:
        The captured stdout, decoded and stripped. When ``check_bin_exists()``
        is false nothing is executed and the empty string is returned.
    """
    argv = args if not isinstance(args, str) else [args]
    output = ""
    if check_bin_exists():
        completed = subprocess.run([bin_file_path, *argv], capture_output=True)
        raw = completed.stdout
        try:
            output = raw.decode().strip()
        except UnicodeDecodeError:
            # Output was not valid in the default codec; retry as GBK and drop
            # any bytes that still cannot be decoded.
            output = raw.decode("gbk", errors="ignore").strip()
    logger.info(output)
    return output
def login_by_bduss(bduss: str):
    """Log in through the binary's ``login -bduss=...`` command.

    Args:
        bduss: The BDUSS value passed to the ``login`` command.

    Returns:
        ``(True, <captured text>)`` when the command output contains the
        success message, where the second item is the stripped text following
        that message. Otherwise the output is printed and ``None`` is returned.
    """
    output = exec_ops(["login", f"-bduss={bduss}"])
    found = re.search("百度帐号登录成功: (.+)$", output)
    if not found:
        print(output)
        return None
    return (True, found.group(1).strip())
def get_curr_working_dir():
    """Return the output of the binary's ``pwd`` command (see ``exec_ops``)."""
    working_dir = exec_ops("pwd")
    return working_dir
def list_file(cwd="/"):
    """List the entries of a directory via the binary's ``cd`` and ``ls``.

    Runs ``cd <cwd>`` followed by ``ls`` through ``exec_ops`` and parses every
    output line that has the shape ``<id> <size> <date time> <name>``.

    Args:
        cwd: Directory passed to the ``cd`` command. Note that this parameter
            shadows the module-level ``cwd`` inside this function.

    Returns:
        A list of dicts with the keys ``"id"`` (int), ``"size"``, ``"date"``
        and ``"name"`` (str), one per matching line. The list is empty when no
        line matches.
    """
    exec_ops(["cd", cwd])
    output = exec_ops("ls")
    # Compiled once because it is applied to every output line.
    pattern = re.compile(
        r"\s+(\d+)\s+(\d+\.\d+\w+)\s+(\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2})\s+(.*)"
    )
    files = []
    for line in output.split("\n"):
        match = pattern.match(line)
        if match:
            files.append(
                {
                    "id": int(match.group(1)),
                    "size": match.group(2).strip(),
                    "date": match.group(3).strip(),
                    "name": match.group(4).strip(),
                }
            )
    # Print the parsed entries.
    for file in files:
        print(file)
    # Return the whole list. The previous code returned the loop variable
    # `file`, i.e. only the last entry, and failed with UnboundLocalError when
    # nothing had matched.
    return files
def get_curr_user():
    """Return the user reported by the binary's ``who`` command.

    Returns:
        ``{"uid": <str>, "username": <str>}`` when the output contains a uid
        and a user name and the uid is not 0; ``None`` otherwise.
    """
    who_output = exec_ops("who")
    found = re.search(r"uid:\s*(\d+), 用户名:\s*(\w+),", who_output)
    if found is None:
        return None
    uid, username = found.group(1), found.group(2)
    # A uid of 0 is treated the same as "no user".
    if int(uid) == 0:
        return None
    return {"uid": uid, "username": username}
def logout():
    """Run ``logout -y`` and report whether the output contains the success text.

    Returns:
        ``True`` when the success message is found in the command output,
        ``False`` otherwise.
    """
    output = exec_ops(["logout", "-y"])
    return re.search("退出用户成功", output) is not None
def get_curr_user_name():
    """Return the current user's name, or a placeholder string when there is none."""
    user = get_curr_user()
    if user:
        return user["username"]
    return "未登录"
# Message used when the binary is missing. It names the expected file, the
# location given by get_matched_summary()[1] to download it from manually, and
# the folder (cwd) to put it in. Built once at import time.
not_exists_msg = f"找不到{bin_file_name},尝试手动从 {get_matched_summary()[1]} 下载,下载后放到 {cwd} 文件夹下,重启界面"
def upload_file_to_baidu_net_disk(pre_log):
    """Run the binary's ``upload`` command with the configured directories.

    The comma-separated ``output_dirs`` setting is split into individual
    arguments, followed by ``upload_dir`` as the last argument.

    Args:
        pre_log: Not used by this function.

    Returns:
        The command output as returned by ``exec_ops``.
    """
    conf = get_global_conf()
    local_dirs = str(conf["output_dirs"]).split(",")
    upload_args = ["upload", *local_dirs, conf["upload_dir"]]
    print(upload_args)
    return exec_ops(upload_args)
def on_ui_tabs():
    """Build the upload tab and return it as a one-element tuple.

    If the binary is missing, a download is attempted first; a failed download
    is reported on stdout and does not abort building the tab. The tab shows
    either the login row or the operation row, depending on whether the binary
    exists and ``get_curr_user()`` returns a user.

    Returns:
        ``((blocks, title, elem_id),)``.
    """
    exists = check_bin_exists()
    if not exists:
        try:
            print("缺少必要的二进制文件,开始下载")
            download_bin_file()
            print("done")
        except Exception as e:
            # Best effort: the tab is still built and shows the manual
            # download hint below.
            print("下载二进制文件时出错:", str(e))
    # Re-check, since the download above may have provided the binary.
    exists = check_bin_exists()
    if not exists:
        print(f"\033[31m{not_exists_msg}\033[0m")
    current_user = get_curr_user()
    show_login = bool(exists and not current_user)
    show_operations = bool(exists and current_user)

    with gr.Blocks(analytics_enabled=False) as baidu_netdisk:
        gr.Textbox(not_exists_msg, visible=not exists)
        with gr.Row(visible=show_login) as login_form:
            bduss_input = gr.Textbox(interactive=True, label="输入bduss,完成后回车登录")
        with gr.Row(visible=show_operations) as operation_form:
            with gr.Column(scale=2):
                logout_btn = gr.Button("登出账户")
            with gr.Column(scale=8):
                log_text = gr.HTML(
                    "如果你看到这个那说明此项那说明出现了问题",
                    elem_id="baidu_netdisk_container_wrapper",
                )

        def on_bduss_input_enter(bduss):
            # Outputs, in order: status text, operation row, login row.
            login_result = login_by_bduss(bduss=bduss)
            if login_result:
                message = f"登陆成功{login_result[1]}"
            else:
                message = "登录失败"
            logged_in = bool(login_result)
            return (
                message,
                gr.update(visible=logged_in),
                gr.update(visible=not logged_in),
            )

        bduss_input.submit(
            on_bduss_input_enter,
            inputs=[bduss_input],
            outputs=[log_text, operation_form, login_form],
        )

        def on_logout():
            # Outputs, in order: login row (shown), operation row (hidden).
            logout()
            return gr.update(visible=True), gr.update(visible=False)

        logout_btn.click(fn=on_logout, outputs=[login_form, operation_form])

    return ((baidu_netdisk, "百度云上传", "baiduyun"),)
def get_default_conf():
    """Build the default upload configuration from the saved options.

    ``output_dirs`` is the comma-separated list of every configured (truthy)
    output directory option in ``opts.data``; ``upload_dir`` is a fixed path.

    Returns:
        A dict with the keys ``"output_dirs"`` (str) and ``"upload_dir"`` (str).
    """
    conf_g = opts.data
    # The previous code listed "outdir_txt2img_grids" twice, so the img2img
    # grid directory was never part of the default list.
    outdir_keys = (
        "outdir_samples",
        "outdir_txt2img_samples",
        "outdir_img2img_samples",
        "outdir_extras_samples",
        "outdir_grids",
        "outdir_txt2img_grids",
        "outdir_img2img_grids",
        "outdir_save",
    )
    # Unset or empty options are skipped; .get keeps a missing key from raising.
    configured_dirs = [conf_g.get(key) for key in outdir_keys]
    outputs_dirs = ",".join(filter(bool, configured_dirs))
    upload_dir = "/stable-diffusion-upload"
    return {
        "output_dirs": outputs_dirs,
        "upload_dir": upload_dir,
    }
def on_ui_settings():
    """Register this extension's two options in the settings section.

    Adds ``baidu_netdisk_output_dirs`` and ``baidu_netdisk_upload_dir`` to
    ``shared.opts``, each with its default from ``get_default_conf()`` and a
    label.
    """
    default_conf = get_default_conf()
    section = ("baidu-netdisk", "百度云上传")
    # (option key, default value, label) for every option to register.
    bd_options = [
        (
            "baidu_netdisk_output_dirs",
            default_conf["output_dirs"],
            "上传的本地文件夹列表,多个文件夹使用逗号分隔",
        ),
        (
            "baidu_netdisk_upload_dir",
            default_conf["upload_dir"],
            "百度网盘用于接收上传文件的文件夹地址",
        ),
    ]
    for key, default, label in bd_options:
        shared.opts.add_option(
            key,
            shared.OptionInfo(default, label, section=section),
        )
# Upload subprocesses keyed by the UUID string returned to the client when the
# upload was started; the status endpoint looks processes up here by that id.
subprocess_cache: Dict[str, asyncio.subprocess.Process] = {}
def _decode_output_line(raw: bytes) -> str:
    """Decode one chunk of subprocess output, falling back to GBK.

    Uses the same strategy as ``exec_ops``: try the default codec first, then
    GBK while ignoring bytes that still cannot be decoded, so that output in a
    non-UTF-8 encoding does not raise.
    """
    try:
        return raw.decode()
    except UnicodeDecodeError:
        return raw.decode("gbk", errors="ignore")


def baidu_netdisk_api(_: gr.Blocks, app: FastAPI):
    """Register the extension's HTTP routes on ``app``.

    Routes (all below ``/baidu_netdisk/``):
      * ``fe-static``: static files served from ``{cwd}/vue/dist``.
      * ``GET hello``: returns ``"hello"``.
      * ``POST upload``: starts the upload subprocess and returns its id.
      * ``GET upload/status/{id}``: returns the output lines produced since
        the previous poll, together with the process state.

    Args:
        _: The gradio Blocks instance passed by the callback; unused.
        app: The FastAPI application to register the routes on.
    """
    pre = "/baidu_netdisk/"
    app.mount(f"{pre}fe-static", StaticFiles(directory=f"{cwd}/vue/dist"), name="baidu_netdisk-fe-static")

    @app.get(f"{pre}hello")
    async def greeting():
        return "hello"

    @app.post(f"{pre}upload")
    async def upload():
        task_id = str(uuid.uuid4())
        conf = get_global_conf()
        dirs = str(conf["output_dirs"]).split(",")
        # NOTE(review): stderr is piped but never read in this file; a child
        # that writes a lot to stderr could block on a full pipe. Confirm
        # whether stderr is needed.
        process = await asyncio.create_subprocess_exec(
            bin_file_path,
            "upload",
            *dirs,
            conf["upload_dir"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        # NOTE(review): entries are never removed from subprocess_cache here.
        subprocess_cache[task_id] = process
        return {"id": task_id}

    # The parameter must be called `id` to bind to the `{id}` path segment.
    @app.get(pre + "upload/status/{id}")
    async def upload_poll(id):
        p = subprocess_cache.get(id)
        if not p:
            raise HTTPException(status_code=404, detail="找不到该subprocess")
        # returncode stays None until the process has been reaped.
        running = not isinstance(p.returncode, int)
        tasks = []
        while True:
            try:
                # Short timeout so a poll returns what is available right now
                # instead of waiting for the next line.
                raw = await asyncio.wait_for(p.stdout.readline(), timeout=0.1)
            except asyncio.TimeoutError:
                break
            if not raw:
                # Empty bytes mean EOF on the child's stdout.
                break
            line = _decode_output_line(raw)
            if not line or line.isspace():
                continue
            tasks.append({"info": parse_log_line(line), "log": line})
        return {"running": running, "tasks": tasks, "pCode": p.returncode}
# Register the callbacks defined in this file with modules.script_callbacks:
# the settings options, the UI tab, and the HTTP routes added on app start.
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)
script_callbacks.on_app_started(baidu_netdisk_api)