mirror of https://github.com/vladmandic/automatic
refactor: update CLI tools for caption module
- Rename cli/api-interrogate.py to cli/api-caption.py
- Update cli/options.py, cli/process.py for new module paths
- Update cli/test-tagger.py for caption module imports
pull/4613/head
parent
f4b5abde68
commit
d78c5c1cd0
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
use clip to interrogate image(s)
|
||||
use clip to caption image(s)
|
||||
"""
|
||||
|
||||
import io
|
||||
|
|
@ -44,28 +44,28 @@ def print_summary():
|
|||
log.info({ 'keyword stats': keywords })
|
||||
|
||||
|
||||
async def interrogate(f):
|
||||
async def caption(f):
|
||||
if not filetype.is_image(f):
|
||||
log.info({ 'interrogate skip': f })
|
||||
log.info({ 'caption skip': f })
|
||||
return
|
||||
json = Map({ 'image': encode(f) })
|
||||
log.info({ 'interrogate': f })
|
||||
log.info({ 'caption': f })
|
||||
# run clip
|
||||
json.model = 'clip'
|
||||
res = await sdapi.post('/sdapi/v1/interrogate', json)
|
||||
res = await sdapi.post('/sdapi/v1/caption', json)
|
||||
caption = ""
|
||||
style = ""
|
||||
if 'caption' in res:
|
||||
caption = res.caption
|
||||
log.info({ 'interrogate caption': caption })
|
||||
log.info({ 'caption result': caption })
|
||||
if ', by' in caption:
|
||||
style = caption.split(', by')[1].strip()
|
||||
log.info({ 'interrogate style': style })
|
||||
log.info({ 'caption style': style })
|
||||
for word in caption.split(' '):
|
||||
if word not in exclude:
|
||||
stats['captions'][word] = stats['captions'][word] + 1 if word in stats['captions'] else 1
|
||||
else:
|
||||
log.error({ 'interrogate clip error': res })
|
||||
log.error({ 'caption clip error': res })
|
||||
# run tagger (DeepBooru)
|
||||
tagger_req = SimpleNamespace(image=json.image, model='deepbooru', show_scores=True)
|
||||
res = await sdapi.post('/sdapi/v1/tagger', tagger_req)
|
||||
|
|
@ -74,13 +74,13 @@ async def interrogate(f):
|
|||
keywords = dict(sorted(res.scores.items(), key=lambda x: x[1], reverse=True))
|
||||
for word in keywords:
|
||||
stats['keywords'][word] = stats['keywords'][word] + 1 if word in stats['keywords'] else 1
|
||||
log.info({'interrogate keywords': keywords})
|
||||
log.info({'caption keywords': keywords})
|
||||
elif 'tags' in res:
|
||||
for tag in res.tags.split(', '):
|
||||
stats['keywords'][tag] = stats['keywords'][tag] + 1 if tag in stats['keywords'] else 1
|
||||
log.info({'interrogate tags': res.tags})
|
||||
log.info({'caption tags': res.tags})
|
||||
else:
|
||||
log.error({'interrogate tagger error': res})
|
||||
log.error({'caption tagger error': res})
|
||||
return caption, keywords, style
|
||||
|
||||
|
||||
|
|
@ -88,19 +88,19 @@ async def main():
|
|||
sys.argv.pop(0)
|
||||
await sdapi.session()
|
||||
if len(sys.argv) == 0:
|
||||
log.error({ 'interrogate': 'no files specified' })
|
||||
log.error({ 'caption': 'no files specified' })
|
||||
for arg in sys.argv:
|
||||
if os.path.exists(arg):
|
||||
if os.path.isfile(arg):
|
||||
await interrogate(arg)
|
||||
await caption(arg)
|
||||
elif os.path.isdir(arg):
|
||||
for root, _dirs, files in os.walk(arg):
|
||||
for f in files:
|
||||
_caption, _keywords, _style = await interrogate(os.path.join(root, f))
|
||||
_caption, _keywords, _style = await caption(os.path.join(root, f))
|
||||
else:
|
||||
log.error({ 'interrogate unknown file type': arg })
|
||||
log.error({ 'caption unknown file type': arg })
|
||||
else:
|
||||
log.error({ 'interrogate file missing': arg })
|
||||
log.error({ 'caption file missing': arg })
|
||||
await sdapi.close()
|
||||
print_summary()
|
||||
|
||||
|
|
@ -151,12 +151,12 @@ def upscale_restore_image(res: Result, upscale: bool = False, restore: bool = Fa
|
|||
return res
|
||||
|
||||
|
||||
def interrogate_image(res: Result, tag: str = None):
|
||||
def caption_image(res: Result, tag: str = None):
|
||||
caption = ''
|
||||
tags = []
|
||||
for model in options.process.interrogate_model:
|
||||
for model in options.process.caption_model:
|
||||
json = util.Map({ 'image': encode(res.image), 'model': model })
|
||||
result = sdapi.postsync('/sdapi/v1/interrogate', json)
|
||||
result = sdapi.postsync('/sdapi/v1/caption', json)
|
||||
if model == 'clip':
|
||||
caption = result.caption if 'caption' in result else ''
|
||||
caption = caption.split(',')[0].replace(' a ', ' ').strip()
|
||||
|
|
@ -176,7 +176,7 @@ def interrogate_image(res: Result, tag: str = None):
|
|||
tags = tags[:options.process.tag_limit]
|
||||
res.caption = caption
|
||||
res.tags = tags
|
||||
res.ops.append('interrogate')
|
||||
res.ops.append('caption')
|
||||
return res
|
||||
|
||||
|
||||
|
|
@ -314,8 +314,8 @@ def file(filename: str, folder: str, tag = None, requested = []):
|
|||
res.message = f'low resolution: [{res.image.width}, {res.image.height}]'
|
||||
res.image = None
|
||||
return res
|
||||
if 'interrogate' in requested:
|
||||
res = interrogate_image(res, tag)
|
||||
if 'caption' in requested:
|
||||
res = caption_image(res, tag)
|
||||
if 'resize' in requested:
|
||||
res = resize_image(res)
|
||||
if 'square' in requested:
|
||||
|
|
|
|||
|
|
@ -130,9 +130,9 @@ process = Map({
|
|||
'body_pad': 0.2, # pad body image percentage
|
||||
'body_model': 2, # body model to use 0/low 1/medium 2/high
|
||||
# similarity detection settings
|
||||
# interrogate settings
|
||||
'interrogate': False, # interrogate images
|
||||
'interrogate_model': ['clip', 'deepdanbooru'], # interrogate models
|
||||
# caption settings
|
||||
'caption': False, # caption images
|
||||
'caption_model': ['clip', 'deepdanbooru'], # caption models
|
||||
'tag_limit': 5, # number of tags to extract
|
||||
# validations
|
||||
# tbd
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class TaggerTest:
|
|||
|
||||
# Load models
|
||||
print("\nLoading models...")
|
||||
from modules.interrogate import waifudiffusion, deepbooru
|
||||
from modules.caption import waifudiffusion, deepbooru
|
||||
|
||||
t0 = time.time()
|
||||
self.waifudiffusion_loaded = waifudiffusion.load_model()
|
||||
|
|
@ -132,7 +132,7 @@ class TaggerTest:
|
|||
print("CLEANUP")
|
||||
print("=" * 70)
|
||||
|
||||
from modules.interrogate import waifudiffusion, deepbooru
|
||||
from modules.caption import waifudiffusion, deepbooru
|
||||
from modules import devices
|
||||
|
||||
waifudiffusion.unload_model()
|
||||
|
|
@ -207,7 +207,7 @@ class TaggerTest:
|
|||
|
||||
# Test 5: If WaifuDiffusion loaded, check session providers
|
||||
if self.waifudiffusion_loaded:
|
||||
from modules.interrogate import waifudiffusion
|
||||
from modules.caption import waifudiffusion
|
||||
if waifudiffusion.tagger.session is not None:
|
||||
session_providers = waifudiffusion.tagger.session.get_providers()
|
||||
self.log_pass(f"WaifuDiffusion session providers: {session_providers}")
|
||||
|
|
@ -251,7 +251,7 @@ class TaggerTest:
|
|||
import torch
|
||||
import gc
|
||||
from modules import devices
|
||||
from modules.interrogate import waifudiffusion, deepbooru
|
||||
from modules.caption import waifudiffusion, deepbooru
|
||||
|
||||
# Memory leak tolerance (MB) - some variance is expected
|
||||
GPU_LEAK_TOLERANCE_MB = 50
|
||||
|
|
@ -473,7 +473,7 @@ class TaggerTest:
|
|||
('tagger_show_scores', bool),
|
||||
('waifudiffusion_model', str),
|
||||
('waifudiffusion_character_threshold', float),
|
||||
('interrogate_offload', bool),
|
||||
('caption_offload', bool),
|
||||
]
|
||||
|
||||
for setting, _expected_type in settings:
|
||||
|
|
@ -521,10 +521,10 @@ class TaggerTest:
|
|||
def tag(self, tagger, **kwargs):
|
||||
"""Helper to call the appropriate tagger."""
|
||||
if tagger == 'waifudiffusion':
|
||||
from modules.interrogate import waifudiffusion
|
||||
from modules.caption import waifudiffusion
|
||||
return waifudiffusion.tagger.predict(self.test_image, **kwargs)
|
||||
else:
|
||||
from modules.interrogate import deepbooru
|
||||
from modules.caption import deepbooru
|
||||
return deepbooru.model.tag(self.test_image, **kwargs)
|
||||
|
||||
# =========================================================================
|
||||
|
|
@ -794,7 +794,7 @@ class TaggerTest:
|
|||
print("TEST: Unified tagger.tag() interface")
|
||||
print("=" * 70)
|
||||
|
||||
from modules.interrogate import tagger
|
||||
from modules.caption import tagger
|
||||
|
||||
# Test WaifuDiffusion through unified interface
|
||||
if self.waifudiffusion_loaded:
|
||||
|
|
|
|||
Loading…
Reference in New Issue