diff --git a/scripts/fabric.py b/scripts/fabric.py index 08ea15f..46b09ce 100644 --- a/scripts/fabric.py +++ b/scripts/fabric.py @@ -27,7 +27,7 @@ except ImportError: from modules.ui import create_refresh_button -__version__ = "0.6.1" +__version__ = "0.6.2" DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1") @@ -172,22 +172,22 @@ class FabricScript(modules.scripts.Script): feedback_enabled = gr.Checkbox(label="Enable", value=False) feedback_during_high_res_fix = gr.Checkbox(label="Enable during hires. fix", value=False) - with gr.Row(): - presets_list = gr.Dropdown(label="Presets", choices=_load_presets(), default=None, live=False) - create_refresh_button(presets_list, lambda: None, lambda: {"choices": _load_presets()}, "fabric_reload_presets_btn") + with gr.Row(): + presets_list = gr.Dropdown(label="Presets", choices=_load_presets(), default=None, live=False) + create_refresh_button(presets_list, lambda: None, lambda: {"choices": _load_presets()}, "fabric_reload_presets_btn") with gr.Tabs(): with gr.Tab("Current batch"): # TODO: figure out why the display is shared between tabs - self.img2img_selected_display = gr.Image(value=None, type="pil", label="Selected image", visible=is_img2img).style(height=256) - self.txt2img_selected_display = gr.Image(value=None, type="pil", label="Selected image", visible=not is_img2img).style(height=256) + self.img2img_selected_display = gr.Image(value=None, type="pil", label="Selected image", visible=is_img2img, height=256) + self.txt2img_selected_display = gr.Image(value=None, type="pil", label="Selected image", visible=not is_img2img, height=256) with gr.Row(): like_btn_selected = gr.Button("👍 Like") dislike_btn_selected = gr.Button("👎 Dislike") with gr.Tab("Upload image"): - upload_img_input = gr.Image(type="pil", label="Upload image").style(height=256) + upload_img_input = gr.Image(type="pil", label="Upload image", height=256) with gr.Row(): like_btn_uploaded = gr.Button("👍 Like") @@ -198,37 +198,35 @@ class FabricScript(modules.scripts.Script): with gr.Row(): remove_selected_like_btn = gr.Button("Remove selected", interactive=False) clear_liked_btn = gr.Button("Clear") - like_gallery = gr.Gallery(label="Liked images", elem_id="fabric_like_gallery").style(columns=4, height=128) + like_gallery = gr.Gallery(label="Liked images", elem_id="fabric_like_gallery", columns=4, height=192) with gr.Tab("👎 Dislikes"): with gr.Row(): remove_selected_dislike_btn = gr.Button("Remove selected", interactive=False) clear_disliked_btn = gr.Button("Clear") - dislike_gallery = gr.Gallery(label="Disliked images", elem_id="fabric_dislike_gallery").style(columns=4, height=128) + dislike_gallery = gr.Gallery(label="Disliked images", elem_id="fabric_dislike_gallery", columns=4, height=192) save_preset_btn = gr.Button("Save as preset") - gr.HTML("
") + gr.HTML("
") - with gr.Column(): + with FormGroup(): with FormRow(): feedback_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label="Feedback start") feedback_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.8, label="Feedback end") - with FormRow(): - feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight") - with FormRow(): - tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False) + + feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight") + tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False) with gr.Accordion("Advanced options", open=DEBUG): with FormGroup(): feedback_min_weight = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.0, label="Min. strength", info="Minimum feedback strength at every diffusion step.", elem_id="fabric_min_weight") feedback_neg_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Negative weight", info="Strength of negative feedback relative to positive feedback.", elem_id="fabric_neg_scale") - with FormGroup(): tome_ratio = gr.Slider(minimum=0.0, maximum=0.75, step=0.125, value=0.5, label="ToMe merge ratio", info="Percentage of tokens to be merged (higher improves speed)", elem_id="fabric_tome_ratio") tome_max_tokens = gr.Slider(minimum=4096, maximum=16*4096, step=4096, value=2*4096, label="ToMe max. tokens", info="Maximum number of tokens after merging (lower improves VRAM usage)", elem_id="fabric_tome_max_tokens") - tome_seed = gr.Number(label="ToMe seed", value=-1, info="Random seed for ToMe partition", elem_id="fabric_tome_seed") + tome_seed = gr.Number(label="ToMe seed", value=-1, step=1, info="Random seed for ToMe partition", elem_id="fabric_tome_seed") @@ -455,7 +453,7 @@ class FabricScript(modules.scripts.Script): tome_enabled=tome_enabled, tome_ratio=(round(tome_ratio * 16) / 16), tome_max_tokens=tome_max_tokens, - tome_seed=get_fixed_seed(tome_seed), + tome_seed=get_fixed_seed(int(tome_seed)), ) @@ -468,6 +466,11 @@ class FabricScript(modules.scripts.Script): log_params["neg_images"] = json.dumps(disliked_paths) del log_params["enabled"] + if not params.tome_enabled: + del log_params["tome_ratio"] + del log_params["tome_max_tokens"] + del log_params["tome_seed"] + log_params = {f"fabric_{k}": v for k, v in log_params.items()} p.extra_generation_params.update(log_params) diff --git a/scripts/patching.py b/scripts/patching.py index e89a88b..c2a8963 100644 --- a/scripts/patching.py +++ b/scripts/patching.py @@ -202,10 +202,11 @@ def patch_unet_forward_pass(p, unet, params): merge, _ = compute_merge(feedback_ctx, args=tome_args, size=(tome_h_latent * num_fb, w_latent), max_tokens=params.tome_max_tokens) feedback_ctx = merge(feedback_ctx) ctx = torch.cat([context, feedback_ctx], dim=1) # (n_cond, seq + seq*n_pos, dim) + weights = torch.ones(ctx.shape[1], device=ctx.device, dtype=ctx.dtype) # (seq + seq*n_pos,) + weights[_x.shape[1]:] = w else: ctx = context - weights = torch.ones_like(ctx[0, :, 0]) # (seq + seq*n_pos,) - weights[_x.shape[1]:] = w + weights = None return weighted_attention(attn1, attn1._fabric_old_forward, _x, ctx, weights, **kwargs) # (n_cond, seq, dim) outs = [] diff --git a/scripts/weighted_attention.py b/scripts/weighted_attention.py index 5e74209..4263632 100644 --- a/scripts/weighted_attention.py +++ b/scripts/weighted_attention.py @@ -63,12 +63,20 @@ def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs): def _get_attn_bias(weights, shape=None, dtype=torch.float32): + # shape of weights needs to be divisible by 8 in order for xformers attn bias to work + last_dim = ((weights.shape[-1] - 1) // 8 + 1) * 8 + w_bias = torch.zeros(weights.shape[:-1] + (last_dim,), device=weights.device, dtype=weights.dtype) + min_val = torch.finfo(dtype).min - w_bias = weights.log().clamp(min=min_val) + w_bias[..., :weights.shape[-1]] = weights.log().clamp(min=min_val) + if shape is not None: - assert shape[-1] == w_bias.shape[-1], "Last dimension of shape must match last dimension of weights (number of keys)" - w_bias = w_bias.view([1] * (len(shape) - 1) + [-1]).expand(shape) + assert shape[-1] == weights.shape[-1], "Last dimension of shape must match last dimension of weights (number of keys)" + w_bias = w_bias.view([1] * (len(shape) - 1) + [-1]).expand(shape[:-1] + (last_dim,)) + + # cast first in order to preserve multiple-of-8 stride w_bias = w_bias.to(dtype=dtype) + w_bias = w_bias[..., :weights.shape[-1]] return w_bias ### The following attn functions are copied and adapted from modules.sd_hijack_optimizations diff --git a/style.css b/style.css index 8cab3fa..726c975 100644 --- a/style.css +++ b/style.css @@ -2,3 +2,7 @@ #fabric_like_gallery img, #fabric_dislike_gallery img { object-fit: scale-down; } + +#fabric { + --layout-gap: 0.75rem; +} \ No newline at end of file