Add new blip2 caption processor tool

pull/1428/head
bmaltais 2023-08-15 06:39:21 -04:00
parent 24d0017675
commit 940302cd93
10 changed files with 286 additions and 13 deletions

View File

@ -517,6 +517,8 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is
* 2023/08/05 (v21.8.8)
- Fix issue with aiofiles: https://github.com/bmaltais/kohya_ss/issues/1359
- Merge sd-scripts updates as of Aug 11 2023
- Add new blip2 caption processor tool
- Add dataset preparation tab to appropriate trainers
* 2023/08/05 (v21.8.7)
- Add manual captioning option. Thanks to https://github.com/channelcat for this great contribution. (https://github.com/bmaltais/kohya_ss/pull/1352)
- Added support for `v_pred_like_loss` to the advanced training tab

View File

@ -40,6 +40,7 @@ from library.tensorboard_gui import (
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab
from library.class_sample_images import SampleImages, run_cmd_sample
@ -729,7 +730,7 @@ def dreambooth_tab(
with gr.Tab('Samples', elem_id='samples_tab'):
sample = SampleImages()
with gr.Tab('Tools'):
with gr.Tab('Dataset Preparation'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
)
@ -740,6 +741,7 @@ def dreambooth_tab(
logging_dir_input=folders.logging_dir,
headless=headless,
)
gradio_dataset_balancing_tab(headless=headless)
with gr.Row():
button_run = gr.Button('Start training', variant='primary')

View File

@ -9,6 +9,7 @@ from .class_sample_images import SampleImages
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from .common_gui import color_aug_changed
class Dreambooth:
@ -49,7 +50,7 @@ class Dreambooth:
self.sample = SampleImages()
with gr.Tab('Tools'):
with gr.Tab('Dataset Preparation'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
)
@ -60,6 +61,7 @@ class Dreambooth:
logging_dir_input=self.folders.logging_dir,
headless=headless,
)
gradio_dataset_balancing_tab(headless=headless)
def save_to_json(self, filepath):
def serialize(obj):

View File

@ -31,7 +31,7 @@ class LoRATools:
gradio_resize_lora_tab(headless=headless)
gradio_verify_lora_tab(headless=headless)
if folders:
with gr.Tab('Deprecated'):
with gr.Tab('Dataset Preparation'):
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,

View File

@ -1,15 +1,7 @@
# v1: initial release
# v2: add open and save folder icons
# v3: Add new Utilities tab for Dreambooth folder preparation
# v3.1: Adding captionning of images to utilities
import gradio as gr
import json
import math
import os
import subprocess
import psutil
import pathlib
import argparse
from datetime import datetime
from library.common_gui import (
@ -17,7 +9,6 @@ from library.common_gui import (
get_any_file_path,
get_saveasfile_path,
color_aug_changed,
save_inference_file,
run_cmd_advanced_training,
run_cmd_training,
update_my_data,
@ -44,6 +35,11 @@ from library.utilities import utilities_tab
from library.class_sample_images import SampleImages, run_cmd_sample
from library.class_lora_tab import LoRATools
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.custom_logging import setup_logging
# Set up logging
@ -1416,6 +1412,20 @@ def lora_tab(
module_dropout,
],
)
with gr.Tab('Dataset Preparation'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=folders.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
)
gradio_dataset_balancing_tab(headless=headless)
with gr.Row():
button_run = gr.Button('Start training', variant='primary')

View File

@ -40,6 +40,7 @@ from library.tensorboard_gui import (
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab
from library.class_sample_images import SampleImages, run_cmd_sample
@ -787,7 +788,7 @@ def ti_tab(
with gr.Tab('Samples', elem_id='samples_tab'):
sample = SampleImages()
with gr.Tab('Tools'):
with gr.Tab('Dataset Preparation'):
gr.Markdown(
'This section provide Dreambooth tools to help setup your dataset...'
)
@ -798,6 +799,7 @@ def ti_tab(
logging_dir_input=folders.logging_dir,
headless=headless,
)
gradio_dataset_balancing_tab(headless=headless)
with gr.Row():
button_run = gr.Button('Start training', variant='primary')

View File

@ -0,0 +1,33 @@
# blip2-for-sd
source: https://github.com/Talmendo/blip2-for-sd
Simple script to make BLIP2 output image description in a format suitable for Stable Diffusion.
The format followed is roughly:
`[STYLE OF PHOTO] photo of a [SUBJECT], [IMPORTANT FEATURE], [MORE DETAILS], [POSE OR ACTION], [FRAMING], [SETTING/BACKGROUND], [LIGHTING], [CAMERA ANGLE], [CAMERA PROPERTIES],in style of [PHOTOGRAPHER]`
## Usage
- Install dependencies according to requirements.txt
- run main.py
`python main.py`
The default model will be loaded automatically from huggingface.
You will be presented with an input to specify the folder to process after the model is loaded.
<img width="854" alt="Screenshot 2023-08-04 102650" src="https://github.com/Talmendo/blip2-for-sd/assets/141401796/fa40cae5-90a4-4dd5-be1d-fc0e8312251a">
- The image or source folder should have the following structure:
![Screenshot 2023-08-04 102544](https://github.com/Talmendo/blip2-for-sd/assets/141401796/eea9c2b0-e96a-40e4-8a6d-32dd7aa3e802)
Each folder represents a base prompt to be used for every image inside.
- You can adjust BLIP2 settings in `caption_processor.py` in between runs, without having to stop the script. Just update it before inputting the new source folder.
## Models
Default model is `Salesforce/blip2-opt-2.7b`, works quite well and doesn't require much VRAM.
Also tested with `Salesforce/blip2-opt-6.7b-coco`, which seems to give better results at the cost of much more VRAM and a large download (~30GB).

View File

@ -0,0 +1,105 @@
import torch
import re
class CaptionProcessor:
    """Build a Stable-Diffusion-style caption by asking a model questions.

    The instance wraps a generation model, its processor and a target device.
    ``caption_me`` asks a fixed series of questions about an image and joins
    the sanitised answers into a single comma-separated caption.
    """

    # Ordered (compiled pattern, replacement) pairs applied to every raw
    # answer. Compiled once at class creation rather than on every call.
    # Order matters: "it is a " must be tried before the shorter "it is ".
    _SHARD_REPLACEMENTS = tuple(
        (re.compile(pattern), replacement)
        for pattern, replacement in (
            # Matches ", a point and shoot" with optional " camera"
            (r", a point and shoot(?: camera)?", ""),
            (r"it is a ", ""),
            (r"it is ", ""),
            (r"hair hair", "hair"),
            (r"wearing nothing", "nude"),
            (r"She's ", ""),
            (r"She is ", ""),
        )
    )

    def __init__(self, model, processor, device):
        """Store the model, its processor and the device inputs are moved to."""
        self.model = model
        self.processor = processor
        self.device = device

    def gen(self, inputs, max_length=10, min_length=0, top_k=30, top_p=0.92, num_beams=4):
        """Run ``model.generate`` on processed *inputs* and return its output.

        Args:
            inputs: Mapping of processed model inputs (as returned by
                ``process``); unpacked as keyword arguments.
            max_length: Maximum length of the generated sequence.
            min_length: Minimum length of the generated sequence.
            top_k: Number of highest-probability tokens to keep.
            top_p: Cumulative probability threshold for token filtering.
            num_beams: Number of beams used for beam search.
        """
        return self.model.generate(
            **inputs,
            # max_new_tokens is mutually exclusive with max_length, so it is
            # deliberately not passed here.
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=1,  # One caption per input
            # Per the Hugging Face generation docs, with beam search this
            # stops as soon as num_beams complete candidates exist.
            early_stopping=True,
            repetition_penalty=1.5,  # Penalize repeated tokens
            no_repeat_ngram_size=2,  # No 2-gram may occur more than once
            # Sampling (do_sample / temperature) is intentionally left off.
            # NOTE(review): per the Hugging Face generation docs top_k/top_p
            # are sampling filters, so they likely have no effect while
            # do_sample is disabled - confirm before tuning them.
            top_k=top_k,
            top_p=top_p,
            min_length=min_length,
        )

    def process(self, prompt, image):
        """Turn *image* and text *prompt* into tensors on ``self.device``.

        NOTE(review): inputs are always cast to float16; this presumably has
        to match the dtype the model was loaded with - confirm with caller.
        """
        return self.processor(image, text=prompt, return_tensors="pt").to(self.device, torch.float16)

    def caption_from(self, generated):
        """Decode generated token ids into stripped caption text.

        Returns:
            A list of captions when more than one sequence was decoded, the
            single caption string when exactly one was decoded, and an empty
            string when nothing was decoded (previously an ``IndexError``).
        """
        decoded = self.processor.batch_decode(generated, skip_special_tokens=True)
        captions = [text.strip() for text in decoded]
        if not captions:
            return ""
        return captions if len(captions) > 1 else captions[0]

    def sanitise_caption(self, caption):
        """Collapse spaced hyphens (``" - "``) in the final caption to ``"-"``."""
        return caption.replace(" - ", "-")

    # TODO this needs some more work
    def sanitise_prompt_shard(self, prompt):
        """Clean up one raw answer before it is placed into the caption.

        Everything from the first ``"Answer:"`` onwards is dropped, the rest
        is stripped, and the class-level replacement table is applied in
        order.
        """
        # Remove everything after "Answer:"
        prompt = prompt.split("Answer:")[0].strip()

        for pattern, replacement in self._SHARD_REPLACEMENTS:
            prompt = pattern.sub(replacement, prompt)

        return prompt

    def ask(self, question, image):
        """Ask *question* about *image* and return the sanitised answer."""
        inputs = self.process(f"Question: {question} Answer:", image)
        answer = self.caption_from(self.gen(inputs))
        return self.sanitise_prompt_shard(answer)

    def caption_me(self, initial_prompt, image):
        """Build a full caption for *image* from a fixed series of questions.

        Args:
            initial_prompt: Base prompt for the image. It is accepted for
                interface compatibility but is not currently inserted into
                the caption.
            image: The image to describe.

        Returns:
            The assembled caption, or an empty string if any step raised
            (the exception is printed, not propagated).
        """
        prompt = ""

        try:
            # Target layout:
            # [STYLE OF PHOTO] photo of a [SUBJECT], [IMPORTANT FEATURE], [MORE DETAILS], [POSE OR ACTION], [FRAMING], [SETTING/BACKGROUND], [LIGHTING], [CAMERA ANGLE], [CAMERA PROPERTIES],in style of [PHOTOGRAPHER]
            hair_color = self.ask("What is her hair color?", image)
            hair_length = self.ask("What is her hair length?", image)
            # NOTE(review): if an answer already ends in "hair" this yields
            # "hair hair"; the shard replacements run before this join and so
            # do not catch it - confirm whether that is intended.
            p_hair = f"{hair_color} {hair_length} hair"

            p_style = self.ask("Between the choices selfie, mirror selfie, candid, professional portrait what is the style of the photo?", image)
            p_clothing = self.ask("What is she wearing if anything?", image)
            p_action = self.ask("What is she doing? Could be something like standing, stretching, walking, squatting, etc", image)
            p_framing = self.ask("Between the choices close up, upper body shot, full body shot what is the framing of the photo?", image)
            p_setting = self.ask("Where is she? Be descriptive and detailed", image)
            p_lighting = self.ask("What is the scene lighting like? For example: soft lighting, studio lighting, natural lighting", image)
            p_angle = self.ask("What angle is the picture taken from? Be succint, like: from the side, from below, from front", image)
            p_camera = self.ask("What kind of camera could this picture have been taken with? Be specific and guess a brand with specific camera type", image)

            prompt = self.sanitise_caption(f"{p_style}, with {p_hair}, wearing {p_clothing}, {p_action}, {p_framing}, {p_setting}, {p_lighting}, {p_angle}, {p_camera}")
        except Exception as e:
            # Best effort: report the problem and fall back to an empty caption.
            print(e)

        return prompt

View File

@ -0,0 +1,89 @@
import requests, torch, sys, os
import argparse
from importlib import reload
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from tqdm import tqdm
import caption_processor
# Module-level state: assigned by load_model() and read by main().
model = None  # generation model; None until load_model() has run
processor = None  # processor paired with the model; None until load_model() has run
device = None  # device name string ("cuda" or "cpu"); None until load_model() has run
def load_model(model_name="Salesforce/blip2-opt-2.7b"):
    """Load the processor and model for *model_name* into module globals.

    The model is loaded with ``torch.float16`` weights and moved to ``"cuda"``
    when CUDA is available, otherwise to ``"cpu"``. The results are stored in
    the module-level ``model``, ``processor`` and ``device`` names.
    """
    global model, processor, device

    print("Loading Model")
    processor = AutoProcessor.from_pretrained(model_name)
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
    )

    # Pick the target device once and report the choice.
    has_cuda = torch.cuda.is_available()
    print("CUDA available, using GPU" if has_cuda else "CUDA not available, using CPU")
    device = "cuda" if has_cuda else "cpu"

    print("Moving model to device")
    model.to(device)
def main(path):
    """Caption every image found in the immediate sub-folders of *path*.

    Each sub-folder name is used as the prompt passed to the caption
    processor for the images it contains. For every image a ``.txt`` file
    with the same stem is written next to it, holding the generated caption
    (empty when captioning that image failed).

    Args:
        path: Folder whose direct sub-folders hold the images to caption.
    """
    # Reload caption_processor so its values can be changed in between
    # executions without having to reload the model, which can take very
    # long. Probably cleaner to do this with a config file and just reload
    # that, but this works for now.
    reload(caption_processor)

    # Matched case-insensitively so files such as "IMG_0001.JPG" are included.
    image_extensions = (".jpg", ".png", ".jpeg", ".webp")

    # list all sub dirs in path
    sub_dirs = [
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    ]

    print("Reading prompts from sub dirs and finding image files")
    prompt_file_dict = {
        prompt: [
            name for name in os.listdir(os.path.join(path, prompt))
            if name.lower().endswith(image_extensions)
        ]
        for prompt in sub_dirs
    }

    for prompt, file_list in prompt_file_dict.items():
        print(f"Found {str(len(file_list))} files for prompt \"{prompt}\"")

    # Built once per run, after the reload above, instead of once per image.
    captioner = caption_processor.CaptionProcessor(model, processor, device)

    for prompt, file_list in prompt_file_dict.items():
        for file_name in tqdm(file_list):
            caption = ""

            # Open and caption inside one guarded block: the image handle is
            # always closed, and an unreadable file is reported instead of
            # aborting the whole batch. Ctrl-C still propagates because only
            # Exception (not BaseException) is caught.
            try:
                with Image.open(os.path.join(path, prompt, file_name)) as image:
                    caption = captioner.caption_me(prompt, image)
            except Exception:
                print("Error creating caption for file: " + file_name)

            # Save the caption next to the image, same name with .txt extension.
            caption_path = os.path.join(path, prompt, os.path.splitext(file_name)[0] + ".txt")
            with open(caption_path, "w", encoding="utf-8") as f:
                f.write(caption)
if __name__ == "__main__":
    # Command-line entry point: caption one folder, or loop interactively.
    parser = argparse.ArgumentParser(description="Enter the path to the file")
    parser.add_argument("path", type=str, nargs='?', default="", help="Path to the file")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode")
    args = parser.parse_args()

    # Fail fast with a usage error before the slow model load; an empty path
    # would otherwise only blow up later inside os.listdir("").
    if not args.interactive and not args.path:
        parser.error("a path is required unless --interactive is given")

    load_model(model_name="Salesforce/blip2-opt-2.7b")

    if args.interactive:
        # Keep the model loaded and process folders until the user declines.
        while True:
            main(input("Enter path: "))
            if input("Continue? (y/n): ").lower() != 'y':
                break
    else:
        main(args.path)

View File

@ -0,0 +1,28 @@
--extra-index-url https://download.pytorch.org/whl/cu118
accelerate==0.21.0
certifi==2023.7.22
charset-normalizer==3.2.0
colorama==0.4.6
filelock==3.12.2
fsspec==2023.6.0
huggingface-hub==0.16.4
idna==3.4
Jinja2==3.1.2
MarkupSafe==2.1.3
mpmath==1.3.0
networkx==3.1
numpy==1.25.2
packaging==23.1
Pillow==10.0.0
psutil==5.9.5
PyYAML==6.0.1
regex==2023.6.3
requests==2.31.0
safetensors==0.3.1
sympy==1.12
tokenizers==0.13.3
torch==2.0.1+cu118
tqdm==4.65.0
transformers==4.31.0
typing_extensions==4.7.1
urllib3==2.0.4