merge process and interrogate, add exif handler

pull/3000/head
Vladimir Mandic 2024-03-03 13:26:09 -05:00
parent 328a9cacd4
commit 1b44a16a4e
14 changed files with 283 additions and 268 deletions

View File

@ -5,7 +5,7 @@
- EDM samplers for Playground require `diffusers==0.27.0`
- StableCascade requires diffusers `kashif/diffusers.git@wuerstchen-v3`
## Update for 2024-03-02
## Update for 2024-03-03
- [Playground v2.5](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic)
- new model version from Playground: based on SDXL, but with some cool new concepts
@ -52,6 +52,10 @@
- **Refiner** validated workflows:
- Fully functional: SD15 + SD15, SDXL + SDXL, SDXL + SDXL-R
- Functional, but result is not as good: SD15 + SDXL, SDXL + SD15, SD15 + SDXL-R
- **UI**
- *interrogate* tab is now merged into *process* tab
- *image viewer* now displays image metadata
- *themes* improve on-the-fly switching
- **Fixes**
- improve *model cpu offload* compatibility
- improve *model sequential offload* compatibility

@ -1 +1 @@
Subproject commit 39159f2d52f53a5cf7ba0dfd1a0f085cff3e71e5
Subproject commit d23a7055778cd1743d1c6239a4ea5f1243b63337

View File

@ -165,7 +165,7 @@
{"id":"","label":"Separate Init Image","localized":"","hint":"Creates an additional window next to Control input labeled Init input, so you can have a separate image for both Control operations and an init source."}
],
"process tab": [
{"id":"","label":"Single Image","localized":"","hint":"Process single image"},
{"id":"","label":"Process Image","localized":"","hint":"Process single image"},
{"id":"","label":"Process Batch","localized":"","hint":"Process batch of images"},
{"id":"","label":"Process Folder","localized":"","hint":"Process all images in a folder"},
{"id":"","label":"Scale by","localized":"","hint":"Use this tab to resize the source image(s) by a chosen factor"},

1
javascript/exifr.js Normal file

File diff suppressed because one or more lines are too long

View File

@ -24,7 +24,6 @@ function modalImageSwitch(offset) {
nextButton.click();
const modalImage = gradioApp().getElementById('modalImage');
const modal = gradioApp().getElementById('lightboxModal');
modalImage.onload = () => modalPreviewZone.focus();
modalImage.src = nextButton.children[0].src;
if (modalImage.style.display === 'none') modal.style.setProperty('background-image', `url(${modalImage.src})`);
}
@ -55,6 +54,24 @@ function modalKeyHandler(event) {
event.stopPropagation();
}
async function displayExif(el) {
  // Populate the #modalExif overlay with metadata parsed from the image element.
  const modalExif = gradioApp().getElementById('modalExif');
  modalExif.innerHTML = ''; // clear stale metadata before async parse
  const exif = await window.exifr.parse(el); // exifr resolves to undefined when the image has no metadata
  if (!exif) return;
  log('exif', exif);
  try {
    // build the header on a single line so the only newlines left are those inside the prompt text
    let html = `<b>Image</b> <a href="${el.src}" target="_blank">${el.src}</a> <b>Size</b> ${el.naturalWidth}x${el.naturalHeight}<br><b>Prompt</b> ${exif.parameters || ''}<br>`;
    // String.replace with a plain-string pattern only replaces the FIRST match; use a /g regex for all newlines
    html = html.replace(/\n/g, '<br>');
    html = html.replace('Negative prompt:', '<br><b>Negative</b>');
    html = html.replace('Steps:', '<br><b>Params</b> Steps:');
    modalExif.innerHTML = html;
  } catch (e) {
    log('displayExif', e); // do not swallow formatting errors silently
  }
}
function showModal(event) {
const source = event.target || event.srcElement;
const modalImage = gradioApp().getElementById('modalImage');
@ -63,6 +80,7 @@ function showModal(event) {
modalImage.onload = () => {
previewInstance.moveTo(0, 0);
modalPreviewZone.focus();
displayExif(modalImage);''
};
modalImage.src = source.src;
if (modalImage.style.display === 'none') lb.style.setProperty('background-image', `url(${source.src})`);
@ -165,45 +183,50 @@ async function initImageViewer() {
const modalZoom = document.createElement('span');
modalZoom.id = 'modal_zoom';
modalZoom.className = 'cursor';
modalZoom.innerHTML = '🔍';
modalZoom.innerHTML = '\uf531';
modalZoom.title = 'Toggle zoomed view';
modalZoom.addEventListener('click', modalZoomToggle, true);
const modalReset = document.createElement('span');
modalReset.id = 'modal_reset';
modalReset.className = 'cursor';
modalReset.innerHTML = '♻️';
modalReset.innerHTML = '\uf532';
modalReset.title = 'Reset zoomed view';
modalReset.addEventListener('click', modalResetInstance, true);
const modalTile = document.createElement('span');
modalTile.id = 'modal_tile';
modalTile.className = 'cursor';
modalTile.innerHTML = '🖽';
modalTile.innerHTML = '\udb81\udd70';
modalTile.title = 'Preview tiling';
modalTile.addEventListener('click', modalTileToggle, true);
const modalSave = document.createElement('span');
modalSave.id = 'modal_save';
modalSave.className = 'cursor';
modalSave.innerHTML = '💾';
modalSave.innerHTML = '\udb80\udd93';
modalSave.title = 'Save Image';
modalSave.addEventListener('click', modalSaveImage, true);
const modalDownload = document.createElement('span');
modalDownload.id = 'modal_download';
modalDownload.className = 'cursor';
modalDownload.innerHTML = '📷';
modalDownload.innerHTML = '\udb85\udc62';
modalDownload.title = 'Download Image';
modalDownload.addEventListener('click', modalDownloadImage, true);
const modalClose = document.createElement('span');
modalClose.id = 'modal_close';
modalClose.className = 'cursor';
modalClose.innerHTML = '🗙';
modalClose.innerHTML = '\udb80\udd57';
modalClose.title = 'Close';
modalClose.addEventListener('click', (evt) => closeModal(evt, true), true);
// exif
const modalExif = document.createElement('div');
modalExif.id = 'modalExif';
modalExif.style = 'position: absolute; bottom: 0px; width: 98%; background-color: rgba(0, 0, 0, 0.5); color: var(--neutral-300); padding: 1em; font-size: small;'
// handlers
modalPreviewZone.addEventListener('mousedown', () => { previewDrag = false; });
modalPreviewZone.addEventListener('touchstart', () => { previewDrag = false; }, { passive: true });
@ -233,6 +256,7 @@ async function initImageViewer() {
modal.appendChild(modalPreviewZone);
modal.appendChild(modalNext);
modal.append(modalControls);
modal.append(modalExif);
modalControls.appendChild(modalZoom);
modalControls.appendChild(modalReset);
modalControls.appendChild(modalTile);

View File

@ -152,7 +152,7 @@ div#extras_scale_to_tab div.form{ flex-direction: row; }
user-select: none; -webkit-user-select: none; flex-direction: row; }
.modalControls { display: flex; justify-content: space-evenly; background-color: transparent; position: absolute; width: 99%; z-index: 1; }
.modalControls:hover { background-color: #50505050; }
.modalControls span { color: white; font-size: 2em; font-weight: bold; cursor: pointer; filter: grayscale(100%); }
.modalControls span { color: white; font-size: 2em !important; font-weight: bold; cursor: pointer; filter: grayscale(100%); }
.modalControls span:hover, .modalControls span:focus { color: var(--highlight-color); filter: none; }
.lightboxModalPreviewZone { display: flex; width: 100%; height: 100%; }
.lightboxModalPreviewZone:focus-visible { outline: none; }

View File

@ -67,6 +67,8 @@ function extract_image_from_gallery(gallery) {
async function setTheme(val, old) {
if (!old || val === old) return;
old = old.replace('modern/', '');
val = val.replace('modern/', '');
const links = Array.from(document.getElementsByTagName('link')).filter((l) => l.href.includes(old));
for (const link of links) {
const href = link.href.replace(old, val);

View File

@ -68,8 +68,8 @@ def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, f
return res
def get_interrogate():
    """List available interrogation model names: the built-in 'clip' and 'deepdanbooru' plus every open_clip CLIP variant."""
    # the interrogate implementation moved from modules.ui_interrogate to modules.interrogate;
    # the fused old/new diff lines left a duplicated unreachable body here — keep only the new one
    from modules.interrogate import get_clip_models
    return ['clip', 'deepdanbooru'] + get_clip_models()
def post_interrogate(req: models.ReqInterrogate):
if req.image is None or len(req.image) < 64:
@ -87,8 +87,8 @@ def post_interrogate(req: models.ReqInterrogate):
caption = deepbooru.model.tag(image)
return models.ResInterrogate(caption=caption)
else:
from modules.ui_interrogate import interrogate_image, analyze_image, get_models
if req.model not in get_models():
from modules.interrogate import interrogate_image, analyze_image, get_clip_models
if req.model not in get_clip_models():
raise HTTPException(status_code=404, detail="Model not found")
try:
caption = interrogate_image(image, model=req.model, mode=req.mode)

View File

@ -79,7 +79,7 @@ class InterrogateModels:
def load_blip_model(self):
self.create_fake_fairscale()
from repositories.blip import models
from repositories.blip import models # pylint: disable=unused-import
from repositories.blip.models import blip
import modules.modelloader as modelloader
model_path = os.path.join(paths.models_path, "BLIP")
@ -195,3 +195,162 @@ class InterrogateModels:
self.unload()
shared.state.end()
return res
# --------- interrogate ui
ci = None
low_vram = False
class BatchWriter:
    """Writes interrogation prompts to per-image sidecar .txt files inside `folder`."""

    def __init__(self, folder):
        self.folder = folder
        # reserved for a future csv output mode; unused for plain-text sidecars
        self.csv, self.file = None, None

    def add(self, file, prompt):
        """Write `prompt` to `<folder>/<file-stem>.txt`.

        Only the basename of `file` is used: joining an absolute `file` path with
        os.path.join would silently discard `self.folder` entirely.
        """
        txt_file = os.path.splitext(os.path.basename(file))[0] + ".txt"
        with open(os.path.join(self.folder, txt_file), 'w', encoding='utf-8') as f:
            f.write(prompt)

    def close(self):
        # nothing is held open for plain-text output; close the csv handle if one was ever created
        if self.file is not None:
            self.file.close()
def get_clip_models():
    """Return every pretrained CLIP model known to open_clip as 'architecture/checkpoint' strings."""
    import open_clip  # local import: open_clip is heavy and only needed for the interrogate UI
    return list(map('/'.join, open_clip.list_pretrained()))
def load_interrogator(model):
    """Create or update the module-level clip_interrogator instance for the requested CLIP model.

    The Interrogator is cached in the module global `ci`; calling again with a
    different model name swaps the CLIP model in-place instead of rebuilding.
    """
    from clip_interrogator import Config, Interrogator
    global ci  # pylint: disable=global-statement
    if ci is not None:
        # already built: only reload when the requested CLIP model differs
        if model != ci.config.clip_model_name:
            ci.config.clip_model_name = model
            shared.log.info(f'Interrogate load: config={ci.config}')
            ci.load_clip_model()
        return
    config = Config(device=devices.get_optimal_device(), cache_path=os.path.join(paths.models_path, 'Interrogator'), clip_model_name=model, quiet=True)
    if low_vram:
        config.apply_low_vram_defaults()
    shared.log.info(f'Interrogate load: config={config}')
    ci = Interrogator(config)
def unload_clip_model():
    """Offload the cached interrogator's caption and CLIP models to CPU and reclaim GPU memory."""
    if ci is None:
        return  # nothing has been loaded yet
    shared.log.debug('Interrogate offload')
    ci.caption_model = ci.caption_model.to(devices.cpu)
    ci.clip_model = ci.clip_model.to(devices.cpu)
    ci.caption_offloaded = True
    ci.clip_offloaded = True
    devices.torch_gc()
def interrogate(image, mode, caption=None):
    """Run clip_interrogator on `image` in the selected mode.

    `caption` optionally seeds the BLIP caption; in 'caption' mode a supplied
    caption is returned as-is. Raises RuntimeError for an unknown mode.
    """
    shared.log.info(f'Interrogate: image={image} mode={mode} config={ci.config}')
    if mode == 'caption':
        return caption if caption is not None else ci.generate_caption(image)
    if mode == 'negative':
        return ci.interrogate_negative(image)
    handlers = {
        'best': ci.interrogate,
        'classic': ci.interrogate_classic,
        'fast': ci.interrogate_fast,
    }
    if mode not in handlers:
        raise RuntimeError(f"Unknown mode {mode}")
    return handlers[mode](image, caption=caption)
def interrogate_image(image, model, mode):
    """Interrogate a single PIL image; returns the generated prompt, or an error string on failure."""
    shared.state.begin()
    shared.state.job = 'interrogate'
    try:
        if shared.backend == shared.Backend.ORIGINAL and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
            # free VRAM held by the generation pipeline before loading the interrogator
            lowvram.send_everything_to_cpu()
            devices.torch_gc()
        load_interrogator(model)
        rgb = image.convert('RGB')
        shared.log.info(f'Interrogate: image={rgb} mode={mode} config={ci.config}')
        prompt = interrogate(rgb, mode)
    except Exception as e:
        # never raise into the UI: report the failure as the prompt text instead
        prompt = f"Exception {type(e)}"
        shared.log.error(f'Interrogate: {e}')
    shared.state.end()
    return prompt
def interrogate_batch(batch_files, batch_folder, batch_str, model, mode, write):
    """Interrogate a batch of images collected from uploaded files, an uploaded folder and/or a folder path string.

    Returns all generated prompts joined by blank lines. When `write` is truthy,
    each prompt is also written to a sidecar .txt file via BatchWriter.
    """
    files = []
    if batch_files is not None:
        files += [f.name for f in batch_files]  # gradio File objects expose their temp path as .name
    if batch_folder is not None:
        files += [f.name for f in batch_folder]
    if batch_str is not None and len(batch_str) > 0 and os.path.exists(batch_str) and os.path.isdir(batch_str):
        files += [os.path.join(batch_str, f) for f in os.listdir(batch_str) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
    if len(files) == 0:
        shared.log.error('Interrogate batch no images')
        return ''
    shared.state.begin()
    shared.state.job = 'batch interrogate'
    prompts = []
    try:
        if shared.backend == shared.Backend.ORIGINAL and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
            # free VRAM held by the generation pipeline before loading the interrogator
            lowvram.send_everything_to_cpu()
            devices.torch_gc()
        load_interrogator(model)
        shared.log.info(f'Interrogate batch: images={len(files)} mode={mode} config={ci.config}')
        captions = []
        # first pass: generate captions
        for file in files:
            caption = ""
            try:
                if shared.state.interrupted:
                    break
                image = Image.open(file).convert('RGB')
                caption = ci.generate_caption(image)
            except Exception as e:
                shared.log.error(f'Interrogate caption: {e}')
            finally:
                captions.append(caption)  # keep captions index-aligned with files even on failure
        # second pass: interrogate
        if write:
            writer = BatchWriter(os.path.dirname(files[0]))
        for idx, file in enumerate(files):
            try:
                if shared.state.interrupted:
                    break
                image = Image.open(file).convert('RGB')
                prompt = interrogate(image, mode, caption=captions[idx])
                prompts.append(prompt)
                if write:
                    writer.add(file, prompt)
            except OSError as e:
                # only file-level errors are tolerated here; interrogation errors abort the batch via the outer handler
                shared.log.error(f'Interrogate batch: {e}')
        if write:
            writer.close()
        ci.config.quiet = False
        unload_clip_model()
    except Exception as e:
        shared.log.error(f'Interrogate batch: {e}')
    shared.state.end()
    return '\n\n'.join(prompts)
def analyze_image(image, model):
    """Rank the image against clip_interrogator's category tables.

    Returns five {label: score} dicts, top-5 each, in fixed order:
    mediums, artists, movements, trendings, flavors.
    """
    load_interrogator(model)
    features = ci.image_to_features(image.convert('RGB'))
    results = []
    for table in (ci.mediums, ci.artists, ci.movements, ci.trendings, ci.flavors):
        top = table.rank(features, 5)
        results.append(dict(zip(top, ci.similarities(features, top))))
    medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks = results
    return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks

View File

@ -413,6 +413,7 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
try:
t0 = time.time()
sd_models_compile.check_deepcache(enable=True)
sd_models.move_model(shared.sd_model, devices.device)
output = shared.sd_model(**base_args) # pylint: disable=not-callable
if isinstance(output, dict):
output = SimpleNamespace(**output)
@ -480,6 +481,7 @@ def process_diffusers(p: processing.StableDiffusionProcessing):
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
update_sampler(shared.sd_model, second_pass=True)
shared.log.info(f'HiRes: class={shared.sd_model.__class__.__name__} sampler="{p.hr_sampler_name}"')
sd_models.move_model(shared.sd_model, devices.device)
hires_args = set_pipeline_args(
model=shared.sd_model,
prompts=[p.refiner_prompt] if len(p.refiner_prompt) > 0 else p.prompts,

View File

@ -149,12 +149,6 @@ def create_ui(startup_timer = None):
ui_models.create_ui()
timer.startup.record("ui-models")
with gr.Blocks(analytics_enabled=False) as interrogate_interface:
from modules import ui_interrogate
ui_interrogate.create_ui()
timer.startup.record("ui-interrogate")
def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@ -371,7 +365,6 @@ def create_ui(startup_timer = None):
interfaces += [(img2img_interface, "Image", "img2img")]
interfaces += [(control_interface, "Control", "control")] if control_interface is not None else []
interfaces += [(extras_interface, "Process", "process")]
interfaces += [(interrogate_interface, "Interrogate", "interrogate")]
interfaces += [(models_interface, "Models", "models")]
interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "System", "system")]

View File

@ -1,231 +0,0 @@
import os
import gradio as gr
import torch
from PIL import Image
import modules.generation_parameters_copypaste as parameters_copypaste
from modules import devices, lowvram, shared, paths, ui_common
ci = None
low_vram = False
class BatchWriter:
def __init__(self, folder):
self.folder = folder
self.csv, self.file = None, None
def add(self, file, prompt):
txt_file = os.path.splitext(file)[0] + ".txt"
with open(os.path.join(self.folder, txt_file), 'w', encoding='utf-8') as f:
f.write(prompt)
def close(self):
if self.file is not None:
self.file.close()
def get_models():
import open_clip
return ['/'.join(x) for x in open_clip.list_pretrained()]
def load_interrogator(clip_model_name):
from clip_interrogator import Config, Interrogator
global ci # pylint: disable=global-statement
if ci is None:
config = Config(device=devices.get_optimal_device(), cache_path=os.path.join(paths.models_path, 'Interrogator'), clip_model_name=clip_model_name, quiet=True)
if low_vram:
config.apply_low_vram_defaults()
shared.log.info(f'Interrogate load: config={config}')
ci = Interrogator(config)
elif clip_model_name != ci.config.clip_model_name:
ci.config.clip_model_name = clip_model_name
shared.log.info(f'Interrogate load: config={ci.config}')
ci.load_clip_model()
def unload():
if ci is not None:
shared.log.debug('Interrogate offload')
ci.caption_model = ci.caption_model.to(devices.cpu)
ci.clip_model = ci.clip_model.to(devices.cpu)
ci.caption_offloaded = True
ci.clip_offloaded = True
devices.torch_gc()
def interrogate(image, mode, caption=None):
shared.log.info(f'Interrogate: image={image} mode={mode} config={ci.config}')
if mode == 'best':
prompt = ci.interrogate(image, caption=caption)
elif mode == 'caption':
prompt = ci.generate_caption(image) if caption is None else caption
elif mode == 'classic':
prompt = ci.interrogate_classic(image, caption=caption)
elif mode == 'fast':
prompt = ci.interrogate_fast(image, caption=caption)
elif mode == 'negative':
prompt = ci.interrogate_negative(image)
else:
raise RuntimeError(f"Unknown mode {mode}")
return prompt
def interrogate_image(image, model, mode):
shared.state.begin()
shared.state.job = 'interrogate'
try:
if shared.backend == shared.Backend.ORIGINAL and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
lowvram.send_everything_to_cpu()
devices.torch_gc()
load_interrogator(model)
image = image.convert('RGB')
shared.log.info(f'Interrogate: image={image} mode={mode} config={ci.config}')
prompt = interrogate(image, mode)
except Exception as e:
prompt = f"Exception {type(e)}"
shared.log.error(f'Interrogate: {e}')
shared.state.end()
return prompt
def interrogate_batch(batch_files, batch_folder, batch_str, model, mode, write):
files = []
if batch_files is not None:
files += [f.name for f in batch_files]
if batch_folder is not None:
files += [f.name for f in batch_folder]
if batch_str is not None and len(batch_str) > 0 and os.path.exists(batch_str) and os.path.isdir(batch_str):
files += [os.path.join(batch_str, f) for f in os.listdir(batch_str) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
if len(files) == 0:
shared.log.error('Interrogate batch no images')
return ''
shared.state.begin()
shared.state.job = 'batch interrogate'
prompts = []
try:
if shared.backend == shared.Backend.ORIGINAL and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
lowvram.send_everything_to_cpu()
devices.torch_gc()
load_interrogator(model)
shared.log.info(f'Interrogate batch: images={len(files)} mode={mode} config={ci.config}')
captions = []
# first pass: generate captions
for file in files:
caption = ""
try:
if shared.state.interrupted:
break
image = Image.open(file).convert('RGB')
caption = ci.generate_caption(image)
except Exception as e:
shared.log.error(f'Interrogate caption: {e}')
finally:
captions.append(caption)
# second pass: interrogate
if write:
writer = BatchWriter(os.path.dirname(files[0]))
for idx, file in enumerate(files):
try:
if shared.state.interrupted:
break
image = Image.open(file).convert('RGB')
prompt = interrogate(image, mode, caption=captions[idx])
prompts.append(prompt)
if write:
writer.add(file, prompt)
except OSError as e:
shared.log.error(f'Interrogate batch: {e}')
if write:
writer.close()
ci.config.quiet = False
unload()
except Exception as e:
shared.log.error(f'Interrogate batch: {e}')
shared.state.end()
return '\n\n'.join(prompts)
def analyze_image(image, model):
load_interrogator(model)
image = image.convert('RGB')
image_features = ci.image_to_features(image)
top_mediums = ci.mediums.rank(image_features, 5)
top_artists = ci.artists.rank(image_features, 5)
top_movements = ci.movements.rank(image_features, 5)
top_trendings = ci.trendings.rank(image_features, 5)
top_flavors = ci.flavors.rank(image_features, 5)
medium_ranks = dict(zip(top_mediums, ci.similarities(image_features, top_mediums)))
artist_ranks = dict(zip(top_artists, ci.similarities(image_features, top_artists)))
movement_ranks = dict(zip(top_movements, ci.similarities(image_features, top_movements)))
trending_ranks = dict(zip(top_trendings, ci.similarities(image_features, top_trendings)))
flavor_ranks = dict(zip(top_flavors, ci.similarities(image_features, top_flavors)))
return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks
def create_ui():
global low_vram # pylint: disable=global-statement
low_vram = shared.cmd_opts.lowvram or shared.cmd_opts.medvram
if not low_vram and torch.cuda.is_available():
device = devices.get_optimal_device()
vram_total = torch.cuda.get_device_properties(device).total_memory
if vram_total <= 12*1024*1024*1024:
low_vram = True
with gr.Row(elem_id="interrogate_tab"):
with gr.Column():
with gr.Tab("Image"):
with gr.Row():
image = gr.Image(type='pil', label="Image")
with gr.Row():
prompt = gr.Textbox(label="Prompt", lines=3)
with gr.Row():
medium = gr.Label(label="Medium", num_top_classes=5)
artist = gr.Label(label="Artist", num_top_classes=5)
movement = gr.Label(label="Movement", num_top_classes=5)
trending = gr.Label(label="Trending", num_top_classes=5)
flavor = gr.Label(label="Flavor", num_top_classes=5)
with gr.Row():
clip_model = gr.Dropdown([], value='ViT-L-14/openai', label='CLIP Model')
ui_common.create_refresh_button(clip_model, get_models, lambda: {"choices": get_models()}, 'refresh_interrogate_models')
mode = gr.Radio(['best', 'fast', 'classic', 'caption', 'negative'], label='Mode', value='best')
with gr.Row():
btn_interrogate_img = gr.Button("Interrogate", variant='primary')
btn_analyze_img = gr.Button("Analyze", variant='primary')
btn_unload = gr.Button("Unload")
with gr.Row():
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "extras", "control"])
for tabname, button in buttons.items():
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(paste_button=button, tabname=tabname, source_text_component=prompt, source_image_component=image,))
btn_interrogate_img.click(interrogate_image, inputs=[image, clip_model, mode], outputs=prompt)
btn_analyze_img.click(analyze_image, inputs=[image, clip_model], outputs=[medium, artist, movement, trending, flavor])
btn_unload.click(unload)
with gr.Tab("Batch"):
with gr.Row():
batch_files = gr.File(label="Files", show_label=True, file_count='multiple', file_types=['image'], type='file', interactive=True, height=100)
with gr.Row():
batch_folder = gr.File(label="Folder", show_label=True, file_count='directory', file_types=['image'], type='file', interactive=True, height=100)
with gr.Row():
batch_str = gr.Text(label="Folder", value="", interactive=True)
with gr.Row():
batch = gr.Text(label="Prompts", lines=10)
with gr.Row():
write = gr.Checkbox(label='Write prompts to files', value=False)
with gr.Row():
clip_model = gr.Dropdown([], value='ViT-L-14/openai', label='CLIP Model')
ui_common.create_refresh_button(clip_model, get_models, lambda: {"choices": get_models()}, 'refresh_interrogate_models')
with gr.Row():
btn_interrogate_batch = gr.Button("Interrogate", variant='primary')
btn_interrogate_batch.click(interrogate_batch, inputs=[batch_files, batch_folder, batch_str, clip_model, mode, write], outputs=[batch])
with gr.Tab("VQA"):
from modules import vqa
with gr.Row():
vqa_image = gr.Image(type='pil', label="Image")
with gr.Row():
vqa_question = gr.Textbox(label="Question")
with gr.Row():
vqa_answer = gr.Textbox(label="Answer", lines=3)
with gr.Row():
vqa_model = gr.Dropdown(list(vqa.MODELS), value='None', label='VQA Model')
vqa_submit = gr.Button("Interrogate", variant='primary')
vqa_submit.click(vqa.interrogate, inputs=[vqa_question, vqa_image, vqa_model], outputs=[vqa_answer])

View File

@ -1,18 +1,18 @@
import json
import gradio as gr
from modules import scripts, shared, ui_common, postprocessing, call_queue
from modules import scripts, shared, ui_common, postprocessing, call_queue, interrogate
import modules.generation_parameters_copypaste as parameters_copypaste
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call # pylint: disable=unused-import
from modules.extras import run_pnginfo
from modules.ui_common import infotext_to_html
def submit_info(image):
    """Extract PNG generation info from `image`; returns (formatted html, raw info, geninfo text).

    The fused old/new diff lines left a duplicated `def` header here
    (wrap_pnginfo was renamed to submit_info) — keep only the new name.
    """
    _, geninfo, info = run_pnginfo(image)
    return infotext_to_html(geninfo), info, geninfo
def submit_process(tab_index, extras_image, image_batch, extras_batch_input_dir, extras_batch_output_dir, show_extras_results, save_output, *script_inputs):
    """Run postprocessing over the selected tab's input; returns (result images, geninfo, json info, empty status).

    The fused old/new diff lines left a duplicated `def` header here
    (submit_click was renamed to submit_process) — keep only the new name.
    """
    result_images, geninfo, js_info = postprocessing.run_postprocessing(tab_index, extras_image, image_batch, extras_batch_input_dir, extras_batch_output_dir, show_extras_results, *script_inputs, save_output=save_output)
    return result_images, geninfo, json.dumps(js_info), ''
@ -22,18 +22,74 @@ def create_ui():
with gr.Row(equal_height=False, variant='compact', elem_classes="extras"):
with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
with gr.TabItem('Process Batch', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
with gr.Tab('Process Image', id="single_image", elem_id="extras_single_tab") as tab_single:
with gr.Row():
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
with gr.Row(elem_id='copy_buttons_process'):
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "control"])
with gr.Tab('Process Batch', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
image_batch = gr.Files(label="Batch process", interactive=True, elem_id="extras_image_batch")
with gr.TabItem('Process Folder', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
with gr.Tab('Process Folder', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
with gr.Row(elem_id="copy_buttons_extras"):
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "control"])
with gr.Row():
save_output = gr.Checkbox(label='Save output', value=True, elem_id="extras_save_output")
with gr.Tab("Interrogate Image"):
with gr.Row():
image = gr.Image(type='pil', label="Image")
with gr.Row():
prompt = gr.Textbox(label="Prompt", lines=3)
with gr.Row():
medium = gr.Label(label="Medium", num_top_classes=5)
artist = gr.Label(label="Artist", num_top_classes=5)
movement = gr.Label(label="Movement", num_top_classes=5)
trending = gr.Label(label="Trending", num_top_classes=5)
flavor = gr.Label(label="Flavor", num_top_classes=5)
with gr.Row():
clip_model = gr.Dropdown([], value='ViT-L-14/openai', label='CLIP Model')
ui_common.create_refresh_button(clip_model, interrogate.get_clip_models, lambda: {"choices": interrogate.get_clip_models()}, 'refresh_interrogate_models')
mode = gr.Radio(['best', 'fast', 'classic', 'caption', 'negative'], label='Mode', value='best')
with gr.Row():
btn_interrogate_img = gr.Button("Interrogate", variant='primary')
btn_analyze_img = gr.Button("Analyze", variant='primary')
btn_unload = gr.Button("Unload")
with gr.Row(elem_id='copy_buttons_interrogate'):
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "extras", "control"])
for tabname, button in buttons.items():
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(paste_button=button, tabname=tabname, source_text_component=prompt, source_image_component=image,))
btn_interrogate_img.click(interrogate.interrogate_image, inputs=[image, clip_model, mode], outputs=prompt)
btn_analyze_img.click(interrogate.analyze_image, inputs=[image, clip_model], outputs=[medium, artist, movement, trending, flavor])
btn_unload.click(interrogate.unload_clip_model)
with gr.Tab("Interrogate Batch"):
with gr.Row():
batch_files = gr.File(label="Files", show_label=True, file_count='multiple', file_types=['image'], type='file', interactive=True, height=100)
with gr.Row():
batch_folder = gr.File(label="Folder", show_label=True, file_count='directory', file_types=['image'], type='file', interactive=True, height=100)
with gr.Row():
batch_str = gr.Text(label="Folder", value="", interactive=True)
with gr.Row():
batch = gr.Text(label="Prompts", lines=10)
with gr.Row():
clip_model = gr.Dropdown([], value='ViT-L-14/openai', label='CLIP Model')
ui_common.create_refresh_button(clip_model, interrogate.get_clip_models, lambda: {"choices": interrogate.get_clip_models()}, 'refresh_interrogate_models')
with gr.Row():
btn_interrogate_batch = gr.Button("Interrogate", variant='primary')
with gr.Tab("Query Image"):
from modules import vqa
with gr.Row():
vqa_image = gr.Image(type='pil', label="Image")
with gr.Row():
vqa_question = gr.Textbox(label="Question")
with gr.Row():
vqa_answer = gr.Textbox(label="Answer", lines=3)
with gr.Row():
vqa_model = gr.Dropdown(list(vqa.MODELS), value='None', label='VQA Model')
vqa_submit = gr.Button("Interrogate", variant='primary')
vqa_submit.click(vqa.interrogate, inputs=[vqa_question, vqa_image, vqa_model], outputs=[vqa_answer])
with gr.Row():
save_output = gr.Checkbox(label='Save output', value=True, elem_id="extras_save_output")
script_inputs = scripts.scripts_postproc.setup_ui()
with gr.Column():
id_part = 'extras'
@ -54,13 +110,13 @@ def create_ui():
tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
extras_image.change(
fn=wrap_gradio_call(wrap_pnginfo),
fn=wrap_gradio_call(submit_info),
inputs=[extras_image],
outputs=[html_info_formatted, exif_info, gen_info],
)
submit.click(
_js="submit_postprocessing",
fn=call_queue.wrap_gradio_gpu_call(submit_click, extra_outputs=[None, '']),
fn=call_queue.wrap_gradio_gpu_call(submit_process, extra_outputs=[None, '']),
inputs=[
tab_index,
extras_image,
@ -78,6 +134,11 @@ def create_ui():
html_log,
]
)
btn_interrogate_batch.click(
fn=interrogate.interrogate_batch,
inputs=[batch_files, batch_folder, batch_str, clip_model, mode, save_output],
outputs=[batch],
)
parameters_copypaste.add_paste_fields("extras", extras_image, None)

2
wiki

@ -1 +1 @@
Subproject commit 5c52cbb7301c3e008a9dfd76702f9321ae7b3a34
Subproject commit a6e56a04f38b8b8fd57f397a6ada720042228174