automatic/modules/sd_hijack_freeu.py

161 lines
6.1 KiB
Python

import math
import torch
from modules import shared, devices
from modules.logger import log
# FreeU UNet hijack, based on <https://github.com/ljleb/sd-webui-freeu/blob/main/lib_free_u/unet.py>
# The official FreeU parameters (b1, b2, s1, s2) are read from shared.opts at call time.
# The tuning knobs below are fixed here but could be exposed as options if needed:
backbone_width: float = 0.5  # fraction of backbone channels scaled by b1/b2 (see ratio_to_region)
backbone_offset: float = 0.0  # start of that channel region, as a fraction of the channel count
skip_cutoff: float = 0.0  # passed to filter_skip as `threshold` (size of the low-frequency box)
skip_high_end_factor: float = 1.0  # target scale for frequencies outside the low-frequency box
start_ratio: float = 0.0  # schedule start; a float is interpreted as a fraction of sampling steps
stop_ratio: float = 1.0  # schedule stop; same interpretation as start_ratio
transition_smoothness: float = 0.0  # weight of the smooth ramp vs. the hard on/off schedule window
# Internal state: True while apply_freeu has FreeU enabled on the loaded model
state_enabled: bool = False
def to_denoising_step(number, steps=None) -> int:
    """Convert a schedule position into an absolute denoising step.

    A ``float`` is treated as a fraction of ``steps`` and truncated toward zero with
    ``int``; any other value is assumed to already be a step index and is returned as-is.
    ``steps`` defaults to ``shared.state.sampling_steps`` when not given.
    """
    total_steps = shared.state.sampling_steps if steps is None else steps
    if not isinstance(number, float):
        return number
    return int(number * total_steps)
def get_schedule_ratio():
    """Return how strongly FreeU should be applied at the current sampling step.

    Two schedules are computed from ``start_ratio``/``stop_ratio`` (converted to steps):

    * flat: 1.0 inside ``[start_step, stop_step)``, else 0.0;
    * smooth: a linear ramp up to 1.0 before ``start_step``, then a linear ramp down
      reaching 0.0 at ``stop_step``, clamped to [0, 1] (0.0 if start and stop coincide).

    The result is ``lerp(flat, smooth, transition_smoothness)``.
    """
    current = shared.state.sampling_step
    start_step = to_denoising_step(start_ratio)
    stop_step = to_denoising_step(stop_ratio)

    flat = 1.0 if start_step <= current < stop_step else 0.0

    if start_step == stop_step:
        smooth = 0.0
    else:
        if current < start_step:
            ramp = current / start_step
        else:
            ramp = 1 + (current - start_step) / (start_step - stop_step)
        smooth = min(1.0, max(0.0, ramp))

    return lerp(flat, smooth, transition_smoothness)
def lerp(a, b, r):
    """Linearly interpolate between ``a`` (at ``r == 0``) and ``b`` (at ``r == 1``)."""
    return a * (1 - r) + b * r
def free_u_cat_hijack(hs, *args, original_function, **kwargs):
    """Replacement for a ``torch.cat``-style call that applies FreeU before concatenating.

    Only calls of the exact form ``original_function([h, h_skip], dim=1)`` are intercepted;
    anything else is forwarded unchanged. When the channel count of ``h`` is 1280 (stage 0)
    or 640 (stage 1), and FreeU is enabled with a non-zero schedule ratio:

    * the backbone channels selected by ``backbone_width``/``backbone_offset`` are scaled
      (in place) by ``lerp(1, b1|b2, schedule_ratio)``;
    * ``h_skip`` is Fourier-filtered by ``filter_skip`` with ``lerp(1, s1|s2, schedule_ratio)``.

    Args:
        hs: expected to be a pair ``(h, h_skip)`` of tensors; ``h`` is indexed as (batch, channel, ...).
        original_function: the wrapped concatenation function.
    Returns:
        Whatever ``original_function`` returns.
    """
    if not shared.opts.freeu_enabled:
        return original_function(hs, *args, **kwargs)
    schedule_ratio = get_schedule_ratio()
    if schedule_ratio == 0:
        return original_function(hs, *args, **kwargs)
    try:
        h, h_skip = hs
    except (ValueError, TypeError):  # wrong arity, or not iterable at all (e.g. a 0-d tensor)
        return original_function(hs, *args, **kwargs)
    if list(kwargs) != ["dim"] or kwargs["dim"] != 1:
        return original_function(hs, *args, **kwargs)

    stage_dims = (1280, 640)  # FreeU only modifies the 1st and 2nd decoder stages
    dims = h.shape[1]
    if dims not in stage_dims:
        return original_function(hs, *args, **kwargs)
    index = stage_dims.index(dims)

    region_begin, region_end, region_inverted = ratio_to_region(backbone_width, backbone_offset, dims)
    channels = torch.arange(dims, device=h.device)
    # half-open [begin, end) so exactly width * dims channels are selected (e.g. the first half
    # for width=0.5, matching official FreeU's h[:, :C // 2]); the inverted case then covers
    # [0, begin) U [end, dims), which is also exactly width * dims channels
    mask = (region_begin <= channels) & (channels < region_end)
    if region_inverted:
        mask = ~mask

    backbone_factor = shared.opts.freeu_b1 if index == 0 else shared.opts.freeu_b2
    skip_factor = shared.opts.freeu_s1 if index == 0 else shared.opts.freeu_s2
    h[:, mask] *= lerp(1, backbone_factor, schedule_ratio)
    h_skip = filter_skip(h_skip, threshold=skip_cutoff, scale=lerp(1, skip_factor, schedule_ratio), scale_high=lerp(1, skip_high_end_factor, schedule_ratio))
    return original_function([h, h_skip], *args, **kwargs)
torch_fft_device = None  # cached result of get_fft_device(); None until the first probe runs


def get_fft_device():
    """Return the device to run FreeU's FFT-based skip filtering on (probed once, then cached).

    On the first call a small random tensor is moved to ``devices.device`` and run through
    ``fftn``/``ifftn``/``fftshift``/``ifftshift``. If that succeeds, ``devices.device`` is
    used; if anything raises, a warning is logged and ``devices.cpu`` is used instead.
    The result is stored in the module-level ``torch_fft_device`` and returned on later calls.
    """
    global torch_fft_device # pylint: disable=global-statement
    if torch_fft_device is not None:
        return torch_fft_device
    try:
        # torch.fft does not accept bfloat16 input, but the FFT never runs in bf16 here:
        # filter_skip upcasts with .float(), and diffusers' FreeU filter casts bf16 to fp32.
        # Probing in bf16 would therefore wrongly disable FreeU for bf16 models.
        probe_dtype = torch.float32 if devices.dtype == torch.bfloat16 else devices.dtype
        probe = torch.randn(4, 4).to(device=devices.device, dtype=probe_dtype)
        _spectrum = torch.fft.ifftn(torch.fft.fftn(probe))
        _shifted = torch.fft.ifftshift(torch.fft.fftshift(probe))
        torch_fft_device = devices.device
    except Exception:  # deliberate best-effort probe: any failure means "fall back to CPU"
        torch_fft_device = devices.cpu
        log.warning(f'FreeU: device={devices.device} dtype={devices.dtype} does not support FFT')
    return torch_fft_device
def no_gpu_complex_support():
    """Report whether an MPS or DirectML backend is available.

    Judging by the name, these are treated as backends without GPU complex-tensor support
    (needed for FFT) — NOTE(review): the only visible call site is commented out.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()
    try:
        import torch_directml
    except ImportError:
        return mps_available or False
    dml_available = torch_directml.is_available()
    return mps_available or dml_available
def filter_skip(x, threshold, scale, scale_high):
    """Scale the low and high spatial frequencies of ``x`` independently (FreeU skip filter).

    The last two dims of ``x`` are treated as spatial (height, width); any leading dims
    (e.g. batch, channel) are filtered independently. In the centered 2-D spectrum, a box of
    half-size ``max(1, floor((H // 2) * threshold))`` by ``max(1, floor((W // 2) * threshold))``
    around the center is multiplied by ``scale``; all other frequencies by ``scale_high``.

    The FFT runs in float32 on ``get_fft_device()``; the result is returned on ``x``'s
    original device and dtype. ``x`` is returned unchanged when both scales are 1.
    """
    if scale == 1 and scale_high == 1:
        return x
    fft_device = get_fft_device()
    spatial = (-2, -1)

    x_freq = torch.fft.fftn(x.to(fft_device).float(), dim=spatial) # pylint: disable=E1102
    x_freq = torch.fft.fftshift(x_freq, dim=spatial) # pylint: disable=E1102

    height, width = x_freq.shape[-2:]
    # a single (H, W) mask broadcasts over all leading dims instead of allocating B*C*H*W floats
    mask = torch.full((height, width), float(scale_high), device=fft_device)
    crow, ccol = height // 2, width // 2
    threshold_row = max(1, math.floor(crow * threshold))
    threshold_col = max(1, math.floor(ccol * threshold))
    mask[crow - threshold_row:crow + threshold_row, ccol - threshold_col:ccol + threshold_col] = scale
    x_freq *= mask

    x_freq = torch.fft.ifftshift(x_freq, dim=spatial) # pylint: disable=E1102
    return torch.fft.ifftn(x_freq, dim=spatial).real.to(device=x.device, dtype=x.dtype) # pylint: disable=E1102
def ratio_to_region(width: float, offset: float, n: int):
    """Map a (width, offset) pair given as fractions of ``n`` onto integer region bounds.

    A negative ``width`` extends the region to the left of ``offset``. ``width`` is capped
    at 1, and ``offset`` is wrapped into [0, 1). When the region fits inside [0, 1] the
    result is ``(start, end, False)``. When it wraps past 1, the result is
    ``(wrapped_end, offset_start, True)``: the covered area is everything *outside*
    those bounds. Bounds are scaled by ``n`` and rounded.
    """
    if width < 0:
        width, offset = -width, offset + width
    width = min(width, 1)
    if offset < 0:
        offset = 1 + offset - int(offset)
    offset = math.fmod(offset, 1.0)

    stop = width + offset
    if stop <= 1:
        return round(offset * n), round(stop * n), False
    return round((stop - 1) * n), round(offset * n), True
def apply_freeu(p):
    """Enable or disable the model's built-in FreeU according to ``shared.opts``.

    Does nothing unless ``shared.sd_model`` exposes ``enable_freeu``.

    * FreeU enabled in options: if ``get_fft_device()`` did not fall back to CPU, the
      b1/b2/s1/s2 options are recorded in ``p.extra_generation_params['FreeU']``, passed to
      ``enable_freeu``, and logged, and ``state_enabled`` is set.
    * FreeU disabled in options: if it was previously enabled by this function,
      ``disable_freeu`` is called and ``state_enabled`` is cleared.

    Args:
        p: processing object; assumed to carry an ``extra_generation_params`` dict — TODO confirm.
    """
    global state_enabled # pylint: disable=global-statement
    if not hasattr(shared.sd_model, 'enable_freeu'):
        return
    if not shared.opts.freeu_enabled:
        if state_enabled:
            shared.sd_model.disable_freeu()
            state_enabled = False
        return
    if get_fft_device() == devices.cpu:
        return  # get_fft_device has already warned that FFT is unsupported on this device
    params = f'b1={shared.opts.freeu_b1} b2={shared.opts.freeu_b2} s1={shared.opts.freeu_s1} s2={shared.opts.freeu_s2}'
    p.extra_generation_params['FreeU'] = params
    shared.sd_model.enable_freeu(s1=shared.opts.freeu_s1, s2=shared.opts.freeu_s2, b1=shared.opts.freeu_b1, b2=shared.opts.freeu_b2)
    state_enabled = True
    # log only when FreeU was actually applied in this call; a stale state_enabled from a
    # previously loaded model must not produce a misleading message
    log.info(f'Applying Free-U: {params}')