import gradio as gr
import json
from .class_configuration_file import ConfigurationFile
from .class_source_model import SourceModel
from .class_folders import Folders
from .class_basic_training import BasicTraining
from .class_advanced_training import AdvancedTraining
from .class_sample_images import SampleImages
from .dreambooth_folder_creation_gui import (
    gradio_dreambooth_folder_creation_tab,
)
from .dataset_balancing_gui import gradio_dataset_balancing_tab
from .common_gui import color_aug_changed


class Dreambooth:
    """Gradio UI container for the kohya Dreambooth training tab.

    Instantiating the class creates the Gradio components (configuration
    file controls, source model, folders, training parameters, sample image
    settings and dataset preparation tools) and stores the resulting helper
    objects as attributes. ``save_to_json`` / ``load_from_json`` persist and
    restore the attribute tree of the instance.
    """

    def __init__(
        self,
        headless: bool = False,
    ):
        """Build the Dreambooth UI components.

        Args:
            headless: Forwarded to the sub-component constructors and to the
                dataset preparation tabs; also stored in a hidden label.
        """
        self.headless = headless

        # Hidden labels holding constant values (True, False and the
        # headless flag).
        self.dummy_db_true = gr.Label(value=True, visible=False)
        self.dummy_db_false = gr.Label(value=False, visible=False)
        self.dummy_headless = gr.Label(value=headless, visible=False)

        gr.Markdown(
            'Train a custom model using kohya dreambooth python code...'
        )

        # Setup Configuration Files Gradio
        self.config = ConfigurationFile(headless)

        self.source_model = SourceModel(headless=headless)

        with gr.Tab('Folders'):
            self.folders = Folders(headless=headless)

        with gr.Tab('Parameters'):
            self.basic_training = BasicTraining(
                learning_rate_value='1e-5',
                lr_scheduler_value='cosine',
                lr_warmup_value='10',
            )
            self.full_bf16 = gr.Checkbox(label='Full bf16', value=False)
            with gr.Accordion('Advanced Configuration', open=False):
                self.advanced_training = AdvancedTraining(headless=headless)
                # Whenever the color_aug control changes, color_aug_changed
                # is called with it as input and cache_latents as output.
                self.advanced_training.color_aug.change(
                    color_aug_changed,
                    inputs=[self.advanced_training.color_aug],
                    outputs=[self.basic_training.cache_latents],
                )

            self.sample = SampleImages()

        with gr.Tab('Dataset Preparation'):
            gr.Markdown(
                'This section provide Dreambooth tools to help setup your dataset...'
            )
            gradio_dreambooth_folder_creation_tab(
                train_data_dir_input=self.folders.train_data_dir,
                reg_data_dir_input=self.folders.reg_data_dir,
                output_dir_input=self.folders.output_dir,
                logging_dir_input=self.folders.logging_dir,
                headless=headless,
            )
            gradio_dataset_balancing_tab(headless=headless)

    def save_to_json(self, filepath):
        """Serialize the instance's attribute tree to a JSON file.

        Errors are reported on stdout rather than raised.

        Args:
            filepath: Path of the JSON file to write.
        """

        def serialize(obj):
            """Convert ``obj`` into a JSON-compatible structure."""
            # NOTE(review): `gr.inputs.Input` and `.get()` are kept from the
            # original code; confirm they exist in the Gradio version in use.
            if isinstance(obj, gr.inputs.Input):
                return obj.get()
            # None maps to JSON null instead of the string "None".
            if obj is None or isinstance(obj, (bool, int, float, str)):
                return obj
            if isinstance(obj, dict):
                return {k: serialize(v) for k, v in obj.items()}
            # Sequences become JSON arrays instead of their repr string.
            if isinstance(obj, (list, tuple)):
                return [serialize(item) for item in obj]
            if hasattr(obj, '__dict__'):
                return serialize(vars(obj))
            return str(obj)  # Fallback for objects that can't be serialized

        try:
            # Serialize once, and before opening the file, so a failure here
            # does not truncate an existing configuration file.
            payload = serialize(vars(self))
            print(payload)
            with open(filepath, 'w', encoding='utf-8') as outfile:
                json.dump(payload, outfile)
        except Exception as e:
            print(f'Error saving to JSON: {str(e)}')

    def load_from_json(self, filepath):
        """Restore attribute values from a JSON file.

        Keys that do not match an attribute are reported and skipped. Errors
        are reported on stdout rather than raised.

        Args:
            filepath: Path of the JSON file to read.
        """

        def deserialize(target, key, value):
            """Apply ``value`` to attribute ``key`` of ``target``."""
            if not hasattr(target, key):
                print(f"Warning: {key} not found in the object's attributes.")
                return

            attr = getattr(target, key)
            # NOTE(review): `gr.inputs.Input` and `.set()` are kept from the
            # original code; confirm they exist in the Gradio version in use.
            if isinstance(attr, gr.inputs.Input):
                attr.set(value)
            elif hasattr(attr, '__dict__'):
                if not isinstance(value, dict):
                    # Do not overwrite an object with a scalar, and do not
                    # abort the whole load because of one malformed entry.
                    print(
                        f'Warning: {key} expects a mapping of attributes; '
                        'value skipped.'
                    )
                    return
                # Recurse into the nested object itself; nested keys are
                # attributes of `attr`, not of the top-level instance.
                for k, v in value.items():
                    deserialize(attr, k, v)
            else:
                setattr(target, key, value)

        try:
            with open(filepath, encoding='utf-8') as json_file:
                data = json.load(json_file)
            for key, value in data.items():
                deserialize(self, key, value)
        except FileNotFoundError:
            print(f'Error: The file {filepath} was not found.')
        except json.JSONDecodeError:
            print(f'Error: The file {filepath} could not be decoded as JSON.')
        except Exception as e:
            print(f'Error loading from JSON: {str(e)}')