diff --git a/.gitmodules b/.gitmodules index 05e6e3c84..a6d3e9bf1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -12,10 +12,6 @@ path = modules/lora url = https://github.com/kohya-ss/sd-scripts ignore = dirty -[submodule "extensions-builtin/clip-interrogator-ext"] - path = extensions-builtin/clip-interrogator-ext - url = https://github.com/Dahvikiin/clip-interrogator-ext.git - ignore = dirty [submodule "extensions-builtin/sd-webui-controlnet"] path = extensions-builtin/sd-webui-controlnet url = https://github.com/Mikubill/sd-webui-controlnet diff --git a/CHANGELOG.md b/CHANGELOG.md index e6cbd5ec4..40a7ed833 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,8 @@ Some highlights: [OpenVINO](https://github.com/vladmandic/automatic/wiki/OpenVIN - new option: *settings -> system paths -> models* can be used to set custom base path for *all* models (previously only as cli option) - remove external clone of items in `/repositories` + - **Interrogator** module has been removed from `extensions-builtin` + and fully implemented (and improved) natively - **UI** - UI tweaks for default themes - UI switch core font in default theme to **noto-sans** diff --git a/README.md b/README.md index 84e66df8d..e56eb9d23 100644 --- a/README.md +++ b/README.md @@ -124,10 +124,7 @@ SD.Next comes with several extensions pre-installed: - [ControlNet](https://github.com/Mikubill/sd-webui-controlnet) - [Agent Scheduler](https://github.com/ArtVentureX/sd-webui-agent-scheduler) -- [Multi-Diffusion Tiled Diffusion and VAE](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111) -- [LyCORIS](https://github.com/KohakuBlueleaf/a1111-sd-webui-lycoris) - [Image Browser](https://github.com/AlUlkesh/stable-diffusion-webui-images-browser) -- [CLiP Interrogator](https://github.com/pharmapsychotic/clip-interrogator-ext) - [Rembg Background Removal](https://github.com/AUTOMATIC1111/stable-diffusion-webui-rembg) ### **Collab** diff --git a/html/locale_en.json b/html/locale_en.json index e2480b253..f0c84af87 100644 --- a/html/locale_en.json +++ b/html/locale_en.json @@ -40,9 +40,9 @@ {"id":"","label":"disabled","localized":"","hint":""} ], "tabs": [ - {"id":"","label":"From Text","localized":"","hint":"Create image from text"}, - {"id":"","label":"From Image","localized":"","hint":"Create image from image"}, - {"id":"","label":"Process Image","localized":"","hint":"Process existing image"}, + {"id":"","label":"Text","localized":"","hint":"Create image from text"}, + {"id":"","label":"Image","localized":"","hint":"Create image from image"}, + {"id":"","label":"Process","localized":"","hint":"Process existing image"}, {"id":"","label":"Train","localized":"","hint":"Run training or model merging"}, {"id":"","label":"Models","localized":"","hint":"Convert or merge your models"}, {"id":"","label":"Interrogator","localized":"","hint":"Run interrogate to get description of your image"}, diff --git a/javascript/amethyst-nightfall.css b/javascript/amethyst-nightfall.css index 18d296730..1c20eb3e8 100644 --- a/javascript/amethyst-nightfall.css +++ b/javascript/amethyst-nightfall.css @@ -238,13 +238,6 @@ svg.feather.feather-image, .feather .feather-image { display: none } --radius-md: 0; --radius-xl: 0; --radius-xxl: 0; - --text-xxs: 9px; - --text-xs: 10px; - --text-sm: 12px; - --text-md: 14px; - --text-lg: 16px; - --text-xl: 22px; - --text-xxl: 26px; --font: 'Source Sans Pro', 'ui-sans-serif', 'system-ui', sans-serif; --font-mono: 'IBM Plex Mono', 'ui-monospace', 'Consolas', monospace; --body-text-size: var(--text-md); diff --git a/javascript/black-orange.css b/javascript/black-orange.css index 4b23af2f5..210544d19 100644 --- a/javascript/black-orange.css +++ b/javascript/black-orange.css @@ -38,9 +38,6 @@ --spacing-xxl: 6px; --line-sm: 1.2em; --line-md: 1.4em; - --text-sm: 12px; - --text-md: 14px; - --text-lg: 15px; } html { font-size: var(--font-size); } @@ -253,10 +250,6 @@ svg.feather.feather-image, .feather .feather-image { display: none } --radius-md: 0; --radius-xl: 0; --radius-xxl: 0; - --text-xxs: 9px; - --text-xs: 10px; - --text-xl: 22px; - --text-xxl: 26px; --body-text-size: var(--text-md); --body-text-weight: 400; --embed-radius: var(--radius-lg); diff --git a/javascript/black-teal.css b/javascript/black-teal.css index a9094fe31..23a69dfc4 100644 --- a/javascript/black-teal.css +++ b/javascript/black-teal.css @@ -33,9 +33,6 @@ --radius-lg: 4px; --line-sm: 1.2em; --line-md: 1.4em; - --text-sm: 12px; - --text-md: 14px; - --text-lg: 15px; } html { font-size: var(--font-size); font-family: var(--font); } @@ -246,10 +243,6 @@ textarea[rows="1"] { height: 33px !important; width: 99% !important; padding: 8p --radius-md: 0; --radius-xl: 0; --radius-xxl: 0; - --text-xxs: 9px; - --text-xs: 10px; - --text-xl: 22px; - --text-xxl: 26px; --body-text-size: var(--text-md); --body-text-weight: 400; --embed-radius: var(--radius-lg); diff --git a/javascript/invoked.css b/javascript/invoked.css index 684527817..39d25b0e6 100644 --- a/javascript/invoked.css +++ b/javascript/invoked.css @@ -234,13 +234,6 @@ button.selected {background: var(--button-primary-background-fill);} --radius-md: 0; --radius-xl: 0; --radius-xxl: 0; - --text-xxs: 9px; - --text-xs: 10px; - --text-sm: 12px; - --text-md: 14px; - --text-lg: 16px; - --text-xl: 22px; - --text-xxl: 26px; --body-text-size: var(--text-md); --body-text-weight: 400; --embed-radius: var(--radius-lg); diff --git a/javascript/light-teal.css b/javascript/light-teal.css index 9b34212ec..5541cea74 100644 --- a/javascript/light-teal.css +++ b/javascript/light-teal.css @@ -33,9 +33,6 @@ --radius-lg: 4px; --line-sm: 1.2em; --line-md: 1.4em; - --text-sm: 12px; - --text-md: 14px; - --text-lg: 15px; } html { font-size: var(--font-size); } @@ -309,8 +306,4 @@ svg.feather.feather-image, .feather .feather-image { display: none } --table-odd-background-fill: #333333; --table-radius: var(--radius-lg); --table-row-focus: var(--color-accent-soft); - --text-lg: 16px; - --text-xs: 10px; - --text-xxl: 26px; - --text-xxs: 9px; } diff --git a/javascript/midnight-barbie.css b/javascript/midnight-barbie.css index 7049a96b9..31916705e 100644 --- a/javascript/midnight-barbie.css +++ b/javascript/midnight-barbie.css @@ -239,13 +239,6 @@ svg.feather.feather-image, .feather .feather-image { display: none } --radius-md: 0; --radius-xl: 0; --radius-xxl: 0; - --text-xxs: 9px; - --text-xs: 10px; - --text-sm: 12px; - --text-md: 14px; - --text-lg: 16px; - --text-xl: 22px; - --text-xxl: 26px; --body-text-size: var(--text-md); --body-text-weight: 400; --embed-radius: var(--radius-lg); diff --git a/javascript/sdnext.css b/javascript/sdnext.css index 9bb3782a5..035862dbd 100644 --- a/javascript/sdnext.css +++ b/javascript/sdnext.css @@ -263,3 +263,13 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt @keyframes move { from { background-position-x: 0, -40px; } to { background-position-x: 0, 40px; } } @keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } } @keyframes color { from { filter: hue-rotate(0deg) } to { filter: hue-rotate(360deg) } } + +:root, .light, .dark { + --text-xxs: 9px; + --text-xs: 10px; + --text-sm: 12px; + --text-md: 14px; + --text-lg: 16px; + --text-xl: 18px; + --text-xxl: 20px; +} diff --git a/modules/loader.py b/modules/loader.py index c725d6d24..3dbb0c4a8 100644 --- a/modules/loader.py +++ b/modules/loader.py @@ -41,4 +41,4 @@ errors.install([gradio]) import diffusers # pylint: disable=W0611,C0411 timer.startup.record("diffusers") -errors.log.debug(f'Load packages: torch={getattr(torch, "__long_version__", torch.__version__)} diffusers={diffusers.__version__} gradio={gradio.__version__}') +errors.log.info(f'Load packages: torch={getattr(torch, "__long_version__", torch.__version__)} diffusers={diffusers.__version__} gradio={gradio.__version__}') diff --git a/modules/ui.py b/modules/ui.py index 0edd90083..a4cb0be12 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -10,7 +10,7 @@ import numpy as np from PIL import Image from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, ui_loadsave, ui_train, ui_models +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, ui_loadsave, ui_train, ui_models, ui_interrogate from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path from modules.shared import opts, cmd_opts @@ -635,7 +635,6 @@ def create_ui(startup_timer = None): img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory", **modules.shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] - for i, tab in enumerate(img2img_tabs): tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab]) @@ -903,6 +902,11 @@ def create_ui(startup_timer = None): ui_models.create_ui() timer.startup.record("ui-models") + with gr.Blocks(analytics_enabled=False) as interrogate_interface: + ui_interrogate.create_ui() + timer.startup.record("ui-interrogate") + + def create_setting_component(key, is_quicksettings=False): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -1102,11 +1106,12 @@ def create_ui(startup_timer = None): timer.startup.record("ui-settings") interfaces = [ - (txt2img_interface, "From Text", "txt2img"), - (img2img_interface, "From Image", "img2img"), - (extras_interface, "Process Image", "process"), + (txt2img_interface, "Text", "txt2img"), + (img2img_interface, "Image", "img2img"), + (extras_interface, "Process", "process"), (train_interface, "Train", "train"), (models_interface, "Models", "models"), + (interrogate_interface, "Interrogate", "interrogate"), ] interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings_interface, "System", "system")] diff --git a/modules/ui_interrogate.py b/modules/ui_interrogate.py new file mode 100644 index 000000000..f8537ac37 --- /dev/null +++ b/modules/ui_interrogate.py @@ -0,0 +1,267 @@ +import os +import base64 +from io import BytesIO +import gradio as gr +import open_clip +import torch +from PIL import Image +from pydantic import BaseModel, Field # pylint: disable=no-name-in-module +from fastapi import FastAPI +from fastapi.exceptions import HTTPException +from clip_interrogator import Config, Interrogator +import modules.generation_parameters_copypaste as parameters_copypaste +from modules import devices, lowvram, shared, paths + + +ci = None +low_vram = False + + +class BatchWriter: + def __init__(self, folder): + self.folder = folder + self.csv, self.file = None, None + + def add(self, file, prompt): + txt_file = os.path.splitext(file)[0] + ".txt" + with open(os.path.join(self.folder, txt_file), 'w', encoding='utf-8') as f: + f.write(prompt) + + def close(self): + if self.file is not None: + self.file.close() + + +def load(clip_model_name): + global ci # pylint: disable=global-statement + if ci is None: + config = Config(device=devices.get_optimal_device(), cache_path=os.path.join(paths.models_path, 'clip-interrogator'), clip_model_name=clip_model_name, quiet=True) + if low_vram: + config.apply_low_vram_defaults() + shared.log.info(f'Interrogate load: config={config}') + ci = Interrogator(config) + elif clip_model_name != ci.config.clip_model_name: + ci.config.clip_model_name = clip_model_name + shared.log.info(f'Interrogate load: config={ci.config}') + ci.load_clip_model() + + +def unload(): + if ci is not None: + shared.log.debug('Interrogate offload') + ci.caption_model = ci.caption_model.to(devices.cpu) + ci.clip_model = ci.clip_model.to(devices.cpu) + ci.caption_offloaded = True + ci.clip_offloaded = True + devices.torch_gc() + + +def image_analysis(image, clip_model_name): + load(clip_model_name) + image = image.convert('RGB') + image_features = ci.image_to_features(image) + top_mediums = ci.mediums.rank(image_features, 5) + top_artists = ci.artists.rank(image_features, 5) + top_movements = ci.movements.rank(image_features, 5) + top_trendings = ci.trendings.rank(image_features, 5) + top_flavors = ci.flavors.rank(image_features, 5) + medium_ranks = {medium: sim for medium, sim in zip(top_mediums, ci.similarities(image_features, top_mediums))} + artist_ranks = {artist: sim for artist, sim in zip(top_artists, ci.similarities(image_features, top_artists))} + movement_ranks = {movement: sim for movement, sim in zip(top_movements, ci.similarities(image_features, top_movements))} + trending_ranks = {trending: sim for trending, sim in zip(top_trendings, ci.similarities(image_features, top_trendings))} + flavor_ranks = {flavor: sim for flavor, sim in zip(top_flavors, ci.similarities(image_features, top_flavors))} + return medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks + + +def interrogate(image, mode, caption=None): + shared.log.info(f'Interrogate: image={image} mode={mode} config={ci.config}') + if mode == 'best': + prompt = ci.interrogate(image, caption=caption) + elif mode == 'caption': + prompt = ci.generate_caption(image) if caption is None else caption + elif mode == 'classic': + prompt = ci.interrogate_classic(image, caption=caption) + elif mode == 'fast': + prompt = ci.interrogate_fast(image, caption=caption) + elif mode == 'negative': + prompt = ci.interrogate_negative(image) + else: + raise RuntimeError(f"Unknown mode {mode}") + return prompt + + +def image_to_prompt(image, mode, clip_model_name): + shared.state.begin() + shared.state.job = 'interrogate' + try: + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + devices.torch_gc() + load(clip_model_name) + image = image.convert('RGB') + shared.log.info(f'Interrogate: image={image} mode={mode} config={ci.config}') + prompt = interrogate(image, mode) + except Exception as e: + prompt = f"Exception {type(e)}" + shared.log.error(f'Interrogate: {e}') + shared.state.end() + return prompt + + +def get_models(): + return ['/'.join(x) for x in open_clip.list_pretrained()] + + +def batch_process(batch_files, batch_folder, batch_str, mode, clip_model, write): + files = [] + if batch_files is not None: + files += [f.name for f in batch_files] + if batch_folder is not None: + files += [f.name for f in batch_folder] + if batch_str is not None and len(batch_str) > 0 and os.path.exists(batch_str) and os.path.isdir(batch_str): + files += [os.path.join(batch_str, f) for f in os.listdir(batch_str) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))] + if len(files) == 0: + shared.log.error('Interrogate batch no images') + return '' + shared.log.info(f'Interrogate batch: images={len(files)} mode={mode} config={ci.config}') + shared.state.begin() + shared.state.job = 'batch interrogate' + prompts = [] + try: + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + devices.torch_gc() + load(clip_model) + captions = [] + # first pass: generate captions + for file in files: + caption = "" + try: + if shared.state.interrupted: + break + image = Image.open(file).convert('RGB') + caption = ci.generate_caption(image) + except Exception as e: + shared.log.error(f'Interrogate caption: {e}') + finally: + captions.append(caption) + # second pass: interrogate + if write: + writer = BatchWriter(os.path.dirname(files[0])) + for idx, file in enumerate(files): + try: + if shared.state.interrupted: + break + image = Image.open(file).convert('RGB') + prompt = interrogate(image, mode, caption=captions[idx]) + prompts.append(prompt) + if write: + writer.add(file, prompt) + except OSError as e: + shared.log.error(f'Interrogate batch: {e}') + if write: + writer.close() + ci.config.quiet = False + unload() + except Exception as e: + shared.log.error(f'Interrogate batch: {e}') + shared.state.end() + return '\n\n'.join(prompts) + + +def create_ui(): + global low_vram # pylint: disable=global-statement + low_vram = shared.cmd_opts.lowvram or shared.cmd_opts.medvram + if not low_vram and torch.cuda.is_available(): + device = devices.get_optimal_device() + vram_total = torch.cuda.get_device_properties(device).total_memory + if vram_total <= 12*1024*1024*1024: + low_vram = True + with gr.Row(elem_id="interrogate_tab"): + with gr.Column(): + with gr.Tab("Image"): + with gr.Row(): + image = gr.Image(type='pil', label="Image") + with gr.Row(): + prompt = gr.Textbox(label="Prompt", lines=3) + with gr.Row(): + medium = gr.Label(label="Medium", num_top_classes=5) + artist = gr.Label(label="Artist", num_top_classes=5) + movement = gr.Label(label="Movement", num_top_classes=5) + trending = gr.Label(label="Trending", num_top_classes=5) + flavor = gr.Label(label="Flavor", num_top_classes=5) + with gr.Row(): + interrogate_btn = gr.Button("Interrogate", variant='primary') + analyze_btn = gr.Button("Analyze", variant='primary') + unload_btn = gr.Button("Unload") + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "extras"]) + for tabname, button in buttons.items(): + parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(paste_button=button, tabname=tabname, source_text_component=prompt, source_image_component=image,)) + with gr.Tab("Batch"): + with gr.Row(): + batch_files = gr.File(label="Files", show_label=True, file_count='multiple', file_types=['image'], type='file', interactive=True, height=100) + with gr.Row(): + batch_folder = gr.File(label="Folder", show_label=True, file_count='directory', file_types=['image'], type='file', interactive=True, height=100) + with gr.Row(): + batch_str = gr.Text(label="Folder", value="", interactive=True) + with gr.Row(): + batch = gr.Text(label="Prompts", lines=10) + with gr.Row(): + write = gr.Checkbox(label='Write prompts to files', value=False) + with gr.Row(): + batch_btn = gr.Button("Interrogate", variant='primary') + with gr.Column(): + with gr.Row(): + clip_model = gr.Dropdown(get_models(), value='ViT-L-14/openai', label='CLIP Model') + with gr.Row(): + mode = gr.Radio(['best', 'fast', 'classic', 'caption', 'negative'], label='Mode', value='best') + interrogate_btn.click(image_to_prompt, inputs=[image, mode, clip_model], outputs=prompt) + analyze_btn.click(image_analysis, inputs=[image, clip_model], outputs=[medium, artist, movement, trending, flavor]) + unload_btn.click(unload) + batch_btn.click(batch_process, inputs=[batch_files, batch_folder, batch_str, mode, clip_model, write], outputs=[batch]) + + +def decode_base64_to_image(encoding): + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + try: + image = Image.open(BytesIO(base64.b64decode(encoding))) + return image + except Exception as e: + raise HTTPException(status_code=500, detail="Invalid encoded image") from e + + +def mount_interrogator_api(_: gr.Blocks, app: FastAPI): # TODO redesign interrogator api + + class InterrogatorAnalyzeRequest(BaseModel): + image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") + clip_model_name: str = Field(default="ViT-L-14/openai", title="Model", description="The interrogate model used. See the models endpoint for a list of available models.") + + class InterrogatorPromptRequest(InterrogatorAnalyzeRequest): + mode: str = Field(default="fast", title="Mode", description="The mode used to generate the prompt. Can be one of: best, fast, classic, negative.") + + @app.get("/interrogator/models") + async def api_get_models(): + return ["/".join(x) for x in open_clip.list_pretrained()] + + @app.post("/interrogator/prompt") + async def api_get_prompt(analyzereq: InterrogatorPromptRequest): + image_b64 = analyzereq.image + if image_b64 is None: + raise HTTPException(status_code=404, detail="Image not found") + img = decode_base64_to_image(image_b64) + prompt = image_to_prompt(img, analyzereq.mode, analyzereq.clip_model_name) + return {"prompt": prompt} + + @app.post("/interrogator/analyze") + async def api_analyze(analyzereq: InterrogatorAnalyzeRequest): + image_b64 = analyzereq.image + if image_b64 is None: + raise HTTPException(status_code=404, detail="Image not found") + + img = decode_base64_to_image(image_b64) + (medium_ranks, artist_ranks, movement_ranks, trending_ranks, flavor_ranks) = image_analysis(img, analyzereq.clip_model_name) + return {"medium": medium_ranks, "artist": artist_ranks, "movement": movement_ranks, "trending": trending_ranks, "flavor": flavor_ranks} + +# script_callbacks.on_app_started(mount_interrogator_api) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index b28597653..334afcc9a 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -25,7 +25,7 @@ def create_ui(): with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single: extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") with gr.TabItem('Process Batch', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch: - image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch") + image_batch = gr.Files(label="Batch process", interactive=True, elem_id="extras_image_batch") with gr.TabItem('Process Folder', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir: extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")