More work on train_utils; integration of train_utils into train_dreambooth, building the pipeline up to the start of training.

Branch: dev_rework
Parent: a09bca43a2
Commit: 34e98b24ba
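This commit wires the new TrainUtils helper into train_dreambooth: model dtypes come from a shared lookup instead of inline fp16/bf16 branches, and the pipeline is flipped between inference and training state through the helper rather than ad-hoc .to()/.train()/.eval() calls. A minimal sketch of the call pattern the hunks below work toward (the checkpoint path and the AdamW optimizer are illustrative placeholders, not values from this commit; the TrainUtils API is still in flux here, so treat this as an assumption):

    import torch
    from accelerate import Accelerator
    from diffusers import DiffusionPipeline
    from utils.train_utils import TrainUtils

    accelerator = Accelerator(mixed_precision="fp16")
    pipeline = DiffusionPipeline.from_pretrained("path/to/model")  # placeholder path

    # TrainUtils.__init__ reads accelerator.optimizer, so attach one first (as train_dreambooth does)
    accelerator.optimizer = torch.optim.AdamW(pipeline.unet.parameters(), lr=1e-6)

    train_utils = TrainUtils(pipeline=pipeline, accelerator=accelerator)
    dtypes = train_utils.get_model_dtypes()                   # {"weight_dtype": ..., "tenc_dtype": ..., "vae_dtype": ...}
    pipeline = train_utils.prepare_pipeline_for_inference()   # eval mode, models moved to device with their dtypes
    pipeline = train_utils.prepare_pipeline_for_training()    # back to train mode before the training loop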
@@ -1,6 +1,7 @@
#IDE
.idea
.vscode/settings.json
.vscode/extensions.json

# Byte-compiled / optimized / DLL files
__pycache__/
@@ -19,6 +19,7 @@ import torch
import torch.backends.cuda
import torch.backends.cudnn
import torch.nn.functional as F
from utils.train_utils import TrainUtils, get_model_dtypes, prepare_pipeline_for_inference, prepare_pipeline_for_training
from accelerate import Accelerator
from accelerate.utils.random import set_seed as set_seed2
from diffusers import (
@@ -65,6 +66,7 @@ from helpers.ema_model import EMAModel
from helpers.log_parser import LogParser
from helpers.mytqdm import mytqdm
from lora_diffusion.lora import (
    LoraInjectedConv2d,
    set_lora_requires_grad,
)
@@ -212,23 +214,18 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
if not args.pad_tokens and args.max_token_length > 75:
    logger.warning("Cannot raise token length limit above 75 when pad_tokens=False")

verify_locon_installed(args)
model_dtypes = TrainUtils.get_model_dtypes()

precision = args.mixed_precision if not shared.force_cpu else "no"

weight_dtype = torch.float32
if precision == "fp16":
    weight_dtype = torch.float16
elif precision == "bf16":
    weight_dtype = torch.bfloat16
weight_dtype = model_dtypes["weight_dtype"]
tenc_dtype = model_dtypes["tenc_dtype"]
vae_dtype = model_dtypes["vae_dtype"]

try:
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=precision,
        mixed_precision=weight_dtype,
        log_with="all",
        project_dir=logging_dir,
        cpu=shared.force_cpu,
    )

    run_name = "dreambooth.events"
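The inline fp16/bf16 branch above is superseded by the dictionary returned from TrainUtils.get_model_dtypes(). Based on the keys consumed here, it is assumed to map each model group to a torch dtype along these lines (a sketch of the assumed return shape, not the function body from this commit; the bf16 example values are illustrative):

    # assumed shape of the get_model_dtypes() output for precision="bf16"
    model_dtypes = {
        "weight_dtype": torch.bfloat16,  # UNet weights
        "tenc_dtype": torch.bfloat16,    # text encoder(s)
        "vae_dtype": torch.float32,      # VAE assumed to stay in fp32 for stability
    }
    weight_dtype = model_dtypes["weight_dtype"]
    tenc_dtype = model_dtypes["tenc_dtype"]
    vae_dtype = model_dtypes["vae_dtype"]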
@@ -306,7 +303,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
    )
    enable_safe_unpickle()
    new_vae.requires_grad_(False)
    new_vae.to(accelerator.device, dtype=weight_dtype)
    new_vae.to(accelerator.device, dtype=vae_dtype)
    return new_vae

disable_safe_unpickle()
@@ -344,7 +341,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
    args.get_pretrained_model_name_or_path(),
    subfolder="text_encoder",
    revision=args.revision,
    torch_dtype=torch.float32,
    torch_dtype=tenc_dtype,
)

if args.model_type == "SDXL":
@@ -361,7 +358,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
    args.get_pretrained_model_name_or_path(),
    subfolder="text_encoder_2",
    revision=args.revision,
    torch_dtype=torch.float32,
    torch_dtype=tenc_dtype,
)

printm("Created tenc")
@@ -376,49 +373,11 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
    args.get_pretrained_model_name_or_path(),
    subfolder="unet",
    revision=args.revision,
    torch_dtype=torch.float32,
    torch_dtype=tenc_dtype,
)

if args.attention == "xformers" and not shared.force_cpu:
    xformerify(unet, use_lora=args.use_lora)
    xformerify(vae, use_lora=args.use_lora)

unet = torch2ify(unet)

if args.full_mixed_precision:
    if args.mixed_precision == "fp16":
        patch_accelerator_for_fp16_training(accelerator)
    unet.to(accelerator.device, dtype=weight_dtype)
else:
    # Check that all trainable models are in full precision
    low_precision_error_string = (
        "Please make sure to always have all model weights in full float32 precision when starting training - "
        "even if doing mixed precision training. copy of the weights should still be float32."
    )

    if accelerator.unwrap_model(unet).dtype != torch.float32:
        logger.warning(
            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
        )

    if (
        args.stop_text_encoder != 0
        and accelerator.unwrap_model(text_encoder).dtype != torch.float32
    ):
        logger.warning(
            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
            f" {low_precision_error_string}"
        )

    if (
        args.stop_text_encoder != 0
        and accelerator.unwrap_model(text_encoder_two).dtype != torch.float32
    ):
        logger.warning(
            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder_two).dtype}."
            f" {low_precision_error_string}"
        )

if args.gradient_checkpointing:
    if args.train_unet:
        unet.enable_gradient_checkpointing()
@@ -432,10 +391,37 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
    text_encoder.text_model.embeddings.position_embedding.requires_grad_(True)
    if args.model_type == "SDXL":
        text_encoder_two.text_model.embeddings.position_embedding.requires_grad_(True)
else:
    text_encoder.to(accelerator.device, dtype=weight_dtype)

if vae is None:
    printm("Loading vae.")
    vae = create_vae()

printm("Creating pipeline.")
if args.model_type == "SDXL":
    text_encoder_two.to(accelerator.device, dtype=weight_dtype)
    s_pipeline = StableDiffusionXLPipeline.from_pretrained(
        args.get_pretrained_model_name_or_path(),
        unet=unet,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_two,
        vae=vae,
        revision=args.revision,
    )
else:
    s_pipeline = DiffusionPipeline.from_pretrained(
        args.get_pretrained_model_name_or_path(),
        unet=unet,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_two,
        vae=vae,
        revision=args.revision,
    )
s_pipeline = TrainUtils.prepare_pipeline_for_inference(s_pipeline, accelerator)
unet = s_pipeline.unet
text_encoder = s_pipeline.text_encoder
if args.model_type == "SDXL":
    text_encoder_two = s_pipeline.text_encoder_two
vae = s_pipeline.vae
accelerator = s_pipeline.accelerator

ema_model = None
if args.use_ema:
@@ -452,12 +438,11 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    if args.attention == "xformers" and not shared.force_cpu:
        xformerify(ema_unet, use_lora=args.use_lora)

    ema_model = EMAModel(
        ema_unet, device=accelerator.device, dtype=weight_dtype
    )
    ema_unet.to(accelerator.device, dtype=weight_dtype)
    del ema_unet
else:
    ema_model = EMAModel(
@@ -468,6 +453,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR

learning_rate = args.learning_rate
txt_learning_rate = args.txt_learning_rate

if args.use_lora:
    learning_rate = args.lora_learning_rate
    txt_learning_rate = args.lora_txt_learning_rate
@@ -517,12 +503,12 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
if stop_text_percentage != 0:
    # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
    text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder(
        text_encoder, dtype=torch.float32, rank=args.lora_txt_rank
        text_encoder, dtype=tenc_dtype, rank=args.lora_txt_rank
    )

    if args.model_type == "SDXL":
        text_encoder_lora_params_two = LoraLoaderMixin._modify_text_encoder(
            text_encoder_two, dtype=torch.float32, rank=args.lora_txt_rank
            text_encoder_two, dtype=tenc_dtype, rank=args.lora_txt_rank
        )
        params_to_optimize = (
            itertools.chain(unet_lora_params, text_encoder_lora_params, text_encoder_lora_params_two))
@@ -532,6 +518,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
else:
    params_to_optimize = unet_lora_params


# Load LoRA weights if specified
if args.lora_model_name is not None and args.lora_model_name != "":
    logger.debug(f"Load lora from {args.lora_model_name}")
@@ -542,7 +529,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
        lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder)
    if text_encoder_two is not None:
        LoraLoaderMixin.load_lora_into_text_encoder(
            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two)
            LoraInjectedConv2d. lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two)


elif stop_text_percentage != 0:
@@ -560,7 +547,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
else:
    params_to_optimize = unet.parameters()

optimizer = get_optimizer(args.optimizer, learning_rate, args.weight_decay, params_to_optimize)
accelerator.optimizer = get_optimizer(args.optimizer, learning_rate, args.weight_decay, params_to_optimize)
optimizer = accelerator.optimizer
if len(optimizer.param_groups) > 1:
    try:
        optimizer.param_groups[1]["weight_decay"] = args.tenc_weight_decay
@@ -576,9 +564,11 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
except:
    logger.warning("Exception setting tenc weight decay")
    traceback.print_exc()


noise_scheduler = get_noise_scheduler(args)
global to_delete
# TODO: Do we need this??
to_delete = [unet, text_encoder, text_encoder_two, tokenizer, tokenizer_two, optimizer, vae]
def cleanup_memory():
    try:
@@ -608,11 +598,6 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
        pass
    cleanup(True)

if args.cache_latents:
    vae.to(accelerator.device, dtype=weight_dtype)
    vae.requires_grad_(False)
    vae.eval()

if status.interrupted:
    result.msg = "Training interrupted."
    stop_profiler(profiler)
@@ -824,11 +809,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
    unet, optimizer, train_dataloader, lr_scheduler
)

if not args.cache_latents and vae is not None:
    vae.to(accelerator.device, dtype=weight_dtype)

if stop_text_percentage == 0:
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=tenc_dtype)
# Afterwards we recalculate our number of training epochs
# We need to initialize the trackers we use, and also store our configuration.
# The trackers will initialize automatically on the main process.
@@ -887,7 +869,7 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
logger.debug(f" Resuming from checkpoint: {resume_from_checkpoint}")
logger.debug(f" First resume epoch: {first_epoch}")
logger.debug(f" First resume step: {resume_step}")
logger.debug(f" Lora: {args.use_lora}, Optimizer: {args.optimizer}, Prec: {precision}")
logger.debug(f" Lora: {args.use_lora}, Optimizer: {args.optimizer}, Prec: {args.mixed_precision}")
logger.debug(f" Gradient Checkpointing: {args.gradient_checkpointing}")
logger.debug(f" EMA: {args.use_ema}")
logger.debug(f" UNET: {args.train_unet}")
@@ -1028,7 +1010,6 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
        torch_dtype=weight_dtype,
        revision=args.revision,
    )
    xformerify(s_pipeline.unet,use_lora=args.use_lora)
else:
    s_pipeline = DiffusionPipeline.from_pretrained(
        args.get_pretrained_model_name_or_path(),
@@ -1040,8 +1021,11 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
        torch_dtype=weight_dtype,
        revision=args.revision,
    )
    xformerify(s_pipeline.unet,use_lora=args.use_lora)
    xformerify(s_pipeline.vae,use_lora=args.use_lora)

train_utils = TrainUtils(pipeline=s_pipeline, accelerator=accelerator)
train_utils.get_model_dtypes()
s_pipeline = train_utils.prepare_pipeline_for_inference()


weights_dir = args.get_pretrained_model_name_or_path()
@@ -1545,6 +1529,8 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

s_pipeline = TrainUtils.prepare_pipeline_for_training()

if args.model_type != "SDXL":
    # TODO: set a prior preservation flag and use that to ensure this ony happens in dreambooth
    if not args.split_loss and not with_prior_preservation:
@@ -1,48 +1,34 @@

from pyexpat import model
from anyio import get_all_backends
from httpx import get
from numpy import save
import torch
import torch.backends.cuda
import torch.backends.cudnn

from dreambooth import shared
from dreambooth.dataclasses.db_config import DreamboothConfig
from dreambooth.shared import DreamState
from dreambooth.shared import db_model_config
from dreambooth.utils.model_utils import (
    disable_safe_unpickle,
    enable_safe_unpickle,
    xformerify,
    torch2ify,
)
from dreambooth.utils.text_utils import encode_hidden_state
from dreambooth.utils.utils import (cleanup, printm,)
from dreambooth.webhook import send_training_update
import accelerate
import torch
from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
from diffusers.models.attention_processor import LoRAAttnProcessor2_0, LoRAAttnProcessor
from diffusers.utils import logging as dl, randn_tensor
from torch.cuda.profiler import profile
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import os
from dreambooth.utils.model_utils import xformerify


class TrainUtils:
    def __init__(self, pipeline, accelerator):
        self.pipeline = pipeline
        self.accelerator = accelerator
        self.model_config = DreamboothConfig()
        self.model_config = DreamboothConfig
        self.model_type = self.model_config.model_type
        self.model_path = self.model_config.model_path
        self.precision = self.model_config.mixed_precision
        self.save_lora =self.model_config.save_lora_during
        self.save_lora = self.model_config.save_lora_during
        self.save_checkpoint = self.model_config.save_ckpt_during
        self.save_difusers = self.model_config.save_state_during
        model_dtypes = self.get_model_dtypes()
        self.tenc_dtype = model_dtypes["tenc_dtype"]
        self.weight_dtype = model_dtypes["weight_dtype"]
        self.vae_dtype = model_dtypes["vae_dtype"]
        self.max_train_epochs = self.model_config.num_train_epochs
        self.stop_text_percentage = self.model_config.stop_text_encoder
        self.text_encoder_epochs = round(self.max_train_epochs * self.stop_text_percentage)
        self.train_unet = self.model_config.train_unet
        self.optimizer = accelerator.optimizer

    def get_model_dtypes(self, precision=None, model_type=None):
        precision = self.precision if precision is None else precision
@@ -71,7 +57,22 @@ class TrainUtils:
        self.vae_dtype = vae_dtype

        return model_dtypes



    """ TODO Use this to pass the params in
    def prepare_accelerator():
        return
    def prepare_optimizer():
        return
    def prepare_scheduler():
        return

    def get_accelerator():
        return self.accelerator

    def get_optimizer():
        return self.accelerator.optimizer
    """
    def prepare_pipeline_for_inference(self, pipeline=None):
        accelerator = self.accelerator
        pipeline = self.pipeline if pipeline is None else pipeline
@@ -79,7 +80,7 @@ class TrainUtils:
        weight_dtype = self.weight_dtype
        tenc_dtype = self.tenc_dtype
        vae_dtype = self.vae_dtype


        # Send all the models to the same device with correct dtypes
        pipeline.unet.to(accelerator.device, weight_dtype)
        pipeline.text_encoder.to(accelerator.device, tenc_dtype)
@@ -95,7 +96,7 @@ class TrainUtils:
        pipeline.vae.eval()

        return pipeline


    def prepare_pipeline_for_training(self, pipeline=None):
        accelerator = self.accelerator
        pipeline = self.pipeline if pipeline is None else pipeline
@@ -112,24 +113,39 @@ class TrainUtils:
        if model_type == "SDXL":
            pipeline.text_encoder_two.to(accelerator.device, dtype=tenc_dtype)


        #xformerify the unet and vae in the pipeline for some sweet crossattention
        xformerify(pipeline.unet, self.model_config.use_lora)
        xformerify(pipeline.vae, self.model_config.use_lora)


        # Get the models ready for training
        # TODO: Add logic to restore correct state according to model type
        # and tenc_training state
        pipeline.unet.train()
        pipeline.text_encoder.train()
        if model_type == "SDXL":
            pipeline.text_encoder_two.train()
        pipeline.vae.train()
        if self.model_config.cache_latents:
            pipeline.vae.requires_grad_(False)
            pipeline.vae.eval()
        if self.model_config.freeze_clip_normalization:
            pipeline.text_encoder.eval()
            if self.model_config.model_type == "SDXL":
                pipeline.text_encoder_two.eval()
        else:
        # and tenc_training state
        pipeline.unet.train()
        pipeline.text_encoder.train()
        if model_type == "SDXL":
            pipeline.text_encoder_two.train()


        return pipeline


    def save_pipeline(self, pipeline=None):
        """
        Save the currrent pipeline state to lora and/or checkpoint
        """
        #pipeline = self.prepare_pipeline_for_inference(self.pipeline) \
        #    if pipeline is None or self.prepare_pipeline_for_inference(pipeline)
        if pipeline is None:
            pipeline = self.prepare_pipeline_for_inference(self.pipeline)
        else:
            pipeline = self.prepare_pipeline_for_inference(pipeline)

        model_type = self.model_type
        save_lora = self.save_lora
@@ -143,36 +159,41 @@ class TrainUtils:
            # Do lora stuff
            if model_type == "SDXL":
                # Do SDXL Stuff for lora
                pass
        except:
            print("Error saving lora")
            success = False

            pass

        if save_checkpoint:
            try:
                # Do checkpoint stuff
                if model_type == "SDXL":
                    # Do SDXL Stuff for checkpoints
                    # Do SDXL Stuff for checkpoints
                    pass
            except:
                print("Error saving checkpoints")
                success = False
                pass

        if save_diffusers:
            try:
                # Do difuser stuff
                if model_type == "SDXL":
                    # Do SDXL Stuff for difusers
                    # Do SDXL Stuff for difusers
                    pass
            except:
                print("Error saving difusers")
                # Keep track of error and continue
                success

                success = False
                pass

        # We done saving
        # Get the pipeline ready to resume training
        self.pipeline = self.prepare_pipeline_for_training(pipeline)
        # return false if anything went wrong so we can handle it externally
        return success



    def save_samples(self, pipeline):
        """
        Save sample images using current pipeline
@@ -180,11 +201,12 @@ class TrainUtils:
        pipeline = self.pipeline or pipeline
        # Get the pipeline ready for inference
        pipeline = self.prepare_pipeline_for_inference(pipeline)


        success = True
        try:
            # Image inference stuff goes here
            images = pipeline.save_image()
            pass
        except:
            print("Error saving image")
            success = False
@@ -193,39 +215,42 @@ class TrainUtils:
        self.pipeline = self.prepare_pipeline_for_training(pipeline)
        return images or success

    def save_training_state(self, accelerator=None, pipeline=None, model_path=None, model_config=None):
    """ def save_training_state(self, accelerator=None, pipeline=None, model_path=None, model_config=None):
        accelerator = self.accelerator if accelerator is None else accelerator
        model_path = self.model_path
        pipeline = self.pipeline if pipeline is None else pipeline
        model_config = self.model_config if model_config is None else model_config


        # Save the optimizer and scheduler states to model_path
        accelerator.save(self.pipeline.unet.optimizer.state_dict(), f"{model_path}/optimizer.pt")
        accelerator.save(self.pipeline.unet.scheduler.state_dict(), f"{model_path}/scheduler.pt")
        accelerator.save(self.pipeline.unet.optimizer.state_dict(),
                         f"{model_path}/optimizer.pt")
        accelerator.save(self.pipeline.unet.scheduler.state_dict(),
                         f"{model_path}/scheduler.pt")
        # Save the accelerator state to model_path
        accelerator.save(accelerator.state_dict(), f"{model_path}/accelerator.pt")
        accelerator.save(accelerator.state_dict(),
                         f"{model_path}/accelerator.pt")
        # Save the pipeline state to model_path
        accelerator.save(self.pipeline.state_dict(), f"{model_path}/pipeline.pt")
        accelerator.save(self.pipeline.state_dict(),
                         f"{model_path}/pipeline.pt")
        # Save the model config to model_path
        accelerator.save(shared.model_config, f"{model_path}/model_config.pt")
        return True
        return True """

    def load_training_state(self, accelerator=None, pipeline=None, model_path=None, model_config=None):
    """ def load_training_state(self, accelerator=None, pipeline=None, model_path=None, model_config=None):
        accelerator = self.accelerator if accelerator is None else accelerator
        model_path = self.model_path if model_path is None else model_path
        pipeline = self.pipeline if pipeline is None else pipeline
        model_config = self.model_config if model_config is None else model_config

        # Load all the .pt files located in model_path back into their relevant objects
        accelerator.
        optimizer.load_state_dict(accelerator.load(f"{model_path}/optimizer.pt"))
        pipeline.unet.scheduler.load_state_dict(accelerator.load(f"{model_path}/scheduler.pt"))
        accelerator.load_state_dict(accelerator.load(f"{model_path}/accelerator.pt"))

        optimizer.load_state_dict(
            accelerator.load(f"{model_path}/optimizer.pt"))
        pipeline.unet.scheduler.load_state_dict(
            accelerator.load(f"{model_path}/scheduler.pt"))
        accelerator.load_state_dict(
            accelerator.load(f"{model_path}/accelerator.pt"))
        pipeline.load_state_dict(accelerator.load(f"{model_path}/pipeline.pt"))
        model_config = accelerator.load(f"{model_path}/model_config.pt")

        return True

        return True """