diff --git a/kohya_gui/class_tensorboard.py b/kohya_gui/class_tensorboard.py index 001c894..9c11379 100644 --- a/kohya_gui/class_tensorboard.py +++ b/kohya_gui/class_tensorboard.py @@ -3,12 +3,15 @@ import gradio as gr import subprocess import time import webbrowser +import shutil try: - os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" - import tensorflow # Attempt to import tensorflow to check if it is installed - - visibility = True + # os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + # import tensorflow # Attempt to import tensorflow to check if it is installed + if shutil.which("tensorboard"): + visibility = True + else: + visibility = False except ImportError: visibility = False diff --git a/tests/test_tensorboard_visibility.py b/tests/test_tensorboard_visibility.py new file mode 100644 index 0000000..3b16e66 --- /dev/null +++ b/tests/test_tensorboard_visibility.py @@ -0,0 +1,20 @@ +import unittest +from unittest.mock import patch +import importlib + +class TestTensorboardVisibility(unittest.TestCase): + + @patch('shutil.which', return_value='/usr/bin/tensorboard') + def test_tensorboard_visibility_when_tensorboard_is_present(self, mock_which): + import kohya_gui.class_tensorboard + importlib.reload(kohya_gui.class_tensorboard) + self.assertTrue(kohya_gui.class_tensorboard.visibility) + + @patch('shutil.which', return_value=None) + def test_tensorboard_visibility_when_tensorboard_is_absent(self, mock_which): + import kohya_gui.class_tensorboard + importlib.reload(kohya_gui.class_tensorboard) + self.assertFalse(kohya_gui.class_tensorboard.visibility) + +if __name__ == '__main__': + unittest.main()