mirror of https://github.com/vladmandic/automatic
native implementation for interrogator
parent
426729aa6f
commit
418279ff7b
|
|
@ -12,10 +12,6 @@
|
|||
path = modules/lora
|
||||
url = https://github.com/kohya-ss/sd-scripts
|
||||
ignore = dirty
|
||||
[submodule "extensions-builtin/clip-interrogator-ext"]
|
||||
path = extensions-builtin/clip-interrogator-ext
|
||||
url = https://github.com/Dahvikiin/clip-interrogator-ext.git
|
||||
ignore = dirty
|
||||
[submodule "extensions-builtin/sd-webui-controlnet"]
|
||||
path = extensions-builtin/sd-webui-controlnet
|
||||
url = https://github.com/Mikubill/sd-webui-controlnet
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ Some highlights: [OpenVINO](https://github.com/vladmandic/automatic/wiki/OpenVIN
|
|||
- new option: *settings -> system paths -> models*
|
||||
can be used to set custom base path for *all* models (previously only as cli option)
|
||||
- remove external clone of items in `/repositories`
|
||||
- **Interrogator** module has been removed from `extensions-builtin`
|
||||
and fully implemented (and improved) natively
|
||||
- **UI**
|
||||
- UI tweaks for default themes
|
||||
- UI switch core font in default theme to **noto-sans**
|
||||
|
|
|
|||
|
|
@ -124,10 +124,7 @@ SD.Next comes with several extensions pre-installed:
|
|||
|
||||
- [ControlNet](https://github.com/Mikubill/sd-webui-controlnet)
|
||||
- [Agent Scheduler](https://github.com/ArtVentureX/sd-webui-agent-scheduler)
|
||||
- [Multi-Diffusion Tiled Diffusion and VAE](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111)
|
||||
- [LyCORIS](https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris)
|
||||
- [Image Browser](https://github.com/AlUlkesh/stable-diffusion-webui-images-browser)
|
||||
- [CLIP Interrogator](https://github.com/pharmapsychotic/clip-interrogator-ext)
|
||||
- [Rembg Background Removal](https://github.com/AUTOMATIC1111/stable-diffusion-webui-rembg)
|
||||
|
||||
### **Colab**
|
||||
|
|
|
|||
|
|
@ -40,9 +40,9 @@
|
|||
{"id":"","label":"disabled","localized":"","hint":""}
|
||||
],
|
||||
"tabs": [
|
||||
{"id":"","label":"From Text","localized":"","hint":"Create image from text"},
|
||||
{"id":"","label":"From Image","localized":"","hint":"Create image from image"},
|
||||
{"id":"","label":"Process Image","localized":"","hint":"Process existing image"},
|
||||
{"id":"","label":"Text","localized":"","hint":"Create image from text"},
|
||||
{"id":"","label":"Image","localized":"","hint":"Create image from image"},
|
||||
{"id":"","label":"Process","localized":"","hint":"Process existing image"},
|
||||
{"id":"","label":"Train","localized":"","hint":"Run training or model merging"},
|
||||
{"id":"","label":"Models","localized":"","hint":"Convert or merge your models"},
|
||||
{"id":"","label":"Interrogator","localized":"","hint":"Run interrogate to get description of your image"},
|
||||
|
|
|
|||
|
|
@ -238,13 +238,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
|
|||
--radius-md: 0;
|
||||
--radius-xl: 0;
|
||||
--radius-xxl: 0;
|
||||
--text-xxs: 9px;
|
||||
--text-xs: 10px;
|
||||
--text-sm: 12px;
|
||||
--text-md: 14px;
|
||||
--text-lg: 16px;
|
||||
--text-xl: 22px;
|
||||
--text-xxl: 26px;
|
||||
--font: 'Source Sans Pro', 'ui-sans-serif', 'system-ui', sans-serif;
|
||||
--font-mono: 'IBM Plex Mono', 'ui-monospace', 'Consolas', monospace;
|
||||
--body-text-size: var(--text-md);
|
||||
|
|
|
|||
|
|
@ -38,9 +38,6 @@
|
|||
--spacing-xxl: 6px;
|
||||
--line-sm: 1.2em;
|
||||
--line-md: 1.4em;
|
||||
--text-sm: 12px;
|
||||
--text-md: 14px;
|
||||
--text-lg: 15px;
|
||||
}
|
||||
|
||||
html { font-size: var(--font-size); }
|
||||
|
|
@ -253,10 +250,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
|
|||
--radius-md: 0;
|
||||
--radius-xl: 0;
|
||||
--radius-xxl: 0;
|
||||
--text-xxs: 9px;
|
||||
--text-xs: 10px;
|
||||
--text-xl: 22px;
|
||||
--text-xxl: 26px;
|
||||
--body-text-size: var(--text-md);
|
||||
--body-text-weight: 400;
|
||||
--embed-radius: var(--radius-lg);
|
||||
|
|
|
|||
|
|
@ -33,9 +33,6 @@
|
|||
--radius-lg: 4px;
|
||||
--line-sm: 1.2em;
|
||||
--line-md: 1.4em;
|
||||
--text-sm: 12px;
|
||||
--text-md: 14px;
|
||||
--text-lg: 15px;
|
||||
}
|
||||
|
||||
html { font-size: var(--font-size); font-family: var(--font); }
|
||||
|
|
@ -246,10 +243,6 @@ textarea[rows="1"] { height: 33px !important; width: 99% !important; padding: 8p
|
|||
--radius-md: 0;
|
||||
--radius-xl: 0;
|
||||
--radius-xxl: 0;
|
||||
--text-xxs: 9px;
|
||||
--text-xs: 10px;
|
||||
--text-xl: 22px;
|
||||
--text-xxl: 26px;
|
||||
--body-text-size: var(--text-md);
|
||||
--body-text-weight: 400;
|
||||
--embed-radius: var(--radius-lg);
|
||||
|
|
|
|||
|
|
@ -234,13 +234,6 @@ button.selected {background: var(--button-primary-background-fill);}
|
|||
--radius-md: 0;
|
||||
--radius-xl: 0;
|
||||
--radius-xxl: 0;
|
||||
--text-xxs: 9px;
|
||||
--text-xs: 10px;
|
||||
--text-sm: 12px;
|
||||
--text-md: 14px;
|
||||
--text-lg: 16px;
|
||||
--text-xl: 22px;
|
||||
--text-xxl: 26px;
|
||||
--body-text-size: var(--text-md);
|
||||
--body-text-weight: 400;
|
||||
--embed-radius: var(--radius-lg);
|
||||
|
|
|
|||
|
|
@ -33,9 +33,6 @@
|
|||
--radius-lg: 4px;
|
||||
--line-sm: 1.2em;
|
||||
--line-md: 1.4em;
|
||||
--text-sm: 12px;
|
||||
--text-md: 14px;
|
||||
--text-lg: 15px;
|
||||
}
|
||||
|
||||
html { font-size: var(--font-size); }
|
||||
|
|
@ -309,8 +306,4 @@ svg.feather.feather-image, .feather .feather-image { display: none }
|
|||
--table-odd-background-fill: #333333;
|
||||
--table-radius: var(--radius-lg);
|
||||
--table-row-focus: var(--color-accent-soft);
|
||||
--text-lg: 16px;
|
||||
--text-xs: 10px;
|
||||
--text-xxl: 26px;
|
||||
--text-xxs: 9px;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -239,13 +239,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
|
|||
--radius-md: 0;
|
||||
--radius-xl: 0;
|
||||
--radius-xxl: 0;
|
||||
--text-xxs: 9px;
|
||||
--text-xs: 10px;
|
||||
--text-sm: 12px;
|
||||
--text-md: 14px;
|
||||
--text-lg: 16px;
|
||||
--text-xl: 22px;
|
||||
--text-xxl: 26px;
|
||||
--body-text-size: var(--text-md);
|
||||
--body-text-weight: 400;
|
||||
--embed-radius: var(--radius-lg);
|
||||
|
|
|
|||
|
|
@ -263,3 +263,13 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
|
|||
@keyframes move { from { background-position-x: 0, -40px; } to { background-position-x: 0, 40px; } }
|
||||
@keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } }
|
||||
@keyframes color { from { filter: hue-rotate(0deg) } to { filter: hue-rotate(360deg) } }
|
||||
|
||||
:root, .light, .dark {
|
||||
--text-xxs: 9px;
|
||||
--text-xs: 10px;
|
||||
--text-sm: 12px;
|
||||
--text-md: 14px;
|
||||
--text-lg: 16px;
|
||||
--text-xl: 18px;
|
||||
--text-xxl: 20px;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,4 +41,4 @@ errors.install([gradio])
|
|||
|
||||
import diffusers # pylint: disable=W0611,C0411
|
||||
timer.startup.record("diffusers")
|
||||
errors.log.debug(f'Load packages: torch={getattr(torch, "__long_version__", torch.__version__)} diffusers={diffusers.__version__} gradio={gradio.__version__}')
|
||||
errors.log.info(f'Load packages: torch={getattr(torch, "__long_version__", torch.__version__)} diffusers={diffusers.__version__} gradio={gradio.__version__}')
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import numpy as np
|
|||
from PIL import Image
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, ui_loadsave, ui_train, ui_models
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, ui_loadsave, ui_train, ui_models, ui_interrogate
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
|
||||
from modules.paths import script_path, data_path
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
|
@ -635,7 +635,6 @@ def create_ui(startup_timer = None):
|
|||
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory", **modules.shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
||||
|
||||
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
||||
|
||||
for i, tab in enumerate(img2img_tabs):
|
||||
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
|
||||
|
||||
|
|
@ -903,6 +902,11 @@ def create_ui(startup_timer = None):
|
|||
ui_models.create_ui()
|
||||
timer.startup.record("ui-models")
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as interrogate_interface:
|
||||
ui_interrogate.create_ui()
|
||||
timer.startup.record("ui-interrogate")
|
||||
|
||||
|
||||
def create_setting_component(key, is_quicksettings=False):
|
||||
def fun():
|
||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||
|
|
@ -1102,11 +1106,12 @@ def create_ui(startup_timer = None):
|
|||
timer.startup.record("ui-settings")
|
||||
|
||||
interfaces = [
|
||||
(txt2img_interface, "From Text", "txt2img"),
|
||||
(img2img_interface, "From Image", "img2img"),
|
||||
(extras_interface, "Process Image", "process"),
|
||||
(txt2img_interface, "Text", "txt2img"),
|
||||
(img2img_interface, "Image", "img2img"),
|
||||
(extras_interface, "Process", "process"),
|
||||
(train_interface, "Train", "train"),
|
||||
(models_interface, "Models", "models"),
|
||||
(interrogate_interface, "Interrogate", "interrogate"),
|
||||
]
|
||||
interfaces += script_callbacks.ui_tabs_callback()
|
||||
interfaces += [(settings_interface, "System", "system")]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,267 @@
|
|||
import os
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import gradio as gr
|
||||
import open_clip
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field # pylint: disable=no-name-in-module
|
||||
from fastapi import FastAPI
|
||||
from fastapi.exceptions import HTTPException
|
||||
from clip_interrogator import Config, Interrogator
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
from modules import devices, lowvram, shared, paths
|
||||
|
||||
|
||||
ci = None
|
||||
low_vram = False
|
||||
|
||||
|
||||
class BatchWriter:
|
||||
def __init__(self, folder):
|
||||
self.folder = folder
|
||||
self.csv, self.file = None, None
|
||||
|
||||
def add(self, file, prompt):
|
||||
txt_file = os.path.splitext(file)[0] + ".txt"
|
||||
with open(os.path.join(self.folder, txt_file), 'w', encoding='utf-8') as f:
|
||||
f.write(prompt)
|
||||
|
||||
def close(self):
|
||||
if self.file is not None:
|
||||
self.file.close()
|
||||
|
||||
|
||||
def load(clip_model_name):
|
||||
global ci # pylint: disable=global-statement
|
||||
if ci is None:
|
||||
config = Config(device=devices.get_optimal_device(), cache_path=os.path.join(paths.models_path, 'clip-interrogator'), clip_model_name=clip_model_name, quiet=True)
|
||||
if low_vram:
|
||||
config.apply_low_vram_defaults()
|
||||
shared.log.info(f'Interrogate load: config={config}')
|
||||
ci = Interrogator(config)
|
||||
elif clip_model_name != ci.config.clip_model_name:
|
||||
ci.config.clip_model_name = clip_model_name
|
||||
shared.log.info(f'Interrogate load: config={ci.config}')
|
||||
ci.load_clip_model()
|
||||
|
||||
|
||||
def unload():
|
||||
if ci is not None:
|
||||
shared.log.debug('Interrogate offload')
|
||||
ci.caption_model = ci.caption_model.to(devices.cpu)
|
||||
ci.clip_model = ci.clip_model.to(devices.cpu)
|
||||
ci.caption_offloaded = True
|
||||
ci.clip_offloaded = True
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
def image_analysis(image, clip_model_name):
|
||||
load(clip_model_name)
|
||||
image = image.convert('RGB')
|
||||
image_features = ci.image_to_features(image)
|
||||
top_mediums = ci.mediums.rank(image_features, 5)
|
||||
top_artists = ci.artists.rank(image_features, 5)
|
||||
top_movements = ci.movements.rank(image_features, 5)
|
||||
top_trendings = ci.trendings.rank(image_features, 5)
|
||||
top_flavors = ci.flavors.rank(image_features, 5)
|
||||
medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))}
|
||||
artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))}
|
||||
movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))}
|
||||
trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))}
|
||||
flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))}
|
||||
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
|
||||
|
||||
|
||||
def interrogate(image, mode, caption=None):
|
||||
shared.log.info(f'Interrogate: image={image} mode={mode} config={ci.config}')
|
||||
if mode == 'best':
|
||||
prompt = ci.interrogate(image, caption=caption)
|
||||
elif mode == 'caption':
|
||||
prompt = ci.generate_caption(image) if caption is None else caption
|
||||
elif mode == 'classic':
|
||||
prompt = ci.interrogate_classic(image, caption=caption)
|
||||
elif mode == 'fast':
|
||||
prompt = ci.interrogate_fast(image, caption=caption)
|
||||
elif mode == 'negative':
|
||||
prompt = ci.interrogate_negative(image)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown mode {mode}")
|
||||
return prompt
|
||||
|
||||
|
||||
def image_to_prompt(image, mode, clip_model_name):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'interrogate'
|
||||
try:
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
load(clip_model_name)
|
||||
image = image.convert('RGB')
|
||||
shared.log.info(f'Interrogate: image={image} mode={mode} config={ci.config}')
|
||||
prompt = interrogate(image, mode)
|
||||
except Exception as e:
|
||||
prompt = f"Exception {type(e)}"
|
||||
shared.log.error(f'Interrogate: {e}')
|
||||
shared.state.end()
|
||||
return prompt
|
||||
|
||||
|
||||
def get_models():
|
||||
return ['/'.join(x) for x in open_clip.list_pretrained()]
|
||||
|
||||
|
||||
def batch_process(batch_files, batch_folder, batch_str, mode, clip_model, write):
|
||||
files = []
|
||||
if batch_files is not None:
|
||||
files += [f.name for f in batch_files]
|
||||
if batch_folder is not None:
|
||||
files += [f.name for f in batch_folder]
|
||||
if batch_str is not None and len(batch_str) > 0 and os.path.exists(batch_str) and os.path.isdir(batch_str):
|
||||
files += [os.path.join(batch_str, f) for f in os.listdir(batch_str) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
|
||||
if len(files) == 0:
|
||||
shared.log.error('Interrogate batch no images')
|
||||
return ''
|
||||
shared.log.info(f'Interrogate batch: images={len(files)} mode={mode} config={ci.config}')
|
||||
shared.state.begin()
|
||||
shared.state.job = 'batch interrogate'
|
||||
prompts = []
|
||||
try:
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
load(clip_model)
|
||||
captions = []
|
||||
# first pass: generate captions
|
||||
for file in files:
|
||||
caption = ""
|
||||
try:
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
image = Image.open(file).convert('RGB')
|
||||
caption = ci.generate_caption(image)
|
||||
except Exception as e:
|
||||
shared.log.error(f'Interrogate caption: {e}')
|
||||
finally:
|
||||
captions.append(caption)
|
||||
# second pass: interrogate
|
||||
if write:
|
||||
writer = BatchWriter(os.path.dirname(files[0]))
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
if shared.state.interrupted:
|
||||
break
|
||||
image = Image.open(file).convert('RGB')
|
||||
prompt = interrogate(image, mode, caption=captions[idx])
|
||||
prompts.append(prompt)
|
||||
if write:
|
||||
writer.add(file, prompt)
|
||||
except OSError as e:
|
||||
shared.log.error(f'Interrogate batch: {e}')
|
||||
if write:
|
||||
writer.close()
|
||||
ci.config.quiet = False
|
||||
unload()
|
||||
except Exception as e:
|
||||
shared.log.error(f'Interrogate batch: {e}')
|
||||
shared.state.end()
|
||||
return '\n\n'.join(prompts)
|
||||
|
||||
|
||||
def create_ui():
|
||||
global low_vram # pylint: disable=global-statement
|
||||
low_vram = shared.cmd_opts.lowvram or shared.cmd_opts.medvram
|
||||
if not low_vram and torch.cuda.is_available():
|
||||
device = devices.get_optimal_device()
|
||||
vram_total = torch.cuda.get_device_properties(device).total_memory
|
||||
if vram_total <= 12*1024*1024*1024:
|
||||
low_vram = True
|
||||
with gr.Row(elem_id="interrogate_tab"):
|
||||
with gr.Column():
|
||||
with gr.Tab("Image"):
|
||||
with gr.Row():
|
||||
image = gr.Image(type='pil', label="Image")
|
||||
with gr.Row():
|
||||
prompt = gr.Textbox(label="Prompt", lines=3)
|
||||
with gr.Row():
|
||||
medium = gr.Label(label="Medium", num_top_classes=5)
|
||||
artist = gr.Label(label="Artist", num_top_classes=5)
|
||||
movement = gr.Label(label="Movement", num_top_classes=5)
|
||||
trending = gr.Label(label="Trending", num_top_classes=5)
|
||||
flavor = gr.Label(label="Flavor", num_top_classes=5)
|
||||
with gr.Row():
|
||||
interrogate_btn = gr.Button("Interrogate", variant='primary')
|
||||
analyze_btn = gr.Button("Analyze", variant='primary')
|
||||
unload_btn = gr.Button("Unload")
|
||||
with gr.Row():
|
||||
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "extras"])
|
||||
for tabname, button in buttons.items():
|
||||
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(paste_button=button, tabname=tabname, source_text_component=prompt, source_image_component=image,))
|
||||
with gr.Tab("Batch"):
|
||||
with gr.Row():
|
||||
batch_files = gr.File(label="Files", show_label=True, file_count='multiple', file_types=['image'], type='file', interactive=True, height=100)
|
||||
with gr.Row():
|
||||
batch_folder = gr.File(label="Folder", show_label=True, file_count='directory', file_types=['image'], type='file', interactive=True, height=100)
|
||||
with gr.Row():
|
||||
batch_str = gr.Text(label="Folder", value="", interactive=True)
|
||||
with gr.Row():
|
||||
batch = gr.Text(label="Prompts", lines=10)
|
||||
with gr.Row():
|
||||
write = gr.Checkbox(label='Write prompts to files', value=False)
|
||||
with gr.Row():
|
||||
batch_btn = gr.Button("Interrogate", variant='primary')
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
clip_model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model')
|
||||
with gr.Row():
|
||||
mode = gr.Radio(['best', 'fast', 'classic', 'caption', 'negative'], label='Mode', value='best')
|
||||
interrogate_btn.click(image_to_prompt, inputs=[image, mode, clip_model], outputs=prompt)
|
||||
analyze_btn.click(image_analysis, inputs=[image, clip_model], outputs=[medium, artist, movement, trending, flavor])
|
||||
unload_btn.click(unload)
|
||||
batch_btn.click(batch_process, inputs=[batch_files, batch_folder, batch_str, mode, clip_model, write], outputs=[batch])
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||
|
||||
|
||||
def mount_interrogator_api(_: gr.Blocks, app: FastAPI): # TODO redesign interrogator api
|
||||
|
||||
class InterrogatorAnalyzeRequest(BaseModel):
|
||||
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
|
||||
clip_model_name: str = Field(default="ViT-L-14/openai", title="Model", description="The interrogate model used. See the models endpoint for a list of available models.")
|
||||
|
||||
class InterrogatorPromptRequest(InterrogatorAnalyzeRequest):
|
||||
mode: str = Field(default="fast", title="Mode", description="The mode used to generate the prompt. Can be one of: best, fast, classic, negative.")
|
||||
|
||||
@app.get("/interrogator/models")
|
||||
async def api_get_models():
|
||||
return ["/".join(x) for x in open_clip.list_pretrained()]
|
||||
|
||||
@app.post("/interrogator/prompt")
|
||||
async def api_get_prompt(analyzereq: InterrogatorPromptRequest):
|
||||
image_b64 = analyzereq.image
|
||||
if image_b64 is None:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
img = decode_base64_to_image(image_b64)
|
||||
prompt = image_to_prompt(img, analyzereq.mode, analyzereq.clip_model_name)
|
||||
return {"prompt": prompt}
|
||||
|
||||
@app.post("/interrogator/analyze")
|
||||
async def api_analyze(analyzereq: InterrogatorAnalyzeRequest):
|
||||
image_b64 = analyzereq.image
|
||||
if image_b64 is None:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
img = decode_base64_to_image(image_b64)
|
||||
(medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks) = image_analysis(img, analyzereq.clip_model_name)
|
||||
return {"medium": medium_ranks, "artist": artist_ranks, "movement": movement_ranks, "trending": trending_ranks, "flavor": flavor_ranks}
|
||||
|
||||
# script_callbacks.on_app_started(mount_interrogator_api)
|
||||
|
|
@ -25,7 +25,7 @@ def create_ui():
|
|||
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
|
||||
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
|
||||
with gr.TabItem('Process Batch', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
|
||||
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
|
||||
image_batch = gr.Files(label="Batch process", interactive=True, elem_id="extras_image_batch")
|
||||
with gr.TabItem('Process Folder', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
|
||||
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
|
||||
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
||||
|
|
|
|||
Loading…
Reference in New Issue