231 lines
7.1 KiB
Python
231 lines
7.1 KiB
Python
import gradio as gr
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.patches import Patch
|
|
import io
|
|
import json
|
|
from PIL import Image
|
|
from typing import List
|
|
|
|
from scripts.enums import StableDiffusionVersion
|
|
from scripts.global_state import get_sd_version
|
|
from scripts.ipadapter.weight import calc_weights
|
|
|
|
|
|
# Bar colors for the weight plot, one per UNet section; they are used both for
# the bars (via get_bar_colors) and for the plot legend.
INPUT_BLOCK_COLOR, MIDDLE_BLOCK_COLOR, OUTPUT_BLOCK_COLOR = (
    "#61bdee",  # input blocks (light blue)
    "#e2e2e2",  # middle block (light grey)
    "#dc6e55",  # output blocks (orange-red)
)
|
|
|
|
|
|
def get_bar_colors(
    sd_version: StableDiffusionVersion, input_color, middle_color, output_color
):
    """Return one bar color per transformer block of ``sd_version``.

    Blocks before the middle block get ``input_color``, the middle block gets
    ``middle_color`` and every block after it gets ``output_color``. The middle
    block is at index 4 for SDXL and at index 6 for every other version; the
    number of blocks comes from ``sd_version.transformer_block_num``.
    """
    if sd_version == StableDiffusionVersion.SDXL:
        middle = 4
    else:
        middle = 6

    colors = []
    for block_idx in range(sd_version.transformer_block_num):
        if block_idx > middle:
            colors.append(output_color)
        elif block_idx == middle:
            colors.append(middle_color)
        else:
            colors.append(input_color)
    return colors
|
|
|
|
|
|
def plot_weights(
    numbers: List[float],
    colors: List[str],
):
    """Render ``numbers`` as a bar chart and return it as a PIL image.

    Args:
        numbers: One weight per transformer block; bar ``i`` is drawn at x = i.
        colors: Bar colors matched to ``numbers`` by position (as produced by
            ``get_bar_colors``).

    Returns:
        A ``PIL.Image.Image`` decoded from the PNG rendering of the chart.
    """
    # Draw on an explicit Figure/Axes instead of pyplot's implicit "current
    # figure", so concurrent calls cannot draw on or close each other's figure.
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.bar(range(len(numbers)), numbers, color=colors)
        ax.set_xlabel("Transformer Index")
        ax.set_ylabel("Weight")
        ax.legend(
            handles=[
                Patch(color=color, label=label)
                for color, label in (
                    (INPUT_BLOCK_COLOR, "Input Block"),
                    (MIDDLE_BLOCK_COLOR, "Middle Block"),
                    (OUTPUT_BLOCK_COLOR, "Output Block"),
                )
            ],
            loc="best",
        )

        # Save the plot to an in-memory PNG buffer.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Always release the figure, even if plotting fails: pyplot keeps every
        # figure it created alive until it is explicitly closed.
        plt.close(fig)

    buffer.seek(0)
    # Convert the buffer to a PIL image and return it.
    return Image.open(buffer)
|
|
|
|
|
|
class AdvancedWeightControl:
    """Gradio controls for per-transformer-block ("advanced") weighting.

    ``render`` creates the widgets; ``register_callbacks`` wires them to the
    unit's weight slider, control type selector and advanced-weighting state.
    The controls are only shown for the control types "IP-Adapter" and
    "Instant-ID".
    """

    # Options of the "Weight Type" dropdown; the selected value is passed
    # verbatim to ``calc_weights`` as ``weight_type``.
    _WEIGHT_TYPES = (
        "normal",
        "ease in",
        "ease out",
        "ease in-out",
        "reverse in-out",
        "weak input",
        "weak output",
        "weak middle",
        "strong middle",
        "style transfer",
        "composition",
        "strong style transfer",
        "style and composition",
        "strong style and composition",
    )

    # Weight types for which the "Composition Weight" slider is shown.
    _COMPOSITION_WEIGHT_TYPES = (
        "style and composition",
        "strong style and composition",
    )

    def __init__(self):
        # All widgets are created later by ``render``.
        self.group = None  # gr.Group wrapping the controls
        self.weight_type = None  # gr.Dropdown over _WEIGHT_TYPES
        self.weight_plot = None  # gr.Image showing the weights as a bar chart
        self.weight_editor = None  # gr.Textbox holding the weights as a JSON list
        self.weight_composition = None  # gr.Slider for the composition weight

    @staticmethod
    def _parse_weights(weights_string):
        """Decode the weights textbox content into a list of numbers.

        Returns the decoded list when ``weights_string`` is a JSON array whose
        items are all non-boolean numbers, otherwise ``None``.
        """
        try:
            weights = json.loads(weights_string)
        except (TypeError, ValueError):
            # TypeError: input is not str/bytes (e.g. None);
            # ValueError: includes json.JSONDecodeError for malformed text.
            return None
        if not isinstance(weights, list):
            return None
        # bool is a subclass of int, so JSON true/false must be rejected
        # explicitly; strings/objects would break plotting and the unit state.
        if not all(
            isinstance(w, (int, float)) and not isinstance(w, bool) for w in weights
        ):
            return None
        return weights

    def render(self):
        """Create the weight-control widgets inside an initially hidden group."""
        # NOTE(review): the nesting below was reconstructed from an unindented
        # source; confirm whether weight_plot belongs inside self.group.
        with gr.Group(visible=False) as self.group:
            with gr.Row():
                self.weight_type = gr.Dropdown(
                    choices=list(self._WEIGHT_TYPES),
                    label="Weight Type",
                    value="normal",
                )
                # Made visible only for _COMPOSITION_WEIGHT_TYPES by a callback
                # registered in register_callbacks.
                self.weight_composition = gr.Slider(
                    label="Composition Weight",
                    minimum=0,
                    maximum=2.0,
                    value=1.0,
                    step=0.01,
                    visible=False,
                )
            self.weight_editor = gr.Textbox(label="Weights", visible=False)

        self.weight_plot = gr.Image(
            value=None,
            label="Weight Plot",
            interactive=False,
            visible=False,
        )

    def register_callbacks(
        self,
        weight_input: gr.Slider,
        advanced_weighting: gr.State,
        control_type: gr.Radio,
        update_unit_counter: gr.Number,
    ):
        """Wire the widgets created by ``render`` to the surrounding unit UI.

        Args:
            weight_input: The unit's overall weight; forwarded to ``calc_weights``.
            advanced_weighting: Receives the parsed weight list, or ``None`` when
                the editor text is invalid or the control type is unsupported.
            control_type: Control type selector; the controls are only shown for
                supported types.
            update_unit_counter: Incremented after ``advanced_weighting`` is set,
                to flush the gr.State change to the unit state.
        """

        def advanced_weighting_supported(selected_type: str) -> bool:
            return selected_type in ("IP-Adapter", "Instant-ID")

        # Show the composition slider only for weight types that use it.
        self.weight_type.change(
            fn=lambda weight_type: gr.update(
                visible=weight_type in self._COMPOSITION_WEIGHT_TYPES
            ),
            inputs=[self.weight_type],
            outputs=[self.weight_composition],
        )

        def update_weight_textbox(
            selected_type: str,
            weight_type: str,
            weight: float,
            weight_composition: float,
        ):
            """Recompute the weight list and show it in the editor textbox."""
            if not advanced_weighting_supported(selected_type):
                return gr.update()

            sd_version = get_sd_version()
            weights = calc_weights(weight_type, weight, sd_version, weight_composition)
            # Serialize with json so the text always round-trips through
            # _parse_weights (str() of non-builtin floats may not be valid JSON).
            text = json.dumps([round(float(w), 2) for w in weights])
            return gr.update(value=text, visible=True)

        trigger_inputs = [self.weight_type, weight_input, self.weight_composition]
        for trigger_input in trigger_inputs:
            trigger_input.change(
                fn=update_weight_textbox,
                inputs=[
                    control_type,
                    self.weight_type,
                    weight_input,
                    self.weight_composition,
                ],
                outputs=[self.weight_editor],
            )

        def update_plot(weights_string: str):
            """Redraw the plot from the editor text; hide it if the text is invalid."""
            weights = self._parse_weights(weights_string)
            if weights is None:
                return gr.update(visible=False)

            sd_version = get_sd_version()
            weight_plot = plot_weights(
                weights,
                get_bar_colors(
                    sd_version,
                    input_color=INPUT_BLOCK_COLOR,
                    middle_color=MIDDLE_BLOCK_COLOR,
                    output_color=OUTPUT_BLOCK_COLOR,
                ),
            )
            return gr.update(value=weight_plot, visible=True)

        def update_advanced_weighting(weights_string: str):
            """Return the parsed weight list, or None if the text is invalid."""
            return self._parse_weights(weights_string)

        self.weight_editor.change(
            fn=update_plot,
            inputs=[self.weight_editor],
            outputs=[self.weight_plot],
        )

        self.weight_editor.change(
            fn=update_advanced_weighting,
            inputs=[self.weight_editor],
            outputs=[advanced_weighting],
        ).then(
            fn=lambda x: gr.update(value=x + 1),
            inputs=[update_unit_counter],
            outputs=[update_unit_counter],
        )  # Necessary to flush gr.State change to unit state.

        # TODO: Expose advanced weighting control for other control types.
        def control_type_change(selected_type: str, old_weights):
            """Show the controls for supported types; otherwise hide and reset."""
            if advanced_weighting_supported(selected_type):
                # Keep the existing weights and leave editor/plot untouched.
                return gr.update(visible=True), old_weights, gr.update(), gr.update()
            return (
                gr.update(visible=False),
                None,
                gr.update(visible=False),
                gr.update(visible=False),
            )

        control_type.change(
            fn=control_type_change,
            inputs=[control_type, advanced_weighting],
            outputs=[
                self.group,
                advanced_weighting,
                self.weight_editor,
                self.weight_plot,
            ],
        )
|