Merge https://github.com/d8ahazard/sd_dreambooth_extension into TrainingOverhaul
commit
672cc66cd4
|
|
@ -315,7 +315,11 @@ class DreamboothConfig(BaseConfig):
|
|||
def get_pretrained_model_name_or_path(self):
|
||||
if self.shared_diffusers_path != "" and not self.use_lora:
|
||||
raise Exception(f"shared_diffusers_path is \"{self.shared_diffusers_path}\" but use_lora is false")
|
||||
return self.shared_diffusers_path if self.shared_diffusers_path != "" else self.pretrained_model_name_or_path
|
||||
if self.shared_diffusers_path != "":
|
||||
return self.shared_diffusers_path
|
||||
if not self.pretrained_model_name_or_path or self.pretrained_model_name_or_path == "":
|
||||
return os.path.join(self.model_dir, "working")
|
||||
return self.pretrained_model_name_or_path
|
||||
|
||||
def load_from_file(self, model_dir=None):
|
||||
"""
|
||||
|
|
@ -364,6 +368,7 @@ def concepts_from_file(concepts_path: str):
|
|||
concepts.append(concept.__dict__)
|
||||
except Exception as e:
|
||||
print(f"Exception parsing concepts: {e}")
|
||||
print(f"Loaded concepts: {concepts}")
|
||||
return concepts
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ class ClassDataset(Dataset):
|
|||
|
||||
for concept_idx, concept in enumerate(concepts):
|
||||
pbar.set_description(f"Processing concept {concept_idx + 1}/{len(concepts)}")
|
||||
instance_dir = concept.instance_data_dir
|
||||
|
||||
if not concept.is_valid:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -238,6 +238,79 @@ def main(args: TrainingConfig, user: str = None) -> TrainResult:
|
|||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
msg = f"Exception initializing accelerator: {e}"
|
||||
logger.warning(msg)
|
||||
result.msg = msg
|
||||
result.config = args
|
||||
stop_profiler(profiler)
|
||||
return result
|
||||
|
||||
# This is the secondary status bar
|
||||
pbar2 = mytqdm(
|
||||
disable=not accelerator.is_local_main_process,
|
||||
position=1,
|
||||
user=user,
|
||||
target="dreamProgress",
|
||||
index=1
|
||||
)
|
||||
# Currently, it's not possible to do gradient accumulation when training two models with
|
||||
# accelerate.accumulate This will be enabled soon in accelerate. For now, we don't allow gradient
|
||||
# accumulation when training two models.
|
||||
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
if (
|
||||
stop_text_percentage != 0
|
||||
and gradient_accumulation_steps > 1
|
||||
and accelerator.num_processes > 1
|
||||
):
|
||||
msg = (
|
||||
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
||||
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future. Text "
|
||||
"encoder training will be disabled."
|
||||
)
|
||||
logger.warning(msg)
|
||||
status.textinfo = msg
|
||||
update_status({"status": msg})
|
||||
stop_text_percentage = 0
|
||||
pretrained_path = args.get_pretrained_model_name_or_path()
|
||||
logger.debug(f"Pretrained path: {pretrained_path}")
|
||||
|
||||
count, instance_prompts, class_prompts = generate_classifiers(
|
||||
args, class_gen_method=class_gen_method, accelerator=accelerator, ui=False, pbar=pbar2
|
||||
)
|
||||
pbar2.reset()
|
||||
if status.interrupted:
|
||||
result.msg = "Training interrupted."
|
||||
stop_profiler(profiler)
|
||||
return result
|
||||
|
||||
if class_gen_method == "Native Diffusers" and count > 0:
|
||||
unload_system_models()
|
||||
|
||||
def create_vae():
|
||||
vae_path = (
|
||||
args.pretrained_vae_name_or_path
|
||||
if args.pretrained_vae_name_or_path
|
||||
else args.get_pretrained_model_name_or_path()
|
||||
)
|
||||
disable_safe_unpickle()
|
||||
new_vae = AutoencoderKL.from_pretrained(
|
||||
vae_path,
|
||||
subfolder=None if args.pretrained_vae_name_or_path else "vae",
|
||||
revision=args.revision,
|
||||
)
|
||||
enable_safe_unpickle()
|
||||
new_vae.requires_grad_(False)
|
||||
new_vae.to(accelerator.device, dtype=weight_dtype)
|
||||
return new_vae
|
||||
|
||||
disable_safe_unpickle()
|
||||
# Load the tokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
os.path.join(pretrained_path, "tokenizer"),
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
del validation_pipeline
|
||||
cleanup()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
accelerate~=0.19.0
|
||||
accelerate~=0.21.0
|
||||
bitsandbytes==0.35.4
|
||||
dadaptation==3.1
|
||||
diffusers~=0.19.0
|
||||
diffusers~=0.19.3
|
||||
discord-webhook~=1.1.0
|
||||
fastapi~=0.94.1
|
||||
gitpython~=3.1.31
|
||||
|
|
|
|||
Loading…
Reference in New Issue