notify #316, support SDXL on img2img Noise Inversion

pull/356/head
Kahsolt 2023-10-20 13:12:26 +08:00
parent d91a9b7ac1
commit 350dc0e3e0
1 changed files with 19 additions and 15 deletions

View File

@ -690,29 +690,33 @@ class AbstractDiffusion:
x = self.p.init_latent
s_in = x.new_ones([x.shape[0]])
if shared.sd_model.parameterization == "v":
skip = 1
else:
skip = 0
cond = self.p.sd_model.get_learned_conditioning(prompts)
# NOTE: should be List[Tensor]
cond_dict_dummy = {
'c_crossattn': [],
'c_concat': [],
}
cond_in = self.make_cond_dict(cond_dict_dummy, cond, self.p.image_conditioning)
skip = 1 if shared.sd_model.parameterization == "v" else 0
sigmas = dnw.get_sigmas(steps).flip(0)
state.sampling_steps = steps
cond = self.p.sd_model.get_learned_conditioning(prompts)
if isinstance(cond, Tensor): # SD1/SD2
cond_dict_dummy = {
'c_crossattn': [], # List[Tensor]
'c_concat': [], # List[Tensor]
}
cond_in = self.make_cond_dict(cond_dict_dummy, cond, self.p.image_conditioning)
else: # SDXL
cond_dict_dummy = {
'crossattn': None, # Tensor
'vector': None, # Tensor
'c_concat': [], # List[Tensor]
}
cond_in = self.make_cond_dict(cond_dict_dummy, cond['crossattn'], self.p.image_conditioning, cond['vector'])
state.sampling_steps = steps
pbar = tqdm(total=steps, desc='Noise Inversion')
for i in range(1, len(sigmas)):
if state.interrupted:
return x
if state.interrupted: return x
state.sampling_step += 1
x_in = x
sigma_in = torch.cat([sigmas[i] * s_in])
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
t = dnw.sigma_to_t(sigma_in)