automatic/modules/api/validate.py

108 lines
3.6 KiB
Python

import re
from modules.logger import log
# value is cost: -1=disabled, 0=unlimited, 1=default, >1 expensive
request_cost = {
"/file": 0,
"/run/predict": 0,
"/sdapi/v1/browser/thumb": 0,
"/sdapi/v1/network/thumb": 0,
"/sdapi/v1/txt2img": 5,
"/sdapi/v1/img2img": 5,
"/sdapi/v1/control": 5,
}
log_cost = {
"/.well-known/appspecific/com.chrome.devtools.json": -1,
"/info": -1,
"/file": -1,
"/token": -1,
"/theme.css": -1,
"/sdapi/v1/browser/thumb": -1,
"/sdapi/v1/network/thumb": -1,
"/run/predict": -1,
"/queue/join": -1,
"/internal/progress": -1,
"/sdapi/v1/version": -1,
"/sdapi/v1/log": -1,
"/sdapi/v1/torch": -1,
"/sdapi/v1/gpu": -1,
"/sdapi/v1/memory": -1,
"/sdapi/v1/platform": -1,
"/sdapi/v1/checkpoint": -1,
"/sdapi/v1/status": 60,
"/sdapi/v1/progress": 60,
}
log_exclude_suffix = ['.css', '.js', '.ico', '.svg']
log_exclude_prefix = ['/assets']
class Limiter():
def __init__(self, limit):
import limits
self.request_backend = limits.storage.MemoryStorage()
self.request_limit = limit # default is 300 requests per minute
self.request_strategy = limits.strategies.SlidingWindowCounterRateLimiter(self.request_backend)
self.request_limiter = limits.parse(f"{self.request_limit}/minute")
self.log_backend = limits.storage.MemoryStorage()
self.log_limit = limit // 5 # default is 300/5=60 logs per minute
self.log_strategy = limits.strategies.FixedWindowRateLimiter(self.log_backend)
self.log_limiter = limits.parse(f"{self.log_limit}/minute")
self.summary = {}
log.info(f'API: limit={self.request_limit} strategy={self.request_strategy.__class__.__name__} backend={self.request_backend.__class__.__name__}')
def stats(self):
for k, v in self.summary.items():
if v > 1:
log.trace(f'API stats: {k}={v}')
def check_request(self, client: str, api: str, quiet: bool = False):
if self.request_limit <= 0:
return True
cost = request_cost.get(api, 1)
if cost < 0:
return False
status = self.request_strategy.hit(self.request_limiter, client, api, cost=cost)
if not status and not quiet:
log.warning(f'API: client={client} api={api} rate limit exceeded')
from fastapi.exceptions import HTTPException
raise HTTPException(status_code=429, detail=f"{client}:{api}: rate limit exceeded")
return status
def check_log(self, client: str, api: str):
if self.log_limit < 0:
return True
if any(api.endswith(s) for s in log_exclude_suffix):
return False
if any(api.startswith(s) for s in log_exclude_prefix):
return False
cost = log_cost.get(api, 1)
if cost < 0:
return False
status = self.log_strategy.hit(self.log_limiter, client, api, cost=cost)
return status
limiter = Limiter(300)
def get_api_stats():
limiter.stats()
def validate_request(client, endpoint):
global limiter # pylint: disable=global-statement
from modules.shared import opts
if opts.server_rate_limit != limiter.request_limit:
limiter = Limiter(opts.server_rate_limit)
api = re.match(r"^[^?#&=]+", endpoint).group(0)
key = f"{client}:{api}"
if key not in limiter.summary:
limiter.summary[key] = 0
limiter.summary[key] += 1
return limiter.check_request(client, api)
def validate_log(client, endpoint):
api = re.match(r"^[^?#&=]+", endpoint).group(0)
return limiter.check_log(client, api)