mirror of https://github.com/vladmandic/automatic
fix(caption): address PR review feedback
- Remove superfluous SimpleNamespace import in cli/api-caption.py, use Map instead
- Drop _ prefix from internal helper functions in modules/api/caption.py
- Move DeepDanbooru model path to top-level models folder instead of nesting under CLIP
pull/4613/head
parent
139e331d80
commit
80014fac7c
|
|
@ -8,7 +8,6 @@ import base64
|
|||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
import filetype
|
||||
from PIL import Image
|
||||
from util import log, Map
|
||||
|
|
@ -67,7 +66,7 @@ async def caption(f):
|
|||
else:
|
||||
log.error({ 'caption clip error': res })
|
||||
# run tagger (DeepBooru)
|
||||
tagger_req = SimpleNamespace(image=json.image, model='deepbooru', show_scores=True)
|
||||
tagger_req = Map({'image': json.image, 'model': 'deepbooru', 'show_scores': True})
|
||||
res = await sdapi.post('/sdapi/v1/tagger', tagger_req)
|
||||
keywords = {}
|
||||
if 'scores' in res and res.scores:
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ ReqTagger/ResTagger, ReqVQA/ResVQA) for precise request validation and typed res
|
|||
The dispatch endpoint uses a discriminated union (ReqCaptionDispatch) and a superset
|
||||
response model (ResCaptionDispatch) that includes fields from all backends.
|
||||
|
||||
Core processing logic is shared between direct and dispatch handlers via internal
|
||||
``_do_openclip``, ``_do_tagger``, and ``_do_vqa`` functions to avoid duplication.
|
||||
Core processing logic is shared between direct and dispatch handlers via
|
||||
``do_openclip``, ``do_tagger``, and ``do_vqa`` functions to avoid duplication.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Union, Literal, Annotated
|
||||
|
|
@ -244,7 +244,7 @@ class ResCaptionDispatch(BaseModel):
|
|||
# Shared Core Logic (eliminates duplication between direct and dispatch endpoints)
|
||||
# =============================================================================
|
||||
|
||||
def _validate_image(image_b64: str):
|
||||
def validate_image(image_b64: str):
|
||||
"""Validate and decode a base64 image string, returning an RGB PIL Image.
|
||||
|
||||
Raises:
|
||||
|
|
@ -256,7 +256,7 @@ def _validate_image(image_b64: str):
|
|||
return image.convert('RGB')
|
||||
|
||||
|
||||
def _build_clip_overrides(req) -> dict:
|
||||
def build_clip_overrides(req) -> dict:
|
||||
"""Build clip_interrogator overrides dict from request fields."""
|
||||
overrides = {}
|
||||
for key in ('max_length', 'chunk_size', 'min_flavors', 'max_flavors', 'flavor_count', 'num_beams'):
|
||||
|
|
@ -266,7 +266,7 @@ def _build_clip_overrides(req) -> dict:
|
|||
return overrides or None
|
||||
|
||||
|
||||
def _get_top_item(result):
|
||||
def get_top_item(result):
|
||||
"""Extract top-ranked item from a Gradio update dict."""
|
||||
if isinstance(result, dict) and 'value' in result:
|
||||
value = result['value']
|
||||
|
|
@ -277,7 +277,7 @@ def _get_top_item(result):
|
|||
return None
|
||||
|
||||
|
||||
def _do_openclip(image, req):
|
||||
def do_openclip(image, req):
|
||||
"""Core OpenCLIP captioning logic shared by direct and dispatch endpoints.
|
||||
|
||||
Returns (caption, medium, artist, movement, trending, flavor).
|
||||
|
|
@ -286,7 +286,7 @@ def _do_openclip(image, req):
|
|||
from modules.caption.openclip import caption_image, analyze_image, refresh_clip_models
|
||||
if req.model not in refresh_clip_models():
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
clip_overrides = _build_clip_overrides(req)
|
||||
clip_overrides = build_clip_overrides(req)
|
||||
try:
|
||||
caption = caption_image(image, clip_model=req.clip_model, blip_model=req.blip_model, mode=req.mode, overrides=clip_overrides)
|
||||
except Exception as e:
|
||||
|
|
@ -294,10 +294,10 @@ def _do_openclip(image, req):
|
|||
if not req.analyze:
|
||||
return caption, None, None, None, None, None
|
||||
results = analyze_image(image, clip_model=req.clip_model, blip_model=req.blip_model)
|
||||
return caption, _get_top_item(results[0]), _get_top_item(results[1]), _get_top_item(results[2]), _get_top_item(results[3]), _get_top_item(results[4])
|
||||
return caption, get_top_item(results[0]), get_top_item(results[1]), get_top_item(results[2]), get_top_item(results[3]), get_top_item(results[4])
|
||||
|
||||
|
||||
def _build_vqa_kwargs(req) -> dict:
|
||||
def build_vqa_kwargs(req) -> dict:
|
||||
"""Build generation kwargs dict from VQA request fields."""
|
||||
kwargs = {}
|
||||
for key in ('max_tokens', 'temperature', 'top_k', 'top_p', 'num_beams', 'do_sample', 'keep_thinking', 'keep_prefill'):
|
||||
|
|
@ -307,7 +307,7 @@ def _build_vqa_kwargs(req) -> dict:
|
|||
return kwargs or None
|
||||
|
||||
|
||||
def _do_vqa(image, req):
|
||||
def do_vqa(image, req):
|
||||
"""Core VLM captioning logic shared by direct and dispatch endpoints.
|
||||
|
||||
Returns (answer, annotated_b64).
|
||||
|
|
@ -321,7 +321,7 @@ def _do_vqa(image, req):
|
|||
model_name=req.model,
|
||||
prefill=req.prefill,
|
||||
thinking_mode=req.thinking_mode,
|
||||
generation_kwargs=_build_vqa_kwargs(req)
|
||||
generation_kwargs=build_vqa_kwargs(req)
|
||||
)
|
||||
if isinstance(answer, str) and answer.startswith('Error:'):
|
||||
raise HTTPException(status_code=422, detail=answer)
|
||||
|
|
@ -333,7 +333,7 @@ def _do_vqa(image, req):
|
|||
return answer, annotated_b64
|
||||
|
||||
|
||||
def _parse_tagger_scores(tags: str) -> dict:
|
||||
def parse_tagger_scores(tags: str) -> dict:
|
||||
"""Parse confidence scores from tagger output string."""
|
||||
scores = {}
|
||||
for item in tags.split(', '):
|
||||
|
|
@ -354,7 +354,7 @@ def _parse_tagger_scores(tags: str) -> dict:
|
|||
return scores or None
|
||||
|
||||
|
||||
def _do_tagger(image, req):
|
||||
def do_tagger(image, req):
|
||||
"""Core tagger logic shared by direct and dispatch endpoints.
|
||||
|
||||
Returns (tags, scores).
|
||||
|
|
@ -387,7 +387,7 @@ def _do_tagger(image, req):
|
|||
shared.opts.waifudiffusion_character_threshold = req.character_threshold
|
||||
shared.opts.waifudiffusion_model = req.model
|
||||
tags = tagger.tag(image, model_name='DeepBooru' if is_deepbooru else None)
|
||||
scores = _parse_tagger_scores(tags) if req.show_scores else None
|
||||
scores = parse_tagger_scores(tags) if req.show_scores else None
|
||||
return tags, scores
|
||||
finally:
|
||||
for key, value in original_opts.items():
|
||||
|
|
@ -438,8 +438,8 @@ def post_caption(req: ReqCaption):
|
|||
**Error Codes:**
|
||||
- ``404``: Image not provided or model not found
|
||||
"""
|
||||
image = _validate_image(req.image)
|
||||
caption, medium, artist, movement, trending, flavor = _do_openclip(image, req)
|
||||
image = validate_image(req.image)
|
||||
caption, medium, artist, movement, trending, flavor = do_openclip(image, req)
|
||||
return ResCaption(caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)
|
||||
|
||||
|
||||
|
|
@ -484,8 +484,8 @@ def post_vqa(req: ReqVQA):
|
|||
- ``404``: Image not provided
|
||||
- ``422``: Model returned an error (e.g., unsupported task for model)
|
||||
"""
|
||||
image = _validate_image(req.image)
|
||||
answer, annotated_b64 = _do_vqa(image, req)
|
||||
image = validate_image(req.image)
|
||||
answer, annotated_b64 = do_vqa(image, req)
|
||||
return ResVQA(answer=answer, annotated_image=annotated_b64)
|
||||
|
||||
|
||||
|
|
@ -522,16 +522,16 @@ def post_caption_dispatch(req: ReqCaptionDispatch):
|
|||
- ``422``: VLM model returned an error
|
||||
"""
|
||||
if req.backend == "openclip":
|
||||
image = _validate_image(req.image)
|
||||
caption, medium, artist, movement, trending, flavor = _do_openclip(image, req)
|
||||
image = validate_image(req.image)
|
||||
caption, medium, artist, movement, trending, flavor = do_openclip(image, req)
|
||||
return ResCaptionDispatch(backend="openclip", caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)
|
||||
elif req.backend == "tagger":
|
||||
image = _validate_image(req.image)
|
||||
tags, scores = _do_tagger(image, req)
|
||||
image = validate_image(req.image)
|
||||
tags, scores = do_tagger(image, req)
|
||||
return ResCaptionDispatch(backend="tagger", tags=tags, scores=scores)
|
||||
elif req.backend == "vlm":
|
||||
image = _validate_image(req.image)
|
||||
answer, annotated_b64 = _do_vqa(image, req)
|
||||
image = validate_image(req.image)
|
||||
answer, annotated_b64 = do_vqa(image, req)
|
||||
return ResCaptionDispatch(backend="vlm", answer=answer, annotated_image=annotated_b64)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown backend: {req.backend}")
|
||||
|
|
@ -558,8 +558,8 @@ def post_tagger(req: ReqTagger):
|
|||
**Error Codes:**
|
||||
- ``404``: Image not provided
|
||||
"""
|
||||
image = _validate_image(req.image)
|
||||
tags, scores = _do_tagger(image, req)
|
||||
image = validate_image(req.image)
|
||||
tags, scores = do_tagger(image, req)
|
||||
return ResTagger(tags=tags, scores=scores)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class DeepDanbooru:
|
|||
with load_lock:
|
||||
if self.model is not None:
|
||||
return
|
||||
model_path = os.path.join(shared.opts.clip_models_path, "DeepDanbooru")
|
||||
model_path = os.path.join(shared.models_path, "DeepDanbooru")
|
||||
shared.log.debug(f'Caption load: module=DeepDanbooru folder="{model_path}"')
|
||||
files = modelloader.load_models(
|
||||
model_path=model_path,
|
||||
|
|
|
|||
Loading…
Reference in New Issue