sd-webui-infinite-image-bro.../scripts/setup.py

294 lines
8.4 KiB
Python

from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
import modules.scripts as scripts
import gradio as gr
import re
import subprocess
import uuid
import asyncio
import subprocess
from modules import script_callbacks, shared
from typing import List, Dict, Literal, Union
from modules.shared import opts
from scripts.baiduyun_task import BaiduyunTask
import datetime
from pydantic import BaseModel
from scripts.log_parser import parse_log_line
from scripts.bin import (
download_bin_file,
get_matched_summary,
check_bin_exists,
cwd,
bin_file_path,
bin_file_name,
)
import functools
from scripts.logger import logger
def get_global_conf():
    """Return the effective extension settings.

    Each entry takes the value stored in the web UI options when that value
    is truthy, and otherwise falls back to the matching entry of
    ``get_default_conf()``.

    Returns:
        A dict with the keys ``"output_dirs"`` and ``"upload_dir"``.
    """
    fallback = get_default_conf()
    stored = opts.data
    output_dirs = stored.get("baidu_netdisk_output_dirs") or fallback.get(
        "output_dirs"
    )
    upload_dir = stored.get("baidu_netdisk_upload_dir") or fallback.get(
        "upload_dir"
    )
    return {"output_dirs": output_dirs, "upload_dir": upload_dir}
def exec_ops(args: Union[List[str], str]):
    """Run the bundled binary with ``args`` and return its stdout as text.

    Args:
        args: A single argument string, or a list of argument strings, that
            is appended after the binary path on the command line.

    Returns:
        The stripped stdout of the process, or ``""`` when
        ``check_bin_exists()`` reports that the binary is missing. The
        returned text is also written to the logger.
    """
    argv = [args] if isinstance(args, str) else args
    if not check_bin_exists():
        output = ""
    else:
        completed = subprocess.run([bin_file_path, *argv], capture_output=True)
        raw = completed.stdout
        try:
            output = raw.decode().strip()
        except UnicodeDecodeError:
            # Output that is not valid UTF-8 is retried as GBK, dropping any
            # bytes that still cannot be decoded.
            output = raw.decode("gbk", errors="ignore").strip()
    logger.info(output)
    return output
def login_by_bduss(bduss: str):
    """Log in through the binary's ``login`` command using a BDUSS value.

    Args:
        bduss: Value passed to the command as ``-bduss=<value>``.

    Returns:
        ``{"status": "ok", "msg": ...}`` holding the text captured after the
        success marker when the command output contains it; otherwise
        ``{"status": "error", "msg": ...}`` holding the full command output.
    """
    cli_output = exec_ops(["login", f"-bduss={bduss}"])
    success = re.search("百度帐号登录成功: (.+)$", cli_output)
    if success is None:
        return {"status": "error", "msg": cli_output}
    return {"status": "ok", "msg": success.group(1).strip()}
def get_curr_working_dir():
    """Return the output of the binary's ``pwd`` command."""
    current_dir = exec_ops("pwd")
    return current_dir
def list_file(cwd="/"):
    """List the entries the binary reports for a directory.

    Runs the binary's ``cd`` command with ``cwd`` and then its ``ls`` command,
    and parses every output line that matches the expected row layout
    (index, size, timestamp, name). Lines that do not match are skipped.

    Args:
        cwd: Directory passed to the ``cd`` command.

    Returns:
        A list of dicts with the keys ``"id"`` (int), ``"size"``, ``"date"``
        and ``"name"`` (stripped strings), in output order.
    """
    exec_ops(["cd", cwd])
    listing = exec_ops("ls")
    row_re = re.compile(
        r"\s+(\d+)\s+(\d+\.\d+\w+)\s+(\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2})\s+(.*)"
    )
    parsed_rows = (row_re.match(row) for row in listing.split("\n"))
    return [
        {
            "id": int(row.group(1)),
            "size": row.group(2).strip(),
            "date": row.group(3).strip(),
            "name": row.group(4).strip(),
        }
        for row in parsed_rows
        if row
    ]
def get_curr_user():
    """Return the user reported by the binary's ``who`` command.

    Returns:
        ``{"uid": <str>, "username": <str>}`` parsed from the command output,
        or ``None`` when the output does not match the expected format or the
        reported uid is ``0``.
    """
    who_output = exec_ops("who")
    found = re.search(r"uid:\s*(\d+), 用户名:\s*(\w+),", who_output)
    if found is None:
        return None
    uid, username = found.group(1), found.group(2)
    if int(uid) == 0:
        return None
    return {"uid": uid, "username": username}
def logout():
    """Run the binary's ``logout -y`` command.

    Returns:
        ``True`` when the command output contains the success marker,
        ``False`` otherwise.
    """
    cli_output = exec_ops(["logout", "-y"])
    return re.search("退出用户成功", cli_output) is not None
# Help text used when the required binary is missing. It is built once at
# import time, so get_matched_summary() is called when this module loads;
# element [1] of its result is presented to the user as the place to download
# the binary from manually, and `cwd` as the folder to put it in.
not_exists_msg = (
f"找不到{bin_file_name},尝试手动从 {get_matched_summary()[1]} 下载,下载后放到 {cwd} 文件夹下,重启界面"
)
def on_ui_tabs():
    """Build the extension's tab for the web UI.

    If the required binary is missing, a download is attempted first; any
    download error is printed rather than raised. When the binary is still
    missing afterwards, the help message is printed in red and shown in the
    tab in place of the normal container.

    Returns:
        A one-element tuple holding ``(blocks, title, elem_id)`` for the tab.
    """
    if not check_bin_exists():
        try:
            print("缺少必要的二进制文件,开始下载")
            download_bin_file()
            print("done")
        except Exception as err:
            print("下载二进制文件时出错:", str(err))
    # Re-check after the download attempt to decide what the tab shows.
    bin_ready = check_bin_exists()
    if not bin_ready:
        print(f"\033[31m{not_exists_msg}\033[0m")
    with gr.Blocks(analytics_enabled=False) as tab_root:
        gr.Textbox(not_exists_msg, visible=not bin_ready)
        with gr.Row(visible=bool(bin_ready)):
            with gr.Column():
                gr.HTML(
                    "如果你看到这个那说明此项那说明出现了问题",
                    elem_id="baidu_netdisk_container_wrapper",
                )
    tab = (tab_root, "百度云上传", "baiduyun")
    return (tab,)
def get_default_conf():
    """Return the default settings derived from the web UI's output folders.

    ``output_dirs`` is a comma-separated string of every configured (truthy)
    output directory option, in the order listed below. ``upload_dir`` is the
    fixed remote folder ``/stable-diffusion-upload``.

    Returns:
        A dict with the keys ``"output_dirs"`` and ``"upload_dir"``.
    """
    conf_g = opts.data
    # The original list named "outdir_txt2img_grids" twice and never included
    # the img2img grid folder; each option now appears exactly once.
    outdir_keys = (
        "outdir_samples",
        "outdir_txt2img_samples",
        "outdir_img2img_samples",
        "outdir_extras_samples",
        "outdir_grids",
        "outdir_txt2img_grids",
        "outdir_img2img_grids",
        "outdir_save",
    )
    # .get() keeps a missing option from raising; unset or empty values are
    # dropped by the filter so they do not produce empty list entries.
    outputs_dirs = ",".join(filter(bool, (conf_g.get(key) for key in outdir_keys)))
    upload_dir = "/stable-diffusion-upload"
    return {
        "output_dirs": outputs_dirs,
        "upload_dir": upload_dir,
    }
def on_ui_settings():
    """Register this extension's options in the web UI settings.

    Adds one option for the local folders to upload and one for the remote
    folder that receives uploads, both under the extension's own settings
    section and both defaulting to the values from ``get_default_conf()``.
    """
    defaults = get_default_conf()
    section = ("baidu-netdisk", "百度云上传")
    # (option key, default value, label shown in the settings UI)
    option_specs = [
        (
            "baidu_netdisk_output_dirs",
            defaults["output_dirs"],
            "上传的本地文件夹列表,多个文件夹使用逗号分隔",
        ),
        (
            "baidu_netdisk_upload_dir",
            defaults["upload_dir"],
            "百度网盘用于接收上传文件的文件夹地址",
        ),
    ]
    for key, default_value, label in option_specs:
        shared.opts.add_option(
            key,
            shared.OptionInfo(default_value, label, section=section),
        )
def singleton_async(fn):
    """Decorate a coroutine function so calls sharing a key cannot overlap.

    The key is the first positional argument of a call, or ``None`` when the
    call has no positional arguments. While a call for a key is still being
    awaited, another call with the same key raises ``Exception``. Keys that
    are currently running are kept in the list ``wrapper.busy``.

    Args:
        fn: The coroutine function to guard.

    Returns:
        The wrapping coroutine function, carrying ``fn``'s metadata.
    """

    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        slot = args[0] if args else None
        if slot in wrapper.busy:
            raise Exception("Function is busy, please try again later.")
        wrapper.busy.append(slot)
        try:
            outcome = await fn(*args, **kwargs)
        finally:
            # Free the key whether fn returned or raised.
            wrapper.busy.remove(slot)
        return outcome

    wrapper.busy = []
    return wrapper
def baidu_netdisk_api(_: gr.Blocks, app: FastAPI):
    """Mount the front-end assets and register the extension's HTTP routes.

    Every route lives under the ``/baidu_netdisk`` prefix:

    * ``fe-static`` serves the files in ``{cwd}/vue/dist``.
    * ``user``, ``user/login`` and ``user/logout`` wrap the account helpers
      defined in this module.
    * ``global_setting`` returns the web UI option data together with this
      extension's default configuration.
    * ``task``, ``tasks`` and ``task/{id}/...`` create, list, inspect, cancel
      and poll ``BaiduyunTask`` objects.

    Args:
        _: The Gradio blocks object passed by the callback; unused.
        app: The FastAPI application the routes are added to.
    """
    pre = "/baidu_netdisk"
    app.mount(
        f"{pre}/fe-static",
        StaticFiles(directory=f"{cwd}/vue/dist"),
        name="baidu_netdisk-fe-static",
    )

    @app.get(f"{pre}/user")
    async def user():
        return get_curr_user()

    @app.post(f"{pre}/user/logout")
    async def user_logout():
        return logout()

    class BaiduyunUserLoginReq(BaseModel):
        bduss: str

    @app.post(f"{pre}/user/login")
    async def user_login(req: BaiduyunUserLoginReq):
        res = login_by_bduss(req.bduss)
        if res["status"] != "ok":
            raise HTTPException(status_code=401, detail=res["msg"])
        return get_curr_user()

    @app.get(f"{pre}/hello")
    async def greeting():
        return "hello"

    @app.get(f"{pre}/global_setting")
    async def global_setting():
        return {"global_setting": opts.data, "default_conf": get_default_conf()}

    class BaiduyunUploadDownloadReq(BaseModel):
        type: Literal["upload", "download"]
        send_dirs: str
        recv_dir: str

    @app.post(f"{pre}/task")
    async def upload(req: BaiduyunUploadDownloadReq):
        task = await BaiduyunTask.create(**req.dict())
        return {"id": task.id}

    @app.get(f"{pre}/tasks")
    async def upload_tasks():
        tasks = []
        for key in BaiduyunTask.get_cache():
            task = BaiduyunTask.get_by_id(key)
            task.update_state()
            tasks.append(task.get_summary())
        # Summaries are returned in the reverse of the cache's iteration order.
        return {"tasks": list(reversed(tasks))}

    @app.get(pre + "/task/{id}/files_state")
    async def task_files_state(id):
        p = BaiduyunTask.get_by_id(id)
        if not p:
            raise HTTPException(status_code=404, detail="找不到该上传任务")
        return {"files_state": p.files_state}

    @app.post(pre + "/task/{id}/cancel")
    async def cancel_task(id):
        p = BaiduyunTask.get_by_id(id)
        if not p:
            raise HTTPException(status_code=404, detail="找不到该上传任务")
        last_tick = await p.cancel()
        return {"last_tick": last_tick}

    # Maps a task id to the asyncio task of the poll currently in flight for
    # it, so concurrent polls of the same id await one shared get_tick() call.
    upload_poll_promise_dict = {}

    @app.get(pre + "/task/{id}/tick")
    async def upload_poll(id):
        async def get_tick_sync_wait_wrapper():
            task = BaiduyunTask.get_by_id(id)
            if not task:
                raise HTTPException(status_code=404, detail="找不到该上传任务")
            return await task.get_tick()

        pending = upload_poll_promise_dict.get(id)
        if pending:
            # Another request already started a poll for this id; share it.
            return await pending

        pending = asyncio.create_task(get_tick_sync_wait_wrapper())
        upload_poll_promise_dict[id] = pending
        try:
            return await pending
        finally:
            # Always drop the entry, even when the poll raised (for example
            # the 404 above). Otherwise the finished task would stay in the
            # dict and every later poll of this id would replay its stale
            # outcome instead of starting a fresh poll.
            upload_poll_promise_dict.pop(id, None)
# Register this module's hooks with the web UI's script_callbacks: the settings
# registration, the tab builder, and the API setup that receives the FastAPI app.
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)
script_callbacks.on_app_started(baidu_netdisk_api)