289 lines
8.8 KiB
Python
289 lines
8.8 KiB
Python
from fastapi import FastAPI, HTTPException
|
||
from fastapi.staticfiles import StaticFiles
|
||
import modules.scripts as scripts
|
||
import gradio as gr
|
||
import re
|
||
import subprocess
|
||
import logging
|
||
import uuid
|
||
import asyncio
|
||
import subprocess
|
||
from modules import script_callbacks, shared
|
||
from typing import List,Dict,Union
|
||
from modules.shared import opts
|
||
from scripts.log_parser import parse_log_line
|
||
from scripts.bin import download_bin_file, get_matched_summary, check_bin_exists,cwd, bin_file_path, bin_file_name
|
||
|
||
|
||
# Module logger: records at DEBUG and above go to ``{cwd}/log.log``; records
# at INFO and above are also echoed to the console.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# getLogger(__name__) hands back the same logger object if this module is
# executed again in the same process (e.g. a script reload). Remove and close
# any handlers left over from a previous run so every message is not emitted
# several times and the previous log file handle is released.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)
    _stale_handler.close()

# Console handler: INFO and above.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

# File handler: DEBUG and above. UTF-8 is explicit because the logged binary
# output contains Chinese text, which the platform default encoding may not
# be able to represent.
file_handler = logging.FileHandler(f"{cwd}/log.log", encoding="utf-8")
file_handler.setLevel(logging.DEBUG)

# Shared line format for both handlers.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)

# Attach both handlers to the module logger.
logger.addHandler(console_handler)
logger.addHandler(file_handler)
|
||
|
||
|
||
def get_global_conf():
    """Return the effective upload configuration.

    Values the user set in the webui settings (``baidu_netdisk_output_dirs`` and
    ``baidu_netdisk_upload_dir``) take precedence; an empty or missing value
    falls back to the matching entry of ``get_default_conf()``.

    Returns:
        ``{"output_dirs": ..., "upload_dir": ...}``
    """
    defaults = get_default_conf()
    settings = opts.data
    output_dirs = settings.get("baidu_netdisk_output_dirs") or defaults.get("output_dirs")
    upload_dir = settings.get("baidu_netdisk_upload_dir") or defaults.get("upload_dir")
    return {"output_dirs": output_dirs, "upload_dir": upload_dir}
|
||
|
||
|
||
|
||
|
||
def exec_ops(args: Union[List[str], str]):
    """Run the netdisk binary synchronously and return its trimmed stdout.

    Args:
        args: Sub-command arguments for the binary; a single string is treated
            as a one-element argument list.

    Returns:
        The binary's stdout decoded as UTF-8 (falling back to GBK with
        undecodable bytes dropped) and stripped, or an empty string when the
        binary is not present. The returned text is logged at INFO level.
    """
    argv = args if not isinstance(args, str) else [args]
    output = ""
    if check_bin_exists():
        completed = subprocess.run([bin_file_path, *argv], capture_output=True)
        stdout_bytes = completed.stdout
        try:
            output = stdout_bytes.decode().strip()
        except UnicodeDecodeError:
            # Non-UTF-8 output; presumably a Chinese-locale Windows console.
            output = stdout_bytes.decode("gbk", errors="ignore").strip()
    logger.info(output)
    return output
|
||
|
||
|
||
def login_by_bduss(bduss: str):
    """Log in to Baidu Netdisk with a BDUSS cookie value.

    Args:
        bduss: The BDUSS value, passed to the binary as ``-bduss=<value>``.

    Returns:
        ``(True, <text after "百度帐号登录成功: ">)`` when the binary reports a
        successful login, otherwise ``None``. On failure the raw output is
        printed to the console for diagnosis.
    """
    output = exec_ops(["login", f"-bduss={bduss}"])
    # re.M: the success line does not have to be the last line of the output.
    # Without it ``$`` only matches at the very end of the whole string.
    match = re.search("百度帐号登录成功: (.+)$", output, re.M)
    if match:
        return (True, match.group(1).strip())
    print(output)
    return None
|
||
|
||
|
||
def get_curr_working_dir():
    """Return the binary's current remote working directory (``pwd`` output)."""
    pwd_args = ["pwd"]
    return exec_ops(pwd_args)
|
||
|
||
|
||
def list_file(cwd="/"):
    """List the entries of a remote Baidu Netdisk directory.

    Changes the binary's working directory to *cwd* with ``cd`` (a persistent
    side effect on the binary's state) and parses the table printed by ``ls``.

    Args:
        cwd: Remote directory to list. Inside this function it shadows the
            module-level local path of the same name.

    Returns:
        A list of dicts with keys ``id`` (int), ``size``, ``date`` and ``name``
        (str), one per parsed row. The list is empty when no row matched, for
        example when the binary is missing.
    """
    exec_ops(["cd", cwd])
    output = exec_ops("ls")
    # One table row: index, size, "YYYY-MM-DD HH:MM:SS" timestamp, name.
    # The size column is matched as any non-blank token so rows whose size is
    # not of the "1.23KB" form (presumably directories) are not dropped.
    row_re = re.compile(
        r"\s+(\d+)\s+(\S+)\s+(\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2})\s+(.*)"
    )
    files = []
    for line in output.split("\n"):
        match = row_re.match(line)
        if not match:
            continue
        file_info = {
            "id": int(match.group(1)),
            "size": match.group(2).strip(),
            "date": match.group(3).strip(),
            "name": match.group(4).strip(),
        }
        logger.debug(file_info)
        files.append(file_info)
    # Return every parsed entry. The old code returned the last loop variable,
    # and raised UnboundLocalError when nothing parsed.
    return files
|
||
|
||
|
||
def get_curr_user():
    """Return the logged-in netdisk account, or ``None`` when logged out.

    Parses the output of the binary's ``who`` command. Both a missing match and
    a uid of 0 are treated as "not logged in".

    Returns:
        ``{"uid": <str>, "username": <str>}`` or ``None``.
    """
    who_output = exec_ops("who")
    found = re.search(r"uid:\s*(\d+), 用户名:\s*(\w+),", who_output)
    if found is None or int(found.group(1)) == 0:
        return None
    uid, username = found.group(1), found.group(2)
    return {"uid": uid, "username": username}
|
||
|
||
|
||
def logout():
    """Log the current account out (``logout -y``).

    Returns:
        True if the binary's output contains "退出用户成功", else False.
    """
    output = exec_ops(["logout", "-y"])
    return "退出用户成功" in output
|
||
|
||
|
||
def get_curr_user_name():
    """Return the logged-in username, or "未登录" ("not logged in")."""
    user = get_curr_user()
    if not user:
        return "未登录"
    return user["username"]
|
||
|
||
|
||
# User-facing hint shown by on_ui_tabs when the netdisk binary is missing:
# where to download it manually and which folder to put it in.
not_exists_msg = (
    f"找不到{bin_file_name},尝试手动从 {get_matched_summary()[1]} 下载,"
    f"下载后放到 {cwd} 文件夹下,重启界面"
)
|
||
|
||
|
||
def upload_file_to_baidu_net_disk(pre_log):
    """Upload the configured local output folders to Baidu Netdisk (blocking).

    Args:
        pre_log: Unused; kept so existing callers keep working.

    Returns:
        The decoded stdout of the binary's ``upload`` command (see ``exec_ops``).
    """
    conf = get_global_conf()
    # The folder list is a comma-separated setting typed by the user. Tolerate
    # spaces around commas and empty entries (e.g. a trailing comma) instead of
    # handing " dir" or "" to the binary as a local path.
    dirs = [d.strip() for d in str(conf["output_dirs"]).split(",") if d.strip()]
    cmd = ["upload", *dirs, conf["upload_dir"]]
    logger.debug(cmd)
    return exec_ops(cmd)
|
||
|
||
|
||
def on_ui_tabs():
    """Build the "百度云上传" webui tab.

    Makes sure the netdisk binary is present (attempting a download when it is
    missing). It then renders either a BDUSS login form or the logged-in
    operations panel, depending on the current login state.

    Returns:
        ``((blocks, "百度云上传", "baiduyun"),)`` — the tuple this module
        registers via ``script_callbacks.on_ui_tabs``.
    """
    exists = check_bin_exists()
    if not exists:
        try:
            print("缺少必要的二进制文件,开始下载")
            download_bin_file()
            print("done")
        except Exception as e:
            # Best effort: a failed download falls through to the manual
            # instructions below instead of breaking the whole UI build.
            print("下载二进制文件时出错:", str(e))
        exists = check_bin_exists()
        if not exists:
            # Red (ANSI) console hint with manual download instructions.
            print(f"\033[31m{not_exists_msg}\033[0m")
    user = get_curr_user()

    with gr.Blocks(analytics_enabled=False) as baidu_netdisk:
        # Only visible while the binary is still missing.
        gr.Textbox(not_exists_msg, visible=not exists)
        with gr.Row(visible=bool(exists and not user)) as login_form:
            bduss_input = gr.Textbox(interactive=True, label="输入bduss,完成后回车登录")
        with gr.Row(visible=bool(exists and user)) as operation_form:
            with gr.Column(scale=2):
                logout_btn = gr.Button("登出账户")
            with gr.Column(scale=8):
                log_text = gr.HTML(
                    "如果你看到这个那说明此项那说明出现了问题",
                    elem_id="baidu_netdisk_container_wrapper",
                )

        def on_bduss_input_enter(bduss):
            # login_by_bduss returns (True, name) on success, None on failure.
            # NOTE(review): on failure log_text stays inside the hidden
            # operation_form, so the "登录失败" text is not actually visible.
            res = login_by_bduss(bduss=bduss)
            return (
                f"登陆成功{res[1]}" if res else "登录失败",
                gr.update(visible=bool(res)),
                gr.update(visible=not res),
            )

        bduss_input.submit(
            on_bduss_input_enter,
            inputs=[bduss_input],
            outputs=[log_text, operation_form, login_form],
        )

        def on_logout():
            # Only switch back to the login form if the binary confirmed the
            # logout; otherwise keep showing the operations panel.
            ok = logout()
            return gr.update(visible=ok), gr.update(visible=not ok)

        logout_btn.click(fn=on_logout, outputs=[login_form, operation_form])

    return ((baidu_netdisk, "百度云上传", "baiduyun"),)
|
||
|
||
def get_default_conf(conf_g=None):
    """Compute the default upload configuration from the webui output folders.

    Args:
        conf_g: Mapping of webui option names to values. Defaults to
            ``opts.data`` (the live webui settings); pass another mapping to
            derive the defaults from it instead.

    Returns:
        ``{"output_dirs": <comma-separated local folders>,
        "upload_dir": "/stable-diffusion-upload"}``. Empty or missing folder
        options are skipped, and a folder set under several options is listed
        once, at its first position.
    """
    if conf_g is None:
        conf_g = opts.data
    option_names = (
        "outdir_samples",
        "outdir_txt2img_samples",
        "outdir_img2img_samples",
        "outdir_extras_samples",
        "outdir_grids",
        "outdir_txt2img_grids",
        "outdir_img2img_grids",  # previously txt2img grids was listed twice instead
        "outdir_save",
    )
    # .get(): tolerate option names this webui version does not define.
    # dict.fromkeys keeps first-seen order while dropping duplicate folders,
    # so the same folder is not uploaded twice.
    folders = dict.fromkeys(
        value for value in (conf_g.get(name) for name in option_names) if value
    )
    return {
        "output_dirs": ",".join(folders),
        "upload_dir": "/stable-diffusion-upload",
    }
|
||
|
||
|
||
|
||
def on_ui_settings():
    """Register this extension's options on the webui settings page.

    Adds ``baidu_netdisk_output_dirs`` and ``baidu_netdisk_upload_dir`` under
    the ("baidu-netdisk", "百度云上传") section. Each defaults to the value
    computed by ``get_default_conf()``.
    """
    default_conf = get_default_conf()
    bd_options = [
        ("baidu_netdisk_output_dirs", default_conf["output_dirs"], "上传的本地文件夹列表,多个文件夹使用逗号分隔"),
        ("baidu_netdisk_upload_dir", default_conf["upload_dir"], "百度网盘用于接收上传文件的文件夹地址"),
    ]
    section = ("baidu-netdisk", "百度云上传")
    # Register each (key, default, label) option in the section above.
    for key, default, label in bd_options:
        shared.opts.add_option(
            key,
            shared.OptionInfo(default, label, section=section),
        )
|
||
|
||
|
||
# Upload task id (uuid4 string) -> the asyncio subprocess running that upload.
# Filled by the POST .../upload route and read by .../upload/status/{id};
# nothing in this module ever removes entries.
subprocess_cache: Dict[str, asyncio.subprocess.Process] = dict()
|
||
|
||
|
||
def baidu_netdisk_api(_: gr.Blocks, app: FastAPI):
    """Mount the extension's HTTP API and static front-end on the webui app.

    Routes, all under ``/baidu_netdisk/``:
      * ``fe-static``               static files served from ``{cwd}/vue/dist``
      * ``GET hello``               returns ``"hello"``
      * ``POST upload``             starts a background upload, returns ``{"id": ...}``
      * ``GET upload/status/{id}``  returns the output lines produced since the
                                    previous poll, plus the running state and
                                    return code
    """
    pre = "/baidu_netdisk/"
    app.mount(f"{pre}fe-static", StaticFiles(directory=f"{cwd}/vue/dist"), name="baidu_netdisk-fe-static")

    @app.get(f"{pre}hello")
    async def greeting():
        return "hello"

    @app.post(f"{pre}upload")
    async def upload():
        task_id = str(uuid.uuid4())
        conf = get_global_conf()
        # Tolerate spaces around commas and empty entries in the user setting.
        dirs = [d.strip() for d in str(conf["output_dirs"]).split(",") if d.strip()]
        process = await asyncio.create_subprocess_exec(
            bin_file_path,
            "upload",
            *dirs,
            conf["upload_dir"],
            stdout=asyncio.subprocess.PIPE,
            # stderr was a separate PIPE that nothing ever read; once its
            # buffer fills, the child blocks on write. Merge it into stdout so
            # upload_poll drains (and reports) error output with the log.
            stderr=asyncio.subprocess.STDOUT,
        )
        subprocess_cache[task_id] = process
        return {"id": task_id}

    @app.get(pre + "upload/status/{id}")
    async def upload_poll(id):
        p = subprocess_cache.get(id)
        if not p:
            raise HTTPException(status_code=404, detail="找不到该subprocess")
        # Read the state before draining, so a poll that reports
        # running=False has also drained output produced before the exit.
        running = not isinstance(p.returncode, int)
        tasks = []
        while True:
            try:
                raw = await asyncio.wait_for(p.stdout.readline(), timeout=0.1)
            except asyncio.TimeoutError:
                break  # no new output right now
            if not raw:
                break  # EOF: the stream is closed
            # Same decoding policy as exec_ops: UTF-8 first, then GBK.
            try:
                line = raw.decode()
            except UnicodeDecodeError:
                line = raw.decode("gbk", errors="ignore")
            if line.isspace():
                continue
            # info may be None for lines parse_log_line does not recognise.
            tasks.append({"info": parse_log_line(line), "log": line})
        return {"running": running, "tasks": tasks, "pCode": p.returncode}
|
||
|
||
|
||
# Hook this extension into the webui: register the settings entries, the
# "百度云上传" tab builder, and the HTTP routes (the route builder is
# registered via on_app_started and receives the FastAPI app).
script_callbacks.on_ui_settings(on_ui_settings)

script_callbacks.on_ui_tabs(on_ui_tabs)
script_callbacks.on_app_started(baidu_netdisk_api)
|