diff --git a/tile_methods/abstractdiffusion.py b/tile_methods/abstractdiffusion.py
index 1b3d8c5..9b153ed 100644
--- a/tile_methods/abstractdiffusion.py
+++ b/tile_methods/abstractdiffusion.py
@@ -690,29 +690,33 @@ class AbstractDiffusion:
         x = self.p.init_latent
         s_in = x.new_ones([x.shape[0]])
-        if shared.sd_model.parameterization == "v":
-            skip = 1
-        else:
-            skip = 0
-        cond = self.p.sd_model.get_learned_conditioning(prompts)
-        # NOTE: should be List[Tensor]
-        cond_dict_dummy = {
-            'c_crossattn': [],
-            'c_concat': [],
-        }
-        cond_in = self.make_cond_dict(cond_dict_dummy, cond, self.p.image_conditioning)
+        skip = 1 if shared.sd_model.parameterization == "v" else 0
         sigmas = dnw.get_sigmas(steps).flip(0)
-        state.sampling_steps = steps
+
+        cond = self.p.sd_model.get_learned_conditioning(prompts)
+        if isinstance(cond, Tensor):    # SD1/SD2
+            cond_dict_dummy = {
+                'c_crossattn': [],      # List[Tensor]
+                'c_concat': [],         # List[Tensor]
+            }
+            cond_in = self.make_cond_dict(cond_dict_dummy, cond, self.p.image_conditioning)
+        else:                           # SDXL
+            cond_dict_dummy = {
+                'crossattn': None,      # Tensor
+                'vector': None,         # Tensor
+                'c_concat': [],         # List[Tensor]
+            }
+            cond_in = self.make_cond_dict(cond_dict_dummy, cond['crossattn'], self.p.image_conditioning, cond['vector'])
+        state.sampling_steps = steps
 
         pbar = tqdm(total=steps, desc='Noise Inversion')
         for i in range(1, len(sigmas)):
-            if state.interrupted:
-                return x
+            if state.interrupted: return x
+
             state.sampling_step += 1
             x_in = x
             sigma_in = torch.cat([sigmas[i] * s_in])
-
             c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
             t = dnw.sigma_to_t(sigma_in)
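
Review note: the core change is that `get_learned_conditioning` returns a plain `Tensor` for SD1/SD2 (ldm) but a dict carrying a `crossattn` tensor and a pooled `vector` tensor for SDXL (sgm), so the conditioning dict handed to the sampler must be built per model family. Below is a minimal sketch of a helper with the same call shape as the `make_cond_dict` calls in the hunk, assuming the dict layouts shown in the diff; the body is illustrative, not the repository's implementation.

    from typing import Optional
    from torch import Tensor

    # Hypothetical stand-in for AbstractDiffusion.make_cond_dict;
    # key layouts are taken from the diff above, the logic is assumed.
    def make_cond_dict(template: dict, crossattn: Tensor, image_cond: Tensor,
                       vector: Optional[Tensor] = None) -> dict:
        """Fill a conditioning-dict template in the SD1/SD2 or SDXL layout."""
        cond = dict(template)
        if vector is None:
            # SD1/SD2 (ldm): the wrapper expects lists of tensors
            cond['c_crossattn'] = [crossattn]
            cond['c_concat'] = [image_cond]
        else:
            # SDXL (sgm): bare tensors, plus the pooled text embedding as 'vector'
            cond['crossattn'] = crossattn
            cond['vector'] = vector
            cond['c_concat'] = [image_cond]
        return cond

Call shapes matching the two branches in the hunk (with `image_cond` standing in for `self.p.image_conditioning`):

    cond_in = make_cond_dict({'c_crossattn': [], 'c_concat': []}, cond, image_cond)  # SD1/SD2
    cond_in = make_cond_dict({'crossattn': None, 'vector': None, 'c_concat': []},
                             cond['crossattn'], image_cond, cond['vector'])          # SDXL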