From 5fb1e8f80a13b08b6a4d63b880a9b07f8ddfdef5 Mon Sep 17 00:00:00 2001 From: benkyoujouzu Date: Sun, 11 Dec 2022 21:57:13 +0800 Subject: [PATCH] optimize UI --- scripts/vxa.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/scripts/vxa.py b/scripts/vxa.py index 9daf7bd..7adc2ec 100644 --- a/scripts/vxa.py +++ b/scripts/vxa.py @@ -96,23 +96,18 @@ def add_tab(): with gr.Row(): with gr.Column(): input_image = gr.Image(elem_id="vxa_input_image") + vxa_prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Prompt to be visualized") + vxa_token_indices = gr.Textbox(value="", label="Indices of tokens to be visualized", lines=2, placeholder="Example: 1, 3 means the sum of the first and the third tokens. 1 is suggected for a single token. Leave blank to visualize all tokens.") + vxa_time_embedding = gr.Textbox(value="1.0", label="Time embedding") + for n, m in shared.sd_model.named_modules(): + if(isinstance(m, CrossAttention)): + hidden_layers[n] = m + hidden_layer_names = list(filter(lambda s : "attn2" in s, hidden_layers.keys())) + hidden_layer_select = gr.Dropdown(value=default_hidden_layer_name, label="Cross-attention layer", choices=hidden_layer_names) + vxa_output_mode = gr.Dropdown(value="masked", label="Output mode", choices=["masked", "grey"]) + vxa_generate = gr.Button(value="Visualize Cross-Attention", elem_id="vxa_gen_btn") with gr.Column(): vxa_output = gr.Image(elem_id = "vxa_output", interactive=False) - vxa_generate = gr.Button(value="Visualize Cross-Attention", elem_id="vxa_gen_btn") - with gr.Row(): - vxa_prompt = gr.Textbox(label="Prompt", placeholder="Prompt to be visualized") - with gr.Row(): - vxa_token_indices = gr.Textbox(value="", label="Indices of tokens to be visualized", placeholder="Example: 1, 3 means the sum of the first and the third tokens. 1 is suggected for a single token. Leave blank to visualize all tokens.") - with gr.Row(): - vxa_time_embedding = gr.Textbox(value="1.0", label="Time embedding") - with gr.Row(): - for n, m in shared.sd_model.named_modules(): - if(isinstance(m, CrossAttention)): - hidden_layers[n] = m - hidden_layer_names = list(filter(lambda s : "attn2" in s, hidden_layers.keys())) - hidden_layer_select = gr.Dropdown(value=default_hidden_layer_name, label="Cross-attention layer", choices=hidden_layer_names) - with gr.Row(): - vxa_output_mode = gr.Dropdown(value="masked", label="Output mode", choices=["masked", "grey"]) vxa_generate.click( fn=generate_vxa, @@ -123,4 +118,5 @@ def add_tab(): return (visualize_cross_attention, "VXA", "visualize_cross_attention"), script_callbacks.on_ui_tabs(add_tab) -script_callbacks.on_model_loaded(update_layer_names) \ No newline at end of file +script_callbacks.on_model_loaded(update_layer_names) +