v0.2
parent
70130b385d
commit
d80e6e77c5
|
|
@ -0,0 +1,19 @@
|
|||
|
||||
function fabric_selected_gallery_index(elem_id) {
|
||||
const el = document.getElementById(elem_id);
|
||||
const buttons = el.querySelectorAll('.thumbnail-item.thumbnail-small');
|
||||
const button = el.querySelector('.thumbnail-item.thumbnail-small.selected');
|
||||
|
||||
let result = -1;
|
||||
buttons.forEach((v, i) => {
|
||||
if (v == button) {
|
||||
result = i;
|
||||
}
|
||||
});
|
||||
|
||||
console.log(result);
|
||||
console.log(button);
|
||||
console.log(buttons);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
@ -10,12 +10,13 @@ from PIL import Image
|
|||
|
||||
import modules.scripts
|
||||
from modules import devices, script_callbacks, images
|
||||
from modules.ui_components import FormGroup
|
||||
|
||||
from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass
|
||||
from scripts.helpers import WebUiComponents
|
||||
|
||||
|
||||
__version__ = "0.1"
|
||||
__version__ = "0.2"
|
||||
|
||||
DEBUG = False
|
||||
|
||||
|
|
@ -43,10 +44,6 @@ class FabricParams:
|
|||
# TODO: replace global state with Gradio state
|
||||
class FabricState:
|
||||
batch_images = []
|
||||
selected_image = None
|
||||
uploaded_image = None
|
||||
liked_images = []
|
||||
disliked_images = []
|
||||
|
||||
|
||||
def encode_to_latent(p, feedback_imgs):
|
||||
|
|
@ -78,16 +75,19 @@ class FabricScript(modules.scripts.Script):
|
|||
return modules.scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
self.selected_image = gr.State(None)
|
||||
liked_images = gr.State([])
|
||||
disliked_images = gr.State([])
|
||||
selected_like = gr.State(None)
|
||||
selected_dislike = gr.State(None)
|
||||
|
||||
with gr.Accordion(f"{self.title()} v{__version__}", open=DEBUG, elem_id="fabric"):
|
||||
if DEBUG:
|
||||
like_example_btn = gr.Button("👍 Example")
|
||||
|
||||
feedback_disabled = gr.Checkbox(label="Disable", value=False)
|
||||
|
||||
with gr.Tabs():
|
||||
with gr.Tab("Current batch"):
|
||||
self.selected_img = gr.Image(value=None, type="pil", label="Selected image").style(height=256)
|
||||
self.selected_img_display = gr.Image(value=None, type="pil", label="Selected image").style(height=256)
|
||||
|
||||
with gr.Row():
|
||||
like_btn_selected = gr.Button("👍 Like")
|
||||
|
|
@ -100,16 +100,28 @@ class FabricScript(modules.scripts.Script):
|
|||
like_btn_uploaded = gr.Button("👍 Like")
|
||||
dislike_btn_uploaded = gr.Button("👎 Dislike")
|
||||
|
||||
|
||||
with gr.Tabs(initial_tab="👍 Likes"):
|
||||
with gr.Tab("👍 Likes"):
|
||||
like_gallery = gr.Gallery(label="Liked images").style(columns=5, height=256)
|
||||
with gr.Row():
|
||||
remove_selected_like_btn = gr.Button("Remove selected", interactive=False)
|
||||
clear_liked_btn = gr.Button("Clear")
|
||||
like_gallery = gr.Gallery(label="Liked images", elem_id="fabric_like_gallery").style(columns=5, height=128)
|
||||
with gr.Tab("👎 Dislikes"):
|
||||
dislike_gallery = gr.Gallery(label="Disliked images").style(columns=5, height=256)
|
||||
with gr.Row():
|
||||
remove_selected_dislike_btn = gr.Button("Remove selected", interactive=False)
|
||||
clear_disliked_btn = gr.Button("Clear")
|
||||
dislike_gallery = gr.Gallery(label="Disliked images", elem_id="fabric_dislike_gallery").style(columns=5, height=128)
|
||||
|
||||
|
||||
with gr.Group():
|
||||
with FormGroup():
|
||||
feedback_disabled = gr.Checkbox(label="Disable feedback", value=False)
|
||||
|
||||
with FormGroup():
|
||||
with gr.Row():
|
||||
# TODO: figure out how to make the step size do what it's supposed to
|
||||
feedback_max_images = gr.Slider(minimum=0, maximum=10, step=1, value=4, label="Max. feedback images")
|
||||
|
||||
with gr.Row():
|
||||
feedback_start = gr.Slider(0.0, 1.0, value=0.0, label="Feedback start")
|
||||
feedback_end = gr.Slider(0.0, 1.0, value=0.8, label="Feedback end")
|
||||
|
|
@ -121,20 +133,49 @@ class FabricScript(modules.scripts.Script):
|
|||
|
||||
WebUiComponents.on_txt2img_gallery(self.register_txt2img_gallery_select)
|
||||
|
||||
like_btn_selected.click(self.on_like_selected, inputs=None, outputs=like_gallery)
|
||||
dislike_btn_selected.click(self.on_dislike_selected, inputs=None, outputs=dislike_gallery)
|
||||
like_btn_selected.click(self.add_image_to_state, inputs=[self.selected_image, liked_images], outputs=[liked_images, like_gallery])
|
||||
dislike_btn_selected.click(self.add_image_to_state, inputs=[self.selected_image, disliked_images], outputs=[disliked_images, dislike_gallery])
|
||||
|
||||
like_btn_uploaded.click(self.on_like_uploaded, inputs=upload_img_input, outputs=like_gallery)
|
||||
dislike_btn_uploaded.click(self.on_dislike_uploaded, inputs=upload_img_input, outputs=dislike_gallery)
|
||||
like_btn_uploaded.click(self.add_image_to_state, inputs=[upload_img_input, liked_images], outputs=[liked_images, like_gallery])
|
||||
dislike_btn_uploaded.click(self.add_image_to_state, inputs=[upload_img_input, disliked_images], outputs=[disliked_images, dislike_gallery])
|
||||
|
||||
clear_liked_btn.click(self.clear_likes, inputs=None, outputs=like_gallery)
|
||||
clear_disliked_btn.click(self.clear_dislikes, inputs=None, outputs=dislike_gallery)
|
||||
clear_liked_btn.click(lambda _: [[], []], inputs=liked_images, outputs=[liked_images, like_gallery])
|
||||
clear_disliked_btn.click(lambda _: [[], []], inputs=disliked_images, outputs=[disliked_images, dislike_gallery])
|
||||
|
||||
like_gallery.select(
|
||||
self.select_for_removal,
|
||||
_js="(a, b) => [a, fabric_selected_gallery_index('fabric_like_gallery')]",
|
||||
inputs=[like_gallery, like_gallery],
|
||||
outputs=[selected_like, remove_selected_like_btn],
|
||||
)
|
||||
|
||||
dislike_gallery.select(
|
||||
self.select_for_removal,
|
||||
_js="(a, b) => [a, fabric_selected_gallery_index('fabric_dislike_gallery')]",
|
||||
inputs=[dislike_gallery, dislike_gallery],
|
||||
outputs=[selected_dislike, remove_selected_dislike_btn],
|
||||
)
|
||||
|
||||
remove_selected_like_btn.click(
|
||||
self.remove_selected,
|
||||
inputs=[liked_images, selected_like],
|
||||
outputs=[liked_images, like_gallery, selected_like, remove_selected_like_btn],
|
||||
)
|
||||
|
||||
remove_selected_dislike_btn.click(
|
||||
self.remove_selected,
|
||||
inputs=[disliked_images, selected_dislike],
|
||||
outputs=[disliked_images, dislike_gallery, selected_dislike, remove_selected_dislike_btn],
|
||||
)
|
||||
|
||||
if DEBUG:
|
||||
like_example_btn.click(functools.partial(self.on_like_example, "example1"), inputs=None, outputs=like_gallery)
|
||||
like_example_btn.click(functools.partial(self.on_like_example, example="example1"), inputs=liked_images, outputs=[liked_images, like_gallery])
|
||||
|
||||
return [
|
||||
liked_images,
|
||||
disliked_images,
|
||||
feedback_disabled,
|
||||
feedback_max_images,
|
||||
feedback_start,
|
||||
feedback_end,
|
||||
feedback_min_weight,
|
||||
|
|
@ -143,40 +184,34 @@ class FabricScript(modules.scripts.Script):
|
|||
]
|
||||
|
||||
|
||||
def on_like_selected(self, *args):
|
||||
if FabricState.selected_image is not None:
|
||||
FabricState.liked_images.append(FabricState.selected_image)
|
||||
return FabricState.liked_images
|
||||
def select_for_removal(self, gallery, selected_idx):
|
||||
return [
|
||||
selected_idx,
|
||||
gr.update(interactive=True),
|
||||
]
|
||||
|
||||
def on_dislike_selected(self, *args):
|
||||
if FabricState.selected_image is not None:
|
||||
FabricState.disliked_images.append(FabricState.selected_image)
|
||||
return FabricState.disliked_images
|
||||
def remove_selected(self, images, idx):
|
||||
if idx >= 0 and idx < len(images):
|
||||
images.pop(idx)
|
||||
|
||||
def on_like_uploaded(self, image):
|
||||
if image is not None:
|
||||
FabricState.liked_images.append(image)
|
||||
return FabricState.liked_images
|
||||
return [
|
||||
images,
|
||||
images,
|
||||
gr.update(value=None),
|
||||
gr.update(interactive=False),
|
||||
]
|
||||
|
||||
def on_dislike_uploaded(self, image):
|
||||
if image is not None:
|
||||
FabricState.disliked_images.append(image)
|
||||
return FabricState.disliked_images
|
||||
def add_image_to_state(self, img, images):
|
||||
if img is not None:
|
||||
images.append(img)
|
||||
return images, images
|
||||
|
||||
def on_like_example(self, example="example1"):
|
||||
def on_like_example(self, liked_images, example="example1"):
|
||||
img_path = Path(__file__).parent.parent.absolute() / "images" / f"{example}.png"
|
||||
image = Image.open(img_path)
|
||||
if image is not None:
|
||||
FabricState.liked_images.append(image)
|
||||
return FabricState.liked_images
|
||||
|
||||
def clear_likes(self, *args):
|
||||
FabricState.liked_images = []
|
||||
return FabricState.liked_images
|
||||
|
||||
def clear_dislikes(self, *args):
|
||||
FabricState.disliked_images = []
|
||||
return FabricState.disliked_images
|
||||
liked_images.append(image)
|
||||
return liked_images, liked_images
|
||||
|
||||
def register_txt2img_gallery_select(self, gallery):
|
||||
gallery.select(
|
||||
|
|
@ -186,7 +221,7 @@ class FabricScript(modules.scripts.Script):
|
|||
gallery,
|
||||
gallery, # can be any Gradio component (but not None), will be overwritten with selected gallery index
|
||||
],
|
||||
outputs=self.selected_img,
|
||||
outputs=[self.selected_image, self.selected_img_display],
|
||||
)
|
||||
|
||||
def on_txt2img_gallery_select(self, gallery, selected_idx):
|
||||
|
|
@ -194,14 +229,16 @@ class FabricScript(modules.scripts.Script):
|
|||
idx = selected_idx - (len(gallery) - len(images))
|
||||
|
||||
if idx >= 0 and idx < len(images):
|
||||
FabricState.selected_image = images[idx]
|
||||
return gr.update(value=images[idx])
|
||||
return images[idx], gr.update(value=images[idx])
|
||||
else:
|
||||
return None
|
||||
return None, None
|
||||
|
||||
def process(self, p, *args):
|
||||
(
|
||||
liked_images,
|
||||
disliked_images,
|
||||
feedback_disabled,
|
||||
feedback_max_images,
|
||||
feedback_start,
|
||||
feedback_end,
|
||||
feedback_min_weight,
|
||||
|
|
@ -210,8 +247,10 @@ class FabricScript(modules.scripts.Script):
|
|||
) = args
|
||||
|
||||
print("[FABRIC] Encoding feedback images into latent space...")
|
||||
pos_latents = encode_to_latent(p, FabricState.liked_images)
|
||||
neg_latents = encode_to_latent(p, FabricState.disliked_images)
|
||||
likes = liked_images[:int(feedback_max_images)]
|
||||
dislikes = disliked_images[:int(feedback_max_images)]
|
||||
pos_latents = encode_to_latent(p, likes)
|
||||
neg_latents = encode_to_latent(p, dislikes)
|
||||
|
||||
print("[FABRIC] Patching U-Net forward pass...")
|
||||
params = FabricParams(
|
||||
|
|
|
|||
|
|
@ -37,8 +37,6 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
z = p.sd_model.q_sample(latent.unsqueeze(0), torch.round(timesteps.float()).long())[0]
|
||||
all_zs.append(z)
|
||||
all_zs = torch.stack(all_zs, dim=0)
|
||||
print(all_zs.shape)
|
||||
print(all_zs[:, 0, 0, 0])
|
||||
|
||||
## cache hidden states
|
||||
cached_hiddens = []
|
||||
|
|
@ -68,11 +66,6 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
seq_len, d_model = x.shape[1:]
|
||||
num_pos = len(params.pos_latents)
|
||||
num_neg = len(params.neg_latents)
|
||||
print("num_pos", num_pos)
|
||||
print("num_neg", num_neg)
|
||||
print("seq_len", seq_len)
|
||||
print("d_model", d_model)
|
||||
print(cached_hs.shape)
|
||||
|
||||
pos_hs = cached_hs[:num_pos].view(1, num_pos * seq_len, d_model).expand(batch_size, -1, -1) # (bs, seq * n_pos, dim)
|
||||
neg_hs = cached_hs[num_pos:].view(1, num_neg * seq_len, d_model).expand(batch_size, -1, -1) # (bs, seq * n_neg, dim)
|
||||
|
|
|
|||
Loading…
Reference in New Issue