mirror of https://github.com/bmaltais/kohya_ss
Add new blip2 caption processor tool
parent 24d0017675 · commit 940302cd93
@ -517,6 +517,8 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is
* 2023/08/05 (v21.8.8)
  - Fix issue with aiofiles: https://github.com/bmaltais/kohya_ss/issues/1359
  - Merge sd-scripts updates as of Aug 11 2023
  - Add new blip2 caption processor tool
  - Add dataset preparation tab to appropriate trainers
* 2023/08/05 (v21.8.7)
  - Add manual captioning option. Thanks to https://github.com/channelcat for this great contribution. (https://github.com/bmaltais/kohya_ss/pull/1352)
  - Added support for `v_pred_like_loss` to the advanced training tab
@ -40,6 +40,7 @@ from library.tensorboard_gui import (
from library.dreambooth_folder_creation_gui import (
    gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab
from library.class_sample_images import SampleImages, run_cmd_sample
@ -729,7 +730,7 @@ def dreambooth_tab(
        with gr.Tab('Samples', elem_id='samples_tab'):
            sample = SampleImages()

    with gr.Tab('Tools'):
        with gr.Tab('Dataset Preparation'):
            gr.Markdown(
                'This section provides Dreambooth tools to help set up your dataset...'
            )
@ -740,6 +741,7 @@ def dreambooth_tab(
                logging_dir_input=folders.logging_dir,
                headless=headless,
            )
            gradio_dataset_balancing_tab(headless=headless)

    with gr.Row():
        button_run = gr.Button('Start training', variant='primary')
@ -9,6 +9,7 @@ from .class_sample_images import SampleImages
from library.dreambooth_folder_creation_gui import (
    gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from .common_gui import color_aug_changed

class Dreambooth:
@ -49,7 +50,7 @@ class Dreambooth:

        self.sample = SampleImages()

        with gr.Tab('Tools'):
            with gr.Tab('Dataset Preparation'):
                gr.Markdown(
                    'This section provides Dreambooth tools to help set up your dataset...'
                )
@ -60,6 +61,7 @@ class Dreambooth:
                    logging_dir_input=self.folders.logging_dir,
                    headless=headless,
                )
                gradio_dataset_balancing_tab(headless=headless)

    def save_to_json(self, filepath):
        def serialize(obj):
@ -31,7 +31,7 @@ class LoRATools:
        gradio_resize_lora_tab(headless=headless)
        gradio_verify_lora_tab(headless=headless)
        if folders:
            with gr.Tab('Deprecated'):
                with gr.Tab('Dataset Preparation'):
                    gradio_dreambooth_folder_creation_tab(
                        train_data_dir_input=folders.train_data_dir,
                        reg_data_dir_input=folders.reg_data_dir,
lora_gui.py
@ -1,15 +1,7 @@
# v1: initial release
# v2: add open and save folder icons
# v3: Add new Utilities tab for Dreambooth folder preparation
# v3.1: Adding captioning of images to utilities

import gradio as gr
import json
import math
import os
import subprocess
import psutil
import pathlib
import argparse
from datetime import datetime
from library.common_gui import (
@ -17,7 +9,6 @@ from library.common_gui import (
    get_any_file_path,
    get_saveasfile_path,
    color_aug_changed,
    save_inference_file,
    run_cmd_advanced_training,
    run_cmd_training,
    update_my_data,
@ -44,6 +35,11 @@ from library.utilities import utilities_tab
from library.class_sample_images import SampleImages, run_cmd_sample
from library.class_lora_tab import LoRATools

from library.dreambooth_folder_creation_gui import (
    gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab

from library.custom_logging import setup_logging

# Set up logging
@ -1416,6 +1412,20 @@ def lora_tab(
                module_dropout,
            ],
        )

        with gr.Tab('Dataset Preparation'):
            gr.Markdown(
                'This section provides Dreambooth tools to help set up your dataset...'
            )
            gradio_dreambooth_folder_creation_tab(
                train_data_dir_input=folders.train_data_dir,
                reg_data_dir_input=folders.reg_data_dir,
                output_dir_input=folders.output_dir,
                logging_dir_input=folders.logging_dir,
                headless=headless,
            )
            gradio_dataset_balancing_tab(headless=headless)

    with gr.Row():
        button_run = gr.Button('Start training', variant='primary')
@ -40,6 +40,7 @@ from library.tensorboard_gui import (
from library.dreambooth_folder_creation_gui import (
    gradio_dreambooth_folder_creation_tab,
)
from library.dataset_balancing_gui import gradio_dataset_balancing_tab
from library.utilities import utilities_tab
from library.class_sample_images import SampleImages, run_cmd_sample
@ -787,7 +788,7 @@ def ti_tab(
        with gr.Tab('Samples', elem_id='samples_tab'):
            sample = SampleImages()

    with gr.Tab('Tools'):
        with gr.Tab('Dataset Preparation'):
            gr.Markdown(
                'This section provides Dreambooth tools to help set up your dataset...'
            )
@ -798,6 +799,7 @@ def ti_tab(
                logging_dir_input=folders.logging_dir,
                headless=headless,
            )
            gradio_dataset_balancing_tab(headless=headless)

    with gr.Row():
        button_run = gr.Button('Start training', variant='primary')
@ -0,0 +1,33 @@
# blip2-for-sd

source: https://github.com/Talmendo/blip2-for-sd

Simple script to make BLIP2 output image descriptions in a format suitable for Stable Diffusion.

The format followed is roughly:

`[STYLE OF PHOTO] photo of a [SUBJECT], [IMPORTANT FEATURE], [MORE DETAILS], [POSE OR ACTION], [FRAMING], [SETTING/BACKGROUND], [LIGHTING], [CAMERA ANGLE], [CAMERA PROPERTIES], in style of [PHOTOGRAPHER]`
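An illustrative caption in this shape (a made-up example, not actual tool output) might be: `candid, photo of a woman, long brown hair, wearing a black dress, standing, full body shot, in a park, natural lighting, from the front, Canon EOS 5D`.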
## Usage
- Install dependencies according to requirements.txt
- Run main.py:
  `python main.py`

The default model is loaded automatically from Hugging Face.
Once the model is loaded, you will be prompted for the folder to process.

<img width="854" alt="Screenshot 2023-08-04 102650" src="https://github.com/Talmendo/blip2-for-sd/assets/141401796/fa40cae5-90a4-4dd5-be1d-fc0e8312251a">
- The image or source folder should have the following structure:

  

  Each folder represents a base prompt to be used for every image inside.
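For illustration, a source folder might be laid out like this (hypothetical folder and file names; `main.py` picks up `.jpg`, `.png`, `.jpeg` and `.webp` files):

```
source_folder/
├── photo of a woman/
│   ├── 001.jpg
│   └── 002.png
└── photo of a man/
    └── 003.webp
```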
- You can adjust BLIP2 settings in `caption_processor.py` between runs, without having to stop the script. Just update the file before entering the next source folder; see the sketch below.
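For example, a minimal tweak (a sketch only; these keyword defaults already exist on `CaptionProcessor.gen` in `caption_processor.py`, shown later in this diff, here raised to allow longer captions and a wider beam search):

```python
# caption_processor.py - edit the generation defaults between runs
def gen(self, inputs, max_length=20, min_length=5, top_k=30, top_p=0.92, num_beams=8):
    ...
```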
## Models
The default model is `Salesforce/blip2-opt-2.7b`; it works quite well and doesn't require much VRAM.
Also tested with `Salesforce/blip2-opt-6.7b-coco`, which seems to give better results at the cost of much more VRAM and a large download (~30GB).
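To try the larger model, one option (an assumption about how you would wire it up, based on the `load_model` helper in `main.py` below) is to change the call at the bottom of `main.py`:

```python
# main.py - swap the default checkpoint for the larger COCO-tuned one
load_model(model_name="Salesforce/blip2-opt-6.7b-coco")
```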
caption_processor.py
@ -0,0 +1,105 @@
import torch
import re

class CaptionProcessor:
    def __init__(self, model, processor, device):
        self.model = model
        self.processor = processor
        self.device = device

    def gen(self, inputs, max_length=10, min_length=0, top_k=30, top_p=0.92, num_beams=4):
        return self.model.generate(
            **inputs,
            # max_new_tokens=25,  # Number of tokens to generate
            max_length=max_length,  # Maximum length of the generated sequence, mutually exclusive with max_new_tokens
            num_beams=num_beams,  # Number of beams to use for beam search
            num_return_sequences=1,  # Number of captions to generate
            early_stopping=True,  # Stop beam search as soon as num_beams complete candidates exist
            repetition_penalty=1.5,  # Penalize repeated words
            no_repeat_ngram_size=2,  # Size of n-grams that may not repeat
            # do_sample=True,  # Introduce randomness to captions
            # temperature=0.9,  # Measure of randomness 0-1, 0 means no randomness
            top_k=top_k,  # Number of highest-probability tokens to keep, 0 means no filtering
            top_p=top_p,  # Probability threshold, 0 means no filtering
            min_length=min_length,  # Minimum length of the generated sequence
        )

    def process(self, prompt, image):
        return self.processor(image, text=prompt, return_tensors="pt").to(self.device, torch.float16)

    def caption_from(self, generated):
        caption_list = self.processor.batch_decode(generated, skip_special_tokens=True)
        caption_list = [caption.strip() for caption in caption_list]
        return caption_list if len(caption_list) > 1 else caption_list[0]

    def sanitise_caption(self, caption):
        return caption.replace(" - ", "-")

    # TODO this needs some more work
    def sanitise_prompt_shard(self, prompt):
        # Remove everything after "Answer:"
        prompt = prompt.split("Answer:")[0].strip()

        # Define a pattern for multiple replacements
        replacements = [
            (r", a point and shoot(?: camera)?", ""),  # Matches ", a point and shoot" with optional " camera"
            (r"it is a ", ""),
            (r"it is ", ""),
            (r"hair hair", "hair"),
            (r"wearing nothing", "nude"),
            (r"She's ", ""),
            (r"She is ", ""),
        ]

        # Apply the replacements using regex
        for pattern, replacement in replacements:
            prompt = re.sub(pattern, replacement, prompt)

        return prompt

    def ask(self, question, image):
        # Build the VQA prompt, run generation, decode, then clean up the answer
        return self.sanitise_prompt_shard(self.caption_from(self.gen(self.process(f"Question: {question} Answer:", image))))

    def caption_me(self, initial_prompt, image):
        prompt = ""

        try:
            # [STYLE OF PHOTO] photo of a [SUBJECT], [IMPORTANT FEATURE], [MORE DETAILS], [POSE OR ACTION], [FRAMING], [SETTING/BACKGROUND], [LIGHTING], [CAMERA ANGLE], [CAMERA PROPERTIES], in style of [PHOTOGRAPHER]
            hair_color = self.ask("What is her hair color?", image)
            hair_length = self.ask("What is her hair length?", image)
            p_hair = f"{hair_color} {hair_length} hair"

            p_style = self.ask("Between the choices selfie, mirror selfie, candid, professional portrait what is the style of the photo?", image)

            p_clothing = self.ask("What is she wearing if anything?", image)

            p_action = self.ask("What is she doing? Could be something like standing, stretching, walking, squatting, etc", image)

            p_framing = self.ask("Between the choices close up, upper body shot, full body shot what is the framing of the photo?", image)

            p_setting = self.ask("Where is she? Be descriptive and detailed", image)

            p_lighting = self.ask("What is the scene lighting like? For example: soft lighting, studio lighting, natural lighting", image)

            p_angle = self.ask("What angle is the picture taken from? Be succinct, like: from the side, from below, from front", image)

            p_camera = self.ask("What kind of camera could this picture have been taken with? Be specific and guess a brand with specific camera type", image)

            # prompt = self.sanitise_caption(f"{p_style}, {initial_prompt} with {p_hair}, wearing {p_clothing}, {p_action}, {p_framing}, {p_setting}, {p_lighting}, {p_angle}, {p_camera}")
            prompt = self.sanitise_caption(f"{p_style}, with {p_hair}, wearing {p_clothing}, {p_action}, {p_framing}, {p_setting}, {p_lighting}, {p_angle}, {p_camera}")

            return prompt
        except Exception as e:
            print(e)

        return prompt
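A minimal usage sketch for this class (an editor's illustration mirroring the setup that `main.py` below performs; the image file name is hypothetical):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from caption_processor import CaptionProcessor

# Load BLIP2 the same way main.py does
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Caption a single image with a base prompt
caption = CaptionProcessor(model, processor, device).caption_me(
    "photo of a woman", Image.open("001.jpg")  # hypothetical image file
)
print(caption)
```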
main.py
@ -0,0 +1,89 @@
import requests, torch, sys, os
import argparse

from importlib import reload
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from tqdm import tqdm

import caption_processor

model = None
processor = None
device = None

def load_model(model_name="Salesforce/blip2-opt-2.7b"):
    global model, processor, device

    print("Loading Model")
    processor = AutoProcessor.from_pretrained(model_name)
    model = Blip2ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float16)

    if torch.cuda.is_available():
        print("CUDA available, using GPU")
        device = "cuda"
    else:
        print("CUDA not available, using CPU")
        device = "cpu"

    print("Moving model to device")
    model.to(device)

def main(path):
    # Reloading caption_processor lets us change its values between executions
    # without having to reload the model, which can take very long.
    # Probably cleaner to do this with a config file and just reload that,
    # but this works for now.
    reload(caption_processor)
    prompt_file_dict = {}

    # List all sub dirs in path
    sub_dirs = [dir for dir in os.listdir(path) if os.path.isdir(os.path.join(path, dir))]

    print("Reading prompts from sub dirs and finding image files")
    for prompt in sub_dirs:
        prompt_file_dict[prompt] = [file for file in os.listdir(os.path.join(path, prompt)) if file.endswith((".jpg", ".png", ".jpeg", ".webp"))]

    for prompt, file_list in prompt_file_dict.items():
        print(f"Found {len(file_list)} files for prompt \"{prompt}\"")

    for prompt, file_list in prompt_file_dict.items():
        total = len(file_list)

        for file in tqdm(file_list):
            # Read image
            image = Image.open(os.path.join(path, prompt, file))

            caption = ""
            # Generate caption; catch Exception (not a bare except) so Ctrl+C still works
            try:
                caption = caption_processor.CaptionProcessor(model, processor, device).caption_me(prompt, image)
            except Exception as e:
                print(f"Error creating caption for file {file}: {e}")

            # Save caption next to the image: same name, .txt extension
            with open(os.path.join(path, prompt, os.path.splitext(file)[0] + ".txt"), "w", encoding="utf-8") as f:
                f.write(caption)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Enter the path to the file")
    parser.add_argument("path", type=str, nargs='?', default="", help="Path to the file")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode")

    args = parser.parse_args()
    interactive = args.interactive

    load_model(model_name="Salesforce/blip2-opt-2.7b")

    if interactive:
        while True:
            path = input("Enter path: ")
            main(path)
            continue_prompt = input("Continue? (y/n): ")
            if continue_prompt.lower() != 'y':
                break
    else:
        path = args.path
        search_subdirectories = False
        main(path)
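Usage follows the argparse setup above: `python main.py /path/to/source` captions one source folder and exits, while `python main.py --interactive` loops, prompting for a new path (and picking up edited `caption_processor.py` settings via the reload) after each run.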
requirements.txt
@ -0,0 +1,28 @@
--extra-index-url https://download.pytorch.org/whl/cu118
accelerate==0.21.0
certifi==2023.7.22
charset-normalizer==3.2.0
colorama==0.4.6
filelock==3.12.2
fsspec==2023.6.0
huggingface-hub==0.16.4
idna==3.4
Jinja2==3.1.2
MarkupSafe==2.1.3
mpmath==1.3.0
networkx==3.1
numpy==1.25.2
packaging==23.1
Pillow==10.0.0
psutil==5.9.5
PyYAML==6.0.1
regex==2023.6.3
requests==2.31.0
safetensors==0.3.1
sympy==1.12
tokenizers==0.13.3
torch==2.0.1+cu118
tqdm==4.65.0
transformers==4.31.0
typing_extensions==4.7.1
urllib3==2.0.4