v0.2
parent
70130b385d
commit
d80e6e77c5
|
|
@ -0,0 +1,19 @@
|
||||||
|
|
||||||
|
function fabric_selected_gallery_index(elem_id) {
|
||||||
|
const el = document.getElementById(elem_id);
|
||||||
|
const buttons = el.querySelectorAll('.thumbnail-item.thumbnail-small');
|
||||||
|
const button = el.querySelector('.thumbnail-item.thumbnail-small.selected');
|
||||||
|
|
||||||
|
let result = -1;
|
||||||
|
buttons.forEach((v, i) => {
|
||||||
|
if (v == button) {
|
||||||
|
result = i;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log(result);
|
||||||
|
console.log(button);
|
||||||
|
console.log(buttons);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
@ -10,12 +10,13 @@ from PIL import Image
|
||||||
|
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
from modules import devices, script_callbacks, images
|
from modules import devices, script_callbacks, images
|
||||||
|
from modules.ui_components import FormGroup
|
||||||
|
|
||||||
from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass
|
from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass
|
||||||
from scripts.helpers import WebUiComponents
|
from scripts.helpers import WebUiComponents
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.1"
|
__version__ = "0.2"
|
||||||
|
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
|
|
||||||
|
|
@ -43,10 +44,6 @@ class FabricParams:
|
||||||
# TODO: replace global state with Gradio state
|
# TODO: replace global state with Gradio state
|
||||||
class FabricState:
|
class FabricState:
|
||||||
batch_images = []
|
batch_images = []
|
||||||
selected_image = None
|
|
||||||
uploaded_image = None
|
|
||||||
liked_images = []
|
|
||||||
disliked_images = []
|
|
||||||
|
|
||||||
|
|
||||||
def encode_to_latent(p, feedback_imgs):
|
def encode_to_latent(p, feedback_imgs):
|
||||||
|
|
@ -78,16 +75,19 @@ class FabricScript(modules.scripts.Script):
|
||||||
return modules.scripts.AlwaysVisible
|
return modules.scripts.AlwaysVisible
|
||||||
|
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
|
self.selected_image = gr.State(None)
|
||||||
|
liked_images = gr.State([])
|
||||||
|
disliked_images = gr.State([])
|
||||||
|
selected_like = gr.State(None)
|
||||||
|
selected_dislike = gr.State(None)
|
||||||
|
|
||||||
with gr.Accordion(f"{self.title()} v{__version__}", open=DEBUG, elem_id="fabric"):
|
with gr.Accordion(f"{self.title()} v{__version__}", open=DEBUG, elem_id="fabric"):
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
like_example_btn = gr.Button("👍 Example")
|
like_example_btn = gr.Button("👍 Example")
|
||||||
|
|
||||||
feedback_disabled = gr.Checkbox(label="Disable", value=False)
|
|
||||||
|
|
||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
with gr.Tab("Current batch"):
|
with gr.Tab("Current batch"):
|
||||||
self.selected_img = gr.Image(value=None, type="pil", label="Selected image").style(height=256)
|
self.selected_img_display = gr.Image(value=None, type="pil", label="Selected image").style(height=256)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
like_btn_selected = gr.Button("👍 Like")
|
like_btn_selected = gr.Button("👍 Like")
|
||||||
|
|
@ -100,16 +100,28 @@ class FabricScript(modules.scripts.Script):
|
||||||
like_btn_uploaded = gr.Button("👍 Like")
|
like_btn_uploaded = gr.Button("👍 Like")
|
||||||
dislike_btn_uploaded = gr.Button("👎 Dislike")
|
dislike_btn_uploaded = gr.Button("👎 Dislike")
|
||||||
|
|
||||||
|
|
||||||
with gr.Tabs(initial_tab="👍 Likes"):
|
with gr.Tabs(initial_tab="👍 Likes"):
|
||||||
with gr.Tab("👍 Likes"):
|
with gr.Tab("👍 Likes"):
|
||||||
like_gallery = gr.Gallery(label="Liked images").style(columns=5, height=256)
|
with gr.Row():
|
||||||
clear_liked_btn = gr.Button("Clear")
|
remove_selected_like_btn = gr.Button("Remove selected", interactive=False)
|
||||||
|
clear_liked_btn = gr.Button("Clear")
|
||||||
|
like_gallery = gr.Gallery(label="Liked images", elem_id="fabric_like_gallery").style(columns=5, height=128)
|
||||||
with gr.Tab("👎 Dislikes"):
|
with gr.Tab("👎 Dislikes"):
|
||||||
dislike_gallery = gr.Gallery(label="Disliked images").style(columns=5, height=256)
|
with gr.Row():
|
||||||
clear_disliked_btn = gr.Button("Clear")
|
remove_selected_dislike_btn = gr.Button("Remove selected", interactive=False)
|
||||||
|
clear_disliked_btn = gr.Button("Clear")
|
||||||
|
dislike_gallery = gr.Gallery(label="Disliked images", elem_id="fabric_dislike_gallery").style(columns=5, height=128)
|
||||||
|
|
||||||
|
|
||||||
with gr.Group():
|
with FormGroup():
|
||||||
|
feedback_disabled = gr.Checkbox(label="Disable feedback", value=False)
|
||||||
|
|
||||||
|
with FormGroup():
|
||||||
|
with gr.Row():
|
||||||
|
# TODO: figure out how to make the step size do what it's supposed to
|
||||||
|
feedback_max_images = gr.Slider(minimum=0, maximum=10, step=1, value=4, label="Max. feedback images")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
feedback_start = gr.Slider(0.0, 1.0, value=0.0, label="Feedback start")
|
feedback_start = gr.Slider(0.0, 1.0, value=0.0, label="Feedback start")
|
||||||
feedback_end = gr.Slider(0.0, 1.0, value=0.8, label="Feedback end")
|
feedback_end = gr.Slider(0.0, 1.0, value=0.8, label="Feedback end")
|
||||||
|
|
@ -121,20 +133,49 @@ class FabricScript(modules.scripts.Script):
|
||||||
|
|
||||||
WebUiComponents.on_txt2img_gallery(self.register_txt2img_gallery_select)
|
WebUiComponents.on_txt2img_gallery(self.register_txt2img_gallery_select)
|
||||||
|
|
||||||
like_btn_selected.click(self.on_like_selected, inputs=None, outputs=like_gallery)
|
like_btn_selected.click(self.add_image_to_state, inputs=[self.selected_image, liked_images], outputs=[liked_images, like_gallery])
|
||||||
dislike_btn_selected.click(self.on_dislike_selected, inputs=None, outputs=dislike_gallery)
|
dislike_btn_selected.click(self.add_image_to_state, inputs=[self.selected_image, disliked_images], outputs=[disliked_images, dislike_gallery])
|
||||||
|
|
||||||
like_btn_uploaded.click(self.on_like_uploaded, inputs=upload_img_input, outputs=like_gallery)
|
like_btn_uploaded.click(self.add_image_to_state, inputs=[upload_img_input, liked_images], outputs=[liked_images, like_gallery])
|
||||||
dislike_btn_uploaded.click(self.on_dislike_uploaded, inputs=upload_img_input, outputs=dislike_gallery)
|
dislike_btn_uploaded.click(self.add_image_to_state, inputs=[upload_img_input, disliked_images], outputs=[disliked_images, dislike_gallery])
|
||||||
|
|
||||||
clear_liked_btn.click(self.clear_likes, inputs=None, outputs=like_gallery)
|
clear_liked_btn.click(lambda _: [[], []], inputs=liked_images, outputs=[liked_images, like_gallery])
|
||||||
clear_disliked_btn.click(self.clear_dislikes, inputs=None, outputs=dislike_gallery)
|
clear_disliked_btn.click(lambda _: [[], []], inputs=disliked_images, outputs=[disliked_images, dislike_gallery])
|
||||||
|
|
||||||
|
like_gallery.select(
|
||||||
|
self.select_for_removal,
|
||||||
|
_js="(a, b) => [a, fabric_selected_gallery_index('fabric_like_gallery')]",
|
||||||
|
inputs=[like_gallery, like_gallery],
|
||||||
|
outputs=[selected_like, remove_selected_like_btn],
|
||||||
|
)
|
||||||
|
|
||||||
|
dislike_gallery.select(
|
||||||
|
self.select_for_removal,
|
||||||
|
_js="(a, b) => [a, fabric_selected_gallery_index('fabric_dislike_gallery')]",
|
||||||
|
inputs=[dislike_gallery, dislike_gallery],
|
||||||
|
outputs=[selected_dislike, remove_selected_dislike_btn],
|
||||||
|
)
|
||||||
|
|
||||||
|
remove_selected_like_btn.click(
|
||||||
|
self.remove_selected,
|
||||||
|
inputs=[liked_images, selected_like],
|
||||||
|
outputs=[liked_images, like_gallery, selected_like, remove_selected_like_btn],
|
||||||
|
)
|
||||||
|
|
||||||
|
remove_selected_dislike_btn.click(
|
||||||
|
self.remove_selected,
|
||||||
|
inputs=[disliked_images, selected_dislike],
|
||||||
|
outputs=[disliked_images, dislike_gallery, selected_dislike, remove_selected_dislike_btn],
|
||||||
|
)
|
||||||
|
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
like_example_btn.click(functools.partial(self.on_like_example, "example1"), inputs=None, outputs=like_gallery)
|
like_example_btn.click(functools.partial(self.on_like_example, example="example1"), inputs=liked_images, outputs=[liked_images, like_gallery])
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
liked_images,
|
||||||
|
disliked_images,
|
||||||
feedback_disabled,
|
feedback_disabled,
|
||||||
|
feedback_max_images,
|
||||||
feedback_start,
|
feedback_start,
|
||||||
feedback_end,
|
feedback_end,
|
||||||
feedback_min_weight,
|
feedback_min_weight,
|
||||||
|
|
@ -142,41 +183,35 @@ class FabricScript(modules.scripts.Script):
|
||||||
feedback_neg_scale,
|
feedback_neg_scale,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def select_for_removal(self, gallery, selected_idx):
|
||||||
|
return [
|
||||||
|
selected_idx,
|
||||||
|
gr.update(interactive=True),
|
||||||
|
]
|
||||||
|
|
||||||
def on_like_selected(self, *args):
|
def remove_selected(self, images, idx):
|
||||||
if FabricState.selected_image is not None:
|
if idx >= 0 and idx < len(images):
|
||||||
FabricState.liked_images.append(FabricState.selected_image)
|
images.pop(idx)
|
||||||
return FabricState.liked_images
|
|
||||||
|
return [
|
||||||
|
images,
|
||||||
|
images,
|
||||||
|
gr.update(value=None),
|
||||||
|
gr.update(interactive=False),
|
||||||
|
]
|
||||||
|
|
||||||
def on_dislike_selected(self, *args):
|
def add_image_to_state(self, img, images):
|
||||||
if FabricState.selected_image is not None:
|
if img is not None:
|
||||||
FabricState.disliked_images.append(FabricState.selected_image)
|
images.append(img)
|
||||||
return FabricState.disliked_images
|
return images, images
|
||||||
|
|
||||||
def on_like_uploaded(self, image):
|
def on_like_example(self, liked_images, example="example1"):
|
||||||
if image is not None:
|
|
||||||
FabricState.liked_images.append(image)
|
|
||||||
return FabricState.liked_images
|
|
||||||
|
|
||||||
def on_dislike_uploaded(self, image):
|
|
||||||
if image is not None:
|
|
||||||
FabricState.disliked_images.append(image)
|
|
||||||
return FabricState.disliked_images
|
|
||||||
|
|
||||||
def on_like_example(self, example="example1"):
|
|
||||||
img_path = Path(__file__).parent.parent.absolute() / "images" / f"{example}.png"
|
img_path = Path(__file__).parent.parent.absolute() / "images" / f"{example}.png"
|
||||||
image = Image.open(img_path)
|
image = Image.open(img_path)
|
||||||
if image is not None:
|
if image is not None:
|
||||||
FabricState.liked_images.append(image)
|
liked_images.append(image)
|
||||||
return FabricState.liked_images
|
return liked_images, liked_images
|
||||||
|
|
||||||
def clear_likes(self, *args):
|
|
||||||
FabricState.liked_images = []
|
|
||||||
return FabricState.liked_images
|
|
||||||
|
|
||||||
def clear_dislikes(self, *args):
|
|
||||||
FabricState.disliked_images = []
|
|
||||||
return FabricState.disliked_images
|
|
||||||
|
|
||||||
def register_txt2img_gallery_select(self, gallery):
|
def register_txt2img_gallery_select(self, gallery):
|
||||||
gallery.select(
|
gallery.select(
|
||||||
|
|
@ -186,7 +221,7 @@ class FabricScript(modules.scripts.Script):
|
||||||
gallery,
|
gallery,
|
||||||
gallery, # can be any Gradio component (but not None), will be overwritten with selected gallery index
|
gallery, # can be any Gradio component (but not None), will be overwritten with selected gallery index
|
||||||
],
|
],
|
||||||
outputs=self.selected_img,
|
outputs=[self.selected_image, self.selected_img_display],
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_txt2img_gallery_select(self, gallery, selected_idx):
|
def on_txt2img_gallery_select(self, gallery, selected_idx):
|
||||||
|
|
@ -194,14 +229,16 @@ class FabricScript(modules.scripts.Script):
|
||||||
idx = selected_idx - (len(gallery) - len(images))
|
idx = selected_idx - (len(gallery) - len(images))
|
||||||
|
|
||||||
if idx >= 0 and idx < len(images):
|
if idx >= 0 and idx < len(images):
|
||||||
FabricState.selected_image = images[idx]
|
return images[idx], gr.update(value=images[idx])
|
||||||
return gr.update(value=images[idx])
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None, None
|
||||||
|
|
||||||
def process(self, p, *args):
|
def process(self, p, *args):
|
||||||
(
|
(
|
||||||
|
liked_images,
|
||||||
|
disliked_images,
|
||||||
feedback_disabled,
|
feedback_disabled,
|
||||||
|
feedback_max_images,
|
||||||
feedback_start,
|
feedback_start,
|
||||||
feedback_end,
|
feedback_end,
|
||||||
feedback_min_weight,
|
feedback_min_weight,
|
||||||
|
|
@ -210,8 +247,10 @@ class FabricScript(modules.scripts.Script):
|
||||||
) = args
|
) = args
|
||||||
|
|
||||||
print("[FABRIC] Encoding feedback images into latent space...")
|
print("[FABRIC] Encoding feedback images into latent space...")
|
||||||
pos_latents = encode_to_latent(p, FabricState.liked_images)
|
likes = liked_images[:int(feedback_max_images)]
|
||||||
neg_latents = encode_to_latent(p, FabricState.disliked_images)
|
dislikes = disliked_images[:int(feedback_max_images)]
|
||||||
|
pos_latents = encode_to_latent(p, likes)
|
||||||
|
neg_latents = encode_to_latent(p, dislikes)
|
||||||
|
|
||||||
print("[FABRIC] Patching U-Net forward pass...")
|
print("[FABRIC] Patching U-Net forward pass...")
|
||||||
params = FabricParams(
|
params = FabricParams(
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,6 @@ def patch_unet_forward_pass(p, unet, params):
|
||||||
z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
|
z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
|
||||||
all_zs.append(z)
|
all_zs.append(z)
|
||||||
all_zs = torch.stack(all_zs, dim=0)
|
all_zs = torch.stack(all_zs, dim=0)
|
||||||
print(all_zs.shape)
|
|
||||||
print(all_zs[:, 0, 0, 0])
|
|
||||||
|
|
||||||
## cache hidden states
|
## cache hidden states
|
||||||
cached_hiddens = []
|
cached_hiddens = []
|
||||||
|
|
@ -68,11 +66,6 @@ def patch_unet_forward_pass(p, unet, params):
|
||||||
seq_len, d_model = x.shape[1:]
|
seq_len, d_model = x.shape[1:]
|
||||||
num_pos = len(params.pos_latents)
|
num_pos = len(params.pos_latents)
|
||||||
num_neg = len(params.neg_latents)
|
num_neg = len(params.neg_latents)
|
||||||
print("num_pos", num_pos)
|
|
||||||
print("num_neg", num_neg)
|
|
||||||
print("seq_len", seq_len)
|
|
||||||
print("d_model", d_model)
|
|
||||||
print(cached_hs.shape)
|
|
||||||
|
|
||||||
pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(batch_size, -1, -1) # (bs, seq * n_pos, dim)
|
pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(batch_size, -1, -1) # (bs, seq * n_pos, dim)
|
||||||
neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(batch_size, -1, -1) # (bs, seq * n_neg, dim)
|
neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(batch_size, -1, -1) # (bs, seq * n_neg, dim)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue