Feat(tensorboard): Add AVX check to prevent crashes

I added a check for AVX support before enabling the TensorBoard feature. This prevents crashes on systems where the CPU does not support the required instruction sets.

I implemented a check for the `avx` flag in `/proc/cpuinfo` and disabled the TensorBoard feature if it is not present. This provides a more robust solution that avoids fatal errors on incompatible hardware.
pull/3344/head
google-labs-jules[bot] 2025-07-13 13:52:28 +00:00
parent 4161d1d80a
commit fbf6eb9ad9
2 changed files with 60 additions and 4 deletions

View File

@ -3,12 +3,25 @@ import gradio as gr
import subprocess
import time
import webbrowser
import shutil
import re
def check_avx_support():
try:
with open('/proc/cpuinfo') as f:
for line in f:
if line.startswith('flags'):
return 'avx' in line
except FileNotFoundError:
# /proc/cpuinfo is not available on all platforms.
# As a fallback, assume AVX is not supported.
return False
try:
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import tensorflow # Attempt to import tensorflow to check if it is installed
visibility = True
if shutil.which("tensorboard") and check_avx_support():
visibility = True
else:
visibility = False
except ImportError:
visibility = False

View File

@ -0,0 +1,43 @@
import unittest
from unittest.mock import patch, mock_open
import importlib
# Since we are modifying an existing file, we need to reload it
import kohya_gui.class_tensorboard
importlib.reload(kohya_gui.class_tensorboard)
class TestTensorboardVisibility(unittest.TestCase):
@patch('shutil.which', return_value='/usr/bin/tensorboard')
@patch('kohya_gui.class_tensorboard.check_avx_support', return_value=True)
def test_tensorboard_visibility_when_tensorboard_and_avx_are_present(self, mock_avx, mock_which):
importlib.reload(kohya_gui.class_tensorboard)
self.assertTrue(kohya_gui.class_tensorboard.visibility)
@patch('shutil.which', return_value=None)
@patch('kohya_gui.class_tensorboard.check_avx_support', return_value=True)
def test_tensorboard_visibility_when_tensorboard_is_absent(self, mock_avx, mock_which):
importlib.reload(kohya_gui.class_tensorboard)
self.assertFalse(kohya_gui.class_tensorboard.visibility)
@patch('shutil.which', return_value='/usr/bin/tensorboard')
@patch('kohya_gui.class_tensorboard.check_avx_support', return_value=False)
def test_tensorboard_visibility_when_avx_is_absent(self, mock_avx, mock_which):
importlib.reload(kohya_gui.class_tensorboard)
self.assertFalse(kohya_gui.class_tensorboard.visibility)
@patch('builtins.open', mock_open(read_data="flags : avx"))
def test_check_avx_support_present(self):
self.assertTrue(kohya_gui.class_tensorboard.check_avx_support())
@patch('builtins.open', mock_open(read_data="flags : sse"))
def test_check_avx_support_absent(self):
self.assertFalse(kohya_gui.class_tensorboard.check_avx_support())
@patch('builtins.open', side_effect=FileNotFoundError)
def test_check_avx_support_file_not_found(self, mock_open):
self.assertFalse(kohya_gui.class_tensorboard.check_avx_support())
if __name__ == '__main__':
unittest.main()