pull/23/head
Dimitri 2023-07-21 22:53:44 +02:00
parent 70130b385d
commit d80e6e77c5
3 changed files with 113 additions and 62 deletions

19
javascript/fabric.js Normal file
View File

@ -0,0 +1,19 @@
function fabric_selected_gallery_index(elem_id) {
const el = document.getElementById(elem_id);
const buttons = el.querySelectorAll('.thumbnail-item.thumbnail-small');
const button = el.querySelector('.thumbnail-item.thumbnail-small.selected');
let result = -1;
buttons.forEach((v, i) => {
if (v == button) {
result = i;
}
});
console.log(result);
console.log(button);
console.log(buttons);
return result;
}

View File

@ -10,12 +10,13 @@ from PIL import Image
import modules.scripts
from modules import devices, script_callbacks, images
from modules.ui_components import FormGroup
from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass
from scripts.helpers import WebUiComponents
__version__ = "0.1"
__version__ = "0.2"
DEBUG = False
@ -43,10 +44,6 @@ class FabricParams:
# TODO: replace global state with Gradio state
class FabricState:
batch_images = []
selected_image = None
uploaded_image = None
liked_images = []
disliked_images = []
def encode_to_latent(p, feedback_imgs):
@ -78,16 +75,19 @@ class FabricScript(modules.scripts.Script):
return modules.scripts.AlwaysVisible
def ui(self, is_img2img):
self.selected_image = gr.State(None)
liked_images = gr.State([])
disliked_images = gr.State([])
selected_like = gr.State(None)
selected_dislike = gr.State(None)
with gr.Accordion(f"{self.title()} v{__version__}", open=DEBUG, elem_id="fabric"):
if DEBUG:
like_example_btn = gr.Button("👍 Example")
feedback_disabled = gr.Checkbox(label="Disable", value=False)
with gr.Tabs():
with gr.Tab("Current batch"):
self.selected_img = gr.Image(value=None, type="pil", label="Selected image").style(height=256)
self.selected_img_display = gr.Image(value=None, type="pil", label="Selected image").style(height=256)
with gr.Row():
like_btn_selected = gr.Button("👍 Like")
@ -100,16 +100,28 @@ class FabricScript(modules.scripts.Script):
like_btn_uploaded = gr.Button("👍 Like")
dislike_btn_uploaded = gr.Button("👎 Dislike")
with gr.Tabs(initial_tab="👍 Likes"):
with gr.Tab("👍 Likes"):
like_gallery = gr.Gallery(label="Liked images").style(columns=5, height=256)
with gr.Row():
remove_selected_like_btn = gr.Button("Remove selected", interactive=False)
clear_liked_btn = gr.Button("Clear")
like_gallery = gr.Gallery(label="Liked images", elem_id="fabric_like_gallery").style(columns=5, height=128)
with gr.Tab("👎 Dislikes"):
dislike_gallery = gr.Gallery(label="Disliked images").style(columns=5, height=256)
with gr.Row():
remove_selected_dislike_btn = gr.Button("Remove selected", interactive=False)
clear_disliked_btn = gr.Button("Clear")
dislike_gallery = gr.Gallery(label="Disliked images", elem_id="fabric_dislike_gallery").style(columns=5, height=128)
with gr.Group():
with FormGroup():
feedback_disabled = gr.Checkbox(label="Disable feedback", value=False)
with FormGroup():
with gr.Row():
# TODO: figure out how to make the step size do what it's supposed to
feedback_max_images = gr.Slider(minimum=0, maximum=10, step=1, value=4, label="Max. feedback images")
with gr.Row():
feedback_start = gr.Slider(0.0, 1.0, value=0.0, label="Feedback start")
feedback_end = gr.Slider(0.0, 1.0, value=0.8, label="Feedback end")
@ -121,20 +133,49 @@ class FabricScript(modules.scripts.Script):
WebUiComponents.on_txt2img_gallery(self.register_txt2img_gallery_select)
like_btn_selected.click(self.on_like_selected, inputs=None, outputs=like_gallery)
dislike_btn_selected.click(self.on_dislike_selected, inputs=None, outputs=dislike_gallery)
like_btn_selected.click(self.add_image_to_state, inputs=[self.selected_image, liked_images], outputs=[liked_images, like_gallery])
dislike_btn_selected.click(self.add_image_to_state, inputs=[self.selected_image, disliked_images], outputs=[disliked_images, dislike_gallery])
like_btn_uploaded.click(self.on_like_uploaded, inputs=upload_img_input, outputs=like_gallery)
dislike_btn_uploaded.click(self.on_dislike_uploaded, inputs=upload_img_input, outputs=dislike_gallery)
like_btn_uploaded.click(self.add_image_to_state, inputs=[upload_img_input, liked_images], outputs=[liked_images, like_gallery])
dislike_btn_uploaded.click(self.add_image_to_state, inputs=[upload_img_input, disliked_images], outputs=[disliked_images, dislike_gallery])
clear_liked_btn.click(self.clear_likes, inputs=None, outputs=like_gallery)
clear_disliked_btn.click(self.clear_dislikes, inputs=None, outputs=dislike_gallery)
clear_liked_btn.click(lambda _: [[], []], inputs=liked_images, outputs=[liked_images, like_gallery])
clear_disliked_btn.click(lambda _: [[], []], inputs=disliked_images, outputs=[disliked_images, dislike_gallery])
like_gallery.select(
self.select_for_removal,
_js="(a, b) => [a, fabric_selected_gallery_index('fabric_like_gallery')]",
inputs=[like_gallery, like_gallery],
outputs=[selected_like, remove_selected_like_btn],
)
dislike_gallery.select(
self.select_for_removal,
_js="(a, b) => [a, fabric_selected_gallery_index('fabric_dislike_gallery')]",
inputs=[dislike_gallery, dislike_gallery],
outputs=[selected_dislike, remove_selected_dislike_btn],
)
remove_selected_like_btn.click(
self.remove_selected,
inputs=[liked_images, selected_like],
outputs=[liked_images, like_gallery, selected_like, remove_selected_like_btn],
)
remove_selected_dislike_btn.click(
self.remove_selected,
inputs=[disliked_images, selected_dislike],
outputs=[disliked_images, dislike_gallery, selected_dislike, remove_selected_dislike_btn],
)
if DEBUG:
like_example_btn.click(functools.partial(self.on_like_example, "example1"), inputs=None, outputs=like_gallery)
like_example_btn.click(functools.partial(self.on_like_example, example="example1"), inputs=liked_images, outputs=[liked_images, like_gallery])
return [
liked_images,
disliked_images,
feedback_disabled,
feedback_max_images,
feedback_start,
feedback_end,
feedback_min_weight,
@ -143,40 +184,34 @@ class FabricScript(modules.scripts.Script):
]
def on_like_selected(self, *args):
if FabricState.selected_image is not None:
FabricState.liked_images.append(FabricState.selected_image)
return FabricState.liked_images
def select_for_removal(self, gallery, selected_idx):
return [
selected_idx,
gr.update(interactive=True),
]
def on_dislike_selected(self, *args):
if FabricState.selected_image is not None:
FabricState.disliked_images.append(FabricState.selected_image)
return FabricState.disliked_images
def remove_selected(self, images, idx):
if idx >= 0 and idx < len(images):
images.pop(idx)
def on_like_uploaded(self, image):
if image is not None:
FabricState.liked_images.append(image)
return FabricState.liked_images
return [
images,
images,
gr.update(value=None),
gr.update(interactive=False),
]
def on_dislike_uploaded(self, image):
if image is not None:
FabricState.disliked_images.append(image)
return FabricState.disliked_images
def add_image_to_state(self, img, images):
if img is not None:
images.append(img)
return images, images
def on_like_example(self, example="example1"):
def on_like_example(self, liked_images, example="example1"):
img_path = Path(__file__).parent.parent.absolute() / "images" / f"{example}.png"
image = Image.open(img_path)
if image is not None:
FabricState.liked_images.append(image)
return FabricState.liked_images
def clear_likes(self, *args):
FabricState.liked_images = []
return FabricState.liked_images
def clear_dislikes(self, *args):
FabricState.disliked_images = []
return FabricState.disliked_images
liked_images.append(image)
return liked_images, liked_images
def register_txt2img_gallery_select(self, gallery):
gallery.select(
@ -186,7 +221,7 @@ class FabricScript(modules.scripts.Script):
gallery,
gallery, # can be any Gradio component (but not None), will be overwritten with selected gallery index
],
outputs=self.selected_img,
outputs=[self.selected_image, self.selected_img_display],
)
def on_txt2img_gallery_select(self, gallery, selected_idx):
@ -194,14 +229,16 @@ class FabricScript(modules.scripts.Script):
idx = selected_idx - (len(gallery) - len(images))
if idx >= 0 and idx < len(images):
FabricState.selected_image = images[idx]
return gr.update(value=images[idx])
return images[idx], gr.update(value=images[idx])
else:
return None
return None, None
def process(self, p, *args):
(
liked_images,
disliked_images,
feedback_disabled,
feedback_max_images,
feedback_start,
feedback_end,
feedback_min_weight,
@ -210,8 +247,10 @@ class FabricScript(modules.scripts.Script):
) = args
print("[FABRIC] Encoding feedback images into latent space...")
pos_latents = encode_to_latent(p, FabricState.liked_images)
neg_latents = encode_to_latent(p, FabricState.disliked_images)
likes = liked_images[:int(feedback_max_images)]
dislikes = disliked_images[:int(feedback_max_images)]
pos_latents = encode_to_latent(p, likes)
neg_latents = encode_to_latent(p, dislikes)
print("[FABRIC] Patching U-Net forward pass...")
params = FabricParams(

View File

@ -37,8 +37,6 @@ def patch_unet_forward_pass(p, unet, params):
z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
all_zs.append(z)
all_zs = torch.stack(all_zs, dim=0)
print(all_zs.shape)
print(all_zs[:, 0, 0, 0])
## cache hidden states
cached_hiddens = []
@ -68,11 +66,6 @@ def patch_unet_forward_pass(p, unet, params):
seq_len, d_model = x.shape[1:]
num_pos = len(params.pos_latents)
num_neg = len(params.neg_latents)
print("num_pos", num_pos)
print("num_neg", num_neg)
print("seq_len", seq_len)
print("d_model", d_model)
print(cached_hs.shape)
pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(batch_size, -1, -1) # (bs, seq * n_pos, dim)
neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(batch_size, -1, -1) # (bs, seq * n_neg, dim)