diff --git a/scripts/fabric.py b/scripts/fabric.py
index 08ea15f..46b09ce 100644
--- a/scripts/fabric.py
+++ b/scripts/fabric.py
@@ -27,7 +27,7 @@ except ImportError:
from modules.ui import create_refresh_button
-__version__ = "0.6.1"
+__version__ = "0.6.2"
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
@@ -172,22 +172,22 @@ class FabricScript(modules.scripts.Script):
feedback_enabled = gr.Checkbox(label="Enable", value=False)
feedback_during_high_res_fix = gr.Checkbox(label="Enable during hires. fix", value=False)
- with gr.Row():
- presets_list = gr.Dropdown(label="Presets", choices=_load_presets(), default=None, live=False)
- create_refresh_button(presets_list, lambda: None, lambda: {"choices": _load_presets()}, "fabric_reload_presets_btn")
+ with gr.Row():
+ presets_list = gr.Dropdown(label="Presets", choices=_load_presets(), default=None, live=False)
+ create_refresh_button(presets_list, lambda: None, lambda: {"choices": _load_presets()}, "fabric_reload_presets_btn")
with gr.Tabs():
with gr.Tab("Current batch"):
# TODO: figure out why the display is shared between tabs
- self.img2img_selected_display = gr.Image(value=None, type="pil", label="Selected image", visible=is_img2img).style(height=256)
- self.txt2img_selected_display = gr.Image(value=None, type="pil", label="Selected image", visible=not is_img2img).style(height=256)
+ self.img2img_selected_display = gr.Image(value=None, type="pil", label="Selected image", visible=is_img2img, height=256)
+ self.txt2img_selected_display = gr.Image(value=None, type="pil", label="Selected image", visible=not is_img2img, height=256)
with gr.Row():
like_btn_selected = gr.Button("👍 Like")
dislike_btn_selected = gr.Button("👎 Dislike")
with gr.Tab("Upload image"):
- upload_img_input = gr.Image(type="pil", label="Upload image").style(height=256)
+ upload_img_input = gr.Image(type="pil", label="Upload image", height=256)
with gr.Row():
like_btn_uploaded = gr.Button("👍 Like")
@@ -198,37 +198,35 @@ class FabricScript(modules.scripts.Script):
with gr.Row():
remove_selected_like_btn = gr.Button("Remove selected", interactive=False)
clear_liked_btn = gr.Button("Clear")
- like_gallery = gr.Gallery(label="Liked images", elem_id="fabric_like_gallery").style(columns=4, height=128)
+ like_gallery = gr.Gallery(label="Liked images", elem_id="fabric_like_gallery", columns=4, height=192)
with gr.Tab("👎 Dislikes"):
with gr.Row():
remove_selected_dislike_btn = gr.Button("Remove selected", interactive=False)
clear_disliked_btn = gr.Button("Clear")
- dislike_gallery = gr.Gallery(label="Disliked images", elem_id="fabric_dislike_gallery").style(columns=4, height=128)
+ dislike_gallery = gr.Gallery(label="Disliked images", elem_id="fabric_dislike_gallery", columns=4, height=192)
save_preset_btn = gr.Button("Save as preset")
-                        gr.HTML("<br />")
+                        gr.HTML("<br />")
- with gr.Column():
+ with FormGroup():
with FormRow():
feedback_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label="Feedback start")
feedback_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.8, label="Feedback end")
- with FormRow():
- feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight")
- with FormRow():
- tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False)
+
+ feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight")
+ tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False)
with gr.Accordion("Advanced options", open=DEBUG):
with FormGroup():
feedback_min_weight = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.0, label="Min. strength", info="Minimum feedback strength at every diffusion step.", elem_id="fabric_min_weight")
feedback_neg_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Negative weight", info="Strength of negative feedback relative to positive feedback.", elem_id="fabric_neg_scale")
- with FormGroup():
tome_ratio = gr.Slider(minimum=0.0, maximum=0.75, step=0.125, value=0.5, label="ToMe merge ratio", info="Percentage of tokens to be merged (higher improves speed)", elem_id="fabric_tome_ratio")
tome_max_tokens = gr.Slider(minimum=4096, maximum=16*4096, step=4096, value=2*4096, label="ToMe max. tokens", info="Maximum number of tokens after merging (lower improves VRAM usage)", elem_id="fabric_tome_max_tokens")
- tome_seed = gr.Number(label="ToMe seed", value=-1, info="Random seed for ToMe partition", elem_id="fabric_tome_seed")
+ tome_seed = gr.Number(label="ToMe seed", value=-1, step=1, info="Random seed for ToMe partition", elem_id="fabric_tome_seed")
@@ -455,7 +453,7 @@ class FabricScript(modules.scripts.Script):
tome_enabled=tome_enabled,
tome_ratio=(round(tome_ratio * 16) / 16),
tome_max_tokens=tome_max_tokens,
- tome_seed=get_fixed_seed(tome_seed),
+ tome_seed=get_fixed_seed(int(tome_seed)),
)
@@ -468,6 +466,11 @@ class FabricScript(modules.scripts.Script):
log_params["neg_images"] = json.dumps(disliked_paths)
del log_params["enabled"]
+ if not params.tome_enabled:
+ del log_params["tome_ratio"]
+ del log_params["tome_max_tokens"]
+ del log_params["tome_seed"]
+
log_params = {f"fabric_{k}": v for k, v in log_params.items()}
p.extra_generation_params.update(log_params)
diff --git a/scripts/patching.py b/scripts/patching.py
index e89a88b..c2a8963 100644
--- a/scripts/patching.py
+++ b/scripts/patching.py
@@ -202,10 +202,11 @@ def patch_unet_forward_pass(p, unet, params):
merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens)
feedback_ctx = merge(feedback_ctx)
ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim)
+ weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype) # (seq + seq*n_pos,)
+ weights[_x.shape[1]:] = w
else:
ctx = context
- weights = torch.ones_like(ctx[0, :, 0]) # (seq + seq*n_pos,)
- weights[_x.shape[1]:] = w
+ weights = None
return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim)
outs = []
diff --git a/scripts/weighted_attention.py b/scripts/weighted_attention.py
index 5e74209..4263632 100644
--- a/scripts/weighted_attention.py
+++ b/scripts/weighted_attention.py
@@ -63,12 +63,20 @@ def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
def _get_attn_bias(weights, shape=None, dtype=torch.float32):
+ # shape of weights needs to be divisible by 8 in order for xformers attn bias to work
+ last_dim = ((weights.shape[-1] - 1) // 8 + 1) * 8
+ w_bias = torch.zeros(weights.shape[:-1] + (last_dim,), device=weights.device, dtype=weights.dtype)
+
min_val = torch.finfo(dtype).min
- w_bias = weights.log().clamp(min=min_val)
+ w_bias[..., :weights.shape[-1]] = weights.log().clamp(min=min_val)
+
if shape is not None:
- assert shape[-1] == w_bias.shape[-1], "Last dimension of shape must match last dimension of weights (number of keys)"
- w_bias = w_bias.view([1] * (len(shape) - 1) + [-1]).expand(shape)
+ assert shape[-1] == weights.shape[-1], "Last dimension of shape must match last dimension of weights (number of keys)"
+ w_bias = w_bias.view([1] * (len(shape) - 1) + [-1]).expand(shape[:-1] + (last_dim,))
+
+ # cast first in order to preserve multiple-of-8 stride
w_bias = w_bias.to(dtype=dtype)
+ w_bias = w_bias[..., :weights.shape[-1]]
return w_bias
### The following attn functions are copied and adapted from modules.sd_hijack_optimizations
diff --git a/style.css b/style.css
index 8cab3fa..726c975 100644
--- a/style.css
+++ b/style.css
@@ -2,3 +2,7 @@
#fabric_like_gallery img, #fabric_dislike_gallery img {
object-fit: scale-down;
}
+
+#fabric {
+ --layout-gap: 0.75rem;
+}
\ No newline at end of file