add manual dataset seed option

noise-scheduler
aria1th 2023-01-19 06:34:57 +09:00
parent 5d82d0d36e
commit da120135a8
3 changed files with 19 additions and 9 deletions

View File

@@ -47,12 +47,16 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None,
cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1,
shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', latent_sampling_std=-1):
shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', latent_sampling_std=-1, manual_seed=-1):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(
shared.opts.dataset_filename_word_regex) > 0 else None
seed = randrange(sys.maxsize)
set_rng(seed) # reset forked RNG state when we create dataset.
print(f"Dataset seed was set to f{seed}")
if manual_seed == -1:
seed = randrange(sys.maxsize)
set_rng(seed) # reset forked RNG state when we create dataset.
print(f"Dataset seed was set to f{seed}")
else:
set_rng(manual_seed)
print(f"Dataset seed was set to f{manual_seed}")
self.placeholder_token = placeholder_token
self.width = width

View File

@@ -485,10 +485,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
move_optimizer=True,
load_hypernetworks_option='', load_training_options=''):
load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
base_hypernetwork_name = hypernetwork_name
manual_seed = int(manual_dataset_seed)
if load_hypernetworks_option != '':
timeStr = time.strftime('%Y%m%d%H%M%S')
dump_hyper: dict = get_training_option(load_hypernetworks_option)
@@ -687,7 +688,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
gradient_step=gradient_step, shuffle_tags=shuffle_tags,
tag_drop_out=tag_drop_out,
latent_sampling_method=latent_sampling_method,
latent_sampling_std=latent_sampling_std)
latent_sampling_std=latent_sampling_std,
manual_seed=manual_seed)
latent_sampling_method = ds.latent_sampling_method
@@ -977,7 +979,7 @@ def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directory,
preview_width, preview_height,
move_optimizer=True,
optional_new_hypernetwork_name='', load_hypernetworks_options='',
load_training_options=''):
load_training_options='', manual_dataset_seed=-1):
load_hypernetworks_options = load_hypernetworks_options.split(',')
load_training_options = load_training_options.split(',')
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
@@ -998,6 +1000,6 @@ def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directory,
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index,
preview_cfg_scale, preview_seed, preview_width, preview_height,
move_optimizer,
load_hypernetworks_option, load_training_option)
load_hypernetworks_option, load_training_option, manual_dataset_seed)
if shared.state.interrupted:
return None, None

View File

@@ -384,6 +384,9 @@ def on_train_tuning(params=None):
label='Save a copy of model to log directory every N steps, 0 to disable', value=500, precision=0)
preview_from_txt2img = gr.Checkbox(
label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
manual_dataset_seed = gr.Number(
label="Manual dataset seed", value=-1, precision=0
)
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
@@ -404,7 +407,8 @@
move_optim_when_generate,
optional_new_hypernetwork_name,
load_hypernetworks_option,
load_training_options
load_training_options,
manual_dataset_seed
],
outputs=[
ti_output,