mirror of https://github.com/bmaltais/kohya_ss
Make Start/Stop buttons visible in headless (#2356)
parent
5a801649c2
commit
4923f5647c
16
README.md
16
README.md
|
|
@ -42,10 +42,12 @@ The GUI allows you to set the training parameters and generate and run the requi
|
|||
- [SDXL training](#sdxl-training)
|
||||
- [Masked loss](#masked-loss)
|
||||
- [Change History](#change-history)
|
||||
- [2024/04/20 (v24.0.6)](#20240420-v2406)
|
||||
- [2024/04/19 (v24.0.5)](#20240419-v2405)
|
||||
- [New Contributors](#new-contributors)
|
||||
- [2024/04/18 (v24.0.4)](#20240418-v2404)
|
||||
- [What's Changed](#whats-changed)
|
||||
- [New Contributors](#new-contributors)
|
||||
- [New Contributors](#new-contributors-1)
|
||||
- [2024/04/24 (v24.0.3)](#20240424-v2403)
|
||||
- [2024/04/24 (v24.0.2)](#20240424-v2402)
|
||||
- [2024/04/17 (v24.0.1)](#20240417-v2401)
|
||||
|
|
@ -412,9 +414,19 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
|
|||
|
||||
## Change History
|
||||
|
||||
### 2024/04/20 (v24.0.6)
|
||||
|
||||
- Make start and stop buttons visible in headless mode
|
||||
|
||||
### 2024/04/19 (v24.0.5)
|
||||
|
||||
- fdds
|
||||
- Hide tensorboard button if tensorflow module is not installed by @bmaltais in <https://github.com/bmaltais/kohya_ss/pull/2347>
|
||||
- wd14 captioning issue with undesired tags and tag replacement by @bmaltais in <https://github.com/bmaltais/kohya_ss/pull/2350>
|
||||
- Changed logger checkbox to dropdown, renamed use_wandb -> log_with by @ccharest93 in <https://github.com/bmaltais/kohya_ss/pull/2352>
|
||||
|
||||
#### New Contributors
|
||||
|
||||
- @ccharest93 made their first contribution in <https://github.com/bmaltais/kohya_ss/pull/2352>
|
||||
|
||||
### 2024/04/18 (v24.0.4)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,11 +37,11 @@
|
|||
#myTensorButton {
|
||||
background: radial-gradient(ellipse, #3a99ff, #52c8ff);
|
||||
color: white;
|
||||
border: none;
|
||||
border: #296eb8;
|
||||
}
|
||||
|
||||
#myTensorButtonStop {
|
||||
background: radial-gradient(ellipse, #52c8ff, #3a99ff);
|
||||
color: black;
|
||||
border: none;
|
||||
border: #296eb8;
|
||||
}
|
||||
|
|
@ -14,12 +14,19 @@ class CommandExecutor:
|
|||
A class to execute and manage commands.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, headless: bool = False):
|
||||
"""
|
||||
Initialize the CommandExecutor.
|
||||
"""
|
||||
self.headless = headless
|
||||
self.process = None
|
||||
self.run_state = gr.Textbox(value="", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
self.button_run = gr.Button("Start training", variant="primary")
|
||||
|
||||
self.button_stop_training = gr.Button(
|
||||
"Stop training", visible=self.process is not None or headless, variant="stop"
|
||||
)
|
||||
|
||||
def execute_command(self, run_cmd: str, use_shell: bool = False, **kwargs):
|
||||
"""
|
||||
|
|
@ -64,16 +71,17 @@ class CommandExecutor:
|
|||
# General exception handling for any other errors
|
||||
log.info(f"Error when terminating process: {e}")
|
||||
else:
|
||||
self.process = None
|
||||
log.info("There is no running process to kill.")
|
||||
|
||||
return gr.Button(visible=True), gr.Button(visible=False)
|
||||
return gr.Button(visible=True), gr.Button(visible=False or self.headless)
|
||||
|
||||
def wait_for_training_to_end(self):
|
||||
while self.is_running():
|
||||
time.sleep(1)
|
||||
log.debug("Waiting for training to end...")
|
||||
log.info("Training has ended.")
|
||||
return gr.Button(visible=True), gr.Button(visible=False)
|
||||
return gr.Button(visible=True), gr.Button(visible=False or self.headless)
|
||||
|
||||
def is_running(self):
|
||||
"""
|
||||
|
|
@ -82,4 +90,4 @@ class CommandExecutor:
|
|||
Returns:
|
||||
- bool: True if the command is running, False otherwise.
|
||||
"""
|
||||
return self.process and self.process.poll() is None
|
||||
return self.process is not None and self.process.poll() is None
|
||||
|
|
|
|||
|
|
@ -7,10 +7,11 @@ from easygui import msgbox
|
|||
from threading import Thread, Event
|
||||
from .custom_logging import setup_logging
|
||||
|
||||
|
||||
class TensorboardManager:
|
||||
DEFAULT_TENSORBOARD_PORT = 6006
|
||||
|
||||
def __init__(self, logging_dir, headless=True, wait_time=5):
|
||||
def __init__(self, logging_dir, headless: bool = False, wait_time=5):
|
||||
self.logging_dir = logging_dir
|
||||
self.headless = headless
|
||||
self.wait_time = wait_time
|
||||
|
|
@ -25,9 +26,17 @@ class TensorboardManager:
|
|||
self.gradio_interface()
|
||||
|
||||
def get_button_states(self, started=False):
|
||||
return gr.Button(visible=not started), gr.Button(visible=started)
|
||||
return gr.Button(visible=not started or self.headless), gr.Button(
|
||||
visible=started or self.headless
|
||||
)
|
||||
|
||||
def start_tensorboard(self, logging_dir=None):
|
||||
if self.tensorboard_proc is not None:
|
||||
self.log.info(
|
||||
"Tensorboard is already running. Terminating existing process before starting new one..."
|
||||
)
|
||||
self.stop_tensorboard()
|
||||
|
||||
if not os.path.exists(logging_dir) or not os.listdir(logging_dir):
|
||||
self.log.error(
|
||||
"Error: logging folder does not exist or does not contain logs."
|
||||
|
|
@ -46,11 +55,6 @@ class TensorboardManager:
|
|||
]
|
||||
|
||||
self.log.info(run_cmd)
|
||||
if self.tensorboard_proc is not None:
|
||||
self.log.info(
|
||||
"Tensorboard is already running. Terminating existing process before starting new one..."
|
||||
)
|
||||
self.stop_tensorboard()
|
||||
|
||||
self.log.info("Starting TensorBoard on port {}".format(self.tensorboard_port))
|
||||
try:
|
||||
|
|
@ -73,7 +77,7 @@ class TensorboardManager:
|
|||
self.thread = Thread(target=open_tensorboard_url)
|
||||
self.thread.start()
|
||||
|
||||
return self.get_button_states(started=True)
|
||||
return self.get_button_states(started=True or self.headless)
|
||||
|
||||
def stop_tensorboard(self):
|
||||
if self.tensorboard_proc is not None:
|
||||
|
|
@ -84,34 +88,38 @@ class TensorboardManager:
|
|||
self.log.info("...process stopped")
|
||||
except Exception as e:
|
||||
self.log.error("Failed to stop Tensorboard:", e)
|
||||
|
||||
|
||||
if self.thread is not None:
|
||||
self.stop_event.set()
|
||||
self.thread.join() # Wait for the thread to finish
|
||||
self.thread = None
|
||||
self.log.info("Thread terminated successfully.")
|
||||
|
||||
return self.get_button_states(started=False)
|
||||
return self.get_button_states(started=False or self.headless)
|
||||
|
||||
def gradio_interface(self):
|
||||
try:
|
||||
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
||||
|
||||
import tensorflow # Attempt to import tensorflow to check if it is installed
|
||||
|
||||
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
import tensorflow # Attempt to import tensorflow to check if it is installed
|
||||
|
||||
visibility = True
|
||||
|
||||
|
||||
except ImportError:
|
||||
self.log.error("tensorflow is not installed, hiding the tensorboard button...")
|
||||
self.log.error(
|
||||
"tensorflow is not installed, hiding the tensorboard button..."
|
||||
)
|
||||
visibility = False
|
||||
|
||||
|
||||
with gr.Row():
|
||||
button_start_tensorboard = gr.Button(
|
||||
value="Start tensorboard", elem_id="myTensorButton", visible=visibility
|
||||
value="Start tensorboard",
|
||||
elem_id="myTensorButton",
|
||||
visible=visibility or self.headless,
|
||||
)
|
||||
button_stop_tensorboard = gr.Button(
|
||||
value="Stop tensorboard",
|
||||
visible=False,
|
||||
visible=False or self.headless,
|
||||
elem_id="myTensorButtonStop",
|
||||
)
|
||||
button_start_tensorboard.click(
|
||||
|
|
@ -124,4 +132,4 @@ class TensorboardManager:
|
|||
self.stop_tensorboard,
|
||||
outputs=[button_start_tensorboard, button_stop_tensorboard],
|
||||
show_progress=False,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -43,17 +43,12 @@ from .custom_logging import setup_logging
|
|||
log = setup_logging()
|
||||
|
||||
# Setup command executor
|
||||
executor = CommandExecutor()
|
||||
executor = None
|
||||
|
||||
# Setup huggingface
|
||||
huggingface = None
|
||||
use_shell = False
|
||||
|
||||
TRAIN_BUTTON_VISIBLE = [
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False),
|
||||
gr.Textbox(value=time.time()),
|
||||
]
|
||||
train_state_value = time.time()
|
||||
|
||||
|
||||
def save_configuration(
|
||||
|
|
@ -495,6 +490,17 @@ def train_model(
|
|||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
global train_state_value
|
||||
|
||||
TRAIN_BUTTON_VISIBLE = [
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False or headless),
|
||||
gr.Textbox(value=train_state_value),
|
||||
]
|
||||
|
||||
if executor.is_running():
|
||||
log.error("Training is already running. Can't start another training session.")
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
log.info(f"Start training Dreambooth...")
|
||||
|
||||
|
|
@ -855,11 +861,13 @@ def train_model(
|
|||
# Run the command
|
||||
|
||||
executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
|
||||
|
||||
train_state_value = time.time()
|
||||
|
||||
return (
|
||||
gr.Button(visible=False),
|
||||
gr.Button(visible=False or headless),
|
||||
gr.Button(visible=True),
|
||||
gr.Textbox(value=time.time()),
|
||||
gr.Textbox(value=train_state_value),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -940,21 +948,15 @@ def dreambooth_tab(
|
|||
with gr.Accordion("HuggingFace", open=False):
|
||||
huggingface = HuggingFace(config=config)
|
||||
|
||||
with gr.Column(), gr.Group():
|
||||
with gr.Row():
|
||||
button_run = gr.Button("Start training", variant="primary")
|
||||
|
||||
button_stop_training = gr.Button(
|
||||
"Stop training", visible=False, variant="stop"
|
||||
)
|
||||
|
||||
global executor
|
||||
executor = CommandExecutor(headless=headless)
|
||||
|
||||
with gr.Column(), gr.Group():
|
||||
with gr.Row():
|
||||
button_print = gr.Button("Print training command")
|
||||
|
||||
# Setup gradio tensorboard buttons
|
||||
with gr.Column(), gr.Group():
|
||||
TensorboardManager(headless=headless, logging_dir=folders.logging_dir)
|
||||
TensorboardManager(headless=headless, logging_dir=folders.logging_dir)
|
||||
|
||||
settings_list = [
|
||||
source_model.pretrained_model_name_or_path,
|
||||
|
|
@ -1100,37 +1102,23 @@ def dreambooth_tab(
|
|||
outputs=[configuration.config_file_name],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
# config.button_save_as_config.click(
|
||||
# save_configuration,
|
||||
# inputs=[dummy_db_true, config.config_file_name] + settings_list,
|
||||
# outputs=[config.config_file_name],
|
||||
# show_progress=False,
|
||||
# )
|
||||
|
||||
# def wait_for_training_to_end():
|
||||
# while executor.is_running():
|
||||
# time.sleep(1)
|
||||
# log.debug("Waiting for training to end...")
|
||||
# log.info("Training has ended.")
|
||||
# return gr.Button(visible=True), gr.Button(visible=False)
|
||||
|
||||
# Hidden textbox used to run the wait_for_training_to_end function to hide stop and show start at the end of the training
|
||||
run_state = gr.Textbox(value="", visible=False)
|
||||
|
||||
run_state = gr.Textbox(value=train_state_value, visible=False)
|
||||
|
||||
run_state.change(
|
||||
fn=executor.wait_for_training_to_end,
|
||||
outputs=[button_run, button_stop_training],
|
||||
outputs=[executor.button_run, executor.button_stop_training],
|
||||
)
|
||||
|
||||
button_run.click(
|
||||
executor.button_run.click(
|
||||
train_model,
|
||||
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
|
||||
outputs=[button_run, button_stop_training, run_state],
|
||||
outputs=[executor.button_run, executor.button_stop_training, run_state],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_stop_training.click(
|
||||
executor.kill_command, outputs=[button_run, button_stop_training]
|
||||
executor.button_stop_training.click(
|
||||
executor.kill_command, outputs=[executor.button_run, executor.button_stop_training]
|
||||
)
|
||||
|
||||
button_print.click(
|
||||
|
|
|
|||
|
|
@ -40,13 +40,12 @@ from .custom_logging import setup_logging
|
|||
log = setup_logging()
|
||||
|
||||
# Setup command executor
|
||||
executor = CommandExecutor()
|
||||
executor = None
|
||||
|
||||
# Setup huggingface
|
||||
huggingface = None
|
||||
use_shell = False
|
||||
|
||||
# from easygui import msgbox
|
||||
train_state_value = time.time()
|
||||
|
||||
folder_symbol = "\U0001f4c2" # 📂
|
||||
refresh_symbol = "\U0001f504" # 🔄
|
||||
|
|
@ -56,11 +55,6 @@ document_symbol = "\U0001F4C4" # 📄
|
|||
PYTHON = sys.executable
|
||||
|
||||
presets_dir = rf"{scriptdir}/presets"
|
||||
TRAIN_BUTTON_VISIBLE = [
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False),
|
||||
gr.Textbox(value=time.time()),
|
||||
]
|
||||
|
||||
|
||||
def save_configuration(
|
||||
|
|
@ -534,6 +528,17 @@ def train_model(
|
|||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
global train_state_value
|
||||
|
||||
TRAIN_BUTTON_VISIBLE = [
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False or headless),
|
||||
gr.Textbox(value=train_state_value),
|
||||
]
|
||||
|
||||
if executor.is_running():
|
||||
log.error("Training is already running. Can't start another training session.")
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
log.debug(f"headless = {headless} ; print_only = {print_only}")
|
||||
|
||||
|
|
@ -927,10 +932,12 @@ def train_model(
|
|||
# Run the command
|
||||
executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
|
||||
|
||||
train_state_value = time.time()
|
||||
|
||||
return (
|
||||
gr.Button(visible=False),
|
||||
gr.Button(visible=False or headless),
|
||||
gr.Button(visible=True),
|
||||
gr.Textbox(value=time.time()),
|
||||
gr.Textbox(value=train_state_value),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1090,21 +1097,14 @@ def finetune_tab(
|
|||
with gr.Accordion("HuggingFace", open=False):
|
||||
huggingface = HuggingFace(config=config)
|
||||
|
||||
with gr.Column(), gr.Group():
|
||||
with gr.Row():
|
||||
button_run = gr.Button("Start training", variant="primary")
|
||||
|
||||
button_stop_training = gr.Button(
|
||||
"Stop training", visible=False, variant="stop"
|
||||
)
|
||||
global executor
|
||||
executor = CommandExecutor(headless=headless)
|
||||
|
||||
with gr.Column(), gr.Group():
|
||||
with gr.Row():
|
||||
button_print = gr.Button("Print training command")
|
||||
|
||||
# Setup gradio tensorboard buttons
|
||||
with gr.Column(), gr.Group():
|
||||
TensorboardManager(headless=headless, logging_dir=folders.logging_dir)
|
||||
TensorboardManager(headless=headless, logging_dir=folders.logging_dir)
|
||||
|
||||
settings_list = [
|
||||
source_model.pretrained_model_name_or_path,
|
||||
|
|
@ -1264,13 +1264,6 @@ def finetune_tab(
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
# config.button_load_config.click(
|
||||
# open_configuration,
|
||||
# inputs=[dummy_db_false, config.config_file_name] + settings_list,
|
||||
# outputs=[config.config_file_name] + settings_list,
|
||||
# show_progress=False,
|
||||
# )
|
||||
|
||||
training_preset.input(
|
||||
open_configuration,
|
||||
inputs=[dummy_db_false, dummy_db_true, configuration.config_file_name]
|
||||
|
|
@ -1280,22 +1273,22 @@ def finetune_tab(
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
# Hidden textbox used to run the wait_for_training_to_end function to hide stop and show start at the end of the training
|
||||
run_state = gr.Textbox(value="", visible=False)
|
||||
run_state = gr.Textbox(value=train_state_value, visible=False)
|
||||
|
||||
run_state.change(
|
||||
fn=executor.wait_for_training_to_end,
|
||||
outputs=[button_run, button_stop_training],
|
||||
outputs=[executor.button_run, executor.button_stop_training],
|
||||
)
|
||||
|
||||
button_run.click(
|
||||
executor.button_run.click(
|
||||
train_model,
|
||||
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
|
||||
outputs=[button_run, button_stop_training, run_state],
|
||||
outputs=[executor.button_run, executor.button_stop_training, run_state],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_stop_training.click(
|
||||
executor.kill_command, outputs=[button_run, button_stop_training]
|
||||
executor.button_stop_training.click(
|
||||
executor.kill_command, outputs=[executor.button_run, executor.button_stop_training]
|
||||
)
|
||||
|
||||
button_print.click(
|
||||
|
|
|
|||
|
|
@ -47,27 +47,18 @@ from .custom_logging import setup_logging
|
|||
log = setup_logging()
|
||||
|
||||
# Setup command executor
|
||||
executor = CommandExecutor()
|
||||
executor = None
|
||||
|
||||
# Setup huggingface
|
||||
huggingface = None
|
||||
use_shell = False
|
||||
|
||||
button_run = gr.Button("Start training", variant="primary")
|
||||
|
||||
button_stop_training = gr.Button("Stop training", visible=False)
|
||||
train_state_value = time.time()
|
||||
|
||||
document_symbol = "\U0001F4C4" # 📄
|
||||
|
||||
|
||||
presets_dir = rf"{scriptdir}/presets"
|
||||
|
||||
TRAIN_BUTTON_VISIBLE = [
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False),
|
||||
gr.Textbox(value=time.time()),
|
||||
]
|
||||
|
||||
|
||||
def save_configuration(
|
||||
save_as_bool,
|
||||
|
|
@ -674,7 +665,17 @@ def train_model(
|
|||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
global command_running
|
||||
global train_state_value
|
||||
|
||||
TRAIN_BUTTON_VISIBLE = [
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False or headless),
|
||||
gr.Textbox(value=train_state_value),
|
||||
]
|
||||
|
||||
if executor.is_running():
|
||||
log.error("Training is already running. Can't start another training session.")
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
log.info(f"Start training LoRA {LoRA_type} ...")
|
||||
|
||||
|
|
@ -1221,12 +1222,15 @@ def train_model(
|
|||
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
|
||||
# Run the command
|
||||
|
||||
executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
|
||||
|
||||
train_state_value = time.time()
|
||||
|
||||
return (
|
||||
gr.Button(visible=False),
|
||||
gr.Button(visible=False or headless),
|
||||
gr.Button(visible=True),
|
||||
gr.Textbox(value=time.time()),
|
||||
gr.Textbox(value=train_state_value),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2084,21 +2088,15 @@ def lora_tab(
|
|||
],
|
||||
)
|
||||
|
||||
with gr.Column(), gr.Group():
|
||||
with gr.Row():
|
||||
button_run = gr.Button("Start training", variant="primary")
|
||||
|
||||
button_stop_training = gr.Button(
|
||||
"Stop training", visible=False, variant="stop"
|
||||
)
|
||||
|
||||
global executor
|
||||
executor = CommandExecutor(headless=headless)
|
||||
|
||||
with gr.Column(), gr.Group():
|
||||
with gr.Row():
|
||||
button_print = gr.Button("Print training command")
|
||||
|
||||
# Setup gradio tensorboard buttons
|
||||
with gr.Column(), gr.Group():
|
||||
TensorboardManager(headless=headless, logging_dir=folders.logging_dir)
|
||||
TensorboardManager(headless=headless, logging_dir=folders.logging_dir)
|
||||
|
||||
settings_list = [
|
||||
source_model.pretrained_model_name_or_path,
|
||||
|
|
@ -2301,29 +2299,22 @@ def lora_tab(
|
|||
show_progress=False,
|
||||
)
|
||||
|
||||
# config.button_save_as_config.click(
|
||||
# save_configuration,
|
||||
# inputs=[dummy_db_true, config.config_file_name] + settings_list,
|
||||
# outputs=[config.config_file_name],
|
||||
# show_progress=False,
|
||||
# )
|
||||
|
||||
# Hidden textbox used to run the wait_for_training_to_end function to hide stop and show start at the end of the training
|
||||
run_state = gr.Textbox(value="", visible=False)
|
||||
run_state = gr.Textbox(value=train_state_value, visible=False)
|
||||
|
||||
run_state.change(
|
||||
fn=executor.wait_for_training_to_end,
|
||||
outputs=[button_run, button_stop_training],
|
||||
outputs=[executor.button_run, executor.button_stop_training],
|
||||
)
|
||||
|
||||
button_run.click(
|
||||
executor.button_run.click(
|
||||
train_model,
|
||||
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
|
||||
outputs=[button_run, button_stop_training, run_state],
|
||||
outputs=[executor.button_run, executor.button_stop_training, run_state],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_stop_training.click(
|
||||
executor.kill_command, outputs=[button_run, button_stop_training]
|
||||
executor.button_stop_training.click(
|
||||
executor.kill_command, outputs=[executor.button_run, executor.button_stop_training]
|
||||
)
|
||||
|
||||
button_print.click(
|
||||
|
|
|
|||
|
|
@ -45,17 +45,12 @@ from .custom_logging import setup_logging
|
|||
log = setup_logging()
|
||||
|
||||
# Setup command executor
|
||||
executor = CommandExecutor()
|
||||
executor = None
|
||||
|
||||
# Setup huggingface
|
||||
huggingface = None
|
||||
use_shell = False
|
||||
|
||||
TRAIN_BUTTON_VISIBLE = [
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False),
|
||||
gr.Textbox(value=time.time()),
|
||||
]
|
||||
train_state_value = time.time()
|
||||
|
||||
|
||||
def save_configuration(
|
||||
|
|
@ -496,6 +491,17 @@ def train_model(
|
|||
):
|
||||
# Get list of function parameters and values
|
||||
parameters = list(locals().items())
|
||||
global train_state_value
|
||||
|
||||
TRAIN_BUTTON_VISIBLE = [
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False or headless),
|
||||
gr.Textbox(value=train_state_value),
|
||||
]
|
||||
|
||||
if executor.is_running():
|
||||
log.error("Training is already running. Can't start another training session.")
|
||||
return TRAIN_BUTTON_VISIBLE
|
||||
|
||||
log.info(f"Start training TI...")
|
||||
|
||||
|
|
@ -874,11 +880,13 @@ def train_model(
|
|||
# Run the command
|
||||
|
||||
executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)
|
||||
|
||||
train_state_value = time.time()
|
||||
|
||||
return (
|
||||
gr.Button(visible=False),
|
||||
gr.Button(visible=False or headless),
|
||||
gr.Button(visible=True),
|
||||
gr.Textbox(value=time.time()),
|
||||
gr.Textbox(value=train_state_value),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1052,21 +1060,15 @@ def ti_tab(
|
|||
with gr.Accordion("HuggingFace", open=False):
|
||||
huggingface = HuggingFace(config=config)
|
||||
|
||||
with gr.Column(), gr.Group():
|
||||
with gr.Row():
|
||||
button_run = gr.Button("Start training", variant="primary")
|
||||
|
||||
button_stop_training = gr.Button(
|
||||
"Stop training", visible=False, variant="stop"
|
||||
)
|
||||
|
||||
global executor
|
||||
executor = CommandExecutor(headless=headless)
|
||||
|
||||
with gr.Column(), gr.Group():
|
||||
with gr.Row():
|
||||
button_print = gr.Button("Print training command")
|
||||
|
||||
# Setup gradio tensorboard buttons
|
||||
with gr.Column(), gr.Group():
|
||||
TensorboardManager(headless=headless, logging_dir=folders.logging_dir)
|
||||
TensorboardManager(headless=headless, logging_dir=folders.logging_dir)
|
||||
|
||||
settings_list = [
|
||||
source_model.pretrained_model_name_or_path,
|
||||
|
|
@ -1211,30 +1213,23 @@ def ti_tab(
|
|||
outputs=[configuration.config_file_name],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
# config.button_save_as_config.click(
|
||||
# save_configuration,
|
||||
# inputs=[dummy_db_true, config.config_file_name] + settings_list,
|
||||
# outputs=[config.config_file_name],
|
||||
# show_progress=False,
|
||||
# )
|
||||
|
||||
# Hidden textbox used to run the wait_for_training_to_end function to hide stop and show start at the end of the training
|
||||
run_state = gr.Textbox(value="", visible=False)
|
||||
|
||||
run_state = gr.Textbox(value=train_state_value, visible=False)
|
||||
|
||||
run_state.change(
|
||||
fn=executor.wait_for_training_to_end,
|
||||
outputs=[button_run, button_stop_training],
|
||||
outputs=[executor.button_run, executor.button_stop_training],
|
||||
)
|
||||
|
||||
button_run.click(
|
||||
executor.button_run.click(
|
||||
train_model,
|
||||
inputs=[dummy_headless] + [dummy_db_false] + settings_list,
|
||||
outputs=[button_run, button_stop_training, run_state],
|
||||
outputs=[executor.button_run, executor.button_stop_training, run_state],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
button_stop_training.click(
|
||||
executor.kill_command, outputs=[button_run, button_stop_training]
|
||||
executor.button_stop_training.click(
|
||||
executor.kill_command, outputs=[executor.button_run, executor.button_stop_training]
|
||||
)
|
||||
|
||||
button_print.click(
|
||||
|
|
|
|||
Loading…
Reference in New Issue