diff --git a/patches/external_pr/dataset.py b/patches/external_pr/dataset.py index 7f068d4..0af6477 100644 --- a/patches/external_pr/dataset.py +++ b/patches/external_pr/dataset.py @@ -47,12 +47,16 @@ class DatasetEntry: class PersonalizedBase(Dataset): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, - shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', latent_sampling_std=-1): + shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', latent_sampling_std=-1, manual_seed=-1): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len( shared.opts.dataset_filename_word_regex) > 0 else None - seed = randrange(sys.maxsize) - set_rng(seed) # reset forked RNG state when we create dataset. - print(f"Dataset seed was set to f{seed}") + if manual_seed == -1: + seed = randrange(sys.maxsize) + set_rng(seed) # reset forked RNG state when we create dataset. + print(f"Dataset seed was set to f{seed}") + else: + set_rng(manual_seed) + print(f"Dataset seed was set to f{manual_seed}") self.placeholder_token = placeholder_token self.width = width diff --git a/patches/external_pr/hypernetwork.py b/patches/external_pr/hypernetwork.py index 98544ab..7c45aa7 100644 --- a/patches/external_pr/hypernetwork.py +++ b/patches/external_pr/hypernetwork.py @@ -485,10 +485,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height, move_optimizer=True, - load_hypernetworks_option='', load_training_options=''): + load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images base_hypernetwork_name = hypernetwork_name + manual_seed = int(manual_dataset_seed) if load_hypernetworks_option != '': timeStr = time.strftime('%Y%m%d%H%M%S') dump_hyper: dict = get_training_option(load_hypernetworks_option) @@ -687,7 +688,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, - latent_sampling_std=latent_sampling_std) + latent_sampling_std=latent_sampling_std, + manual_seed=manual_seed) latent_sampling_method = ds.latent_sampling_method @@ -977,7 +979,7 @@ def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directo preview_width, preview_height, move_optimizer=True, optional_new_hypernetwork_name='', load_hypernetworks_options='', - load_training_options=''): + load_training_options='', manual_dataset_seed=-1): load_hypernetworks_options = load_hypernetworks_options.split(',') load_training_options = load_training_options.split(',') # images allows training previews to have infotext. Importing it at the top causes a circular import problem. @@ -998,6 +1000,6 @@ def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directo preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height, move_optimizer, - load_hypernetworks_option, load_training_option) + load_hypernetworks_option, load_training_option, manual_dataset_seed) if shared.state.interrupted: return None, None diff --git a/patches/external_pr/ui.py b/patches/external_pr/ui.py index 7230554..497a296 100644 --- a/patches/external_pr/ui.py +++ b/patches/external_pr/ui.py @@ -384,6 +384,9 @@ def on_train_tuning(params=None): label='Save a copy of model to log directory every N steps, 0 to disable', value=500, precision=0) preview_from_txt2img = gr.Checkbox( label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) + manual_dataset_seed = gr.Number( + label="Manual dataset seed", value=-1, precision=0 + ) with gr.Row(): interrupt_training = gr.Button(value="Interrupt") train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') @@ -404,7 +407,8 @@ def on_train_tuning(params=None): move_optim_when_generate, optional_new_hypernetwork_name, load_hypernetworks_option, - load_training_options + load_training_options, + manual_dataset_seed ], outputs=[ ti_output,