pull/1223/head
bmaltais 2023-07-12 08:41:00 -04:00
commit dfe847071f
1 changed files with 1 additions and 2 deletions

View File

@ -362,8 +362,7 @@ def train(args):
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
# with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
if True:
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else: