refactor: update CLI tools for caption module

- Rename cli/api-interrogate.py to cli/api-caption.py
- Update cli/options.py, cli/process.py for new module paths
- Update cli/test-tagger.py for caption module imports
pull/4613/head
CalamitousFelicitousness 2026-01-26 01:18:00 +00:00
parent f4b5abde68
commit d78c5c1cd0
4 changed files with 33 additions and 33 deletions

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
""" """
use clip to interrogate image(s) use clip to caption image(s)
""" """
import io import io
@@ -44,28 +44,28 @@ def print_summary():
log.info({ 'keyword stats': keywords }) log.info({ 'keyword stats': keywords })
async def interrogate(f): async def caption(f):
if not filetype.is_image(f): if not filetype.is_image(f):
log.info({ 'interrogate skip': f }) log.info({ 'caption skip': f })
return return
json = Map({ 'image': encode(f) }) json = Map({ 'image': encode(f) })
log.info({ 'interrogate': f }) log.info({ 'caption': f })
# run clip # run clip
json.model = 'clip' json.model = 'clip'
res = await sdapi.post('/sdapi/v1/interrogate', json) res = await sdapi.post('/sdapi/v1/caption', json)
caption = "" caption = ""
style = "" style = ""
if 'caption' in res: if 'caption' in res:
caption = res.caption caption = res.caption
log.info({ 'interrogate caption': caption }) log.info({ 'caption result': caption })
if ', by' in caption: if ', by' in caption:
style = caption.split(', by')[1].strip() style = caption.split(', by')[1].strip()
log.info({ 'interrogate style': style }) log.info({ 'caption style': style })
for word in caption.split(' '): for word in caption.split(' '):
if word not in exclude: if word not in exclude:
stats['captions'][word] = stats['captions'][word] + 1 if word in stats['captions'] else 1 stats['captions'][word] = stats['captions'][word] + 1 if word in stats['captions'] else 1
else: else:
log.error({ 'interrogate clip error': res }) log.error({ 'caption clip error': res })
# run tagger (DeepBooru) # run tagger (DeepBooru)
tagger_req = SimpleNamespace(image=json.image, model='deepbooru', show_scores=True) tagger_req = SimpleNamespace(image=json.image, model='deepbooru', show_scores=True)
res = await sdapi.post('/sdapi/v1/tagger', tagger_req) res = await sdapi.post('/sdapi/v1/tagger', tagger_req)
@@ -74,13 +74,13 @@ async def interrogate(f):
keywords = dict(sorted(res.scores.items(), key=lambda x: x[1], reverse=True)) keywords = dict(sorted(res.scores.items(), key=lambda x: x[1], reverse=True))
for word in keywords: for word in keywords:
stats['keywords'][word] = stats['keywords'][word] + 1 if word in stats['keywords'] else 1 stats['keywords'][word] = stats['keywords'][word] + 1 if word in stats['keywords'] else 1
log.info({'interrogate keywords': keywords}) log.info({'caption keywords': keywords})
elif 'tags' in res: elif 'tags' in res:
for tag in res.tags.split(', '): for tag in res.tags.split(', '):
stats['keywords'][tag] = stats['keywords'][tag] + 1 if tag in stats['keywords'] else 1 stats['keywords'][tag] = stats['keywords'][tag] + 1 if tag in stats['keywords'] else 1
log.info({'interrogate tags': res.tags}) log.info({'caption tags': res.tags})
else: else:
log.error({'interrogate tagger error': res}) log.error({'caption tagger error': res})
return caption, keywords, style return caption, keywords, style
@@ -88,19 +88,19 @@ async def main():
sys.argv.pop(0) sys.argv.pop(0)
await sdapi.session() await sdapi.session()
if len(sys.argv) == 0: if len(sys.argv) == 0:
log.error({ 'interrogate': 'no files specified' }) log.error({ 'caption': 'no files specified' })
for arg in sys.argv: for arg in sys.argv:
if os.path.exists(arg): if os.path.exists(arg):
if os.path.isfile(arg): if os.path.isfile(arg):
await interrogate(arg) await caption(arg)
elif os.path.isdir(arg): elif os.path.isdir(arg):
for root, _dirs, files in os.walk(arg): for root, _dirs, files in os.walk(arg):
for f in files: for f in files:
_caption, _keywords, _style = await interrogate(os.path.join(root, f)) _caption, _keywords, _style = await caption(os.path.join(root, f))
else: else:
log.error({ 'interrogate unknown file type': arg }) log.error({ 'caption unknown file type': arg })
else: else:
log.error({ 'interrogate file missing': arg }) log.error({ 'caption file missing': arg })
await sdapi.close() await sdapi.close()
print_summary() print_summary()

View File

@@ -151,12 +151,12 @@ def upscale_restore_image(res: Result, upscale: bool = False, restore: bool = Fa
return res return res
def interrogate_image(res: Result, tag: str = None): def caption_image(res: Result, tag: str = None):
caption = '' caption = ''
tags = [] tags = []
for model in options.process.interrogate_model: for model in options.process.caption_model:
json = util.Map({ 'image': encode(res.image), 'model': model }) json = util.Map({ 'image': encode(res.image), 'model': model })
result = sdapi.postsync('/sdapi/v1/interrogate', json) result = sdapi.postsync('/sdapi/v1/caption', json)
if model == 'clip': if model == 'clip':
caption = result.caption if 'caption' in result else '' caption = result.caption if 'caption' in result else ''
caption = caption.split(',')[0].replace(' a ', ' ').strip() caption = caption.split(',')[0].replace(' a ', ' ').strip()
@@ -176,7 +176,7 @@ def interrogate_image(res: Result, tag: str = None):
tags = tags[:options.process.tag_limit] tags = tags[:options.process.tag_limit]
res.caption = caption res.caption = caption
res.tags = tags res.tags = tags
res.ops.append('interrogate') res.ops.append('caption')
return res return res
@@ -314,8 +314,8 @@ def file(filename: str, folder: str, tag = None, requested = []):
res.message = f'low resolution: [{res.image.width}, {res.image.height}]' res.message = f'low resolution: [{res.image.width}, {res.image.height}]'
res.image = None res.image = None
return res return res
if 'interrogate' in requested: if 'caption' in requested:
res = interrogate_image(res, tag) res = caption_image(res, tag)
if 'resize' in requested: if 'resize' in requested:
res = resize_image(res) res = resize_image(res)
if 'square' in requested: if 'square' in requested:

View File

@@ -130,9 +130,9 @@ process = Map({
'body_pad': 0.2, # pad body image percentage 'body_pad': 0.2, # pad body image percentage
'body_model': 2, # body model to use 0/low 1/medium 2/high 'body_model': 2, # body model to use 0/low 1/medium 2/high
# similarity detection settings # similarity detection settings
# interrogate settings # caption settings
'interrogate': False, # interrogate images 'caption': False, # caption images
'interrogate_model': ['clip', 'deepdanbooru'], # interrogate models 'caption_model': ['clip', 'deepdanbooru'], # caption models
'tag_limit': 5, # number of tags to extract 'tag_limit': 5, # number of tags to extract
# validations # validations
# tbd # tbd

View File

@@ -116,7 +116,7 @@ class TaggerTest:
# Load models # Load models
print("\nLoading models...") print("\nLoading models...")
from modules.interrogate import waifudiffusion, deepbooru from modules.caption import waifudiffusion, deepbooru
t0 = time.time() t0 = time.time()
self.waifudiffusion_loaded = waifudiffusion.load_model() self.waifudiffusion_loaded = waifudiffusion.load_model()
@@ -132,7 +132,7 @@ class TaggerTest:
print("CLEANUP") print("CLEANUP")
print("=" * 70) print("=" * 70)
from modules.interrogate import waifudiffusion, deepbooru from modules.caption import waifudiffusion, deepbooru
from modules import devices from modules import devices
waifudiffusion.unload_model() waifudiffusion.unload_model()
@@ -207,7 +207,7 @@ class TaggerTest:
# Test 5: If WaifuDiffusion loaded, check session providers # Test 5: If WaifuDiffusion loaded, check session providers
if self.waifudiffusion_loaded: if self.waifudiffusion_loaded:
from modules.interrogate import waifudiffusion from modules.caption import waifudiffusion
if waifudiffusion.tagger.session is not None: if waifudiffusion.tagger.session is not None:
session_providers = waifudiffusion.tagger.session.get_providers() session_providers = waifudiffusion.tagger.session.get_providers()
self.log_pass(f"WaifuDiffusion session providers: {session_providers}") self.log_pass(f"WaifuDiffusion session providers: {session_providers}")
@@ -251,7 +251,7 @@ class TaggerTest:
import torch import torch
import gc import gc
from modules import devices from modules import devices
from modules.interrogate import waifudiffusion, deepbooru from modules.caption import waifudiffusion, deepbooru
# Memory leak tolerance (MB) - some variance is expected # Memory leak tolerance (MB) - some variance is expected
GPU_LEAK_TOLERANCE_MB = 50 GPU_LEAK_TOLERANCE_MB = 50
@@ -473,7 +473,7 @@ class TaggerTest:
('tagger_show_scores', bool), ('tagger_show_scores', bool),
('waifudiffusion_model', str), ('waifudiffusion_model', str),
('waifudiffusion_character_threshold', float), ('waifudiffusion_character_threshold', float),
('interrogate_offload', bool), ('caption_offload', bool),
] ]
for setting, _expected_type in settings: for setting, _expected_type in settings:
@@ -521,10 +521,10 @@ class TaggerTest:
def tag(self, tagger, **kwargs): def tag(self, tagger, **kwargs):
"""Helper to call the appropriate tagger.""" """Helper to call the appropriate tagger."""
if tagger == 'waifudiffusion': if tagger == 'waifudiffusion':
from modules.interrogate import waifudiffusion from modules.caption import waifudiffusion
return waifudiffusion.tagger.predict(self.test_image, **kwargs) return waifudiffusion.tagger.predict(self.test_image, **kwargs)
else: else:
from modules.interrogate import deepbooru from modules.caption import deepbooru
return deepbooru.model.tag(self.test_image, **kwargs) return deepbooru.model.tag(self.test_image, **kwargs)
# ========================================================================= # =========================================================================
@@ -794,7 +794,7 @@ class TaggerTest:
print("TEST: Unified tagger.tag() interface") print("TEST: Unified tagger.tag() interface")
print("=" * 70) print("=" * 70)
from modules.interrogate import tagger from modules.caption import tagger
# Test WaifuDiffusion through unified interface # Test WaifuDiffusion through unified interface
if self.waifudiffusion_loaded: if self.waifudiffusion_loaded: