From d80e6e77c5a1c434690badf5bf8085a2bc9787d3 Mon Sep 17 00:00:00 2001 From: Dimitri Date: Fri, 21 Jul 2023 22:53:44 +0200 Subject: [PATCH] v0.2 --- javascript/fabric.js | 19 ++++++ scripts/fabric.py | 149 +++++++++++++++++++++++++++---------------- scripts/patching.py | 7 -- 3 files changed, 113 insertions(+), 62 deletions(-) create mode 100644 javascript/fabric.js diff --git a/javascript/fabric.js b/javascript/fabric.js new file mode 100644 index 0000000..d6bf58d --- /dev/null +++ b/javascript/fabric.js @@ -0,0 +1,19 @@ + +function fabric_selected_gallery_index(elem_id) { + const el = document.getElementById(elem_id); + const buttons = el.querySelectorAll('.thumbnail-item.thumbnail-small'); + const button = el.querySelector('.thumbnail-item.thumbnail-small.selected'); + + let result = -1; + buttons.forEach((v, i) => { + if (v == button) { + result = i; + } + }); + + console.log(result); + console.log(button); + console.log(buttons); + + return result; +} diff --git a/scripts/fabric.py b/scripts/fabric.py index 529fb3e..0c31ca4 100644 --- a/scripts/fabric.py +++ b/scripts/fabric.py @@ -10,12 +10,13 @@ from PIL import Image import modules.scripts from modules import devices, script_callbacks, images +from modules.ui_components import FormGroup from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass from scripts.helpers import WebUiComponents -__version__ = "0.1" +__version__ = "0.2" DEBUG = False @@ -43,10 +44,6 @@ class FabricParams: # TODO: replace global state with Gradio state class FabricState: batch_images = [] - selected_image = None - uploaded_image = None - liked_images = [] - disliked_images = [] def encode_to_latent(p, feedback_imgs): @@ -78,16 +75,19 @@ class FabricScript(modules.scripts.Script): return modules.scripts.AlwaysVisible def ui(self, is_img2img): + self.selected_image = gr.State(None) + liked_images = gr.State([]) + disliked_images = gr.State([]) + selected_like = gr.State(None) + selected_dislike = gr.State(None) with gr.Accordion(f"{self.title()} v{__version__}", open=DEBUG, elem_id="fabric"): if DEBUG: like_example_btn = gr.Button("👍 Example") - feedback_disabled = gr.Checkbox(label="Disable", value=False) - with gr.Tabs(): with gr.Tab("Current batch"): - self.selected_img = gr.Image(value=None, type="pil", label="Selected image").style(height=256) + self.selected_img_display = gr.Image(value=None, type="pil", label="Selected image").style(height=256) with gr.Row(): like_btn_selected = gr.Button("👍 Like") @@ -100,16 +100,28 @@ class FabricScript(modules.scripts.Script): like_btn_uploaded = gr.Button("👍 Like") dislike_btn_uploaded = gr.Button("👎 Dislike") + with gr.Tabs(initial_tab="👍 Likes"): with gr.Tab("👍 Likes"): - like_gallery = gr.Gallery(label="Liked images").style(columns=5, height=256) - clear_liked_btn = gr.Button("Clear") + with gr.Row(): + remove_selected_like_btn = gr.Button("Remove selected", interactive=False) + clear_liked_btn = gr.Button("Clear") + like_gallery = gr.Gallery(label="Liked images", elem_id="fabric_like_gallery").style(columns=5, height=128) with gr.Tab("👎 Dislikes"): - dislike_gallery = gr.Gallery(label="Disliked images").style(columns=5, height=256) - clear_disliked_btn = gr.Button("Clear") + with gr.Row(): + remove_selected_dislike_btn = gr.Button("Remove selected", interactive=False) + clear_disliked_btn = gr.Button("Clear") + dislike_gallery = gr.Gallery(label="Disliked images", elem_id="fabric_dislike_gallery").style(columns=5, height=128) - with gr.Group(): + with FormGroup(): + feedback_disabled = gr.Checkbox(label="Disable feedback", value=False) + + with FormGroup(): + with gr.Row(): + # TODO: figure out how to make the step size do what it's supposed to + feedback_max_images = gr.Slider(minimum=0, maximum=10, step=1, value=4, label="Max. feedback images") + with gr.Row(): feedback_start = gr.Slider(0.0, 1.0, value=0.0, label="Feedback start") feedback_end = gr.Slider(0.0, 1.0, value=0.8, label="Feedback end") @@ -121,20 +133,49 @@ class FabricScript(modules.scripts.Script): WebUiComponents.on_txt2img_gallery(self.register_txt2img_gallery_select) - like_btn_selected.click(self.on_like_selected, inputs=None, outputs=like_gallery) - dislike_btn_selected.click(self.on_dislike_selected, inputs=None, outputs=dislike_gallery) + like_btn_selected.click(self.add_image_to_state, inputs=[self.selected_image, liked_images], outputs=[liked_images, like_gallery]) + dislike_btn_selected.click(self.add_image_to_state, inputs=[self.selected_image, disliked_images], outputs=[disliked_images, dislike_gallery]) - like_btn_uploaded.click(self.on_like_uploaded, inputs=upload_img_input, outputs=like_gallery) - dislike_btn_uploaded.click(self.on_dislike_uploaded, inputs=upload_img_input, outputs=dislike_gallery) + like_btn_uploaded.click(self.add_image_to_state, inputs=[upload_img_input, liked_images], outputs=[liked_images, like_gallery]) + dislike_btn_uploaded.click(self.add_image_to_state, inputs=[upload_img_input, disliked_images], outputs=[disliked_images, dislike_gallery]) - clear_liked_btn.click(self.clear_likes, inputs=None, outputs=like_gallery) - clear_disliked_btn.click(self.clear_dislikes, inputs=None, outputs=dislike_gallery) + clear_liked_btn.click(lambda _: [[], []], inputs=liked_images, outputs=[liked_images, like_gallery]) + clear_disliked_btn.click(lambda _: [[], []], inputs=disliked_images, outputs=[disliked_images, dislike_gallery]) + + like_gallery.select( + self.select_for_removal, + _js="(a, b) => [a, fabric_selected_gallery_index('fabric_like_gallery')]", + inputs=[like_gallery, like_gallery], + outputs=[selected_like, remove_selected_like_btn], + ) + + dislike_gallery.select( + self.select_for_removal, + _js="(a, b) => [a, fabric_selected_gallery_index('fabric_dislike_gallery')]", + inputs=[dislike_gallery, dislike_gallery], + outputs=[selected_dislike, remove_selected_dislike_btn], + ) + + remove_selected_like_btn.click( + self.remove_selected, + inputs=[liked_images, selected_like], + outputs=[liked_images, like_gallery, selected_like, remove_selected_like_btn], + ) + + remove_selected_dislike_btn.click( + self.remove_selected, + inputs=[disliked_images, selected_dislike], + outputs=[disliked_images, dislike_gallery, selected_dislike, remove_selected_dislike_btn], + ) if DEBUG: - like_example_btn.click(functools.partial(self.on_like_example, "example1"), inputs=None, outputs=like_gallery) + like_example_btn.click(functools.partial(self.on_like_example, example="example1"), inputs=liked_images, outputs=[liked_images, like_gallery]) return [ + liked_images, + disliked_images, feedback_disabled, + feedback_max_images, feedback_start, feedback_end, feedback_min_weight, @@ -142,41 +183,35 @@ class FabricScript(modules.scripts.Script): feedback_neg_scale, ] + + def select_for_removal(self, gallery, selected_idx): + return [ + selected_idx, + gr.update(interactive=True), + ] - def on_like_selected(self, *args): - if FabricState.selected_image is not None: - FabricState.liked_images.append(FabricState.selected_image) - return FabricState.liked_images + def remove_selected(self, images, idx): + if idx >= 0 and idx < len(images): + images.pop(idx) + + return [ + images, + images, + gr.update(value=None), + gr.update(interactive=False), + ] - def on_dislike_selected(self, *args): - if FabricState.selected_image is not None: - FabricState.disliked_images.append(FabricState.selected_image) - return FabricState.disliked_images + def add_image_to_state(self, img, images): + if img is not None: + images.append(img) + return images, images - def on_like_uploaded(self, image): - if image is not None: - FabricState.liked_images.append(image) - return FabricState.liked_images - - def on_dislike_uploaded(self, image): - if image is not None: - FabricState.disliked_images.append(image) - return FabricState.disliked_images - - def on_like_example(self, example="example1"): + def on_like_example(self, liked_images, example="example1"): img_path = Path(__file__).parent.parent.absolute() / "images" / f"{example}.png" image = Image.open(img_path) if image is not None: - FabricState.liked_images.append(image) - return FabricState.liked_images - - def clear_likes(self, *args): - FabricState.liked_images = [] - return FabricState.liked_images - - def clear_dislikes(self, *args): - FabricState.disliked_images = [] - return FabricState.disliked_images + liked_images.append(image) + return liked_images, liked_images def register_txt2img_gallery_select(self, gallery): gallery.select( @@ -186,7 +221,7 @@ class FabricScript(modules.scripts.Script): gallery, gallery, # can be any Gradio component (but not None), will be overwritten with selected gallery index ], - outputs=self.selected_img, + outputs=[self.selected_image, self.selected_img_display], ) def on_txt2img_gallery_select(self, gallery, selected_idx): @@ -194,14 +229,16 @@ class FabricScript(modules.scripts.Script): idx = selected_idx - (len(gallery) - len(images)) if idx >= 0 and idx < len(images): - FabricState.selected_image = images[idx] - return gr.update(value=images[idx]) + return images[idx], gr.update(value=images[idx]) else: - return None + return None, None def process(self, p, *args): ( + liked_images, + disliked_images, feedback_disabled, + feedback_max_images, feedback_start, feedback_end, feedback_min_weight, @@ -210,8 +247,10 @@ class FabricScript(modules.scripts.Script): ) = args print("[FABRIC] Encoding feedback images into latent space...") - pos_latents = encode_to_latent(p, FabricState.liked_images) - neg_latents = encode_to_latent(p, FabricState.disliked_images) + likes = liked_images[:int(feedback_max_images)] + dislikes = disliked_images[:int(feedback_max_images)] + pos_latents = encode_to_latent(p, likes) + neg_latents = encode_to_latent(p, dislikes) print("[FABRIC] Patching U-Net forward pass...") params = FabricParams( diff --git a/scripts/patching.py b/scripts/patching.py index 8552e57..7e1d737 100644 --- a/scripts/patching.py +++ b/scripts/patching.py @@ -37,8 +37,6 @@ def patch_unet_forward_pass(p, unet, params): z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0] all_zs.append(z) all_zs = torch.stack(all_zs, dim=0) - print(all_zs.shape) - print(all_zs[:, 0, 0, 0]) ## cache hidden states cached_hiddens = [] @@ -68,11 +66,6 @@ def patch_unet_forward_pass(p, unet, params): seq_len, d_model = x.shape[1:] num_pos = len(params.pos_latents) num_neg = len(params.neg_latents) - print("num_pos", num_pos) - print("num_neg", num_neg) - print("seq_len", seq_len) - print("d_model", d_model) - print(cached_hs.shape) pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(batch_size, -1, -1) # (bs, seq * n_pos, dim) neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(batch_size, -1, -1) # (bs, seq * n_neg, dim)