api refactor: force access control and handle subpaths

Signed-off-by: Vladimir Mandic <mandic00@live.com>
pull/3754/merge
Vladimir Mandic 2025-02-14 11:25:29 -05:00
parent 877c488fa8
commit dbf20d1388
19 changed files with 92 additions and 45 deletions

View File

@ -60,6 +60,8 @@ We're back with another update with over 50 commits!
- improve upscaler compatibility
- enable upscaler compile by default
- fix shape mismatch errors on too many resolution changes
- **ZLUDA**
- update to `zluda==3.8.8`
- **Other**
- **Asymmetric tiling**
allows for configurable image tiling for x/y axis separately
@ -90,6 +92,8 @@ We're back with another update with over 50 commits!
- force pydantic version reinstall/reload
- multi-unit when using controlnet-union
- pulid with hidiffusion
- api: stricter access control
- api: universal handle mount subpaths
## Update for 2025-02-05

View File

@ -294,7 +294,7 @@ function selectHistory(id) {
const headers = new Headers();
headers.set('Content-Type', 'application/json');
const init = { method: 'POST', body: { name: id }, headers };
fetch('/sdapi/v1/history', { method: 'POST', body: JSON.stringify({ name: id }), headers });
fetch(`${window.api}/history`, { method: 'POST', body: JSON.stringify({ name: id }), headers });
}
let enDirty = false;

View File

@ -92,7 +92,7 @@ async function addSeparators() {
async function delayFetchThumb(fn) {
while (outstanding > 16) await new Promise((resolve) => setTimeout(resolve, 50)); // eslint-disable-line no-promise-executor-return
outstanding++;
const res = await fetch(`/sdapi/v1/browser/thumb?file=${encodeURI(fn)}`, { priority: 'low' });
const res = await fetch(`${window.api}/browser/thumb?file=${encodeURI(fn)}`, { priority: 'low' });
if (!res.ok) {
error(`fetchThumb: ${res.statusText}`);
outstanding--;
@ -334,7 +334,7 @@ async function fetchFilesHT(evt) {
el.status.innerText = `Folder | ${evt.target.name} | in-progress`;
let numFiles = 0;
const res = await fetch(`/sdapi/v1/browser/files?folder=${encodeURI(evt.target.name)}`);
const res = await fetch(`${window.api}/browser/files?folder=${encodeURI(evt.target.name)}`);
if (!res || res.status !== 200) {
el.status.innerText = `Folder | ${evt.target.name} | failed: ${res?.statusText}`;
return;
@ -412,7 +412,7 @@ async function pruneImages() {
async function galleryVisible() {
// if (el.folders.children.length > 0) return;
const res = await fetch('/sdapi/v1/browser/folders');
const res = await fetch(`${window.api}/browser/folders`);
if (!res || res.status !== 200) return;
el.folders.innerHTML = '';
url = res.url.split('/sdapi')[0].replace('http', 'ws'); // update global url as ws need fqdn

View File

@ -37,7 +37,7 @@ async function createSplash() {
await preloadImages();
const imgEl = `<div id="spash-img" class="splash-img" alt="logo" style="background-image: url(file=html/logo-bg-${dark ? 'dark' : 'light'}.jpg), url(file=html/logo-bg-${num}.jpg); background-blend-mode: ${dark ? 'multiply' : 'lighten'}"></div>`;
document.getElementById('splash').insertAdjacentHTML('afterbegin', imgEl);
fetch('/sdapi/v1/motd')
fetch(`${window.api}/motd`)
.then((res) => res.text())
.then((text) => {
const motdEl = document.getElementById('motd');
@ -52,7 +52,7 @@ async function removeSplash() {
log('removeSplash');
const t = Math.round(performance.now() - appStartTime) / 1000;
log('startupTime', t);
xhrPost('/sdapi/v1/log', { message: `ready time=${t}` });
xhrPost(`${window.api}/log`, { message: `ready time=${t}` });
}
window.onload = createSplash;

View File

@ -70,7 +70,7 @@ async function logMonitor() {
if (!logMonitorEl) return;
const atBottom = logMonitorEl.scrollHeight <= (logMonitorEl.scrollTop + logMonitorEl.clientHeight);
try {
const res = await fetch('/sdapi/v1/log?clear=True');
const res = await fetch(`${window.api}/log?clear=True`);
if (res?.ok) {
logMonitorStatus = true;
const lines = await res.json();
@ -116,7 +116,7 @@ async function initLogMonitor() {
</table>
`;
el.style.display = 'none';
fetch(`/sdapi/v1/start?agent=${encodeURI(navigator.userAgent)}`);
fetch(`${window.api}/start?agent=${encodeURI(navigator.userAgent)}`);
logMonitor();
log('initLogMonitor');
}

View File

@ -25,7 +25,7 @@ const loginHTML = `
function forceLogin() {
const form = document.createElement('form');
form.method = 'POST';
form.action = '/login';
form.action = `${location.href}login`;
form.id = 'loginForm';
form.style.cssText = loginCSS;
form.innerHTML = loginHTML;
@ -39,8 +39,8 @@ function forceLogin() {
const formData = new FormData(form);
formData.append('username', username.value);
formData.append('password', password.value);
console.warn('login', formData);
fetch('/login', {
console.warn('login', location.href, formData);
fetch(`${location.href}login`, {
method: 'POST',
body: formData,
})
@ -59,7 +59,7 @@ function forceLogin() {
}
function loginCheck() {
fetch('/login_check', {})
fetch(`${location.href}login_check`, {})
.then((res) => {
if (res.status === 200) console.log('login ok');
else forceLogin();

View File

@ -30,7 +30,7 @@ async function updateNVMLChart(mem, load) {
async function updateNVML() {
try {
const res = await fetch('/sdapi/v1/nvml');
const res = await fetch(`${window.api}/nvml`);
if (!res.ok) {
clearInterval(nvmlInterval);
nvmlEl.style.display = 'none';

View File

@ -136,16 +136,16 @@ async function getLocaleData(desiredLocale = null) {
// primary
let json = {};
try {
let res = await fetch(`/file=html/locale_${localeData.locale}.json`);
let res = await fetch(`${window.subpath}/file=html/locale_${localeData.locale}.json`);
if (!res || !res.ok) {
localeData.locale = 'en';
res = await fetch(`/file=html/locale_${localeData.locale}.json`);
res = await fetch(`${window.subpath}/file=html/locale_${localeData.locale}.json`);
}
json = await res.json();
} catch { /**/ }
try {
const res = await fetch(`/file=html/override_${localeData.locale}.json`);
const res = await fetch(`${window.subpath}/file=html/override_${localeData.locale}.json`);
if (res && res.ok) json.override = await res.json();
} catch { /**/ }

View File

@ -152,7 +152,7 @@ async function initModels() {
const el = gradioApp().getElementById('main_info');
const en = gradioApp().getElementById('txt2img_extra_networks');
if (!el || !en) return;
const req = await fetch('/sdapi/v1/sd-models');
const req = await fetch(`${window.api}/sd-models`);
const res = req.ok ? await req.json() : [];
log('initModels', res.length);
const ready = () => `

View File

@ -1,4 +1,6 @@
/* eslint-disable no-undef */
window.api = '/sdapi/v1';
window.subpath = '';
async function initStartup() {
log('initStartup');
@ -23,6 +25,11 @@ async function initStartup() {
// make sure all of the ui is ready and options are loaded
while (Object.keys(window.opts).length === 0) await sleep(50);
log('mountURL', window.opts.subpath);
if (window.opts.subpath?.length > 0) {
window.subpath = window.opts.subpath;
window.api = `${window.subpath}/sdapi/v1`;
}
executeCallbacks(uiReadyCallbacks);
initLogMonitor();
setupExtraNetworks();

View File

@ -385,7 +385,7 @@ function monitorServerStatus() {
<h1>Waiting for server...</h1>
<script>
function monitorServerStatus() {
fetch('/sdapi/v1/progress?skip_current_image=true')
fetch('${window.api}/progress?skip_current_image=true')
.then((res) => { !res?.ok ? setTimeout(monitorServerStatus, 1000) : location.reload(); })
.catch((e) => setTimeout(monitorServerStatus, 1000))
}
@ -400,7 +400,7 @@ function monitorServerStatus() {
function restartReload() {
document.body.style = 'background: #222222; font-size: 1rem; font-family:monospace; margin-top:20%; color:lightgray; text-align:center';
document.body.innerHTML = '<h1>Server shutdown in progress...</h1>';
fetch('/sdapi/v1/progress?skip_current_image=true')
fetch(`${window.api}/progress?skip_current_image=true`)
.then((res) => setTimeout(restartReload, 1000))
.catch((e) => setTimeout(monitorServerStatus, 500));
return [];
@ -479,7 +479,7 @@ function toggleCompact(val, old) {
function previewTheme() {
let name = gradioApp().getElementById('setting_gradio_theme').querySelectorAll('input')?.[0].value || '';
fetch('/file=html/themes.json')
fetch(`${window.subpath}/file=html/themes.json`)
.then((res) => {
res.json()
.then((themes) => {

View File

@ -32,7 +32,11 @@ class Api:
self.generate = generate.APIGenerate(queue_lock)
self.process = process.APIProcess(queue_lock)
self.control = control.APIControl(queue_lock)
# compatibility api
self.text2imgapi = self.generate.post_text2img
self.img2imgapi = self.generate.post_img2img
def register(self):
# server api
self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=List[str])
@ -97,16 +101,15 @@ class Api:
self.add_api_route("/sdapi/v1/refresh-loras", endpoints.post_refresh_loras, methods=["POST"])
# gallery api
gallery.register_api(app)
gallery.register_api(self.app)
# compatibility api
self.text2imgapi = self.generate.post_text2img
self.img2imgapi = self.generate.post_img2img
def add_api_route(self, path: str, endpoint, **kwargs):
if (shared.cmd_opts.auth or shared.cmd_opts.auth_file) and shared.cmd_opts.api_only:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
return self.app.add_api_route(path, endpoint, **kwargs)
kwargs['dependencies'] = [Depends(self.auth)]
if shared.opts.subpath is not None and len(shared.opts.subpath) > 0:
self.app.add_api_route(f'{shared.opts.subpath}{path}', endpoint, **kwargs)
self.app.add_api_route(path, endpoint, **kwargs)
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
# this is only needed for api-only since otherwise auth is handled in gradio/routes.py

View File

@ -126,7 +126,7 @@ def register_api(app: FastAPI): # register api
shared.log.error(f'Gallery image: file="{filepath}" {e}')
return {}
@app.get('/sdapi/v1/browser/folders', response_model=List[str])
# @app.get('/sdapi/v1/browser/folders', response_model=List[str])
def get_folders():
folders = [shared.opts.data.get(f, '') for f in OPTS_FOLDERS]
folders += list(shared.opts.browser_folders.split(','))
@ -141,7 +141,7 @@ def register_api(app: FastAPI): # register api
debug(f'Browser folders: {folders}')
return JSONResponse(content=folders)
@app.get("/sdapi/v1/browser/thumb", response_model=dict)
# @app.get("/sdapi/v1/browser/thumb", response_model=dict)
async def get_thumb(file: str):
try:
decoded = unquote(file).replace('%3A', ':')
@ -154,7 +154,7 @@ def register_api(app: FastAPI): # register api
content = { 'error': str(e) }
return JSONResponse(content=content)
@app.get("/sdapi/v1/browser/files", response_model=list)
# @app.get("/sdapi/v1/browser/files", response_model=list)
async def ht_files(folder: str):
try:
t0 = time.time()
@ -172,6 +172,10 @@ def register_api(app: FastAPI): # register api
shared.log.error(f'Gallery: {folder} {e}')
return []
shared.api.add_api_route("/sdapi/v1/browser/folders", get_folders, methods=["GET"], response_model=List[str])
shared.api.add_api_route("/sdapi/v1/browser/thumb", get_thumb, methods=["GET"], response_model=dict)
shared.api.add_api_route("/sdapi/v1/browser/files", ht_files, methods=["GET"], response_model=list)
@app.websocket("/sdapi/v1/browser/files")
async def ws_files(ws: WebSocket):
try:

View File

@ -328,6 +328,8 @@ class ReqInterrogate(BaseModel):
clip_model: str = Field(default="", title="CLiP Model", description="The interrogate model used.")
blip_model: str = Field(default="", title="BLiP Model", description="The interrogate model used.")
InterrogateRequest = ReqInterrogate # alias for backwards compatibility
class ResInterrogate(BaseModel):
caption: Optional[str] = Field(default=None, title="Caption", description="The generated caption for the image.")
medium: Optional[str] = Field(default=None, title="Medium", description="Image medium.")

View File

@ -26,6 +26,22 @@ except Exception:
pass
def sanitize_filename_part(text, replace_spaces=True):
if text is None:
return None
if replace_spaces:
text = text.replace(' ', '_')
invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
max_filename_part_length = 64
text = text.translate({ord(x): '_' for x in invalid_filename_chars})
text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
text = text.rstrip(invalid_filename_postfix)
return text
def atomically_save_image():
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
while True:

View File

@ -86,5 +86,5 @@ def progressapi(req: ProgressRequest):
return res
def setup_progress_api(app):
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=InternalProgressResponse)
def setup_progress_api():
shared.api.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=InternalProgressResponse)

View File

@ -793,6 +793,7 @@ options_templates.update(options_section(('ui', "User Interface"), {
"theme_style": OptionInfo("Auto", "Theme mode", gr.Radio, {"choices": ["Auto", "Dark", "Light"]}),
"gradio_theme": OptionInfo("black-teal", "UI theme", gr.Dropdown, lambda: {"choices": theme.list_themes()}, refresh=theme.refresh_themes),
"ui_locale": OptionInfo("Auto", "UI locale", gr.Dropdown, lambda: {"choices": theme.list_locales()}),
"subpath": OptionInfo("", "Mount URL subpath"),
"autolaunch": OptionInfo(False, "Autolaunch browser upon startup"),
"font_size": OptionInfo(14, "Font size", gr.Slider, {"minimum": 8, "maximum": 32, "step": 1, "visible": True}),
"aspect_ratios": OptionInfo("1:1, 4:3, 3:2, 16:9, 16:10, 21:9, 2:3, 3:4, 9:16, 10:16, 9:21", "Allowed aspect ratios"),

View File

@ -53,7 +53,7 @@ card_list = '''
preview_map = None
def init_api(app):
def init_api():
def fetch_file(filename: str = ""):
if not os.path.exists(filename):
@ -102,10 +102,10 @@ def init_api(app):
# shared.log.debug(f"Networks desc: page='{page.name}' item={item['name']} len={len(desc)}")
return JSONResponse({"description": desc})
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
app.add_api_route("/sd_extra_networks/info", get_info, methods=["GET"])
app.add_api_route("/sd_extra_networks/description", get_desc, methods=["GET"])
shared.api.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
shared.api.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
shared.api.add_api_route("/sd_extra_networks/info", get_info, methods=["GET"])
shared.api.add_api_route("/sd_extra_networks/description", get_desc, methods=["GET"])
class ExtraNetworksPage:

View File

@ -250,6 +250,18 @@ def start_common():
timer.startup.record("cleanup")
def mount_subpath(app):
if shared.cmd_opts.subpath:
shared.opts.subpath = shared.cmd_opts.subpath
if shared.opts.subpath is None or len(shared.opts.subpath) == 0:
return
import gradio
if not shared.opts.subpath.startswith('/'):
shared.opts.subpath = f'/{shared.opts.subpath}'
gradio.mount_gradio_app(app, shared.demo, path=shared.opts.subpath)
shared.log.info(f'Mounted: subpath="{shared.opts.subpath}"')
def start_ui():
log.debug('UI start sequence')
modules.script_callbacks.before_ui_callback()
@ -324,19 +336,14 @@ def start_ui():
shared.demo.server.wants_restart = False
modules.api.middleware.setup_middleware(app, shared.cmd_opts)
if shared.cmd_opts.subpath:
import gradio
gradio.mount_gradio_app(app, shared.demo, path=f"/{shared.cmd_opts.subpath}")
shared.log.info(f'Redirector mounted: /{shared.cmd_opts.subpath}')
timer.startup.record("launch")
modules.progress.setup_progress_api(app)
shared.api = create_api(app)
shared.api.register()
modules.progress.setup_progress_api()
modules.ui_extra_networks.init_api()
timer.startup.record("api")
modules.ui_extra_networks.init_api(app)
modules.script_callbacks.app_started_callback(shared.demo, app)
timer.startup.record("app-started")
@ -347,6 +354,7 @@ def start_ui():
time_component = [f'{k}:{round(v,3)}' for (k,v) in modules.scripts.time_component.items() if v > 0.005]
if len(time_component) > 0:
shared.log.debug(f'Scripts components: {time_component}')
return app
def webui(restart=False):
@ -355,10 +363,11 @@ def webui(restart=False):
modules.script_callbacks.script_unloaded_callback()
start_common()
start_ui()
app = start_ui()
modules.script_callbacks.after_ui_callback()
modules.sd_models.write_metadata()
load_model()
mount_subpath(app)
shared.opts.save(shared.config_filename)
if shared.cmd_opts.profile:
for k, v in modules.script_callbacks.callback_map.items():
@ -407,6 +416,7 @@ def api_only():
app = FastAPI(**fastapi_args)
modules.api.middleware.setup_middleware(app, shared.cmd_opts)
shared.api = create_api(app)
shared.api.register()
shared.api.wants_restart = False
modules.script_callbacks.app_started_callback(None, app)
modules.sd_models.write_metadata()