cleanup and some more safety

parent b338aaa437
commit 3aa30df1e4
@@ -939,10 +939,12 @@ def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directo
    load_training_options = load_training_options.split(',')
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    for load_hypernetworks_option in load_hypernetworks_options:
        load_hypernetworks_option = load_hypernetworks_option.strip(' ')
        if get_training_option(load_hypernetworks_option) is False:
            print(f"Cannot load from {load_hypernetworks_option}!")
            continue
    for load_training_option in load_training_options:
        load_training_option = load_training_option.strip(' ')
        if get_training_option(load_training_option) is False:
            print(f"Cannot load from {load_training_option}!")
            continue
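The first hunk's shape is: split a comma-separated option string, trim stray spaces from each entry, and skip (rather than fail on) any entry that cannot be loaded. A minimal standalone sketch of that pattern, assuming a hypothetical dict-backed get_training_option in place of the repo's real lookup:

    # Hypothetical option table; the names and values here are illustrative only.
    KNOWN_OPTIONS = {"clip_grad_mode": "norm", "latent_sampling_method": "once"}

    def get_training_option(name):
        # Stand-in for the repo's lookup; returns False for unknown keys,
        # matching the `is False` checks in the diff above.
        return KNOWN_OPTIONS.get(name, False)

    def load_options(raw):
        # Split, trim, validate; skip unloadable entries instead of raising.
        loaded = {}
        for option in raw.split(','):
            option = option.strip(' ')  # tolerate "a, b , c" style input
            if get_training_option(option) is False:
                print(f"Cannot load from {option}!")
                continue
            loaded[option] = get_training_option(option)
        return loaded

    load_options("clip_grad_mode, bogus_key , latent_sampling_method")

Note that strip(' ') removes only spaces; a plain strip() would also drop tabs and newlines, which may or may not be wanted for user-supplied option strings.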
@@ -319,10 +319,6 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
    del x

    _loss_step += loss.item()
    scaler.scale(loss).backward()
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is None:
                print("Found no grad!")
    # go back until we reach gradient accumulation steps
    if (j + 1) % gradient_step != 0:
        continue
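The nested sweep over optimizer.param_groups in this hunk is a debugging aid: after backward(), any parameter whose .grad is still None never received a gradient, which usually means it was detached from the loss graph. A self-contained sketch of that check, using a toy model in place of the embedding being trained:

    import torch

    def report_missing_grads(optimizer):
        # After backward(), a parameter with grad still None received no gradient.
        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    print("Found no grad!")

    model = torch.nn.Linear(4, 1)  # toy stand-in
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss = model(torch.randn(2, 4)).mean()
    loss.backward()
    report_missing_grads(optimizer)  # silent here; every parameter got a grad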
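The surrounding lines show gradient accumulation under mixed precision: scaler.scale(loss).backward() adds scaled gradients on every micro-batch, and the `(j + 1) % gradient_step != 0` guard defers the optimizer step until gradient_step micro-batches have accumulated. A minimal sketch of that loop shape, with a toy model and random data standing in for the embedding trainer:

    import torch

    model = torch.nn.Linear(4, 1)                      # toy stand-in
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    gradient_step = 4                                  # micro-batches per optimizer step

    for j in range(16):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        scaler.scale(loss).backward()                  # accumulate (scaled) gradients
        # go back until we reach gradient accumulation steps
        if (j + 1) % gradient_step != 0:
            continue
        scaler.step(optimizer)                         # unscales; skips step on inf/NaN grads
        scaler.update()
        optimizer.zero_grad(set_to_none=True)          # reset for the next accumulation window

With enabled=False (the CPU path here), GradScaler degrades to a plain backward()/step(), so the same loop runs with or without AMP.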