refactor: update CLI tools for caption module

- Rename cli/api-interrogate.py to cli/api-caption.py
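- Update cli/options.py, cli/process.py for the renamed caption options (`caption`, `caption_model`), the `caption_image` function, and the `/sdapi/v1/caption` endpoint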
- Update cli/test-tagger.py for caption module imports and the renamed `caption_offload` setting
pull/4613/head
CalamitousFelicitousness 2026-01-26 01:18:00 +00:00
parent f4b5abde68
commit d78c5c1cd0
4 changed files with 33 additions and 33 deletions

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
"""
use clip to interrogate image(s)
use clip to caption image(s)
"""
import io
@ -44,28 +44,28 @@ def print_summary():
log.info({ 'keyword stats': keywords })
async def interrogate(f):
async def caption(f):
if not filetype.is_image(f):
log.info({ 'interrogate skip': f })
log.info({ 'caption skip': f })
return
json = Map({ 'image': encode(f) })
log.info({ 'interrogate': f })
log.info({ 'caption': f })
# run clip
json.model = 'clip'
res = await sdapi.post('/sdapi/v1/interrogate', json)
res = await sdapi.post('/sdapi/v1/caption', json)
caption = ""
style = ""
if 'caption' in res:
caption = res.caption
log.info({ 'interrogate caption': caption })
log.info({ 'caption result': caption })
if ', by' in caption:
style = caption.split(', by')[1].strip()
log.info({ 'interrogate style': style })
log.info({ 'caption style': style })
for word in caption.split(' '):
if word not in exclude:
stats['captions'][word] = stats['captions'][word] + 1 if word in stats['captions'] else 1
else:
log.error({ 'interrogate clip error': res })
log.error({ 'caption clip error': res })
# run tagger (DeepBooru)
tagger_req = SimpleNamespace(image=json.image, model='deepbooru', show_scores=True)
res = await sdapi.post('/sdapi/v1/tagger', tagger_req)
@ -74,13 +74,13 @@ async def interrogate(f):
keywords = dict(sorted(res.scores.items(), key=lambda x: x[1], reverse=True))
for word in keywords:
stats['keywords'][word] = stats['keywords'][word] + 1 if word in stats['keywords'] else 1
log.info({'interrogate keywords': keywords})
log.info({'caption keywords': keywords})
elif 'tags' in res:
for tag in res.tags.split(', '):
stats['keywords'][tag] = stats['keywords'][tag] + 1 if tag in stats['keywords'] else 1
log.info({'interrogate tags': res.tags})
log.info({'caption tags': res.tags})
else:
log.error({'interrogate tagger error': res})
log.error({'caption tagger error': res})
return caption, keywords, style
@ -88,19 +88,19 @@ async def main():
sys.argv.pop(0)
await sdapi.session()
if len(sys.argv) == 0:
log.error({ 'interrogate': 'no files specified' })
log.error({ 'caption': 'no files specified' })
for arg in sys.argv:
if os.path.exists(arg):
if os.path.isfile(arg):
await interrogate(arg)
await caption(arg)
elif os.path.isdir(arg):
for root, _dirs, files in os.walk(arg):
for f in files:
_caption, _keywords, _style = await interrogate(os.path.join(root, f))
_caption, _keywords, _style = await caption(os.path.join(root, f))
else:
log.error({ 'interrogate unknown file type': arg })
log.error({ 'caption unknown file type': arg })
else:
log.error({ 'interrogate file missing': arg })
log.error({ 'caption file missing': arg })
await sdapi.close()
print_summary()

View File

@ -151,12 +151,12 @@ def upscale_restore_image(res: Result, upscale: bool = False, restore: bool = Fa
return res
def interrogate_image(res: Result, tag: str = None):
def caption_image(res: Result, tag: str = None):
caption = ''
tags = []
for model in options.process.interrogate_model:
for model in options.process.caption_model:
json = util.Map({ 'image': encode(res.image), 'model': model })
result = sdapi.postsync('/sdapi/v1/interrogate', json)
result = sdapi.postsync('/sdapi/v1/caption', json)
if model == 'clip':
caption = result.caption if 'caption' in result else ''
caption = caption.split(',')[0].replace(' a ', ' ').strip()
@ -176,7 +176,7 @@ def interrogate_image(res: Result, tag: str = None):
tags = tags[:options.process.tag_limit]
res.caption = caption
res.tags = tags
res.ops.append('interrogate')
res.ops.append('caption')
return res
@ -314,8 +314,8 @@ def file(filename: str, folder: str, tag = None, requested = []):
res.message = f'low resolution: [{res.image.width}, {res.image.height}]'
res.image = None
return res
if 'interrogate' in requested:
res = interrogate_image(res, tag)
if 'caption' in requested:
res = caption_image(res, tag)
if 'resize' in requested:
res = resize_image(res)
if 'square' in requested:

View File

@ -130,9 +130,9 @@ process = Map({
'body_pad': 0.2, # pad body image percentage
'body_model': 2, # body model to use 0/low 1/medium 2/high
# similarity detection settings
# interrogate settings
'interrogate': False, # interrogate images
'interrogate_model': ['clip', 'deepdanbooru'], # interrogate models
# caption settings
'caption': False, # caption images
'caption_model': ['clip', 'deepdanbooru'], # caption models
'tag_limit': 5, # number of tags to extract
# validations
# tbd

View File

@ -116,7 +116,7 @@ class TaggerTest:
# Load models
print("\nLoading models...")
from modules.interrogate import waifudiffusion, deepbooru
from modules.caption import waifudiffusion, deepbooru
t0 = time.time()
self.waifudiffusion_loaded = waifudiffusion.load_model()
@ -132,7 +132,7 @@ class TaggerTest:
print("CLEANUP")
print("=" * 70)
from modules.interrogate import waifudiffusion, deepbooru
from modules.caption import waifudiffusion, deepbooru
from modules import devices
waifudiffusion.unload_model()
@ -207,7 +207,7 @@ class TaggerTest:
# Test 5: If WaifuDiffusion loaded, check session providers
if self.waifudiffusion_loaded:
from modules.interrogate import waifudiffusion
from modules.caption import waifudiffusion
if waifudiffusion.tagger.session is not None:
session_providers = waifudiffusion.tagger.session.get_providers()
self.log_pass(f"WaifuDiffusion session providers: {session_providers}")
@ -251,7 +251,7 @@ class TaggerTest:
import torch
import gc
from modules import devices
from modules.interrogate import waifudiffusion, deepbooru
from modules.caption import waifudiffusion, deepbooru
# Memory leak tolerance (MB) - some variance is expected
GPU_LEAK_TOLERANCE_MB = 50
@ -473,7 +473,7 @@ class TaggerTest:
('tagger_show_scores', bool),
('waifudiffusion_model', str),
('waifudiffusion_character_threshold', float),
('interrogate_offload', bool),
('caption_offload', bool),
]
for setting, _expected_type in settings:
@ -521,10 +521,10 @@ class TaggerTest:
def tag(self, tagger, **kwargs):
"""Helper to call the appropriate tagger."""
if tagger == 'waifudiffusion':
from modules.interrogate import waifudiffusion
from modules.caption import waifudiffusion
return waifudiffusion.tagger.predict(self.test_image, **kwargs)
else:
from modules.interrogate import deepbooru
from modules.caption import deepbooru
return deepbooru.model.tag(self.test_image, **kwargs)
# =========================================================================
@ -794,7 +794,7 @@ class TaggerTest:
print("TEST: Unified tagger.tag() interface")
print("=" * 70)
from modules.interrogate import tagger
from modules.caption import tagger
# Test WaifuDiffusion through unified interface
if self.waifudiffusion_loaded: