commit
e156cef442
|
|
@ -139,6 +139,7 @@ class FabricParams:
|
|||
tome_ratio: float = 0.5
|
||||
tome_max_tokens: int = 4*4096
|
||||
tome_seed: int = -1
|
||||
burnout_protection: bool = False
|
||||
|
||||
|
||||
# TODO: replace global state with Gradio state
|
||||
|
|
@ -218,6 +219,7 @@ class FabricScript(modules.scripts.Script):
|
|||
|
||||
feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight")
|
||||
tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False)
|
||||
burnout_protection = gr.Checkbox(label="Burnout protection (enable if results contain artifacts or are especially dark)", value=False)
|
||||
|
||||
with gr.Accordion("Advanced options", open=DEBUG):
|
||||
with FormGroup():
|
||||
|
|
@ -302,6 +304,7 @@ class FabricScript(modules.scripts.Script):
|
|||
(tome_ratio, "fabric_tome_ratio"),
|
||||
(tome_max_tokens, "fabric_tome_max_tokens"),
|
||||
(tome_seed, "fabric_tome_seed"),
|
||||
(burnout_protection, "fabric_burnout_protection"),
|
||||
(feedback_during_high_res_fix, "fabric_feedback_during_high_res_fix"),
|
||||
(liked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_pos_images")) if "fabric_pos_images" in d else None),
|
||||
(disliked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_neg_images")) if "fabric_neg_images" in d else None),
|
||||
|
|
@ -323,6 +326,7 @@ class FabricScript(modules.scripts.Script):
|
|||
tome_ratio,
|
||||
tome_max_tokens,
|
||||
tome_seed,
|
||||
burnout_protection,
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -429,6 +433,7 @@ class FabricScript(modules.scripts.Script):
|
|||
tome_ratio,
|
||||
tome_max_tokens,
|
||||
tome_seed,
|
||||
burnout_protection,
|
||||
) = args
|
||||
|
||||
# restore original U-Net forward pass in case previous batch errored out
|
||||
|
|
@ -454,6 +459,7 @@ class FabricScript(modules.scripts.Script):
|
|||
tome_ratio=(round(tome_ratio * 16) / 16),
|
||||
tome_max_tokens=tome_max_tokens,
|
||||
tome_seed=get_fixed_seed(int(tome_seed)),
|
||||
burnout_protection=burnout_protection,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -156,6 +156,10 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
"seed": params.tome_seed,
|
||||
}
|
||||
|
||||
prev_vals = {
|
||||
"weight_modifier": 1.0,
|
||||
}
|
||||
|
||||
def new_forward(self, x, timesteps=None, context=None, **kwargs):
|
||||
_, uncond_ids, cond_ids, context = unmark_prompt_context(context)
|
||||
has_cond = len(cond_ids) > 0
|
||||
|
|
@ -167,11 +171,22 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
if not params.feedback_during_high_res_fix:
|
||||
print("[FABRIC] Skipping feedback during high-res fix")
|
||||
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
|
||||
|
||||
pos_weight, neg_weight = get_curr_feedback_weight(p, params, timesteps[0].item(), num_timesteps=num_timesteps)
|
||||
if pos_weight <= 0 and neg_weight <= 0:
|
||||
return self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
|
||||
if params.burnout_protection and "cond" in prev_vals and "uncond" in prev_vals:
|
||||
# burnout protection: if the difference betwen cond/uncond was too high in the previous step (sign of instability), slash the weight modifier
|
||||
diff_std = (prev_vals["cond"] - prev_vals["uncond"]).std(dim=(2, 3)).max().item()
|
||||
diff_abs_mean = (prev_vals["cond"] - prev_vals["uncond"]).mean(dim=(2, 3)).abs().max().item()
|
||||
if diff_std > 0.06 or diff_abs_mean > 0.02:
|
||||
prev_vals["weight_modifier"] *= 0.5
|
||||
else:
|
||||
prev_vals["weight_modifier"] = min(1.0, 1.5 * prev_vals["weight_modifier"])
|
||||
|
||||
pos_weight, neg_weight = pos_weight * prev_vals["weight_modifier"], neg_weight * prev_vals["weight_modifier"]
|
||||
|
||||
pos_latents, neg_latents = get_latents_from_params(p, params, w, h)
|
||||
pos_latents = pos_latents if has_cond else []
|
||||
neg_latents = neg_latents if has_uncond else []
|
||||
|
|
@ -288,6 +303,18 @@ def patch_unet_forward_pass(p, unet, params):
|
|||
# run forward pass with cached hidden states
|
||||
out = self._fabric_old_forward(x, timesteps, context, **kwargs)
|
||||
|
||||
cond_outs = out[cond_ids]
|
||||
uncond_outs = out[uncond_ids]
|
||||
|
||||
if has_cond:
|
||||
prev_vals["cond"] = cond_outs.detach().clone()
|
||||
if has_uncond:
|
||||
prev_vals["uncond"] = uncond_outs.detach().clone()
|
||||
|
||||
if params.burnout_protection:
|
||||
# burnout protection: recenter the output to prevent instabilities caused by mean drift
|
||||
mean = out.mean(dim=(2, 3), keepdim=True)
|
||||
out = out - 0.5 * mean
|
||||
finally:
|
||||
# restore original pass
|
||||
for module in self.modules():
|
||||
|
|
|
|||
Loading…
Reference in New Issue