add manual dataset seed option
parent
5d82d0d36e
commit
da120135a8
|
|
@ -47,12 +47,16 @@ class DatasetEntry:
|
|||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None,
|
||||
cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1,
|
||||
shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', latent_sampling_std=-1):
|
||||
shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', latent_sampling_std=-1, manual_seed=-1):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(
|
||||
shared.opts.dataset_filename_word_regex) > 0 else None
|
||||
seed = randrange(sys.maxsize)
|
||||
set_rng(seed) # reset forked RNG state when we create dataset.
|
||||
print(f"Dataset seed was set to f{seed}")
|
||||
if manual_seed == -1:
|
||||
seed = randrange(sys.maxsize)
|
||||
set_rng(seed) # reset forked RNG state when we create dataset.
|
||||
print(f"Dataset seed was set to f{seed}")
|
||||
else:
|
||||
set_rng(manual_seed)
|
||||
print(f"Dataset seed was set to f{manual_seed}")
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.width = width
|
||||
|
|
|
|||
|
|
@ -485,10 +485,11 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
|
||||
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height,
|
||||
move_optimizer=True,
|
||||
load_hypernetworks_option='', load_training_options=''):
|
||||
load_hypernetworks_option='', load_training_options='', manual_dataset_seed=-1):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
base_hypernetwork_name = hypernetwork_name
|
||||
manual_seed = int(manual_dataset_seed)
|
||||
if load_hypernetworks_option != '':
|
||||
timeStr = time.strftime('%Y%m%d%H%M%S')
|
||||
dump_hyper: dict = get_training_option(load_hypernetworks_option)
|
||||
|
|
@ -687,7 +688,8 @@ def internal_clean_training(hypernetwork_name, data_root, log_directory,
|
|||
gradient_step=gradient_step, shuffle_tags=shuffle_tags,
|
||||
tag_drop_out=tag_drop_out,
|
||||
latent_sampling_method=latent_sampling_method,
|
||||
latent_sampling_std=latent_sampling_std)
|
||||
latent_sampling_std=latent_sampling_std,
|
||||
manual_seed=manual_seed)
|
||||
|
||||
latent_sampling_method = ds.latent_sampling_method
|
||||
|
||||
|
|
@ -977,7 +979,7 @@ def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directo
|
|||
preview_width, preview_height,
|
||||
move_optimizer=True,
|
||||
optional_new_hypernetwork_name='', load_hypernetworks_options='',
|
||||
load_training_options=''):
|
||||
load_training_options='', manual_dataset_seed=-1):
|
||||
load_hypernetworks_options = load_hypernetworks_options.split(',')
|
||||
load_training_options = load_training_options.split(',')
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
|
|
@ -998,6 +1000,6 @@ def train_hypernetwork_tuning(id_task, hypernetwork_name, data_root, log_directo
|
|||
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index,
|
||||
preview_cfg_scale, preview_seed, preview_width, preview_height,
|
||||
move_optimizer,
|
||||
load_hypernetworks_option, load_training_option)
|
||||
load_hypernetworks_option, load_training_option, manual_dataset_seed)
|
||||
if shared.state.interrupted:
|
||||
return None, None
|
||||
|
|
|
|||
|
|
@ -384,6 +384,9 @@ def on_train_tuning(params=None):
|
|||
label='Save a copy of model to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||
preview_from_txt2img = gr.Checkbox(
|
||||
label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
||||
manual_dataset_seed = gr.Number(
|
||||
label="Manual dataset seed", value=-1, precision=0
|
||||
)
|
||||
with gr.Row():
|
||||
interrupt_training = gr.Button(value="Interrupt")
|
||||
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
||||
|
|
@ -404,7 +407,8 @@ def on_train_tuning(params=None):
|
|||
move_optim_when_generate,
|
||||
optional_new_hypernetwork_name,
|
||||
load_hypernetworks_option,
|
||||
load_training_options
|
||||
load_training_options,
|
||||
manual_dataset_seed
|
||||
],
|
||||
outputs=[
|
||||
ti_output,
|
||||
|
|
|
|||
Loading…
Reference in New Issue