mirror of https://github.com/vladmandic/automatic
refactor: update CLI tools for caption module
- Rename cli/api-interrogate.py to cli/api-caption.py
- Update cli/options.py, cli/process.py for new module paths
- Update cli/test-tagger.py for caption module imports
pull/4613/head
parent
f4b5abde68
commit
d78c5c1cd0
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
"""
|
"""
|
||||||
use clip to interrogate image(s)
|
use clip to caption image(s)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import io
|
import io
|
||||||
|
|
@ -44,28 +44,28 @@ def print_summary():
|
||||||
log.info({ 'keyword stats': keywords })
|
log.info({ 'keyword stats': keywords })
|
||||||
|
|
||||||
|
|
||||||
async def interrogate(f):
|
async def caption(f):
|
||||||
if not filetype.is_image(f):
|
if not filetype.is_image(f):
|
||||||
log.info({ 'interrogate skip': f })
|
log.info({ 'caption skip': f })
|
||||||
return
|
return
|
||||||
json = Map({ 'image': encode(f) })
|
json = Map({ 'image': encode(f) })
|
||||||
log.info({ 'interrogate': f })
|
log.info({ 'caption': f })
|
||||||
# run clip
|
# run clip
|
||||||
json.model = 'clip'
|
json.model = 'clip'
|
||||||
res = await sdapi.post('/sdapi/v1/interrogate', json)
|
res = await sdapi.post('/sdapi/v1/caption', json)
|
||||||
caption = ""
|
caption = ""
|
||||||
style = ""
|
style = ""
|
||||||
if 'caption' in res:
|
if 'caption' in res:
|
||||||
caption = res.caption
|
caption = res.caption
|
||||||
log.info({ 'interrogate caption': caption })
|
log.info({ 'caption result': caption })
|
||||||
if ', by' in caption:
|
if ', by' in caption:
|
||||||
style = caption.split(', by')[1].strip()
|
style = caption.split(', by')[1].strip()
|
||||||
log.info({ 'interrogate style': style })
|
log.info({ 'caption style': style })
|
||||||
for word in caption.split(' '):
|
for word in caption.split(' '):
|
||||||
if word not in exclude:
|
if word not in exclude:
|
||||||
stats['captions'][word] = stats['captions'][word] + 1 if word in stats['captions'] else 1
|
stats['captions'][word] = stats['captions'][word] + 1 if word in stats['captions'] else 1
|
||||||
else:
|
else:
|
||||||
log.error({ 'interrogate clip error': res })
|
log.error({ 'caption clip error': res })
|
||||||
# run tagger (DeepBooru)
|
# run tagger (DeepBooru)
|
||||||
tagger_req = SimpleNamespace(image=json.image, model='deepbooru', show_scores=True)
|
tagger_req = SimpleNamespace(image=json.image, model='deepbooru', show_scores=True)
|
||||||
res = await sdapi.post('/sdapi/v1/tagger', tagger_req)
|
res = await sdapi.post('/sdapi/v1/tagger', tagger_req)
|
||||||
|
|
@ -74,13 +74,13 @@ async def interrogate(f):
|
||||||
keywords = dict(sorted(res.scores.items(), key=lambda x: x[1], reverse=True))
|
keywords = dict(sorted(res.scores.items(), key=lambda x: x[1], reverse=True))
|
||||||
for word in keywords:
|
for word in keywords:
|
||||||
stats['keywords'][word] = stats['keywords'][word] + 1 if word in stats['keywords'] else 1
|
stats['keywords'][word] = stats['keywords'][word] + 1 if word in stats['keywords'] else 1
|
||||||
log.info({'interrogate keywords': keywords})
|
log.info({'caption keywords': keywords})
|
||||||
elif 'tags' in res:
|
elif 'tags' in res:
|
||||||
for tag in res.tags.split(', '):
|
for tag in res.tags.split(', '):
|
||||||
stats['keywords'][tag] = stats['keywords'][tag] + 1 if tag in stats['keywords'] else 1
|
stats['keywords'][tag] = stats['keywords'][tag] + 1 if tag in stats['keywords'] else 1
|
||||||
log.info({'interrogate tags': res.tags})
|
log.info({'caption tags': res.tags})
|
||||||
else:
|
else:
|
||||||
log.error({'interrogate tagger error': res})
|
log.error({'caption tagger error': res})
|
||||||
return caption, keywords, style
|
return caption, keywords, style
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -88,19 +88,19 @@ async def main():
|
||||||
sys.argv.pop(0)
|
sys.argv.pop(0)
|
||||||
await sdapi.session()
|
await sdapi.session()
|
||||||
if len(sys.argv) == 0:
|
if len(sys.argv) == 0:
|
||||||
log.error({ 'interrogate': 'no files specified' })
|
log.error({ 'caption': 'no files specified' })
|
||||||
for arg in sys.argv:
|
for arg in sys.argv:
|
||||||
if os.path.exists(arg):
|
if os.path.exists(arg):
|
||||||
if os.path.isfile(arg):
|
if os.path.isfile(arg):
|
||||||
await interrogate(arg)
|
await caption(arg)
|
||||||
elif os.path.isdir(arg):
|
elif os.path.isdir(arg):
|
||||||
for root, _dirs, files in os.walk(arg):
|
for root, _dirs, files in os.walk(arg):
|
||||||
for f in files:
|
for f in files:
|
||||||
_caption, _keywords, _style = await interrogate(os.path.join(root, f))
|
_caption, _keywords, _style = await caption(os.path.join(root, f))
|
||||||
else:
|
else:
|
||||||
log.error({ 'interrogate unknown file type': arg })
|
log.error({ 'caption unknown file type': arg })
|
||||||
else:
|
else:
|
||||||
log.error({ 'interrogate file missing': arg })
|
log.error({ 'caption file missing': arg })
|
||||||
await sdapi.close()
|
await sdapi.close()
|
||||||
print_summary()
|
print_summary()
|
||||||
|
|
||||||
|
|
@ -151,12 +151,12 @@ def upscale_restore_image(res: Result, upscale: bool = False, restore: bool = Fa
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def interrogate_image(res: Result, tag: str = None):
|
def caption_image(res: Result, tag: str = None):
|
||||||
caption = ''
|
caption = ''
|
||||||
tags = []
|
tags = []
|
||||||
for model in options.process.interrogate_model:
|
for model in options.process.caption_model:
|
||||||
json = util.Map({ 'image': encode(res.image), 'model': model })
|
json = util.Map({ 'image': encode(res.image), 'model': model })
|
||||||
result = sdapi.postsync('/sdapi/v1/interrogate', json)
|
result = sdapi.postsync('/sdapi/v1/caption', json)
|
||||||
if model == 'clip':
|
if model == 'clip':
|
||||||
caption = result.caption if 'caption' in result else ''
|
caption = result.caption if 'caption' in result else ''
|
||||||
caption = caption.split(',')[0].replace(' a ', ' ').strip()
|
caption = caption.split(',')[0].replace(' a ', ' ').strip()
|
||||||
|
|
@ -176,7 +176,7 @@ def interrogate_image(res: Result, tag: str = None):
|
||||||
tags = tags[:options.process.tag_limit]
|
tags = tags[:options.process.tag_limit]
|
||||||
res.caption = caption
|
res.caption = caption
|
||||||
res.tags = tags
|
res.tags = tags
|
||||||
res.ops.append('interrogate')
|
res.ops.append('caption')
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -314,8 +314,8 @@ def file(filename: str, folder: str, tag = None, requested = []):
|
||||||
res.message = f'low resolution: [{res.image.width}, {res.image.height}]'
|
res.message = f'low resolution: [{res.image.width}, {res.image.height}]'
|
||||||
res.image = None
|
res.image = None
|
||||||
return res
|
return res
|
||||||
if 'interrogate' in requested:
|
if 'caption' in requested:
|
||||||
res = interrogate_image(res, tag)
|
res = caption_image(res, tag)
|
||||||
if 'resize' in requested:
|
if 'resize' in requested:
|
||||||
res = resize_image(res)
|
res = resize_image(res)
|
||||||
if 'square' in requested:
|
if 'square' in requested:
|
||||||
|
|
|
||||||
|
|
@ -130,9 +130,9 @@ process = Map({
|
||||||
'body_pad': 0.2, # pad body image percentage
|
'body_pad': 0.2, # pad body image percentage
|
||||||
'body_model': 2, # body model to use 0/low 1/medium 2/high
|
'body_model': 2, # body model to use 0/low 1/medium 2/high
|
||||||
# similarity detection settings
|
# similarity detection settings
|
||||||
# interrogate settings
|
# caption settings
|
||||||
'interrogate': False, # interrogate images
|
'caption': False, # caption images
|
||||||
'interrogate_model': ['clip', 'deepdanbooru'], # interrogate models
|
'caption_model': ['clip', 'deepdanbooru'], # caption models
|
||||||
'tag_limit': 5, # number of tags to extract
|
'tag_limit': 5, # number of tags to extract
|
||||||
# validations
|
# validations
|
||||||
# tbd
|
# tbd
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ class TaggerTest:
|
||||||
|
|
||||||
# Load models
|
# Load models
|
||||||
print("\nLoading models...")
|
print("\nLoading models...")
|
||||||
from modules.interrogate import waifudiffusion, deepbooru
|
from modules.caption import waifudiffusion, deepbooru
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
self.waifudiffusion_loaded = waifudiffusion.load_model()
|
self.waifudiffusion_loaded = waifudiffusion.load_model()
|
||||||
|
|
@ -132,7 +132,7 @@ class TaggerTest:
|
||||||
print("CLEANUP")
|
print("CLEANUP")
|
||||||
print("=" * 70)
|
print("=" * 70)
|
||||||
|
|
||||||
from modules.interrogate import waifudiffusion, deepbooru
|
from modules.caption import waifudiffusion, deepbooru
|
||||||
from modules import devices
|
from modules import devices
|
||||||
|
|
||||||
waifudiffusion.unload_model()
|
waifudiffusion.unload_model()
|
||||||
|
|
@ -207,7 +207,7 @@ class TaggerTest:
|
||||||
|
|
||||||
# Test 5: If WaifuDiffusion loaded, check session providers
|
# Test 5: If WaifuDiffusion loaded, check session providers
|
||||||
if self.waifudiffusion_loaded:
|
if self.waifudiffusion_loaded:
|
||||||
from modules.interrogate import waifudiffusion
|
from modules.caption import waifudiffusion
|
||||||
if waifudiffusion.tagger.session is not None:
|
if waifudiffusion.tagger.session is not None:
|
||||||
session_providers = waifudiffusion.tagger.session.get_providers()
|
session_providers = waifudiffusion.tagger.session.get_providers()
|
||||||
self.log_pass(f"WaifuDiffusion session providers: {session_providers}")
|
self.log_pass(f"WaifuDiffusion session providers: {session_providers}")
|
||||||
|
|
@ -251,7 +251,7 @@ class TaggerTest:
|
||||||
import torch
|
import torch
|
||||||
import gc
|
import gc
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from modules.interrogate import waifudiffusion, deepbooru
|
from modules.caption import waifudiffusion, deepbooru
|
||||||
|
|
||||||
# Memory leak tolerance (MB) - some variance is expected
|
# Memory leak tolerance (MB) - some variance is expected
|
||||||
GPU_LEAK_TOLERANCE_MB = 50
|
GPU_LEAK_TOLERANCE_MB = 50
|
||||||
|
|
@ -473,7 +473,7 @@ class TaggerTest:
|
||||||
('tagger_show_scores', bool),
|
('tagger_show_scores', bool),
|
||||||
('waifudiffusion_model', str),
|
('waifudiffusion_model', str),
|
||||||
('waifudiffusion_character_threshold', float),
|
('waifudiffusion_character_threshold', float),
|
||||||
('interrogate_offload', bool),
|
('caption_offload', bool),
|
||||||
]
|
]
|
||||||
|
|
||||||
for setting, _expected_type in settings:
|
for setting, _expected_type in settings:
|
||||||
|
|
@ -521,10 +521,10 @@ class TaggerTest:
|
||||||
def tag(self, tagger, **kwargs):
|
def tag(self, tagger, **kwargs):
|
||||||
"""Helper to call the appropriate tagger."""
|
"""Helper to call the appropriate tagger."""
|
||||||
if tagger == 'waifudiffusion':
|
if tagger == 'waifudiffusion':
|
||||||
from modules.interrogate import waifudiffusion
|
from modules.caption import waifudiffusion
|
||||||
return waifudiffusion.tagger.predict(self.test_image, **kwargs)
|
return waifudiffusion.tagger.predict(self.test_image, **kwargs)
|
||||||
else:
|
else:
|
||||||
from modules.interrogate import deepbooru
|
from modules.caption import deepbooru
|
||||||
return deepbooru.model.tag(self.test_image, **kwargs)
|
return deepbooru.model.tag(self.test_image, **kwargs)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
@ -794,7 +794,7 @@ class TaggerTest:
|
||||||
print("TEST: Unified tagger.tag() interface")
|
print("TEST: Unified tagger.tag() interface")
|
||||||
print("=" * 70)
|
print("=" * 70)
|
||||||
|
|
||||||
from modules.interrogate import tagger
|
from modules.caption import tagger
|
||||||
|
|
||||||
# Test WaifuDiffusion through unified interface
|
# Test WaifuDiffusion through unified interface
|
||||||
if self.waifudiffusion_loaded:
|
if self.waifudiffusion_loaded:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue