import os
import time
from scripts.tool import human_readable_size
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
import modules.scripts as scripts
import gradio as gr
import re
import subprocess
import uuid
import asyncio
from modules import script_callbacks, shared
from typing import List, Dict, Literal, Union
from modules.shared import opts
from scripts.baiduyun_task import BaiduyunTask
import datetime
from pydantic import BaseModel
from scripts.log_parser import parse_log_line
from scripts.bin import (
    download_bin_file,
    get_matched_summary,
    check_bin_exists,
    cwd,
    bin_file_path,
    bin_file_name,
    is_win,
)
from scripts.tool import get_windows_drives, convert_to_bytes
import functools
from scripts.logger import logger


# One row of the binary's `ls` output: size, size token, "YYYY-MM-DD HH:MM:SS", name.
# Directory names end with "/" (see list_file).
_LS_LINE_PATTERN = re.compile(
    r"\s+(\d+)\s+([\w\-.]+)\s+(\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2})\s+(.*)"
)


def get_global_conf():
    """Return the effective upload configuration.

    Values stored in the WebUI options (``baidu_netdisk_output_dirs`` /
    ``baidu_netdisk_upload_dir``) win; falsy or missing values fall back to
    ``get_default_conf()``.
    """
    default_conf = get_default_conf()
    return {
        "output_dirs": opts.data.get("baidu_netdisk_output_dirs")
        or default_conf.get("output_dirs"),
        "upload_dir": opts.data.get("baidu_netdisk_upload_dir")
        or default_conf.get("upload_dir"),
    }


def exec_ops(args: Union[List[str], str]) -> str:
    """Run the bundled binary (``bin_file_path``) with ``args``.

    Args:
        args: a single argument string or a list of arguments.

    Returns:
        The binary's stdout, decoded and stripped. Decoding tries the default
        codec first and falls back to GBK (ignoring undecodable bytes).
        Returns "" without running anything when ``check_bin_exists()`` is false.
    """
    args = [args] if isinstance(args, str) else args
    res = ""
    if check_bin_exists():
        # List form, no shell: arguments (e.g. user-supplied BDUSS) are not shell-interpreted.
        result = subprocess.run([bin_file_path, *args], capture_output=True)
        try:
            res = result.stdout.decode().strip()
        except UnicodeDecodeError:
            res = result.stdout.decode("gbk", errors="ignore").strip()
    logger.info(res)
    return res


def login_by_bduss(bduss: str):
    """Log in with a BDUSS cookie value.

    Returns:
        ``{"status": "ok", "msg": <text after the success marker>}`` when the
        binary's output ends with the login-success line, otherwise
        ``{"status": "error", "msg": <raw output>}``.
    """
    output = exec_ops(["login", f"-bduss={bduss}"])
    match = re.search("百度帐号登录成功: (.+)$", output)
    if match:
        return {"status": "ok", "msg": match.group(1).strip()}
    else:
        return {"status": "error", "msg": output}


def get_curr_working_dir():
    """Return the binary's output for ``pwd`` (its current remote directory)."""
    return exec_ops("pwd")


def list_file(cwd="/"):
    """List a remote directory via the binary's ``ls`` command.

    Args:
        cwd: remote directory to list (shadows the module-level ``cwd`` locally).

    Returns:
        A list of dicts with keys ``size`` (size token as printed), ``date``,
        ``name`` (trailing "/" removed), ``type`` ("dir" or "file") and
        ``bytes`` (``convert_to_bytes(size)``, or "-" when size is "-").

    Raises:
        Exception: when the output contains the network-error message.
    """
    output = exec_ops(["ls", cwd])
    if output.find("获取目录下的文件列表: 网络错误") != -1:
        raise Exception("获取目录下的文件列表: 网络错误")
    files = []
    for line in output.split("\n"):
        match = _LS_LINE_PATTERN.match(line)
        if not match:
            # Header/footer lines and anything not shaped like a file row.
            continue
        name = match.group(4).strip()
        f_type = "dir" if name.endswith("/") else "file"
        size = match.group(2)
        files.append(
            {
                "size": size,
                "date": match.group(3),
                "name": name.strip("/"),
                "type": f_type,
                "bytes": convert_to_bytes(size) if size != "-" else size,
            }
        )
    return files


def get_curr_user():
    """Return ``{"uid", "username"}`` parsed from ``who``, or None.

    None is returned when the output cannot be parsed or the uid is 0
    (treated here as "not logged in").
    """
    match = re.search(
        r"uid:\s*(\d+), 用户名:\s*(\w+),",
        exec_ops("who"),
    )
    if not match:
        return
    uid = match.group(1)
    if int(uid) == 0:
        return
    username = match.group(2)
    return {"uid": uid, "username": username}


def logout():
    """Run ``logout -y``; return True if the success message appears in the output."""
    match = re.search("退出用户成功", exec_ops(["logout", "-y"]))
    return bool(match)


# Shown to the user when the binary is missing; get_matched_summary()[1] is
# evaluated once at import time.
not_exists_msg = (
    f"找不到{bin_file_name},尝试手动从 {get_matched_summary()[1]} 下载,下载后放到 {cwd} 文件夹下,重启界面"
)


def on_ui_tabs():
    """Build the extension's Gradio tab.

    If the binary is missing, one download attempt is made first. The tab shows
    the "binary missing" message when it is still absent, otherwise an HTML
    placeholder element (``baidu_netdisk_container_wrapper``) — presumably
    replaced by the frontend served under ``/baidu_netdisk/fe-static``;
    verify against the JS side.

    Returns:
        A one-element tuple ``((blocks, title, elem_id),)`` as expected by
        ``script_callbacks.on_ui_tabs``.
    """
    exists = check_bin_exists()
    if not exists:
        try:
            print("缺少必要的二进制文件,开始下载")
            download_bin_file()
            print("done")
        except Exception as e:
            print("下载二进制文件时出错:", str(e))
        exists = check_bin_exists()
    if not exists:
        print(f"\033[31m{not_exists_msg}\033[0m")
    with gr.Blocks(analytics_enabled=False) as baidu_netdisk:
        gr.Textbox(not_exists_msg, visible=not exists)
        with gr.Row(visible=bool(exists)):
            with gr.Column():
                gr.HTML(
                    "如果你看到这个那说明此项那说明出现了问题",
                    elem_id="baidu_netdisk_container_wrapper",
                )
    return ((baidu_netdisk, "百度云上传", "baiduyun"),)


def get_default_conf():
    """Compute default upload settings from the WebUI output-directory options.

    Returns:
        ``{"output_dirs": <comma-joined non-empty dirs>, "upload_dir": "/stable-diffusion-upload"}``.
        Missing or empty options are skipped and duplicate directories are
        listed once (first occurrence wins), so the same folder is not
        uploaded twice.
    """
    conf_g = opts.data
    output_dir_keys = [
        "outdir_samples",
        "outdir_txt2img_samples",
        "outdir_img2img_samples",
        "outdir_extras_samples",
        "outdir_grids",
        "outdir_txt2img_grids",
        # Previously "outdir_txt2img_grids" was listed twice and img2img grids were dropped.
        "outdir_img2img_grids",
        "outdir_save",
    ]
    # .get(): a missing option must not raise KeyError; falsy values are filtered out.
    dirs = [conf_g.get(key) for key in output_dir_keys]
    outputs_dirs = ",".join(dict.fromkeys(d for d in dirs if d))
    upload_dir = "/stable-diffusion-upload"
    return {
        "output_dirs": outputs_dirs,
        "upload_dir": upload_dir,
    }


def singleton_async(fn):
    """Decorate an async function so only one call per first-positional-arg runs at once.

    A second call whose first positional argument (or None, when there is none)
    is already in flight raises ``Exception`` instead of waiting.
    """

    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        key = args[0] if len(args) > 0 else None
        if key in wrapper.busy:
            raise Exception("Function is busy, please try again later.")
        wrapper.busy.append(key)
        try:
            return await fn(*args, **kwargs)
        finally:
            wrapper.busy.remove(key)

    # A list (not a set) so unhashable keys are allowed.
    wrapper.busy = []
    return wrapper


def baidu_netdisk_api(_: gr.Blocks, app: FastAPI):
    """Register the extension's HTTP API and static frontend on ``app``.

    All routes live under ``/baidu_netdisk``; the built frontend at
    ``{cwd}/vue/dist`` is mounted at ``/baidu_netdisk/fe-static``.
    """
    pre = "/baidu_netdisk"
    app.mount(
        f"{pre}/fe-static",
        StaticFiles(directory=f"{cwd}/vue/dist"),
        name="baidu_netdisk-fe-static",
    )

    @app.get(f"{pre}/user")
    async def user():
        return get_curr_user()

    @app.post(f"{pre}/user/logout")
    async def user_logout():
        return logout()

    class BaiduyunUserLoginReq(BaseModel):
        bduss: str

    @app.post(f"{pre}/user/login")
    async def user_login(req: BaiduyunUserLoginReq):
        res = login_by_bduss(req.bduss)
        if res["status"] != "ok":
            raise HTTPException(status_code=401, detail=res["msg"])
        return get_curr_user()

    @app.get(f"{pre}/hello")
    async def greeting():
        return "hello"

    @app.get(f"{pre}/global_setting")
    async def global_setting():
        return {
            "global_setting": opts.data,
            "default_conf": get_default_conf(),
            "cwd": cwd,
            "is_win": is_win,
            "sd_cwd": os.getcwd(),
        }

    class BaiduyunUploadDownloadReq(BaseModel):
        type: Literal["upload", "download"]
        send_dirs: str
        recv_dir: str

    @app.post(f"{pre}/task")
    async def upload(req: BaiduyunUploadDownloadReq):
        task = await BaiduyunTask.create(**req.dict())
        return {"id": task.id}

    @app.get(f"{pre}/tasks")
    async def upload_tasks():
        tasks = []
        for key in BaiduyunTask.get_cache():
            task = BaiduyunTask.get_by_id(key)
            task.update_state()
            tasks.append(task.get_summary())
        # Reversed iteration order of the cache (newest first, assuming insertion order).
        return {"tasks": list(reversed(tasks))}

    @app.delete(pre + "/task/{id}")
    async def remove_task_cache(id: str):
        c = BaiduyunTask.get_cache()
        if id in c:
            c.pop(id)

    @app.get(pre + "/task/{id}/files_state")
    async def task_files_state(id):
        p = BaiduyunTask.get_by_id(id)
        if not p:
            raise HTTPException(status_code=404, detail="找不到该上传任务")
        return {"files_state": p.files_state}

    @app.post(pre + "/task/{id}/cancel")
    async def cancel_task(id):
        p = BaiduyunTask.get_by_id(id)
        if not p:
            raise HTTPException(status_code=404, detail="找不到该上传任务")
        last_tick = await p.cancel()
        return {"last_tick": last_tick}

    # task id -> in-flight asyncio.Task for get_tick(); lets concurrent polls
    # for the same id share one underlying wait.
    upload_poll_promise_dict = {}

    @app.get(pre + "/task/{id}/tick")
    async def upload_poll(id):
        async def get_tick_sync_wait_wrapper():
            task = BaiduyunTask.get_by_id(id)
            if not task:
                raise HTTPException(status_code=404, detail="找不到该上传任务")
            return await task.get_tick()

        pending = upload_poll_promise_dict.get(id)
        if pending is not None:
            # Another request already started the wait; share its outcome.
            return await pending
        pending = asyncio.create_task(get_tick_sync_wait_wrapper())
        upload_poll_promise_dict[id] = pending
        try:
            return await pending
        finally:
            # Always drop the shared task, including on failure; otherwise a
            # failed task stays cached and every later poll replays its error.
            if upload_poll_promise_dict.get(id) is pending:
                upload_poll_promise_dict.pop(id)

    @app.get(pre + "/files/{target}")
    async def get_target_floder_files(
        target: Literal["local", "netdisk"], folder_path: str
    ):
        # NOTE(review): with target="local" any directory readable by the WebUI
        # process can be listed by whoever can reach this API.
        files = []
        try:
            if target == "local":
                if is_win and folder_path == "/":
                    # "/" on Windows is a virtual root listing the drive letters.
                    for item in get_windows_drives():
                        files.append({"type": "dir", "size": "-", "name": item})
                else:
                    for item in os.listdir(folder_path):
                        path = os.path.join(folder_path, item)
                        try:
                            mod_time = os.path.getmtime(path)
                        except OSError:
                            # Dangling symlink or entry removed since listdir:
                            # skip it rather than failing the whole listing.
                            continue
                        date = time.strftime(
                            "%Y-%m-%d %H:%M:%S", time.localtime(mod_time)
                        )
                        if os.path.isfile(path):
                            size_in_bytes = os.path.getsize(path)
                            files.append(
                                {
                                    "type": "file",
                                    "date": date,
                                    "size": human_readable_size(size_in_bytes),
                                    "name": item,
                                    "bytes": size_in_bytes,
                                }
                            )
                        elif os.path.isdir(path):
                            files.append(
                                {"type": "dir", "date": date, "size": "-", "name": item}
                            )
            else:
                files = list_file(folder_path)
        except Exception as e:
            logger.error(e)
            raise HTTPException(status_code=400, detail=str(e))
        return {"files": files}


script_callbacks.on_ui_tabs(on_ui_tabs)
script_callbacks.on_app_started(baidu_netdisk_api)