fix(caption): address PR review feedback

- Remove superfluous SimpleNamespace import in cli/api-caption.py, use Map instead
- Drop _ prefix from internal helper functions in modules/api/caption.py
- Move DeepDanbooru model path to top-level models folder instead of nesting under CLIP
pull/4613/head
CalamitousFelicitousness 2026-02-11 02:45:22 +00:00
parent 139e331d80
commit 80014fac7c
3 changed files with 28 additions and 29 deletions

View File

@ -8,7 +8,6 @@ import base64
import sys
import os
import asyncio
from types import SimpleNamespace
import filetype
from PIL import Image
from util import log, Map
@ -67,7 +66,7 @@ async def caption(f):
else:
log.error({ 'caption clip error': res })
# run tagger (DeepBooru)
tagger_req = SimpleNamespace(image=json.image, model='deepbooru', show_scores=True)
tagger_req = Map({'image': json.image, 'model': 'deepbooru', 'show_scores': True})
res = await sdapi.post('/sdapi/v1/tagger', tagger_req)
keywords = {}
if 'scores' in res and res.scores:

View File

@ -21,8 +21,8 @@ ReqTagger/ResTagger, ReqVQA/ResVQA) for precise request validation and typed res
The dispatch endpoint uses a discriminated union (ReqCaptionDispatch) and a superset
response model (ResCaptionDispatch) that includes fields from all backends.
Core processing logic is shared between direct and dispatch handlers via internal
``_do_openclip``, ``_do_tagger``, and ``_do_vqa`` functions to avoid duplication.
Core processing logic is shared between direct and dispatch handlers via
``do_openclip``, ``do_tagger``, and ``do_vqa`` functions to avoid duplication.
"""
from typing import Optional, List, Union, Literal, Annotated
@ -244,7 +244,7 @@ class ResCaptionDispatch(BaseModel):
# Shared Core Logic (eliminates duplication between direct and dispatch endpoints)
# =============================================================================
def _validate_image(image_b64: str):
def validate_image(image_b64: str):
"""Validate and decode a base64 image string, returning an RGB PIL Image.
Raises:
@ -256,7 +256,7 @@ def _validate_image(image_b64: str):
return image.convert('RGB')
def _build_clip_overrides(req) -> dict:
def build_clip_overrides(req) -> dict:
"""Build clip_interrogator overrides dict from request fields."""
overrides = {}
for key in ('max_length', 'chunk_size', 'min_flavors', 'max_flavors', 'flavor_count', 'num_beams'):
@ -266,7 +266,7 @@ def _build_clip_overrides(req) -> dict:
return overrides or None
def _get_top_item(result):
def get_top_item(result):
"""Extract top-ranked item from a Gradio update dict."""
if isinstance(result, dict) and 'value' in result:
value = result['value']
@ -277,7 +277,7 @@ def _get_top_item(result):
return None
def _do_openclip(image, req):
def do_openclip(image, req):
"""Core OpenCLIP captioning logic shared by direct and dispatch endpoints.
Returns (caption, medium, artist, movement, trending, flavor).
@ -286,7 +286,7 @@ def _do_openclip(image, req):
from modules.caption.openclip import caption_image, analyze_image, refresh_clip_models
if req.model not in refresh_clip_models():
raise HTTPException(status_code=404, detail="Model not found")
clip_overrides = _build_clip_overrides(req)
clip_overrides = build_clip_overrides(req)
try:
caption = caption_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode, overrides=clip_overrides)
except Exception as e:
@ -294,10 +294,10 @@ def _do_openclip(image, req):
if not req.analyze:
return caption, None, None, None, None, None
results = analyze_image(image, clip_model=req.clip_model, blip_model=req.blip_model)
return caption, _get_top_item(results[0]), _get_top_item(results[1]), _get_top_item(results[2]), _get_top_item(results[3]), _get_top_item(results[4])
return caption, get_top_item(results[0]), get_top_item(results[1]), get_top_item(results[2]), get_top_item(results[3]), get_top_item(results[4])
def _build_vqa_kwargs(req) -> dict:
def build_vqa_kwargs(req) -> dict:
"""Build generation kwargs dict from VQA request fields."""
kwargs = {}
for key in ('max_tokens', 'temperature', 'top_k', 'top_p', 'num_beams', 'do_sample', 'keep_thinking', 'keep_prefill'):
@ -307,7 +307,7 @@ def _build_vqa_kwargs(req) -> dict:
return kwargs or None
def _do_vqa(image, req):
def do_vqa(image, req):
"""Core VLM captioning logic shared by direct and dispatch endpoints.
Returns (answer, annotated_b64).
@ -321,7 +321,7 @@ def _do_vqa(image, req):
model_name=req.model,
prefill=req.prefill,
thinking_mode=req.thinking_mode,
generation_kwargs=_build_vqa_kwargs(req)
generation_kwargs=build_vqa_kwargs(req)
)
if isinstance(answer, str) and answer.startswith('Error:'):
raise HTTPException(status_code=422, detail=answer)
@ -333,7 +333,7 @@ def _do_vqa(image, req):
return answer, annotated_b64
def _parse_tagger_scores(tags: str) -> dict:
def parse_tagger_scores(tags: str) -> dict:
"""Parse confidence scores from tagger output string."""
scores = {}
for item in tags.split(', '):
@ -354,7 +354,7 @@ def _parse_tagger_scores(tags: str) -> dict:
return scores or None
def _do_tagger(image, req):
def do_tagger(image, req):
"""Core tagger logic shared by direct and dispatch endpoints.
Returns (tags, scores).
@ -387,7 +387,7 @@ def _do_tagger(image, req):
shared.opts.waifudiffusion_character_threshold = req.character_threshold
shared.opts.waifudiffusion_model = req.model
tags = tagger.tag(image, model_name='DeepBooru' if is_deepbooru else None)
scores = _parse_tagger_scores(tags) if req.show_scores else None
scores = parse_tagger_scores(tags) if req.show_scores else None
return tags, scores
finally:
for key, value in original_opts.items():
@ -438,8 +438,8 @@ def post_caption(req: ReqCaption):
**Error Codes:**
- ``404``: Image not provided or model not found
"""
image = _validate_image(req.image)
caption, medium, artist, movement, trending, flavor = _do_openclip(image, req)
image = validate_image(req.image)
caption, medium, artist, movement, trending, flavor = do_openclip(image, req)
return ResCaption(caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)
@ -484,8 +484,8 @@ def post_vqa(req: ReqVQA):
- ``404``: Image not provided
- ``422``: Model returned an error (e.g., unsupported task for model)
"""
image = _validate_image(req.image)
answer, annotated_b64 = _do_vqa(image, req)
image = validate_image(req.image)
answer, annotated_b64 = do_vqa(image, req)
return ResVQA(answer=answer, annotated_image=annotated_b64)
@ -522,16 +522,16 @@ def post_caption_dispatch(req: ReqCaptionDispatch):
- ``422``: VLM model returned an error
"""
if req.backend == "openclip":
image = _validate_image(req.image)
caption, medium, artist, movement, trending, flavor = _do_openclip(image, req)
image = validate_image(req.image)
caption, medium, artist, movement, trending, flavor = do_openclip(image, req)
return ResCaptionDispatch(backend="openclip", caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)
elif req.backend == "tagger":
image = _validate_image(req.image)
tags, scores = _do_tagger(image, req)
image = validate_image(req.image)
tags, scores = do_tagger(image, req)
return ResCaptionDispatch(backend="tagger", tags=tags, scores=scores)
elif req.backend == "vlm":
image = _validate_image(req.image)
answer, annotated_b64 = _do_vqa(image, req)
image = validate_image(req.image)
answer, annotated_b64 = do_vqa(image, req)
return ResCaptionDispatch(backend="vlm", answer=answer, annotated_image=annotated_b64)
else:
raise HTTPException(status_code=400, detail=f"Unknown backend: {req.backend}")
@ -558,8 +558,8 @@ def post_tagger(req: ReqTagger):
**Error Codes:**
- ``404``: Image not provided
"""
image = _validate_image(req.image)
tags, scores = _do_tagger(image, req)
image = validate_image(req.image)
tags, scores = do_tagger(image, req)
return ResTagger(tags=tags, scores=scores)

View File

@ -18,7 +18,7 @@ class DeepDanbooru:
with load_lock:
if self.model is not None:
return
model_path = os.path.join(shared.opts.clip_models_path, "DeepDanbooru")
model_path = os.path.join(shared.models_path, "DeepDanbooru")
shared.log.debug(f'Caption load: module=DeepDanbooru folder="{model_path}"')
files = modelloader.load_models(
model_path=model_path,