# automatic/modules/api/server.py

import os
import time
from fastapi import Request, Depends
from fastapi.exceptions import HTTPException
from fastapi.responses import FileResponse
import installer
from modules import shared
from modules.logger import log
from modules.api import models, helpers
# media types for the extensions we are willing to serve
MEDIA_TYPES = {
    'js': 'application/javascript',
    'mjs': 'application/javascript',
    'map': 'application/json',
    'json': 'application/json',
    'css': 'text/css',
    'html': 'text/html',
    'wasm': 'application/wasm',
    'ttf': 'font/ttf',
}


def get_js(request: Request):
    """Serve a static asset named by the `file` query parameter.

    Raises HTTPException 400 on a missing parameter or disallowed extension,
    403 when the path escapes the working directory, 404 when not found.
    Returns a FileResponse with a media type derived from the extension.
    """
    file = request.query_params.get("file", None)
    if (file is None) or (len(file) == 0):
        raise HTTPException(status_code=400, detail="file parameter is required")
    ext = file.split('.')[-1]
    if ext not in MEDIA_TYPES:
        raise HTTPException(status_code=400, detail=f"invalid file extension: {ext}")
    # security fix: the original served any readable path with an allowed extension
    # (arbitrary file read / path traversal); restrict to files under the working directory
    resolved = os.path.abspath(file)
    root = os.path.abspath(os.getcwd())
    if not resolved.startswith(root + os.sep):
        raise HTTPException(status_code=403, detail=f"file access denied: {file}")
    if not os.path.exists(resolved):
        log.error(f"API: file not found: {file}")
        raise HTTPException(status_code=404, detail=f"file not found: {file}")
    # unknown extensions cannot reach here (whitelist above), so no octet-stream fallback is needed
    return FileResponse(resolved, media_type=MEDIA_TYPES[ext])
def get_version():
    """Return version information as reported by the installer module."""
    version_info = installer.get_version()
    return version_info
def get_motd():
    """Build the message-of-the-day HTML string.

    Starts with local version info (when available) and, if the `motd`
    option is enabled, appends the remote MOTD fetched over HTTP.
    Network failures are logged and ignored (best-effort).
    """
    import requests  # local import: only needed when this endpoint is hit
    motd = ""
    ver = get_version()
    if ver.get("updated", None) is not None:
        motd = f"version <b>{ver['commit']} {ver['updated']}</b> <span style='color: var(--primary-500)'>{ver['url'].split('/')[-1]}</span><br>" # pylint: disable=use-maxsplit-arg
    if shared.opts.motd:
        try:
            res = requests.get("https://vladmandic.github.io/sdnext/motd", timeout=3)
            if res.status_code == 200:
                msg = (res.text or "").strip()
                log.info(f"MOTD: {msg if len(msg) > 0 else 'N/A'}")
                # fix: append the stripped message that was logged, not the raw
                # response body (the original appended unstripped res.text)
                motd += msg
            else:
                log.error(f"MOTD: {res.status_code}")
        except Exception as err:
            # best-effort: never let a network problem break the UI
            log.error(f"MOTD: {err}")
    return motd
def get_platform():
    """Return installer platform info merged with loader package versions."""
    from modules.loader import get_packages as loader_get_packages
    platform_info = dict(installer.get_platform())
    platform_info.update(loader_get_packages())  # loader keys win on collision, as before
    return platform_info
def get_torch():
    """Return a shallow dict copy of the installer's collected torch info."""
    info = dict(installer.torch_info)
    return info
def get_log(req: models.ReqGetLog = Depends()):
    """Return buffered log lines (first `req.lines` entries, or a full copy
    when `req.lines` <= 0) and optionally clear the buffer afterwards."""
    if req.lines > 0:
        entries = log.buffer[:req.lines]
    else:
        entries = log.buffer.copy()
    if req.clear:
        log.buffer.clear()
    return entries
def post_log(req: models.ReqPostLog):
    """Relay UI-originated messages into the server log at the matching level."""
    dispatch = (
        (req.message, log.info),
        (req.debug, log.debug),
        (req.error, log.error),
    )
    for text, emit in dispatch:
        if text is not None:
            emit(f'UI: {text}')
    return {}
def post_shutdown():
    """Log the shutdown request, then terminate the process via SystemExit."""
    import sys
    log.info("Shutdown request received")
    sys.exit(0)
def get_cmd_flags():
    """Expose parsed command-line options as a plain dict."""
    flags = vars(shared.cmd_opts)
    return flags
def get_history(req: models.ReqHistory = Depends()):
    """Return state history entries, optionally filtered by job id."""
    wanted = req.id
    if wanted is None or len(wanted) == 0:
        items = shared.state.state_history
    else:
        items = [entry for entry in shared.state.state_history if entry['id'] == wanted]
    return [models.ResHistory(**entry) for entry in items]
def get_progress(req: models.ReqProgress = Depends()):
    """Report progress of the current generation job.

    Combines batch position and sampling step into a single 0..1 progress
    value and estimates remaining time from elapsed time. Optionally embeds
    the current preview image as base64 unless the request skips it.
    NOTE: mutates shared.state.sampling_steps (see while loop below).
    """
    if shared.state.job_count == 0 and shared.state.sampling_step == 0: # truly idle
        return models.ResProgress(id=shared.state.id, progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
    shared.state.do_set_current_image()
    current_image = None
    if shared.state.current_image and not req.skip_current_image:
        current_image = helpers.encode_pil_to_base64(shared.state.current_image)
    # batch position (x of y) and step position within the current batch item
    batch_x = max(shared.state.job_no, 0)
    batch_y = max(shared.state.job_count, 1)
    step_x = max(shared.state.sampling_step, 0)
    prev_steps = max(shared.state.sampling_steps, 1)
    # if the actual step counter overshoots the planned step count, grow the
    # plan in whole-pass increments so progress never exceeds 1
    # (presumably for pipelines that report more steps than announced — TODO confirm)
    while step_x > shared.state.sampling_steps:
        shared.state.sampling_steps += prev_steps
    step_y = max(shared.state.sampling_steps, 1)
    # overall progress: completed batch items (full passes) plus current step
    current = step_y * batch_x + step_x
    total = step_y * batch_y
    progress = min((current / total) if current > 0 and total > 0 else 0, 1)
    time_since_start = time.time() - shared.state.time_start
    # simple linear extrapolation: remaining = elapsed / progress - elapsed
    eta_relative = (time_since_start / progress) - time_since_start if progress > 0 else 0
    # log.critical(f'get_progress: batch {batch_x}/{batch_y} step {step_x}/{step_y} current {current}/{total} time={time_since_start} eta={eta_relative}')
    # log.critical(shared.state)
    res = models.ResProgress(id=shared.state.id, progress=round(progress, 2), eta_relative=round(eta_relative, 2), current_image=current_image, textinfo=shared.state.textinfo, state=shared.state.dict(), )
    return res
def get_status():
    """Return the current server state summary."""
    status = shared.state.status()
    return status
def post_interrupt():
    """Request interruption of the current job; returns an empty dict."""
    shared.state.interrupt()
    return {}
def post_skip():
    """Request skipping of the current job.

    Returns an empty dict for consistency with the other POST handlers
    (post_interrupt, post_log); the original implicitly returned None.
    """
    shared.state.skip()
    return {}
def get_memory():
    """Report process RAM usage plus CUDA memory statistics.

    Either section degrades to an {'error': ...} dict when its backing
    library is unavailable or raises; all values are in bytes.
    """
    try:
        import psutil
        proc = psutil.Process(os.getpid())
        info = proc.memory_info()  # only rss is cross-platform guaranteed so we dont rely on other values
        total = 100 * info.rss / proc.memory_percent()  # derive total memory since the direct value is not cross-platform safe
        ram = { 'free': total - info.rss, 'used': info.rss, 'total': total }
    except Exception as err:
        ram = { 'error': f'{err}' }
    try:
        import torch
        if not torch.cuda.is_available():
            cuda = { 'error': 'unavailable' }
        else:
            free, total_mem = torch.cuda.mem_get_info()
            stats = dict(torch.cuda.memory_stats(shared.device))
            def pair(prefix):  # current/peak pair for one allocator counter family
                return { 'current': stats[f'{prefix}.all.current'], 'peak': stats[f'{prefix}.all.peak'] }
            cuda = {
                'system': { 'free': free, 'used': total_mem - free, 'total': total_mem },
                'active': pair('active_bytes'),
                'allocated': pair('allocated_bytes'),
                'reserved': pair('reserved_bytes'),
                'inactive': pair('inactive_split_bytes'),
                'events': { 'retries': stats['num_alloc_retries'], 'oom': stats['num_ooms'] },
            }
    except Exception as err:
        cuda = { 'error': f'{err}' }
    return models.ResMemory(ram = ram, cuda = cuda)