diff --git a/scripts/fabric.py b/scripts/fabric.py index 1cd1117..d3c2b68 100644 --- a/scripts/fabric.py +++ b/scripts/fabric.py @@ -139,6 +139,7 @@ class FabricParams: tome_ratio: float = 0.5 tome_max_tokens: int = 4*4096 tome_seed: int = -1 + burnout_protection: bool = False # TODO: replace global state with Gradio state @@ -218,6 +219,7 @@ class FabricScript(modules.scripts.Script): feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight") tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False) + burnout_protection = gr.Checkbox(label="Burnout protection (enable if results contain artifacts or are especially dark)", value=False) with gr.Accordion("Advanced options", open=DEBUG): with FormGroup(): @@ -302,6 +304,7 @@ class FabricScript(modules.scripts.Script): (tome_ratio, "fabric_tome_ratio"), (tome_max_tokens, "fabric_tome_max_tokens"), (tome_seed, "fabric_tome_seed"), + (burnout_protection, "fabric_burnout_protection"), (feedback_during_high_res_fix, "fabric_feedback_during_high_res_fix"), (liked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_pos_images")) if "fabric_pos_images" in d else None), (disliked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_neg_images")) if "fabric_neg_images" in d else None), @@ -323,6 +326,7 @@ class FabricScript(modules.scripts.Script): tome_ratio, tome_max_tokens, tome_seed, + burnout_protection, ] @@ -429,6 +433,7 @@ class FabricScript(modules.scripts.Script): tome_ratio, tome_max_tokens, tome_seed, + burnout_protection, ) = args # restore original U-Net forward pass in case previous batch errored out @@ -454,6 +459,7 @@ class FabricScript(modules.scripts.Script): tome_ratio=(round(tome_ratio * 16) / 16), tome_max_tokens=tome_max_tokens, tome_seed=get_fixed_seed(int(tome_seed)), + burnout_protection=burnout_protection, ) diff --git a/scripts/patching.py 
b/scripts/patching.py index e3ed907..a298a45 100644 --- a/scripts/patching.py +++ b/scripts/patching.py @@ -156,6 +156,10 @@ def patch_unet_forward_pass(p, unet, params): "seed": params.tome_seed, } + prev_vals = { + "weight_modifier": 1.0, + } + def new_forward(self, x, timesteps=None, context=None, **kwargs): _, uncond_ids, cond_ids, context = unmark_prompt_context(context) has_cond = len(cond_ids) > 0 @@ -167,11 +171,22 @@ def patch_unet_forward_pass(p, unet, params): if not params.feedback_during_high_res_fix: print("[FABRIC] Skipping feedback during high-res fix") return self._fabric_old_forward(x, timesteps, context, **kwargs) - + pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item(), num_timesteps=num_timesteps) if pos_weight <= 0 and neg_weight <= 0: return self._fabric_old_forward(x, timesteps, context, **kwargs) + if params.burnout_protection and "cond" in prev_vals and "uncond" in prev_vals: + # burnout protection: if the difference between cond/uncond was too high in the previous step (sign of instability), slash the weight modifier + diff_std = (prev_vals["cond"] - prev_vals["uncond"]).std(dim=(2, 3)).max().item() + diff_abs_mean = (prev_vals["cond"] - prev_vals["uncond"]).mean(dim=(2, 3)).abs().max().item() + if diff_std > 0.06 or diff_abs_mean > 0.02: + prev_vals["weight_modifier"] *= 0.5 + else: + prev_vals["weight_modifier"] = min(1.0, 1.5 * prev_vals["weight_modifier"]) + + pos_weight, neg_weight = pos_weight * prev_vals["weight_modifier"], neg_weight * prev_vals["weight_modifier"] + pos_latents, neg_latents = get_latents_from_params(p, params, w, h) pos_latents = pos_latents if has_cond else [] neg_latents = neg_latents if has_uncond else [] @@ -288,6 +303,18 @@ def patch_unet_forward_pass(p, unet, params): # run forward pass with cached hidden states out = self._fabric_old_forward(x, timesteps, context, **kwargs) + cond_outs = out[cond_ids] + uncond_outs = out[uncond_ids] + + if has_cond: + prev_vals["cond"] = 
cond_outs.detach().clone() + if has_uncond: + prev_vals["uncond"] = uncond_outs.detach().clone() + + if params.burnout_protection: + # burnout protection: recenter the output to prevent instabilities caused by mean drift + mean = out.mean(dim=(2, 3), keepdim=True) + out = out - 0.5 * mean finally: # restore original pass for module in self.modules():