Merge pull request #38 from dvruette/fix-burnout

Burnout protection
main
Dimitri 2024-03-09 18:17:40 +01:00 committed by GitHub
commit e156cef442
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 1 deletions

View File

@ -139,6 +139,7 @@ class FabricParams:
tome_ratio: float = 0.5
tome_max_tokens: int = 4*4096
tome_seed: int = -1
burnout_protection: bool = False
# TODO: replace global state with Gradio state
@ -218,6 +219,7 @@ class FabricScript(modules.scripts.Script):
feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight")
tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False)
burnout_protection = gr.Checkbox(label="Burnout protection (enable if results contain artifacts or are especially dark)", value=False)
with gr.Accordion("Advanced options", open=DEBUG):
with FormGroup():
@ -302,6 +304,7 @@ class FabricScript(modules.scripts.Script):
(tome_ratio, "fabric_tome_ratio"),
(tome_max_tokens, "fabric_tome_max_tokens"),
(tome_seed, "fabric_tome_seed"),
(burnout_protection, "fabric_burnout_protection"),
(feedback_during_high_res_fix, "fabric_feedback_during_high_res_fix"),
(liked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_pos_images")) if "fabric_pos_images" in d else None),
(disliked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_neg_images")) if "fabric_neg_images" in d else None),
@ -323,6 +326,7 @@ class FabricScript(modules.scripts.Script):
tome_ratio,
tome_max_tokens,
tome_seed,
burnout_protection,
]
@ -429,6 +433,7 @@ class FabricScript(modules.scripts.Script):
tome_ratio,
tome_max_tokens,
tome_seed,
burnout_protection,
) = args
# restore original U-Net forward pass in case previous batch errored out
@ -454,6 +459,7 @@ class FabricScript(modules.scripts.Script):
tome_ratio=(round(tome_ratio * 16) / 16),
tome_max_tokens=tome_max_tokens,
tome_seed=get_fixed_seed(int(tome_seed)),
burnout_protection=burnout_protection,
)

View File

@ -156,6 +156,10 @@ def patch_unet_forward_pass(p, unet, params):
"seed": params.tome_seed,
}
prev_vals = {
"weight_modifier": 1.0,
}
def new_forward(self, x, timesteps=None, context=None, **kwargs):
_, uncond_ids, cond_ids, context = unmark_prompt_context(context)
has_cond = len(cond_ids) > 0
@ -167,11 +171,22 @@ def patch_unet_forward_pass(p, unet, params):
if not params.feedback_during_high_res_fix:
print("[FABRIC] Skipping feedback during high-res fix")
return self._fabric_old_forward(x, timesteps, context, **kwargs)
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item(), num_timesteps=num_timesteps)
if pos_weight <= 0 and neg_weight <= 0:
return self._fabric_old_forward(x, timesteps, context, **kwargs)
if params.burnout_protection and "cond" in prev_vals and "uncond" in prev_vals:
# burnout protection: if the difference between cond/uncond was too high in the previous step (sign of instability), slash the weight modifier
diff_std = (prev_vals["cond"] - prev_vals["uncond"]).std(dim=(2, 3)).max().item()
diff_abs_mean = (prev_vals["cond"] - prev_vals["uncond"]).mean(dim=(2, 3)).abs().max().item()
if diff_std > 0.06 or diff_abs_mean > 0.02:
prev_vals["weight_modifier"] *= 0.5
else:
prev_vals["weight_modifier"] = min(1.0, 1.5 * prev_vals["weight_modifier"])
pos_weight, neg_weight = pos_weight * prev_vals["weight_modifier"], neg_weight * prev_vals["weight_modifier"]
pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
pos_latents = pos_latents if has_cond else []
neg_latents = neg_latents if has_uncond else []
@ -288,6 +303,18 @@ def patch_unet_forward_pass(p, unet, params):
# run forward pass with cached hidden states
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
cond_outs = out[cond_ids]
uncond_outs = out[uncond_ids]
if has_cond:
prev_vals["cond"] = cond_outs.detach().clone()
if has_uncond:
prev_vals["uncond"] = uncond_outs.detach().clone()
if params.burnout_protection:
# burnout protection: recenter the output to prevent instabilities caused by mean drift
mean = out.mean(dim=(2, 3), keepdim=True)
out = out - 0.5 * mean
finally:
# restore original pass
for module in self.modules():