"""Proportional noise step scheduler for LatentDiffusion training.

Patches ldm.models.diffusion.ddpm.LatentDiffusion.forward so that the upper
bound of the diffusion timestep sampled during training follows the scheduler
below instead of always being num_timesteps.
"""
import torch

import ldm.models.diffusion.ddpm

from modules import shared


class Scheduler:
    """Proportional noise step scheduler.

    Scales a value (typically the number of diffusion timesteps) linearly
    with the training step, optionally cycling every `cycle_step` steps.
    """

    def __init__(self, cycle_step=128, repeat=True):
        self.disabled = True  # no-op until explicitly enabled via set()
        self.cycle_step = int(cycle_step)
        self.repeat = repeat
        self.run_assertion()

    def __call__(self, value, step):
        if self.disabled:
            return value
        if self.repeat:
            # Cycle: ramp from 1 up to `value` every `cycle_step` steps.
            step %= self.cycle_step
            return max(1, int(value * step / self.cycle_step))
        else:
            # Ramp once over the first `cycle_step` steps, then stay at `value`.
            return value if step >= self.cycle_step else max(1, int(value * step / self.cycle_step))

    def run_assertion(self):
        assert type(self.cycle_step) is int
        assert type(self.repeat) is bool
        # A cycling scheduler needs a positive cycle length (guards `step %= 0`).
        assert not self.repeat or self.cycle_step > 0

    def set(self, cycle_step=-1, repeat=-1, disabled=True):
        """Reconfigure in place; pass -1 to leave a field unchanged."""
        self.disabled = disabled
        if cycle_step >= 0:
            self.cycle_step = int(cycle_step)
        if repeat != -1:
            self.repeat = repeat
        self.run_assertion()


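# A minimal usage sketch (hypothetical numbers, not part of the module API):
# an enabled scheduler with cycle_step=100 ramps value=1000 linearly over
# each 100-step cycle.
#
#   s = Scheduler(cycle_step=100, repeat=True)
#   s.set(cycle_step=100, repeat=True, disabled=False)
#   s(1000, 0)    # -> 1   (the max(1, ...) floor)
#   s(1000, 50)   # -> 500
#   s(1000, 150)  # -> 500 (150 % 100 == 50, so the ramp repeats)
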
# Shared instance; stays disabled until set_scheduler() enables it.
training_scheduler = Scheduler(cycle_step=-1, repeat=False)


def get_current(value, step=None):
    """Scale `value` by the scheduler, using `step` or the active hypernetwork's step."""
    if step is None:
        # No explicit step: fall back to the currently training hypernetwork.
        if not hasattr(shared, 'accessible_hypernetwork'):
            return value
        hypernetwork = shared.accessible_hypernetwork
        if getattr(hypernetwork, 'training', False) and getattr(hypernetwork, 'step', None) is not None:
            return training_scheduler(value, hypernetwork.step)
        return value
    return max(1, training_scheduler(value, step))


def set_scheduler(cycle_step, repeat, enabled=False):
    # Mutates the shared instance in place, so no `global` rebinding is needed.
    training_scheduler.set(cycle_step, repeat, not enabled)


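# A hedged wiring sketch (values hypothetical): enable a one-shot ramp over
# the first 500 training steps; afterwards the full timestep range is used.
#
#   set_scheduler(cycle_step=500, repeat=False, enabled=True)
#   get_current(1000, step=250)  # -> 500
#   get_current(1000, step=600)  # -> 1000
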
def forward(self, x, c, *args, **kwargs):
    # Replacement for LatentDiffusion.forward: routes the timestep upper bound
    # through get_current(), so the sampled noise level is capped by the scheduler.
    t = torch.randint(0, get_current(self.num_timesteps), (x.shape[0],), device=self.device).long()
    if self.model.conditioning_key is not None:
        assert c is not None
        if self.cond_stage_trainable:
            c = self.get_learned_conditioning(c)
        if self.shorten_cond_schedule:  # TODO: drop this option
            tc = self.cond_ids[t].to(self.device)
            c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
    return self.p_losses(x, c, t, *args, **kwargs)


# Monkey-patch LatentDiffusion so training picks timesteps through the scheduler.
ldm.models.diffusion.ddpm.LatentDiffusion.forward = forward
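
# Note: the patch above takes effect at import time, so this module must be
# imported before training begins for the scheduled timestep cap to apply.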