pull/699/head
bmaltais 2023-04-27 09:03:59 -04:00
commit 7bd1cb9d08
18 changed files with 1640 additions and 285 deletions

View File

@ -305,6 +305,15 @@ This will store a backup file with your current locally installed pip packages a
## Change History ## Change History
* 2023/04/25 (v21.5.7)
- `tag_images_by_wd14_tagger.py` can now get arguments from outside. [PR #453](https://github.com/kohya-ss/sd-scripts/pull/453) Thanks to mio2333!
- Added `--save_every_n_steps` option to each training script. The model is saved every specified steps.
- `--save_last_n_steps` option can be used to save only the specified number of models (old models will be deleted).
- If you specify the `--save_state` option, the state will also be saved at the same time. You can specify the number of steps to keep the state with the `--save_last_n_steps_state` option (the same value as `--save_last_n_steps` is used if omitted).
- You can use the epoch-based model saving and state saving options together.
- Not tested in multi-GPU environment. Please report any bugs.
- `--cache_latents_to_disk` option automatically enables `--cache_latents` option when specified. [#438](https://github.com/kohya-ss/sd-scripts/issues/438)
- Fixed a bug in `gen_img_diffusers.py` where latents upscaler would fail with a batch size of 2 or more.
* 2023/04/24 (v21.5.6) * 2023/04/24 (v21.5.6)
- Fix triton error - Fix triton error
- Fix issue with merge lora path with spaces - Fix issue with merge lora path with spaces
@ -337,13 +346,3 @@ This will store a backup file with your current locally installed pip packages a
- Implemented DyLoRA GUI support. There will now be a new 'DyLoRA Unit` slider when the LoRA type is selected as `kohya DyLoRA` to specify the desired Unit value for DyLoRA training. - Implemented DyLoRA GUI support. There will now be a new 'DyLoRA Unit` slider when the LoRA type is selected as `kohya DyLoRA` to specify the desired Unit value for DyLoRA training.
- Update gui.bat and gui.ps1 based on: https://github.com/bmaltais/kohya_ss/issues/188 - Update gui.bat and gui.ps1 based on: https://github.com/bmaltais/kohya_ss/issues/188
- Update `setup.bat` to install torch 2.0.0 instead of 1.2.1. If you want to upgrade from 1.2.1 to 2.0.0 run setup.bat again, select 1 to uninstall the previous torch modules, then select 2 for torch 2.0.0 - Update `setup.bat` to install torch 2.0.0 instead of 1.2.1. If you want to upgrade from 1.2.1 to 2.0.0 run setup.bat again, select 1 to uninstall the previous torch modules, then select 2 for torch 2.0.0
* 2023/04/09 (v21.5.2)
- Added support for training with weighted captions. Thanks to AI-Casanova for the great contribution!
- Please refer to the PR for details: [PR #336](https://github.com/kohya-ss/sd-scripts/pull/336)
- Specify the `--weighted_captions` option. It is available for all training scripts except Textual Inversion and XTI.
- This option is also applicable to token strings of the DreamBooth method.
- The syntax for weighted captions is almost the same as the Web UI, and you can use things like `(abc)`, `[abc]`, and `(abc:1.23)`. Nesting is also possible.
- If you include a comma in the parentheses, the parentheses will not be properly matched in the prompt shuffle/dropout, so do not include a comma in the parentheses.
- Run gui.sh from any place

View File

@ -108,7 +108,11 @@ def save_configuration(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,weighted_captions, min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -217,7 +221,11 @@ def open_configuration(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,weighted_captions, min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -308,7 +316,11 @@ def train_model(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,weighted_captions, min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -333,9 +345,12 @@ def train_model(
if check_if_model_exist(output_name, output_dir, save_model_as): if check_if_model_exist(output_name, output_dir, save_model_as):
return return
if optimizer == 'Adafactor' and lr_warmup != '0': if optimizer == 'Adafactor' and lr_warmup != '0':
msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning") msgbox(
"Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
title='Warning',
)
lr_warmup = '0' lr_warmup = '0'
# Get a list of all subfolders in train_data_dir, excluding hidden folders # Get a list of all subfolders in train_data_dir, excluding hidden folders
@ -525,6 +540,9 @@ def train_model(
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size, vae_batch_size=vae_batch_size,
min_snr_gamma=min_snr_gamma, min_snr_gamma=min_snr_gamma,
save_every_n_steps=save_every_n_steps,
save_last_n_steps=save_last_n_steps,
save_last_n_steps_state=save_last_n_steps_state,
) )
run_cmd += run_cmd_sample( run_cmd += run_cmd_sample(
@ -736,6 +754,9 @@ def dreambooth_tab(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -839,6 +860,9 @@ def dreambooth_tab(
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
weighted_captions, weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
] ]
button_open_config.click( button_open_config.click(

View File

@ -275,7 +275,7 @@ def train(args):
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype) latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
else: else:
# latentに変換 # latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@ -285,18 +285,19 @@ def train(args):
with torch.set_grad_enabled(args.train_text_encoder): with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning # Get the text embedding for conditioning
if args.weighted_captions: if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(tokenizer, encoder_hidden_states = get_weighted_text_embeddings(
text_encoder, tokenizer,
batch["captions"], text_encoder,
accelerator.device, batch["captions"],
args.max_token_length // 75 if args.max_token_length else 1, accelerator.device,
clip_skip=args.clip_skip, args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
) )
else: else:
input_ids = batch["input_ids"].to(accelerator.device) input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states( encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
) )
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device) noise = torch.randn_like(latents, device=latents.device)
@ -351,6 +352,27 @@ def train(args):
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
) )
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@ -376,21 +398,23 @@ def train(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None: if args.save_every_n_epochs is not None:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path if accelerator.is_main_process:
train_util.save_sd_model_on_epoch_end( src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
args, train_util.save_sd_model_on_epoch_end_or_stepwise(
accelerator, args,
src_path, True,
save_stable_diffusion_format, accelerator,
use_safetensors, src_path,
save_dtype, save_stable_diffusion_format,
epoch, use_safetensors,
num_train_epochs, save_dtype,
global_step, epoch,
unwrap_model(text_encoder), num_train_epochs,
unwrap_model(unet), global_step,
vae, unwrap_model(text_encoder),
) unwrap_model(unet),
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
@ -401,7 +425,7 @@ def train(args):
accelerator.end_training() accelerator.end_training()
if args.save_state: if args.save_state and is_main_process:
train_util.save_state_on_train_end(args, accelerator) train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す del accelerator # この後メモリを使うのでこれは消す
@ -437,4 +461,4 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)
train(args) train(args)

View File

@ -224,7 +224,7 @@ def main(args):
print("done!") print("done!")
if __name__ == "__main__": def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument( parser.add_argument(
@ -284,6 +284,11 @@ if __name__ == "__main__":
) )
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args() args = parser.parse_args()
# スペルミスしていたオプションを復元する # スペルミスしていたオプションを復元する

View File

@ -107,7 +107,11 @@ def save_configuration(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,weighted_captions, min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -222,7 +226,11 @@ def open_configuration(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,weighted_captions, min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -319,13 +327,20 @@ def train_model(
sample_prompts, sample_prompts,
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma,weighted_captions, min_snr_gamma,
weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
if check_if_model_exist(output_name, output_dir, save_model_as): if check_if_model_exist(output_name, output_dir, save_model_as):
return return
if optimizer == 'Adafactor' and lr_warmup != '0': if optimizer == 'Adafactor' and lr_warmup != '0':
msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning") msgbox(
"Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
title='Warning',
)
lr_warmup = '0' lr_warmup = '0'
# create caption json file # create caption json file
@ -487,6 +502,9 @@ def train_model(
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size, vae_batch_size=vae_batch_size,
min_snr_gamma=min_snr_gamma, min_snr_gamma=min_snr_gamma,
save_every_n_steps=save_every_n_steps,
save_last_n_steps=save_last_n_steps,
save_last_n_steps_state=save_last_n_steps_state,
) )
run_cmd += run_cmd_sample( run_cmd += run_cmd_sample(
@ -709,6 +727,9 @@ def finetune_tab():
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -806,6 +827,9 @@ def finetune_tab():
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
weighted_captions, weighted_captions,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
] ]
button_run.click(train_model, inputs=settings_list) button_run.click(train_model, inputs=settings_list)

View File

@ -945,7 +945,7 @@ class PipelineLike:
# encode the init image into latents and scale the latents # encode the init image into latents and scale the latents
init_image = init_image.to(device=self.device, dtype=latents_dtype) init_image = init_image.to(device=self.device, dtype=latents_dtype)
if init_image.size()[1:] == (height // 8, width // 8): if init_image.size()[-2:] == (height // 8, width // 8):
init_latents = init_image init_latents = init_image
else: else:
if vae_batch_size >= batch_size: if vae_batch_size >= batch_size:

View File

@ -21,33 +21,65 @@ import shutil
import logging import logging
import subprocess import subprocess
log = logging.getLogger("sd") log = logging.getLogger('sd')
# setup console and file logging # setup console and file logging
def setup_logging(clean=False): def setup_logging(clean=False):
try: try:
if clean and os.path.isfile('setup.log'): if clean and os.path.isfile('setup.log'):
os.remove('setup.log') os.remove('setup.log')
time.sleep(0.1) # prevent race condition time.sleep(0.1) # prevent race condition
except: except:
pass pass
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s', filename='setup.log', filemode='a', encoding='utf-8', force=True) logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s',
filename='setup.log',
filemode='a',
encoding='utf-8',
force=True,
)
from rich.theme import Theme from rich.theme import Theme
from rich.logging import RichHandler from rich.logging import RichHandler
from rich.console import Console from rich.console import Console
from rich.pretty import install as pretty_install from rich.pretty import install as pretty_install
from rich.traceback import install as traceback_install from rich.traceback import install as traceback_install
console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({
"traceback.border": "black", console = Console(
"traceback.border.syntax_error": "black", log_time=True,
"inspect.value.border": "black", log_time_format='%H:%M:%S-%f',
})) theme=Theme(
{
'traceback.border': 'black',
'traceback.border.syntax_error': 'black',
'inspect.value.border': 'black',
}
),
)
pretty_install(console=console) pretty_install(console=console)
traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False, suppress=[]) traceback_install(
rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=logging.DEBUG if args.debug else logging.INFO, console=console) console=console,
extra_lines=1,
width=console.width,
word_wrap=False,
indent_guides=False,
suppress=[],
)
rh = RichHandler(
show_time=True,
omit_repeated_times=False,
show_level=True,
show_path=False,
markup=False,
rich_tracebacks=True,
log_time_format='%H:%M:%S-%f',
level=logging.DEBUG if args.debug else logging.INFO,
console=console,
)
rh.set_name(logging.DEBUG if args.debug else logging.INFO) rh.set_name(logging.DEBUG if args.debug else logging.INFO)
log.addHandler(rh) log.addHandler(rh)
def UI(**kwargs): def UI(**kwargs):
css = '' css = ''
@ -56,7 +88,9 @@ def UI(**kwargs):
print('Load CSS...') print('Load CSS...')
css += file.read() + '\n' css += file.read() + '\n'
interface = gr.Blocks(css=css, title='Kohya_ss GUI', theme=gr.themes.Default()) interface = gr.Blocks(
css=css, title='Kohya_ss GUI', theme=gr.themes.Default()
)
with interface: with interface:
with gr.Tab('Dreambooth'): with gr.Tab('Dreambooth'):

View File

@ -803,6 +803,16 @@ def gradio_advanced_training():
label='Additional parameters', label='Additional parameters',
placeholder='(Optional) Use to provide additional parameters not handled by the GUI. Eg: --some_parameters "value"', placeholder='(Optional) Use to provide additional parameters not handled by the GUI. Eg: --some_parameters "value"',
) )
with gr.Row():
save_every_n_steps = gr.Number(
label='Save every N steps', value=0, precision=0, info='(Optional) The model is saved every specified steps'
)
save_last_n_steps = gr.Number(
label='Save last N steps', value=0, precision=0, info='(Optional) Save only the specified number of models (old models will be deleted)'
)
save_last_n_steps_state = gr.Number(
label='Save last N steps', value=0, precision=0, info='(Optional) Save only the specified number of states (old models will be deleted)'
)
with gr.Row(): with gr.Row():
keep_tokens = gr.Slider( keep_tokens = gr.Slider(
label='Keep n tokens', value='0', minimum=0, maximum=32, step=1 label='Keep n tokens', value='0', minimum=0, maximum=32, step=1
@ -917,6 +927,9 @@ def gradio_advanced_training():
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
) )
@ -952,6 +965,15 @@ def run_cmd_advanced_training(**kwargs):
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}' f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
if int(kwargs.get('bucket_reso_steps', 64)) >= 1 if int(kwargs.get('bucket_reso_steps', 64)) >= 1
else '', else '',
f' --save_every_n_steps="{int(kwargs.get("save_every_n_steps", 0))}"'
if int(kwargs.get('save_every_n_steps')) > 0
else '',
f' --save_last_n_steps="{int(kwargs.get("save_last_n_steps", 0))}"'
if int(kwargs.get('save_last_n_steps')) > 0
else '',
f' --save_last_n_steps_state="{int(kwargs.get("save_last_n_steps_state", 0))}"'
if int(kwargs.get('save_last_n_steps_state')) > 0
else '',
f' --min_snr_gamma={int(kwargs.get("min_snr_gamma", 0))}' f' --min_snr_gamma={int(kwargs.get("min_snr_gamma", 0))}'
if int(kwargs.get('min_snr_gamma', 0)) >= 1 if int(kwargs.get('min_snr_gamma', 0)) >= 1
else '', else '',

View File

@ -74,6 +74,11 @@ LAST_STATE_NAME = "{}-state"
DEFAULT_EPOCH_NAME = "epoch" DEFAULT_EPOCH_NAME = "epoch"
DEFAULT_LAST_OUTPUT_NAME = "last" DEFAULT_LAST_OUTPUT_NAME = "last"
DEFAULT_STEP_NAME = "at"
STEP_STATE_NAME = "{}-step{:08d}-state"
STEP_FILE_NAME = "{}-step{:08d}"
STEP_DIFFUSERS_DIR_NAME = "{}-step{:08d}"
# region dataset # region dataset
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
@ -1986,18 +1991,38 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument( parser.add_argument(
"--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する" "--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する"
) )
parser.add_argument(
"--save_every_n_steps", type=int, default=None, help="save checkpoint every N steps / 学習中のモデルを指定ステップごとに保存する"
)
parser.add_argument( parser.add_argument(
"--save_n_epoch_ratio", "--save_n_epoch_ratio",
type=int, type=int,
default=None, default=None,
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存するたとえば5を指定すると最低5個のファイルが保存される", help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存するたとえば5を指定すると最低5個のファイルが保存される",
) )
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument(
"--save_last_n_epochs",
type=int,
default=None,
help="save last N checkpoints when saving every N epochs (remove older checkpoints) / 指定エポックごとにモデルを保存するとき最大Nエポック保存する古いチェックポイントは削除する",
)
parser.add_argument( parser.add_argument(
"--save_last_n_epochs_state", "--save_last_n_epochs_state",
type=int, type=int,
default=None, default=None,
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)", help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する--save_last_n_epochsの指定を上書きする",
)
parser.add_argument(
"--save_last_n_steps",
type=int,
default=None,
help="save checkpoints until N steps elapsed (remove older checkpoints if N steps elapsed) / 指定ステップごとにモデルを保存するとき、このステップ数経過するまで保存する(このステップ数経過したら削除する)",
)
parser.add_argument(
"--save_last_n_steps_state",
type=int,
default=None,
help="save states until N steps elapsed (remove older states if N steps elapsed, overrides --save_last_n_steps) / 指定ステップごとにstateを保存するとき、このステップ数経過するまで保存するこのステップ数経過したら削除する。--save_last_n_stepsを上書きする",
) )
parser.add_argument( parser.add_argument(
"--save_state", "--save_state",
@ -2160,6 +2185,12 @@ def verify_training_args(args: argparse.Namespace):
if args.v2 and args.clip_skip is not None: if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
if args.cache_latents_to_disk and not args.cache_latents:
args.cache_latents = True
print(
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
)
def add_dataset_arguments( def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
@ -2903,26 +2934,53 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
return encoder_hidden_states return encoder_hidden_states
def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): def default_if_none(value, default):
model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name return default if value is None else value
ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
return model_name, ckpt_name
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): def get_epoch_ckpt_name(args: argparse.Namespace, ext: str, epoch_no: int):
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
if saving: return EPOCH_FILE_NAME.format(model_name, epoch_no) + ext
os.makedirs(args.output_dir, exist_ok=True)
save_func()
if args.save_last_n_epochs is not None:
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
remove_old_func(remove_epoch_no)
return saving
def save_sd_model_on_epoch_end( def get_step_ckpt_name(args: argparse.Namespace, ext: str, step_no: int):
model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
return STEP_FILE_NAME.format(model_name, step_no) + ext
def get_last_ckpt_name(args: argparse.Namespace, ext: str):
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
return model_name + ext
def get_remove_epoch_no(args: argparse.Namespace, epoch_no: int):
if args.save_last_n_epochs is None:
return None
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
if remove_epoch_no < 0:
return None
return remove_epoch_no
def get_remove_step_no(args: argparse.Namespace, step_no: int):
if args.save_last_n_steps is None:
return None
# last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する
# save_every_n_steps=10, save_last_n_steps=30の場合、50step目には30step分残し、10step目を削除する
remove_step_no = step_no - args.save_last_n_steps - 1
remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
if remove_step_no < 0:
return None
return remove_step_no
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
def save_sd_model_on_epoch_end_or_stepwise(
args: argparse.Namespace, args: argparse.Namespace,
on_epoch_end: bool,
accelerator, accelerator,
src_path: str, src_path: str,
save_stable_diffusion_format: bool, save_stable_diffusion_format: bool,
@ -2935,57 +2993,87 @@ def save_sd_model_on_epoch_end(
unet, unet,
vae, vae,
): ):
epoch_no = epoch + 1 if on_epoch_end:
model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) epoch_no = epoch + 1
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
if not saving:
return
if save_stable_diffusion_format: model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
remove_no = get_remove_epoch_no(args, epoch_no)
def save_sd():
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_sd(old_epoch_no):
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
save_func = save_sd
remove_old_func = remove_sd
else: else:
# 保存するか否かは呼び出し側で判断済み
def save_du(): model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
epoch_no = epoch # 例: 最初のepochの途中で保存したら0になる、SDモデルに保存される
remove_no = get_remove_step_no(args, global_step)
os.makedirs(args.output_dir, exist_ok=True)
if save_stable_diffusion_format:
ext = ".safetensors" if use_safetensors else ".ckpt"
if on_epoch_end:
ckpt_name = get_epoch_ckpt_name(args, ext, epoch_no)
else:
ckpt_name = get_step_ckpt_name(args, ext, global_step)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
# remove older checkpoints
if remove_no is not None:
if on_epoch_end:
remove_ckpt_name = get_epoch_ckpt_name(args, ext, remove_no)
else:
remove_ckpt_name = get_step_ckpt_name(args, ext, remove_no)
remove_ckpt_file = os.path.join(args.output_dir, remove_ckpt_name)
if os.path.exists(remove_ckpt_file):
print(f"removing old checkpoint: {remove_ckpt_file}")
os.remove(remove_ckpt_file)
else:
if on_epoch_end:
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
print(f"saving model: {out_dir}") else:
os.makedirs(out_dir, exist_ok=True) out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step))
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name)
def remove_du(old_epoch_no): print(f"saving model: {out_dir}")
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) model_util.save_diffusers_checkpoint(
if os.path.exists(out_dir_old): args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
print(f"removing old model: {out_dir_old}") )
shutil.rmtree(out_dir_old) if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name)
save_func = save_du # remove older checkpoints
remove_old_func = remove_du if remove_no is not None:
if on_epoch_end:
remove_out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, remove_no))
else:
remove_out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, remove_no))
saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) if os.path.exists(remove_out_dir):
if saving and args.save_state: print(f"removing old model: {remove_out_dir}")
save_state_on_epoch_end(args, accelerator, model_name, epoch_no) shutil.rmtree(remove_out_dir)
if on_epoch_end:
save_and_remove_state_on_epoch_end(args, accelerator, epoch_no)
else:
save_and_remove_state_stepwise(args, accelerator, global_step)
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no): def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no):
print("saving state.") model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
print(f"saving state at epoch {epoch_no}")
os.makedirs(args.output_dir, exist_ok=True)
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)) state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
accelerator.save_state(state_dir) accelerator.save_state(state_dir)
if args.save_state_to_huggingface: if args.save_state_to_huggingface:
@ -3001,12 +3089,40 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
shutil.rmtree(state_dir_old) shutil.rmtree(state_dir_old)
def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no):
model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
print(f"saving state at step {step_no}")
os.makedirs(args.output_dir, exist_ok=True)
state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no))
accelerator.save_state(state_dir)
if args.save_state_to_huggingface:
print("uploading state to huggingface.")
huggingface_util.upload(args, state_dir, "/" + STEP_STATE_NAME.format(model_name, step_no))
last_n_steps = args.save_last_n_steps_state if args.save_last_n_steps_state else args.save_last_n_steps
if last_n_steps is not None:
# last_n_steps前のstep_noから、save_every_n_stepsの倍数のstep_noを計算して削除する
remove_step_no = step_no - last_n_steps - 1
remove_step_no = remove_step_no - (remove_step_no % args.save_every_n_steps)
if remove_step_no > 0:
state_dir_old = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, remove_step_no))
if os.path.exists(state_dir_old):
print(f"removing old state: {state_dir_old}")
shutil.rmtree(state_dir_old)
def save_state_on_train_end(args: argparse.Namespace, accelerator): def save_state_on_train_end(args: argparse.Namespace, accelerator):
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
print("saving last state.") print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)) state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
accelerator.save_state(state_dir) accelerator.save_state(state_dir)
if args.save_state_to_huggingface: if args.save_state_to_huggingface:
print("uploading last state to huggingface.") print("uploading last state to huggingface.")
huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name)) huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
@ -3024,7 +3140,7 @@ def save_sd_model_on_train_end(
unet, unet,
vae, vae,
): ):
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
if save_stable_diffusion_format: if save_stable_diffusion_format:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)

View File

@ -126,8 +126,19 @@ def save_configuration(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas, down_lr_weight,
weighted_captions,unit, mid_lr_weight,
up_lr_weight,
block_lr_zero_threshold,
block_dims,
block_alphas,
conv_dims,
conv_alphas,
weighted_captions,
unit,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -247,8 +258,19 @@ def open_configuration(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas, down_lr_weight,
weighted_captions,unit, mid_lr_weight,
up_lr_weight,
block_lr_zero_threshold,
block_dims,
block_alphas,
conv_dims,
conv_alphas,
weighted_captions,
unit,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -359,8 +381,19 @@ def train_model(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas, down_lr_weight,
weighted_captions,unit, mid_lr_weight,
up_lr_weight,
block_lr_zero_threshold,
block_dims,
block_alphas,
conv_dims,
conv_alphas,
weighted_captions,
unit,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
print_only_bool = True if print_only.get('label') == 'True' else False print_only_bool = True if print_only.get('label') == 'True' else False
@ -400,9 +433,12 @@ def train_model(
if check_if_model_exist(output_name, output_dir, save_model_as): if check_if_model_exist(output_name, output_dir, save_model_as):
return return
if optimizer == 'Adafactor' and lr_warmup != '0': if optimizer == 'Adafactor' and lr_warmup != '0':
msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning") msgbox(
"Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
title='Warning',
)
lr_warmup = '0' lr_warmup = '0'
# If string is empty set string to 0. # If string is empty set string to 0.
@ -431,7 +467,7 @@ def train_model(
try: try:
# Extract the number of repeats from the folder name # Extract the number of repeats from the folder name
repeats = int(folder.split('_')[0]) repeats = int(folder.split('_')[0])
# Count the number of images in the folder # Count the number of images in the folder
num_images = len( num_images = len(
[ [
@ -455,10 +491,12 @@ def train_model(
print(f'Folder {folder}: {steps} steps') print(f'Folder {folder}: {steps} steps')
total_steps += steps total_steps += steps
except ValueError: except ValueError:
# Handle the case where the folder name does not contain an underscore # Handle the case where the folder name does not contain an underscore
print(f"Error: '{folder}' does not contain an underscore, skipping...") print(
f"Error: '{folder}' does not contain an underscore, skipping..."
)
# calculate max_train_steps # calculate max_train_steps
max_train_steps = int( max_train_steps = int(
@ -535,13 +573,25 @@ def train_model(
return return
run_cmd += f' --network_module=lycoris.kohya' run_cmd += f' --network_module=lycoris.kohya'
run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"' run_cmd += f' --network_args "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=loha"'
if LoRA_type in ['Kohya LoCon', 'Standard']: if LoRA_type in ['Kohya LoCon', 'Standard']:
kohya_lora_var_list = ['down_lr_weight', 'mid_lr_weight', 'up_lr_weight', 'block_lr_zero_threshold', 'block_dims', 'block_alphas', 'conv_dims', 'conv_alphas'] kohya_lora_var_list = [
'down_lr_weight',
'mid_lr_weight',
'up_lr_weight',
'block_lr_zero_threshold',
'block_dims',
'block_alphas',
'conv_dims',
'conv_alphas',
]
run_cmd += f' --network_module=networks.lora' run_cmd += f' --network_module=networks.lora'
kohya_lora_vars = {key: value for key, value in vars().items() if key in kohya_lora_var_list and value} kohya_lora_vars = {
key: value
for key, value in vars().items()
if key in kohya_lora_var_list and value
}
network_args = '' network_args = ''
if LoRA_type == 'Kohya LoCon': if LoRA_type == 'Kohya LoCon':
@ -553,12 +603,28 @@ def train_model(
if network_args: if network_args:
run_cmd += f' --network_args{network_args}' run_cmd += f' --network_args{network_args}'
if LoRA_type in ['Kohya DyLoRA']: if LoRA_type in ['Kohya DyLoRA']:
kohya_lora_var_list = ['conv_dim', 'conv_alpha', 'down_lr_weight', 'mid_lr_weight', 'up_lr_weight', 'block_lr_zero_threshold', 'block_dims', 'block_alphas', 'conv_dims', 'conv_alphas', 'unit'] kohya_lora_var_list = [
'conv_dim',
'conv_alpha',
'down_lr_weight',
'mid_lr_weight',
'up_lr_weight',
'block_lr_zero_threshold',
'block_dims',
'block_alphas',
'conv_dims',
'conv_alphas',
'unit',
]
run_cmd += f' --network_module=networks.dylora' run_cmd += f' --network_module=networks.dylora'
kohya_lora_vars = {key: value for key, value in vars().items() if key in kohya_lora_var_list and value} kohya_lora_vars = {
key: value
for key, value in vars().items()
if key in kohya_lora_var_list and value
}
network_args = '' network_args = ''
@ -641,6 +707,9 @@ def train_model(
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size, vae_batch_size=vae_batch_size,
min_snr_gamma=min_snr_gamma, min_snr_gamma=min_snr_gamma,
save_every_n_steps=save_every_n_steps,
save_last_n_steps=save_last_n_steps,
save_last_n_steps_state=save_last_n_steps_state,
) )
run_cmd += run_cmd_sample( run_cmd += run_cmd_sample(
@ -650,7 +719,7 @@ def train_model(
sample_prompts, sample_prompts,
output_dir, output_dir,
) )
# if not down_lr_weight == '': # if not down_lr_weight == '':
# run_cmd += f' --down_lr_weight="{down_lr_weight}"' # run_cmd += f' --down_lr_weight="{down_lr_weight}"'
# if not mid_lr_weight == '': # if not mid_lr_weight == '':
@ -667,9 +736,6 @@ def train_model(
# run_cmd += f' --conv_dims="{conv_dims}"' # run_cmd += f' --conv_dims="{conv_dims}"'
# if not conv_alphas == '': # if not conv_alphas == '':
# run_cmd += f' --conv_alphas="{conv_alphas}"' # run_cmd += f' --conv_alphas="{conv_alphas}"'
if print_only_bool: if print_only_bool:
print( print(
@ -903,17 +969,27 @@ def lora_tab(
step=1, step=1,
interactive=True, interactive=True,
) )
# Show of hide LoCon conv settings depending on LoRA type selection # Show of hide LoCon conv settings depending on LoRA type selection
def update_LoRA_settings(LoRA_type): def update_LoRA_settings(LoRA_type):
# Print a message when LoRA type is changed # Print a message when LoRA type is changed
print('LoRA type changed...') print('LoRA type changed...')
# Determine if LoCon_row should be visible based on LoRA_type # Determine if LoCon_row should be visible based on LoRA_type
LoCon_row = LoRA_type in {'LoCon', 'Kohya DyLoRA', 'Kohya LoCon', 'LyCORIS/LoHa', 'LyCORIS/LoCon'} LoCon_row = LoRA_type in {
'LoCon',
'Kohya DyLoRA',
'Kohya LoCon',
'LyCORIS/LoHa',
'LyCORIS/LoCon',
}
# Determine if LoRA_type_change should be visible based on LoRA_type # Determine if LoRA_type_change should be visible based on LoRA_type
LoRA_type_change = LoRA_type in {'Standard', 'Kohya DyLoRA', 'Kohya LoCon'} LoRA_type_change = LoRA_type in {
'Standard',
'Kohya DyLoRA',
'Kohya LoCon',
}
# Determine if kohya_dylora_visible should be visible based on LoRA_type # Determine if kohya_dylora_visible should be visible based on LoRA_type
kohya_dylora_visible = LoRA_type == 'Kohya DyLoRA' kohya_dylora_visible = LoRA_type == 'Kohya DyLoRA'
@ -925,7 +1001,6 @@ def lora_tab(
gr.Group.update(visible=kohya_dylora_visible), gr.Group.update(visible=kohya_dylora_visible),
) )
with gr.Row(): with gr.Row():
max_resolution = gr.Textbox( max_resolution = gr.Textbox(
label='Max resolution', label='Max resolution',
@ -941,9 +1016,12 @@ def lora_tab(
label='Stop text encoder training', label='Stop text encoder training',
info='After what % of steps should the text encoder stop being trained. 0 = train for all steps.', info='After what % of steps should the text encoder stop being trained. 0 = train for all steps.',
) )
enable_bucket = gr.Checkbox(label='Enable buckets', value=True, enable_bucket = gr.Checkbox(
info='Allow non similar resolution dataset images to be trained on.',) label='Enable buckets',
value=True,
info='Allow non similar resolution dataset images to be trained on.',
)
with gr.Accordion('Advanced Configuration', open=False): with gr.Accordion('Advanced Configuration', open=False):
with gr.Row(visible=True) as kohya_advanced_lora: with gr.Row(visible=True) as kohya_advanced_lora:
with gr.Tab(label='Weights'): with gr.Tab(label='Weights'):
@ -951,46 +1029,46 @@ def lora_tab(
down_lr_weight = gr.Textbox( down_lr_weight = gr.Textbox(
label='Down LR weights', label='Down LR weights',
placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1', placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1',
info='Specify the learning rate weight of the down blocks of U-Net.' info='Specify the learning rate weight of the down blocks of U-Net.',
) )
mid_lr_weight = gr.Textbox( mid_lr_weight = gr.Textbox(
label='Mid LR weights', label='Mid LR weights',
placeholder='(Optional) eg: 0.5', placeholder='(Optional) eg: 0.5',
info='Specify the learning rate weight of the mid block of U-Net.' info='Specify the learning rate weight of the mid block of U-Net.',
) )
up_lr_weight = gr.Textbox( up_lr_weight = gr.Textbox(
label='Up LR weights', label='Up LR weights',
placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1', placeholder='(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1',
info='Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight.' info='Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight.',
) )
block_lr_zero_threshold = gr.Textbox( block_lr_zero_threshold = gr.Textbox(
label='Blocks LR zero threshold', label='Blocks LR zero threshold',
placeholder='(Optional) eg: 0.1', placeholder='(Optional) eg: 0.1',
info='If the weight is not more than this value, the LoRA module is not created. The default is 0.' info='If the weight is not more than this value, the LoRA module is not created. The default is 0.',
) )
with gr.Tab(label='Blocks'): with gr.Tab(label='Blocks'):
with gr.Row(visible=True): with gr.Row(visible=True):
block_dims = gr.Textbox( block_dims = gr.Textbox(
label='Block dims', label='Block dims',
placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2', placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
info='Specify the dim (rank) of each block. Specify 25 numbers.' info='Specify the dim (rank) of each block. Specify 25 numbers.',
) )
block_alphas = gr.Textbox( block_alphas = gr.Textbox(
label='Block alphas', label='Block alphas',
placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2', placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
info='Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.' info='Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.',
) )
with gr.Tab(label='Conv'): with gr.Tab(label='Conv'):
with gr.Row(visible=True): with gr.Row(visible=True):
conv_dims = gr.Textbox( conv_dims = gr.Textbox(
label='Conv dims', label='Conv dims',
placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2', placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
info='Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block. Specify 25 numbers.' info='Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block. Specify 25 numbers.',
) )
conv_alphas = gr.Textbox( conv_alphas = gr.Textbox(
label='Conv alphas', label='Conv alphas',
placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2', placeholder='(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2',
info='Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. If omitted, the value of conv_alpha is used.' info='Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. If omitted, the value of conv_alpha is used.',
) )
with gr.Row(): with gr.Row():
no_token_padding = gr.Checkbox( no_token_padding = gr.Checkbox(
@ -1000,7 +1078,9 @@ def lora_tab(
label='Gradient accumulate steps', value='1' label='Gradient accumulate steps', value='1'
) )
weighted_captions = gr.Checkbox( weighted_captions = gr.Checkbox(
label='Weighted captions', value=False, info='Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.', label='Weighted captions',
value=False,
info='Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder.',
) )
with gr.Row(): with gr.Row():
prior_loss_weight = gr.Number( prior_loss_weight = gr.Number(
@ -1041,6 +1121,9 @@ def lora_tab(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -1054,9 +1137,11 @@ def lora_tab(
sample_sampler, sample_sampler,
sample_prompts, sample_prompts,
) = sample_gradio_config() ) = sample_gradio_config()
LoRA_type.change( LoRA_type.change(
update_LoRA_settings, inputs=[LoRA_type], outputs=[LoCon_row, kohya_advanced_lora, kohya_dylora] update_LoRA_settings,
inputs=[LoRA_type],
outputs=[LoCon_row, kohya_advanced_lora, kohya_dylora],
) )
with gr.Tab('Tools'): with gr.Tab('Tools'):
@ -1164,8 +1249,19 @@ def lora_tab(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
down_lr_weight,mid_lr_weight,up_lr_weight,block_lr_zero_threshold,block_dims,block_alphas,conv_dims,conv_alphas, down_lr_weight,
weighted_captions, unit, mid_lr_weight,
up_lr_weight,
block_lr_zero_threshold,
block_dims,
block_alphas,
conv_dims,
conv_alphas,
weighted_captions,
unit,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
] ]
button_open_config.click( button_open_config.click(

View File

@ -8,7 +8,6 @@ easygui==0.98.3
einops==0.6.0 einops==0.6.0
ftfy==6.1.1 ftfy==6.1.1
gradio==3.27.0; sys_platform != 'darwin' gradio==3.27.0; sys_platform != 'darwin'
# gradio==3.19.1; sys_platform != 'darwin'
gradio==3.23.0; sys_platform == 'darwin' gradio==3.23.0; sys_platform == 'darwin'
lion-pytorch==0.0.6 lion-pytorch==0.0.6
opencv-python==4.7.0.68 opencv-python==4.7.0.68
@ -20,6 +19,7 @@ tk==0.1.0
toml==0.10.2 toml==0.10.2
transformers==4.26.0 transformers==4.26.0
voluptuous==0.13.1 voluptuous==0.13.1
wandb==0.15.0
# for BLIP captioning # for BLIP captioning
fairscale==0.4.13 fairscale==0.4.13
requests==2.28.2 requests==2.28.2

View File

@ -115,6 +115,9 @@ def save_configuration(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -230,6 +233,9 @@ def open_configuration(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
# Get list of function parameters and values # Get list of function parameters and values
parameters = list(locals().items()) parameters = list(locals().items())
@ -327,6 +333,9 @@ def train_model(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
): ):
if pretrained_model_name_or_path == '': if pretrained_model_name_or_path == '':
msgbox('Source model information is missing') msgbox('Source model information is missing')
@ -362,9 +371,12 @@ def train_model(
if check_if_model_exist(output_name, output_dir, save_model_as): if check_if_model_exist(output_name, output_dir, save_model_as):
return return
if optimizer == 'Adafactor' and lr_warmup != '0': if optimizer == 'Adafactor' and lr_warmup != '0':
msgbox("Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", title="Warning") msgbox(
"Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.",
title='Warning',
)
lr_warmup = '0' lr_warmup = '0'
# Get a list of all subfolders in train_data_dir # Get a list of all subfolders in train_data_dir
@ -525,6 +537,9 @@ def train_model(
additional_parameters=additional_parameters, additional_parameters=additional_parameters,
vae_batch_size=vae_batch_size, vae_batch_size=vae_batch_size,
min_snr_gamma=min_snr_gamma, min_snr_gamma=min_snr_gamma,
save_every_n_steps=save_every_n_steps,
save_last_n_steps=save_last_n_steps,
save_last_n_steps_state=save_last_n_steps_state,
) )
run_cmd += f' --token_string="{token_string}"' run_cmd += f' --token_string="{token_string}"'
run_cmd += f' --init_word="{init_word}"' run_cmd += f' --init_word="{init_word}"'
@ -791,6 +806,9 @@ def ti_tab(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
) = gradio_advanced_training() ) = gradio_advanced_training()
color_aug.change( color_aug.change(
color_aug_changed, color_aug_changed,
@ -899,6 +917,9 @@ def ti_tab(
additional_parameters, additional_parameters,
vae_batch_size, vae_batch_size,
min_snr_gamma, min_snr_gamma,
save_every_n_steps,
save_last_n_steps,
save_last_n_steps_state,
] ]
button_open_config.click( button_open_config.click(

View File

@ -243,7 +243,13 @@ def create_upscaler(**kwargs):
model = Upscaler() model = Upscaler()
print(f"Loading weights from {weights}...") print(f"Loading weights from {weights}...")
model.load_state_dict(torch.load(weights, map_location=torch.device("cpu"))) if os.path.splitext(weights)[1] == ".safetensors":
from safetensors.torch import load_file
sd = load_file(weights)
else:
sd = torch.load(weights, map_location=torch.device("cpu"))
model.load_state_dict(sd)
return model return model

900
train_README-zh.md Normal file
View File

@ -0,0 +1,900 @@
__由于文档正在更新中描述可能有错误。__
# 关于本学习文档,通用描述
本库支持模型微调(fine tuning)、DreamBooth、训练LoRA和文本反转(Textual Inversion)(包括[XTI:P+](https://github.com/kohya-ss/sd-scripts/pull/327)
本文档将说明它们通用的学习数据准备方法和选项等。
# 概要
请提前参考本仓库的README准备好环境。
以下本节说明。
1. 关于准备学习数据的新形式(使用设置文件)
1. 对于在学习中使用的术语的简要解释
1. 先前的指定格式(不使用设置文件,而是从命令行指定)
1. 生成学习过程中的示例图像
1. 各脚本中常用的共同选项
1. 准备 fine tuning 方法的元数据:如说明文字(打标签)等
1. 如果只执行一次,学习就可以进行(相关内容,请参阅各个脚本的文档)。如果需要,以后可以随时参考。
# 关于准备训练数据
在任意文件夹(也可以是多个文件夹)中准备好训练数据的图像文件。支持 `.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp` 格式的文件。通常不需要进行任何预处理,如调整大小等。
但是请勿使用极小的图像其尺寸比训练分辨率稍后将提到还小建议事先使用超分辨率AI等进行放大。另外请注意不要使用过大的图像约为3000 x 3000像素以上因为这可能会导致错误建议事先缩小。
在训练时,需要整理要用于训练模型的图像数据,并将其指定给脚本。根据训练数据的数量、训练目标和说明(图像描述)是否可用等因素,可以使用几种方法指定训练数据。以下是其中的一些方法(每个名称都不是通用的,而是该存储库自定义的定义)。有关正则化图像的信息将在稍后提供。
1. DreamBooth、class + identifier方式可使用正则化图像
将训练目标与特定单词identifier相关联进行训练。无需准备说明。例如当要学习特定角色时由于无需准备说明因此比较方便但由于学习数据的所有元素都与identifier相关联例如发型、服装、背景等因此在生成时可能会出现无法更换服装的情况。
2. DreamBooth、说明方式可使用正则化图像
准备记录每个图像说明的文本文件进行训练。例如通过将图像详细信息如穿着白色衣服的角色A、穿着红色衣服的角色A等记录在说明中可以将角色和其他元素分离并期望模型更准确地学习角色。
3. 微调方式(不可使用正则化图像)
先将说明收集到元数据文件中。支持分离标签和说明以及预先缓存latents等功能以加速训练这些将在另一篇文档中介绍虽然名为fine tuning方式但不仅限于fine tuning。
你要学的东西和你可以使用的规范方法的组合如下。
| 学习对象或方法 | 脚本 | DB/class+identifier | DB/caption | fine tuning |
|----------------| ----- | ----- | ----- | ----- |
| fine tuning微调模型 | `fine_tune.py`| x | x | o |
| DreamBooth训练模型 | `train_db.py`| o | o | x |
| LoRA | `train_network.py`| o | o | o |
| Textual Invesion | `train_textual_inversion.py`| o | o | o |
## 选择哪一个
如果您想要学习LoRA、Textual Inversion而不需要准备简介文件则建议使用DreamBooth class+identifier。如果您能够准备好则DreamBooth Captions方法更好。如果您有大量的训练数据并且不使用规则化图像则请考虑使用fine-tuning方法。
对于DreamBooth也是一样的但不能使用fine-tuning方法。对于fine-tuning方法只能使用fine-tuning方式。
# 每种方法的指定方式
在这里,我们只介绍每种指定方法的典型模式。有关更详细的指定方法,请参见[数据集设置](./config_README-ja.md)。
# DreamBoothclass+identifier方法可使用规则化图像
在该方法中,每个图像将被视为使用与 `class identifier` 相同的标题进行训练(例如 `shs dog`)。
这样一来每张图片都相当于使用标题“分类标识”例如“shs dog”进行训练。
## step 1.确定identifier和class
要将学习的目标与identifier和属于该目标的class相关联。
(虽然有很多称呼,但暂时按照原始论文的说法。)
以下是简要说明(请查阅详细信息)。
class是学习目标的一般类别。例如如果要学习特定品种的狗则class将是“dog”。对于动漫角色根据模型不同可能是“boy”或“girl”也可能是“1boy”或“1girl”。
identifier是用于识别学习目标并进行学习的单词。可以使用任何单词但是根据原始论文“Tokenizer生成的3个或更少字符的罕见单词”是最好的选择。
使用identifier和class例如“shs dog”可以将模型训练为从class中识别并学习所需的目标。
在图像生成时使用“shs dog”将生成所学习狗种的图像。
作为identifier我最近使用的一些参考是“shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny”等。最好是不包含在Danbooru标签中的单词。
## step 2. 决定是否使用正则化图像,并生成正则化图像
正则化图像是为防止前面提到的语言漂移,即整个类别被拉扯成为学习目标而生成的图像。如果不使用正则化图像,例如在 `shs 1girl` 中学习特定角色时,即使在简单的 `1girl` 提示下生成,也会越来越像该角色。这是因为 `1girl` 在训练时的标题中包含了该角色的信息。
通过同时学习目标图像和正则化图像,类别仍然保持不变,仅在将标识符附加到提示中时才生成目标图像。
如果您只想在LoRA或DreamBooth中使用特定的角色则可以不使用正则化图像。
在Textual Inversion中也不需要使用如果要学习的token string不包含在标题中则不会学习任何内容
一般情况下,使用在训练目标模型时只使用类别名称生成的图像作为正则化图像是常见的做法(例如 `1girl`)。但是,如果生成的图像质量不佳,可以尝试修改提示或使用从网络上另外下载的图像。
(由于正则化图像也被训练,因此其质量会影响模型。)
通常,准备数百张图像是理想的(图像数量太少会导致类别图像无法推广并学习它们的特征)。
如果要使用生成的图像请将其大小通常与训练分辨率更准确地说是bucket的分辨率相适应。
## step 2. 设置文件的描述
创建一个文本文件,并将其扩展名更改为`.toml`。例如,您可以按以下方式进行描述:
(以``开头的部分是注释,因此您可以直接复制粘贴,或者将其删除,都没有问题。)
```toml
[general]
enable_bucket = true # 是否使用Aspect Ratio Bucketing
[[datasets]]
resolution = 512 # 学习分辨率
batch_size = 4 # 批量大小
[[datasets.subsets]]
image_dir = 'C:\hoge' # 指定包含训练图像的文件夹
class_tokens = 'hoge girl' # 指定标识符类
num_repeats = 10 # 训练图像的迭代次数
# 以下仅在使用正则化图像时进行描述。不使用则删除
[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg' # 指定包含正则化图像的文件夹
class_tokens = 'girl' # 指定类别
num_repeats = 1 # 正则化图像的迭代次数基本上1就可以了
```
基本上只需更改以下位置即可进行学习。
1. 学习分辨率
指定一个数字表示正方形(如果是 `512`,则为 512x512如果使用方括号和逗号分隔的两个数字则表示横向×纵向如果是`[512,768]`,则为 512x768。在SD1.x系列中原始学习分辨率为512。指定较大的分辨率`[512,768]` 可能会减少纵向和横向图像生成时的错误。在SD2.x 768系列中分辨率为 `768`
1. 批量大小
指定同时学习多少个数据。这取决于GPU的VRAM大小和学习分辨率。详细信息将在后面说明。此外fine tuning/DreamBooth/LoRA等也会影响批量大小请查看各个脚本的说明。
1. 文件夹指定
指定用于学习的图像和正则化图像(仅在使用时)的文件夹。指定包含图像数据的文件夹。
1. identifier 和 class 的指定
如前所述,与示例相同。
1. 迭代次数
将在后面说明。
### 关于重复次数
重复次数用于调整正则化图像和训练用图像的数量。由于正则化图像的数量多于训练用图像,因此需要重复使用训练用图像来达到一对一的比例,从而实现训练。
请将重复次数指定为“ __训练用图像的重复次数×训练用图像的数量≥正则化图像的重复次数×正则化图像的数量__ ”。
1个epoch数据一周一次的数据量为“训练用图像的重复次数×训练用图像的数量”。如果正则化图像的数量多于这个值则剩余的正则化图像将不会被使用。
## 步骤 3. 学习
请根据每个文档的参考进行学习。
# DreamBooth标题方式可使用规范化图像
在此方式中,每个图像都将通过标题进行学习。
## 步骤 1. 准备标题文件
请将与图像具有相同文件名且扩展名为 `.caption`(可以在设置中更改)的文件放置在用于训练图像的文件夹中。每个文件应该只有一行。编码为 `UTF-8`
## 步骤 2. 决定是否使用规范化图像,并在使用时生成规范化图像
与class+identifier格式相同。可以在规范化图像上附加标题但通常不需要。
## 步骤 2. 编写设置文件
创建一个文本文件并将扩展名更改为 `.toml`。例如,可以按以下方式进行记录。
```toml
[general]
enable_bucket = true # Aspect Ratio Bucketingを使うか否か
[[datasets]]
resolution = 512 # 学習解像度
batch_size = 4 # 批量大小
[[datasets.subsets]]
image_dir = 'C:\hoge' # 指定包含训练图像的文件夹
caption_extension = '.caption' # 使用字幕文件扩展名 .txt 时重写
num_repeats = 10 # 训练图像的迭代次数
# 以下仅在使用正则化图像时进行描述。不使用则删除
[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg' #指定包含正则化图像的文件夹
class_tokens = 'girl' # class を指定
num_repeats = 1 #
正则化图像的迭代次数基本上1就可以了
```
基本上,您可以通过仅重写以下位置来学习。除非另有说明,否则与类+标识符方法相同。
1. 学习分辨率
2. 批量大小
3. 文件夹指定
4. 标题文件的扩展名
可以指定任意的扩展名。
5. 重复次数
## 步骤 3. 学习
请参考每个文档进行学习。
# 微调方法
## 步骤 1. 准备元数据
将标题和标签整合到管理文件中称为元数据。它的扩展名为 `.json`格式为json。由于创建方法较长因此在本文档的末尾进行了描述。
## 步骤 2. 编写设置文件
创建一个文本文件,将扩展名设置为 `.toml`。例如,可以按以下方式编写:
```toml
[general]
shuffle_caption = true
keep_tokens = 1
[[datasets]]
resolution = 512 # 图像分辨率
batch_size = 4 # 批量大小
[[datasets.subsets]]
image_dir = 'C:\piyo' # 指定包含训练图像的文件夹
metadata_file = 'C:\piyo\piyo_md.json' # 元数据文件名
```
基本上您可以通过仅重写以下位置来学习。如无特别说明与DreamBooth相同类+标识符方式。
1. 学习解像度
2. 批次大小
3. 指定文件夹
4. 元数据文件名
指定使用后面所述方法创建的元数据文件。
## 第三步:学习
请参考各个文档进行学习。
# 学习中使用的术语简单解释
由于省略了细节并且我自己也没有完全理解,因此请自行查阅详细信息。
## 微调fine tuning
指训练模型并微调其性能。具体含义因用法而异,但在 Stable Diffusion 中狭义的微调是指使用图像和标题进行训练模型。DreamBooth 可视为狭义微调的一种特殊方法。广义的微调包括 LoRA、Textual Inversion、Hypernetworks 等,包括训练模型的所有内容。
## 步骤step
粗略地说,每次在训练数据上进行一次计算即为一步。具体来说,“将训练数据的标题传递给当前模型,将生成的图像与训练数据的图像进行比较,稍微更改模型,以使其更接近训练数据”即为一步。
## 批次大小batch size
批次大小指定每个步骤要计算多少数据。批量计算可以提高速度。一般来说,批次大小越大,精度也越高。
“批次大小×步数”是用于训练的数据数量。因此,建议减少步数以增加批次大小。
(但是,例如,“批次大小为 1步数为 1600”和“批次大小为 4步数为 400”将不会产生相同的结果。如果使用相同的学习速率通常后者会导致模型欠拟合。请尝试增加学习率例如 `2e-6`),将步数设置为 500 等。)
批次大小越大GPU 内存消耗就越大。如果内存不足,将导致错误,或者在边缘时将导致训练速度降低。建议在任务管理器或 `nvidia-smi` 命令中检查使用的内存量进行调整。
另外,批次是指“一块数据”的意思。
## 学习率
学习率指的是每个步骤中改变的程度。如果指定一个大的值,学习速度就会加快,但是可能会出现变化太大导致模型崩溃或无法达到最佳状态的情况。如果指定一个小的值,学习速度会变慢,也可能无法达到最佳状态。
在fine tuning、DreamBooth、LoRA等过程中学习率会有很大的差异并且也会受到训练数据、所需训练的模型、批量大小和步骤数等因素的影响。建议从一般的值开始观察训练状态并逐渐调整。
默认情况下,整个训练过程中学习率是固定的。但是可以通过调度程序指定学习率如何变化,因此结果也会有所不同。
## 时代epoch
Epoch指的是训练数据被完整训练一遍即数据一周的情况。如果指定了重复次数则在重复后的数据一周后就是1个epoch。
1个epoch的步骤数通常为“数据量÷批量大小”但如果使用Aspect Ratio Bucketing则略微增加由于不同bucket的数据不能在同一个批次中因此步骤数会增加
## 纵横比分桶Aspect Ratio Bucketing)
Stable Diffusion 的 v1 是以 512\*512 的分辨率进行训练的,但同时也可以在其他分辨率下进行训练,例如 256\*1024 和 384\*640。这样可以减少裁剪的部分期望更准确地学习图像和标题之间的关系。
此外,由于可以在任意分辨率下进行训练,因此不再需要事先统一图像数据的纵横比。
该设置在配置中有效,可以切换,但在此之前的配置文件示例中已启用(设置为 `true`)。
学习分辨率将根据参数所提供的分辨率面积即内存使用量进行调整以64像素为单位默认值可更改在纵横方向上进行调整和创建。
在机器学习中,通常需要将所有输入大小统一,但实际上只要在同一批次中统一即可。 NovelAI 所说的分桶(bucketing) 指的是,预先将训练数据按照纵横比分类到每个学习分辨率下,并通过使用每个 bucket 内的图像创建批次来统一批次图像大小。
# 以前的指定格式(不使用 .toml 文件,而是使用命令行选项指定)
这是一种通过命令行选项而不是指定 .toml 文件的方法。有 DreamBooth 类+标识符方法、DreamBooth 标题方法、微调方法三种方式。
## DreamBooth、类+标识符方式
指定文件夹名称以指定迭代次数。还要使用 `train_data_dir``reg_data_dir` 选项。
### 第1步。准备用于训练的图像
创建一个用于存储训练图像的文件夹。__此外__按以下名称创建目录。
```
<迭代次数>_<标识符> <类别>
```
不要忘记下划线``_``。
例如如果在名为“sls frog”的提示下重复数据 20 次则为“20_sls frog”。如下所示
![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png)
### 多个类别、多个标识符的学习
该方法很简单在用于训练的图像文件夹中需要准备多个文件夹每个文件夹都是以“重复次数_<标识符> <类别>”命名的同样在正则化图像文件夹中也需要准备多个文件夹每个文件夹都是以“重复次数_<类别>”命名的。
例如如果要同时训练“sls青蛙”和“cpc兔子”则应按以下方式准备文件夹。
![image](https://user-images.githubusercontent.com/52813779/210777933-a22229db-b219-4cd8-83ca-e87320fc4192.png)
如果一个类别包含多个对象可以只使用一个正则化图像文件夹。例如如果在1girl类别中有角色A和角色B则可以按照以下方式处理
- train_girls
- 10_sls 1girl
- 10_cpc 1girl
- reg_girls
- 1_1girl
### step 2. 准备正规化图像
这是使用规则化图像时的过程。
创建一个文件夹来存储规则化的图像。 __此外__ 创建一个名为``<repeat count>_<class>`` 的目录。
例如使用提示“frog”并且不重复数据仅一次
![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png)
步骤3. 执行学习
执行每个学习脚本。使用 `--train_data_dir` 选项指定包含训练数据文件夹的父文件夹(不是包含图像的文件夹),使用 `--reg_data_dir` 选项指定包含正则化图像的父文件夹(不是包含图像的文件夹)。
## DreamBooth带标题方式
在包含训练图像和正则化图像的文件夹中,将与图像具有相同文件名的文件.caption可以使用选项进行更改放置在该文件夹中然后从该文件中加载标题作为提示进行学习。
※文件夹名称(标识符类)不再用于这些图像的训练。
默认的标题文件扩展名为.caption。可以使用学习脚本的 `--caption_extension` 选项进行更改。 使用 `--shuffle_caption` 选项,同时对每个逗号分隔的部分进行学习时会对学习时的标题进行混洗。
## 微调方式
创建元数据的方式与使用配置文件相同。 使用 `in_json` 选项指定元数据文件。
# 学习过程中的样本输出
通过在训练中使用模型生成图像,可以检查学习进度。将以下选项指定为学习脚本。
- `--sample_every_n_steps` / `--sample_every_n_epochs`
指定要采样的步数或纪元数。为这些数字中的每一个输出样本。如果两者都指定,则 epoch 数优先。
- `--sample_prompts`
指定示例输出的提示文件。
- `--sample_sampler`
指定用于采样输出的采样器。
`'ddim', 'pndm', 'heun', 'dpmsolver', 'dpmsolver++', 'dpmsingle', 'k_lms', 'k_euler', 'k_euler_a', 'k_dpm_2', 'k_dpm_2_a'`が選べます。
要输出样本,您需要提前准备一个包含提示的文本文件。每行输入一个提示。
```txt
# prompt 1
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
以“#”开头的行是注释。您可以使用“`--` + 小写字母”为生成的图像指定选项,例如 `--n`。您可以使用:
- `--n` 否定提示到下一个选项。
- `--w` 指定生成图像的宽度。
- `--h` 指定生成图像的高度。
- `--d` 指定生成图像的种子。
- `--l` 指定生成图像的 CFG 比例。
- `--s` 指定生成过程中的步骤数。
# 每个脚本通用的常用选项
文档更新可能跟不上脚本更新。在这种情况下,请使用 `--help` 选项检查可用选项。
## 学习模型规范
- `--v2` / `--v_parameterization`
如果使用 Hugging Face 的 stable-diffusion-2-base 或来自它的微调模型作为学习目标模型(对于在推理时指示使用 `v2-inference.yaml` 的模型),`- 当使用-v2` 选项与 stable-diffusion-2、768-v-ema.ckpt 及其微调模型(对于在推理过程中使用 `v2-inference-v.yaml` 的模型),`- 指定两个 -v2`和 `--v_parameterization` 选项。
以下几点在 Stable Diffusion 2.0 中发生了显着变化。
1. 使用分词器
2. 使用哪个Text Encoder使用哪个输出层2.0使用倒数第二层)
3. Text Encoder的输出维度(768->1024)
4. U-Net的结构CrossAttention的头数等
5. v-parameterization采样方式好像变了
其中碱基使用1-4个非碱基使用1-5个768-v。使用 1-4 进行 v2 选择,使用 5 进行 v_parameterization 选择。
-`--pretrained_model_name_or_path`
指定要从中执行额外训练的模型。您可以指定稳定扩散检查点文件(.ckpt 或 .safetensors、扩散器本地磁盘上的模型目录或扩散器模型 ID例如“stabilityai/stable-diffusion-2”
## 学习设置
- `--output_dir`
指定训练后保存模型的文件夹。
- `--output_name`
指定不带扩展名的模型文件名。
- `--dataset_config`
指定描述数据集配置的 .toml 文件。
- `--max_train_steps` / `--max_train_epochs`
指定要学习的步数或纪元数。如果两者都指定,则 epoch 数优先。
-
- `--mixed_precision`
训练混合精度以节省内存。指定像`--mixed_precision = "fp16"`。与无混合精度(默认)相比,精度可能较低,但训练所需的 GPU 内存明显较少。
在RTX30系列以后也可以指定`bf16`,请配合您在搭建环境时做的加速设置)。
- `--gradient_checkpointing`
通过逐步计算权重而不是在训练期间一次计算所有权重来减少训练所需的 GPU 内存量。关闭它不会影响准确性,但打开它允许更大的批量大小,所以那里有影响。
另外,打开它通常会减慢速度,但可以增加批量大小,因此总的学习时间实际上可能会更快。
- `--xformers` / `--mem_eff_attn`
当指定 xformers 选项时,使用 xformers 的 CrossAttention。如果未安装 xformers 或发生错误(取决于环境,例如 `mixed_precision="no"`),请指定 `mem_eff_attn` 选项而不是使用 CrossAttention 的内存节省版本xformers 比 慢)。
- `--save_precision`
指定保存时的数据精度。为 save_precision 选项指定 float、fp16 或 bf16 将以该格式保存模型(在 DreamBooth 中保存 Diffusers 格式时无效,微调)。当您想缩小模型的尺寸时请使用它。
- `--save_every_n_epochs` / `--save_state` / `--resume`
为 save_every_n_epochs 选项指定一个数字可以在每个时期的训练期间保存模型。
如果同时指定save_state选项学习状态包括优化器的状态等都会一起保存。。保存目的地将是一个文件夹。
学习状态输出到目标文件夹中名为“<output_name>-??????-state”??????是纪元数)的文件夹中。长时间学习时请使用。
使用 resume 选项从保存的训练状态恢复训练。指定学习状态文件夹(其中的状态文件夹,而不是 `output_dir`)。
请注意,由于 Accelerator 规范epoch 数和全局步数不会保存,即使恢复时它们也从 1 开始。
- `--save_model_as` DreamBooth, fine tuning 仅有的)
您可以从 `ckpt, safetensors, diffusers, diffusers_safetensors` 中选择模型保存格式。
- `--save_model_as=safetensors` 指定喜欢当读取稳定扩散格式ckpt 或安全张量)并以扩散器格式保存时,缺少的信息通过从 Hugging Face 中删除 v1.5 或 v2.1 信息来补充。
- `--clip_skip`
`2` 如果指定,则使用文本编码器 (CLIP) 的倒数第二层的输出。如果省略 1 或选项,则使用最后一层。
*SD2.0默认使用倒数第二层学习SD2.0时请不要指定。
如果被训练的模型最初被训练为使用第二层,则 2 是一个很好的值。
如果您使用的是最后一层那么整个模型都会根据该假设进行训练。因此如果再次使用第二层进行训练可能需要一定数量的teacher数据和更长时间的学习才能得到想要的学习结果。
- `--max_token_length`
默认值为 75。您可以通过指定“150”或“225”来扩展令牌长度来学习。使用长字幕学习时指定。
但由于学习时token展开的规范与Automatic1111的web UI除法等规范略有不同如非必要建议用75学习。
与clip_skip一样学习与模型学习状态不同的长度可能需要一定量的teacher数据和更长的学习时间。
- `--persistent_data_loader_workers`
在 Windows 环境中指定它可以显着减少时期之间的延迟。
- `--max_data_loader_n_workers`
指定数据加载的进程数。大量的进程会更快地加载数据并更有效地使用 GPU但会消耗更多的主内存。默认是"`8`或者`CPU并发执行线程数 - 1`,取小者"所以如果主存没有空间或者GPU使用率大概在90%以上,就看那些数字和 `2` 或将其降低到大约 `1`
- `--logging_dir` / `--log_prefix`
保存学习日志的选项。在 logging_dir 选项中指定日志保存目标文件夹。以 TensorBoard 格式保存日志。
例如,如果您指定 --logging_dir=logs将在您的工作文件夹中创建一个日志文件夹并将日志保存在日期/时间文件夹中。
此外,如果您指定 --log_prefix 选项,则指定的字符串将添加到日期和时间之前。使用“--logging_dir=logs --log_prefix=db_style1_”进行识别。
要检查 TensorBoard 中的日志,请打开另一个命令提示符并在您的工作文件夹中键入:
```
tensorboard --logdir=logs
```
我觉得tensorboard会在环境搭建的时候安装如果没有安装请用`pip install tensorboard`安装。)
然后打开浏览器到http://localhost:6006/就可以看到了。
- `--noise_offset`
本文的实现https://www.crosslabs.org//blog/diffusion-with-offset-noise
看起来它可能会为整体更暗和更亮的图像产生更好的结果。它似乎对 LoRA 学习也有效。指定一个大约 0.1 的值似乎很好。
- `--debug_dataset`
通过添加此选项,您可以在学习之前检查将学习什么样的图像数据和标题。按 Esc 退出并返回命令行。按 `S` 进入下一步(批次),按 `E` 进入下一个纪元。
*图片在 Linux 环境(包括 Colab下不显示。
- `--vae`
如果您在 vae 选项中指定稳定扩散检查点、VAE 检查点文件、扩散模型或 VAE两者都可以指定本地或拥抱面模型 ID则该 VAE 用于学习(缓存时的潜伏)或在学习过程中获得潜伏)。
对于 DreamBooth 和微调,保存的模型将包含此 VAE
- `--cache_latents`
在主内存中缓存 VAE 输出以减少 VRAM 使用。除 flip_aug 之外的任何增强都将不可用。此外,整体学习速度略快。
- `--min_snr_gamma`
指定最小 SNR 加权策略。细节是[这里](https://github.com/kohya-ss/sd-scripts/pull/308)请参阅。论文中推荐`5`。
## 优化器相关
- `--optimizer_type`
-- 指定优化器类型。您可以指定
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
- 与过去版本中未指定选项时相同
- AdamW8bit : 同上
- 与过去版本中指定的 --use_8bit_adam 相同
- Lion : https://github.com/lucidrains/lion-pytorch
- 与过去版本中指定的 --use_lion_optimizer 相同
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
- SGDNesterov8bit : 引数同上
- DAdaptation : https://github.com/facebookresearch/dadaptation
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
- 任何优化器
- `--learning_rate`
指定学习率。合适的学习率取决于学习脚本,所以请参考每个解释。
- `--lr_scheduler` / `--lr_warmup_steps` / `--lr_scheduler_num_cycles` / `--lr_scheduler_power`
学习率的调度程序相关规范。
使用 lr_scheduler 选项您可以从线性、余弦、cosine_with_restarts、多项式、常数、constant_with_warmup 或任何调度程序中选择学习率调度程序。默认值是常量。
使用 lr_warmup_steps您可以指定预热调度程序的步数逐渐改变学习率
lr_scheduler_num_cycles 是 cosine with restarts 调度器中的重启次数lr_scheduler_power 是多项式调度器中的多项式幂。
有关详细信息,请自行研究。
要使用任何调度程序,请像使用任何优化器一样使用“--scheduler_args”指定可选参数。
### 关于指定优化器
使用 --optimizer_args 选项指定优化器选项参数。可以以key=value的格式指定多个值。此外您可以指定多个值以逗号分隔。例如要指定 AdamW 优化器的参数,``--optimizer_args weight_decay=0.01 betas=.9,.999``。
指定可选参数时,请检查每个优化器的规格。
一些优化器有一个必需的参数,如果省略它会自动添加(例如 SGDNesterov 的动量)。检查控制台输出。
D-Adaptation 优化器自动调整学习率。学习率选项指定的值不是学习率本身而是D-Adaptation决定的学习率的应用率所以通常指定1.0。如果您希望 Text Encoder 的学习率是 U-Net 的一半,请指定 ``--text_encoder_lr=0.5 --unet_lr=1.0``。
如果指定 relative_step=TrueAdaFactor 优化器可以自动调整学习率(如果省略,将默认添加)。自动调整时,学习率调度器被迫使用 adafactor_scheduler。此外指定 scale_parameter 和 warmup_init 似乎也不错。
自动调整的选项类似于``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"``。
如果您不想自动调整学习率,请添加可选参数 ``relative_step=False``。在那种情况下,似乎建议将 constant_with_warmup 用于学习率调度程序,而不要为梯度剪裁范数。所以参数就像``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0``。
### 使用任何优化器
使用 ``torch.optim`` 优化器时,仅指定类名(例如 ``--optimizer_type=RMSprop``),使用其他模块的优化器时,指定“模块名.类名”。(例如``--optimizer_type=bitsandbytes.optim.lamb.LAMB``)。
(内部仅通过 importlib 未确认操作。如果需要,请安装包。)
<!--
## 使用任意大小的图像进行训练 --resolution
你可以在广场外学习。请在分辨率中指定“宽度、高度”如“448,640”。宽度和高度必须能被 64 整除。匹配训练图像和正则化图像的大小。
就我个人而言我经常生成垂直长的图像所以我有时会用“448、640”来学习。
## 纵横比分桶 --enable_bucket / --min_bucket_reso / --max_bucket_reso
它通过指定 enable_bucket 选项来启用。 Stable Diffusion 在 512x512 分辨率下训练,但也在 256x768 和 384x640 等分辨率下训练。
如果指定此选项,则不需要将训练图像和正则化图像统一为特定分辨率。从多种分辨率(纵横比)中进行选择,并在该分辨率下学习。
由于分辨率为 64 像素,纵横比可能与原始图像不完全相同。
您可以使用 min_bucket_reso 选项指定分辨率的最小大小,使用 max_bucket_reso 指定最大大小。默认值分别为 256 和 1024。
例如,将最小尺寸指定为 384 将不会使用 256x1024 或 320x768 等分辨率。
如果将分辨率增加到 768x768您可能需要将 1280 指定为最大尺寸。
启用 Aspect Ratio Ratio Bucketing 时,最好准备具有与训练图像相似的各种分辨率的正则化图像。
(因为一批中的图像不偏向于训练图像和正则化图像。
## 扩充 --color_aug / --flip_aug
增强是一种通过在学习过程中动态改变数据来提高模型性能的方法。在使用 color_aug 巧妙地改变色调并使用 flip_aug 左右翻转的同时学习。
由于数据是动态变化的,因此不能与 cache_latents 选项一起指定。
## 使用 fp16 梯度训练(实验特征)--full_fp16
如果指定 full_fp16 选项,梯度从普通 float32 变为 float16 (fp16) 并学习(它似乎是 full fp16 学习而不是混合精度)。
结果,似乎 SD1.x 512x512 大小可以在 VRAM 使用量小于 8GB 的​​情况下学习,而 SD2.x 512x512 大小可以在 VRAM 使用量小于 12GB 的情况下学习。
预先在加速配置中指定 fp16并可选择设置 ``mixed_precision="fp16"``bf16 不起作用)。
为了最大限度地减少内存使用,请使用 xformers、use_8bit_adam、cache_latents、gradient_checkpointing 选项并将 train_batch_size 设置为 1。
(如果你负担得起,逐步增加 train_batch_size 应该会提高一点精度。)
它是通过修补 PyTorch 源代码实现的(已通过 PyTorch 1.12.1 和 1.13.0 确认)。准确率会大幅下降,途中学习失败的概率也会增加。
学习率和步数的设置似乎很严格。请注意它们并自行承担使用它们的风险。
-->
# 创建元数据文件
## 准备教师资料
如上所述准备好你要学习的图像数据,放在任意文件夹中。
例如,存储这样的图像:
![教师数据文件夹的屏幕截图](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png)
## 自动字幕
如果您只想学习没有标题的标签,请跳过。
另外,手动准备字幕时,请准备在与教师数据图像相同的目录下,文件名相同,扩展名.caption等。每个文件应该是只有一行的文本文件。
### 使用 BLIP 添加字幕
最新版本不再需要 BLIP 下载、权重下载和额外的虚拟环境。按原样工作。
运行 finetune 文件夹中的 make_captions.py。
```
python finetune\make_captions.py --batch_size <バッチサイズ> <教師データフォルダ>
```
如果batch size为8训练数据放在父文件夹train_data中则会如下所示
```
python finetune\make_captions.py --batch_size 8 ..\train_data
```
字幕文件创建在与教师数据图像相同的目录中,具有相同的文件名和扩展名.caption。
根据 GPU 的 VRAM 容量增加或减少 batch_size。越大越快我认为 12GB 的 VRAM 可以多一点)。
您可以使用 max_length 选项指定标题的最大长度。默认值为 75。如果使用 225 的令牌长度训练模型,它可能会更长。
您可以使用 caption_extension 选项更改标题扩展名。默认为 .caption.txt 与稍后描述的 DeepDanbooru 冲突)。
如果有多个教师数据文件夹,则对每个文件夹执行。
请注意,推理是随机的,因此每次运行时结果都会发生变化。如果要修复它,请使用 --seed 选项指定一个随机数种子,例如 `--seed 42`
其他的选项请参考help with `--help`(好像没有文档说明参数的含义,得看源码)。
默认情况下,会生成扩展名为 .caption 的字幕文件。
![caption生成的文件夹](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png)
例如,标题如下:
![字幕和图像](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png)
## 由 DeepDanbooru 标记
如果不想给danbooru标签本身打标签请继续“标题和标签信息的预处理”。
标记是使用 DeepDanbooru 或 WD14Tagger 完成的。 WD14Tagger 似乎更准确。如果您想使用 WD14Tagger 进行标记,请跳至下一章。
### 环境布置
将 DeepDanbooru https://github.com/KichangKim/DeepDanbooru 克隆到您的工作文件夹中,或下载并展开 zip。我解压缩了它。
另外,从 DeepDanbooru 发布页面 https://github.com/KichangKim/DeepDanbooru/releases 上的“DeepDanbooru 预训练模型 v3-20211112-sgd-e28”的资产下载 deepdanbooru-v3-20211112-sgd-e28.zip 并解压到 DeepDanbooru 文件夹。
从下面下载。单击以打开资产并从那里下载。
![DeepDanbooru下载页面](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png)
做一个这样的目录结构
![DeepDanbooru的目录结构](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png)
为扩散器环境安装必要的库。进入 DeepDanbooru 文件夹并安装它(我认为它实际上只是添加了 tensorflow-io
```
pip install -r requirements.txt
```
接下来,安装 DeepDanbooru 本身。
```
pip install .
```
这样就完成了标注环境的准备工作。
### 实施标记
转到 DeepDanbooru 的文件夹并运行 deepdanbooru 进行标记。
```
deepdanbooru evaluate <教师资料夹> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
```
如果将训练数据放在父文件夹train_data中则如下所示。
```
deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt
```
在与教师数据图像相同的目录中创建具有相同文件名和扩展名.txt 的标记文件。它很慢,因为它是一个接一个地处理的。
如果有多个教师数据文件夹,则对每个文件夹执行。
它生成如下。
![DeepDanbooru生成的文件](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png)
它会被这样标记(信息量很大...)。
![DeepDanbooru标签和图片](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png)
## WD14Tagger标记为
此过程使用 WD14Tagger 而不是 DeepDanbooru。
使用 Mr. Automatic1111 的 WebUI 中使用的标记器。我参考了这个 github 页面上的信息 (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger)。
初始环境维护所需的模块已经安装。权重自动从 Hugging Face 下载。
### 实施标记
运行脚本以进行标记。
```
python tag_images_by_wd14_tagger.py --batch_size <バッチサイズ> <教師データフォルダ>
```
如果将训练数据放在父文件夹train_data中则如下所示
```
python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
```
模型文件将在首次启动时自动下载到 wd14_tagger_model 文件夹(文件夹可以在选项中更改)。它将如下所示。
![下载文件](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png)
在与教师数据图像相同的目录中创建具有相同文件名和扩展名.txt 的标记文件。
![生成的标签文件](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
![标签和图片](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
使用 thresh 选项,您可以指定确定的标签的置信度数以附加标签。默认值为 0.35,与 WD14Tagger 示例相同。较低的值给出更多的标签,但准确性较低。
根据 GPU 的 VRAM 容量增加或减少 batch_size。越大越快我认为 12GB 的 VRAM 可以多一点)。您可以使用 caption_extension 选项更改标记文件扩展名。默认为 .txt。
您可以使用 model_dir 选项指定保存模型的文件夹。
此外,如果指定 force_download 选项,即使有保存目标文件夹,也会重新下载模型。
如果有多个教师数据文件夹,则对每个文件夹执行。
## 预处理字幕和标签信息
将字幕和标签作为元数据合并到一个文件中,以便从脚本中轻松处理。
### 字幕预处理
要将字幕放入元数据,请在您的工作文件夹中运行以下命令(如果您不使用字幕进行学习,则不需要运行它)(它实际上是一行,依此类推)。指定 `--full_path` 选项以将图像文件的完整路径存储在元数据中。如果省略此选项,则会记录相对路径,但 .toml 文件中需要单独的文件夹规范。
```
python merge_captions_to_metadata.py --full_path <教师资料夹>
  --in_json <要读取的元数据文件名> <元数据文件名>
```
元数据文件名是任意名称。
如果训练数据为train_data没有读取元数据文件元数据文件为meta_cap.json则会如下。
```
python merge_captions_to_metadata.py --full_path train_data meta_cap.json
```
您可以使用 caption_extension 选项指定标题扩展。
如果有多个教师数据文件夹,请指定 full_path 参数并为每个文件夹执行。
```
python merge_captions_to_metadata.py --full_path
train_data1 meta_cap1.json
python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
train_data2 meta_cap2.json
```
如果省略in_json如果有写入目标元数据文件将从那里读取并覆盖。
__* 每次重写 in_json 选项和写入目标并写入单独的元数据文件是安全的。 __
### 标签预处理
同样,标签也收集在元数据中(如果标签不用于学习,则无需这样做)。
```
python merge_dd_tags_to_metadata.py --full_path <教师资料夹>
--in_json <要读取的元数据文件名> <要写入的元数据文件名>
```
同样的目录结构读取meta_cap.json和写入meta_cap_dd.json时会是这样的。
```
python merge_dd_tags_to_metadata.py --full_path train_data --in_json meta_cap.json meta_cap_dd.json
```
如果有多个教师数据文件夹,请指定 full_path 参数并为每个文件夹执行。
```
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json
train_data1 meta_cap_dd1.json
python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json
train_data2 meta_cap_dd2.json
```
如果省略in_json如果有写入目标元数据文件将从那里读取并覆盖。
__※ 通过每次重写 in_json 选项和写入目标,写入单独的元数据文件是安全的。 __
### 标题和标签清理
到目前为止标题和DeepDanbooru标签已经被整理到元数据文件中。然而自动标题生成的标题存在表达差异等微妙问题而标签中可能包含下划线和评级DeepDanbooru的情况下。因此最好使用编辑器的替换功能清理标题和标签。
※例如如果要学习动漫中的女孩标题可能会包含girl/girls/woman/women等不同的表达方式。另外将"anime girl"简单地替换为"girl"可能更合适。
我们提供了用于清理的脚本,请根据情况编辑脚本并使用它。
(不需要指定教师数据文件夹。将清理元数据中的所有数据。)
```
python clean_captions_and_tags.py <要读取的元数据文件名> <要写入的元数据文件名>
```
--in_json 请注意,不包括在内。例如:
```
python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
```
标题和标签的预处理现已完成。
## 预先获取 latents
※ 这一步骤并非必须。即使省略此步骤,也可以在训练过程中获取 latents。但是如果在训练时执行 `random_crop``color_aug` 等操作,则无法预先获取 latents因为每次图像都会改变。如果不进行预先获取则可以使用到目前为止的元数据进行训练。
提前获取图像的潜在表达并保存到磁盘上。这样可以加速训练过程。同时进行 bucketing根据宽高比对训练数据进行分类
请在工作文件夹中输入以下内容。
```
python prepare_buckets_latents.py --full_path <教师资料夹>
<要读取的元数据文件名> <要写入的元数据文件名>
<要微调的模型名称或检查点>
--batch_size <批量大小>
--max_resolution <分辨率宽、高>
--mixed_precision <准确性>
```
如果要从meta_clean.json中读取元数据并将其写入meta_lat.json使用模型model.ckpt批处理大小为4训练分辨率为512*512精度为nofloat32则应如下所示。
```
python prepare_buckets_latents.py --full_path
train_data meta_clean.json meta_lat.json model.ckpt
--batch_size 4 --max_resolution 512,512 --mixed_precision no
```
教师数据文件夹中latents以numpy的npz格式保存。
您可以使用--min_bucket_reso选项指定最小分辨率大小--max_bucket_reso指定最大大小。默认值分别为256和1024。例如如果指定最小大小为384则将不再使用分辨率为256 * 1024或320 * 768等。如果将分辨率增加到768 * 768等较大的值则最好将最大大小指定为1280等。
如果指定--flip_aug选项则进行左右翻转的数据增强。虽然这可以使数据量伪造一倍但如果数据不是左右对称的例如角色外观、发型等则可能会导致训练不成功。
对于翻转的图像也会获取latents并保存名为\ *_flip.npz的文件这是一个简单的实现。在fline_tune.py中不需要特定的选项。如果有带有\_flip的文件则会随机加载带有和不带有flip的文件。
即使VRAM为12GB批量大小也可以稍微增加。分辨率以“宽度高度”的形式指定必须是64的倍数。分辨率直接影响fine tuning时的内存大小。在12GB VRAM中512,512似乎是极限*。如果有16GB则可以将其提高到512,704或512,768。即使分辨率为256,256等VRAM 8GB也很难承受因为参数、优化器等与分辨率无关需要一定的内存
*有报道称在batch size为1的训练中使用12GB VRAM和640,640的分辨率。
以下是bucketing结果的显示方式。
![bucketing的結果](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png)
如果有多个教师数据文件夹,请指定 full_path 参数并为每个文件夹执行
```
python prepare_buckets_latents.py --full_path
train_data1 meta_clean.json meta_lat1.json model.ckpt
--batch_size 4 --max_resolution 512,512 --mixed_precision no
python prepare_buckets_latents.py --full_path
train_data2 meta_lat1.json meta_lat2.json model.ckpt
--batch_size 4 --max_resolution 512,512 --mixed_precision no
```
可以将读取源和写入目标设为相同,但分开设定更为安全。
__※建议每次更改参数并将其写入另一个元数据文件以确保安全性。__

View File

@ -25,6 +25,7 @@ from library.config_util import (
import library.custom_train_functions as custom_train_functions import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
def train(args): def train(args):
train_util.verify_training_args(args) train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False) train_util.prepare_dataset_args(args, False)
@ -273,18 +274,19 @@ def train(args):
# Get the text embedding for conditioning # Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
if args.weighted_captions: if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(tokenizer, encoder_hidden_states = get_weighted_text_embeddings(
text_encoder, tokenizer,
batch["captions"], text_encoder,
accelerator.device, batch["captions"],
args.max_token_length // 75 if args.max_token_length else 1, accelerator.device,
clip_skip=args.clip_skip, args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
) )
else: else:
input_ids = batch["input_ids"].to(accelerator.device) input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states( encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
) )
# Sample a random timestep for each image # Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@ -335,6 +337,27 @@ def train(args):
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
) )
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
current_loss = loss.detach().item() current_loss = loss.detach().item()
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@ -364,21 +387,24 @@ def train(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None: if args.save_every_n_epochs is not None:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path if accelerator.is_main_process:
train_util.save_sd_model_on_epoch_end( # checking for saving is in util
args, src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
accelerator, train_util.save_sd_model_on_epoch_end_or_stepwise(
src_path, args,
save_stable_diffusion_format, True,
use_safetensors, accelerator,
save_dtype, src_path,
epoch, save_stable_diffusion_format,
num_train_epochs, use_safetensors,
global_step, save_dtype,
unwrap_model(text_encoder), epoch,
unwrap_model(unet), num_train_epochs,
vae, global_step,
) unwrap_model(text_encoder),
unwrap_model(unet),
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
@ -389,7 +415,7 @@ def train(args):
accelerator.end_training() accelerator.end_training()
if args.save_state: if args.save_state and is_main_process:
train_util.save_state_on_train_end(args, accelerator) train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す del accelerator # この後メモリを使うのでこれは消す
@ -434,4 +460,4 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)
train(args) train(args)

View File

@ -549,6 +549,27 @@ def train(args):
# else: # else:
# on_step_start = lambda *args, **kwargs: None # on_step_start = lambda *args, **kwargs: None
# function for saving/removing
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
metadata["ss_training_finished_at"] = str(time.time())
metadata["ss_steps"] = str(steps)
metadata["ss_epoch"] = str(epoch_no)
unwrapped_nw.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
if is_main_process: if is_main_process:
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
@ -638,6 +659,21 @@ def train(args):
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
) )
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, unwrap_model(network), global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item() current_loss = loss.detach().item()
if epoch == 0: if epoch == 0:
loss_list.append(current_loss) loss_list.append(current_loss)
@ -662,35 +698,26 @@ def train(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
if args.save_every_n_epochs is not None: if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, unwrap_model(network), global_step, epoch + 1)
def save_func(): remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as if remove_epoch_no is not None:
ckpt_file = os.path.join(args.output_dir, ckpt_name) remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
metadata["ss_training_finished_at"] = str(time.time()) remove_model(remove_ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no): if args.save_state:
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
if is_main_process:
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# end of epoch # end of epoch
metadata["ss_epoch"] = str(num_train_epochs) # metadata["ss_epoch"] = str(num_train_epochs)
metadata["ss_training_finished_at"] = str(time.time()) metadata["ss_training_finished_at"] = str(time.time())
if is_main_process: if is_main_process:
@ -698,22 +725,15 @@ def train(args):
accelerator.end_training() accelerator.end_training()
if args.save_state: if is_main_process and args.save_state:
train_util.save_state_on_train_end(args, accelerator) train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す del accelerator # この後メモリを使うのでこれは消す
if is_main_process: if is_main_process:
os.makedirs(args.output_dir, exist_ok=True) ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.") print("model saved.")

View File

@ -339,6 +339,23 @@ def train(args):
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
# function for saving/removing
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1 current_epoch.value = epoch + 1
@ -423,6 +440,23 @@ def train(args):
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
) )
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, updated_embs, global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item() current_loss = loss.detach().item()
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@ -449,26 +483,18 @@ def train(args):
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
if args.save_every_n_epochs is not None: if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if accelerator.is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, updated_embs, epoch + 1, global_step)
def save_func(): remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as if remove_epoch_no is not None:
ckpt_file = os.path.join(args.output_dir, ckpt_name) remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
print(f"saving checkpoint: {ckpt_file}") remove_model(remove_ckpt_name)
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no): if args.save_state:
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
train_util.sample_images( train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
@ -482,7 +508,7 @@ def train(args):
accelerator.end_training() accelerator.end_training()
if args.save_state: if args.save_state and is_main_process:
train_util.save_state_on_train_end(args, accelerator) train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone() updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
@ -490,16 +516,9 @@ def train(args):
del accelerator # この後メモリを使うのでこれは消す del accelerator # この後メモリを使うのでこれは消す
if is_main_process: if is_main_process:
os.makedirs(args.output_dir, exist_ok=True) ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.") print("model saved.")

View File

@ -373,6 +373,23 @@ def train(args):
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name) accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
# function for saving/removing
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1 current_epoch.value = epoch + 1
@ -462,6 +479,23 @@ def train(args):
# accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement # accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
# ) # )
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, updated_embs, global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item() current_loss = loss.detach().item()
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
@ -488,26 +522,18 @@ def train(args):
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
if args.save_every_n_epochs is not None: if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if accelerator.is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, updated_embs, epoch + 1, global_step)
def save_func(): remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as if remove_epoch_no is not None:
ckpt_file = os.path.join(args.output_dir, ckpt_name) remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
print(f"saving checkpoint: {ckpt_file}") remove_model(remove_ckpt_name)
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no): if args.save_state:
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
# TODO: fix sample_images # TODO: fix sample_images
# train_util.sample_images( # train_util.sample_images(
@ -522,7 +548,7 @@ def train(args):
accelerator.end_training() accelerator.end_training()
if args.save_state: if args.save_state and is_main_process:
train_util.save_state_on_train_end(args, accelerator) train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone() updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
@ -530,16 +556,9 @@ def train(args):
del accelerator # この後メモリを使うのでこれは消す del accelerator # この後メモリを使うのでこれは消す
if is_main_process: if is_main_process:
os.makedirs(args.output_dir, exist_ok=True) ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.") print("model saved.")