diff --git a/kohya_gui/class_tensorboard.py b/kohya_gui/class_tensorboard.py index 001c894..6ce3d34 100644 --- a/kohya_gui/class_tensorboard.py +++ b/kohya_gui/class_tensorboard.py @@ -3,12 +3,25 @@ import gradio as gr import subprocess import time import webbrowser +import shutil +import re + +def check_avx_support(): + try: + with open('/proc/cpuinfo') as f: + for line in f: + if line.startswith('flags'): + return 'avx' in line + except FileNotFoundError: + # /proc/cpuinfo is not available on all platforms. + # As a fallback, assume AVX is not supported. + return False try: - os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" - import tensorflow # Attempt to import tensorflow to check if it is installed - - visibility = True + if shutil.which("tensorboard") and check_avx_support(): + visibility = True + else: + visibility = False except ImportError: visibility = False diff --git a/tests/test_tensorboard_visibility.py b/tests/test_tensorboard_visibility.py new file mode 100644 index 0000000..3950480 --- /dev/null +++ b/tests/test_tensorboard_visibility.py @@ -0,0 +1,43 @@ +import unittest +from unittest.mock import patch, mock_open +import importlib + +# Since we are modifying an existing file, we need to reload it +import kohya_gui.class_tensorboard +importlib.reload(kohya_gui.class_tensorboard) + +class TestTensorboardVisibility(unittest.TestCase): + + @patch('shutil.which', return_value='/usr/bin/tensorboard') + @patch('kohya_gui.class_tensorboard.check_avx_support', return_value=True) + def test_tensorboard_visibility_when_tensorboard_and_avx_are_present(self, mock_avx, mock_which): + importlib.reload(kohya_gui.class_tensorboard) + self.assertTrue(kohya_gui.class_tensorboard.visibility) + + @patch('shutil.which', return_value=None) + @patch('kohya_gui.class_tensorboard.check_avx_support', return_value=True) + def test_tensorboard_visibility_when_tensorboard_is_absent(self, mock_avx, mock_which): + importlib.reload(kohya_gui.class_tensorboard) + self.assertFalse(kohya_gui.class_tensorboard.visibility) + + @patch('shutil.which', return_value='/usr/bin/tensorboard') + @patch('kohya_gui.class_tensorboard.check_avx_support', return_value=False) + def test_tensorboard_visibility_when_avx_is_absent(self, mock_avx, mock_which): + importlib.reload(kohya_gui.class_tensorboard) + self.assertFalse(kohya_gui.class_tensorboard.visibility) + + @patch('builtins.open', mock_open(read_data="flags : avx")) + def test_check_avx_support_present(self): + self.assertTrue(kohya_gui.class_tensorboard.check_avx_support()) + + @patch('builtins.open', mock_open(read_data="flags : sse")) + def test_check_avx_support_absent(self): + self.assertFalse(kohya_gui.class_tensorboard.check_avx_support()) + + @patch('builtins.open', side_effect=FileNotFoundError) + def test_check_avx_support_file_not_found(self, mock_open): + self.assertFalse(kohya_gui.class_tensorboard.check_avx_support()) + + +if __name__ == '__main__': + unittest.main()