mirror of https://github.com/bmaltais/kohya_ss
Feat(tensorboard): Add AVX check to prevent crashes
I added a check for AVX support before enabling the TensorBoard feature. This prevents crashes on systems where the CPU does not support the required instruction sets. I implemented a check for the `avx` flag in `/proc/cpuinfo` and disabled the TensorBoard feature if it is not present. This provides a more robust solution that avoids fatal errors on incompatible hardware.pull/3344/head
parent
4161d1d80a
commit
fbf6eb9ad9
|
|
@ -3,12 +3,25 @@ import gradio as gr
|
|||
import subprocess
|
||||
import time
|
||||
import webbrowser
|
||||
import shutil
|
||||
import re
|
||||
|
||||
def check_avx_support():
|
||||
try:
|
||||
with open('/proc/cpuinfo') as f:
|
||||
for line in f:
|
||||
if line.startswith('flags'):
|
||||
return 'avx' in line
|
||||
except FileNotFoundError:
|
||||
# /proc/cpuinfo is not available on all platforms.
|
||||
# As a fallback, assume AVX is not supported.
|
||||
return False
|
||||
|
||||
try:
|
||||
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
||||
import tensorflow # Attempt to import tensorflow to check if it is installed
|
||||
|
||||
if shutil.which("tensorboard") and check_avx_support():
|
||||
visibility = True
|
||||
else:
|
||||
visibility = False
|
||||
except ImportError:
|
||||
visibility = False
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,43 @@
|
|||
import unittest
|
||||
from unittest.mock import patch, mock_open
|
||||
import importlib
|
||||
|
||||
# Since we are modifying an existing file, we need to reload it
|
||||
import kohya_gui.class_tensorboard
|
||||
importlib.reload(kohya_gui.class_tensorboard)
|
||||
|
||||
class TestTensorboardVisibility(unittest.TestCase):
|
||||
|
||||
@patch('shutil.which', return_value='/usr/bin/tensorboard')
|
||||
@patch('kohya_gui.class_tensorboard.check_avx_support', return_value=True)
|
||||
def test_tensorboard_visibility_when_tensorboard_and_avx_are_present(self, mock_avx, mock_which):
|
||||
importlib.reload(kohya_gui.class_tensorboard)
|
||||
self.assertTrue(kohya_gui.class_tensorboard.visibility)
|
||||
|
||||
@patch('shutil.which', return_value=None)
|
||||
@patch('kohya_gui.class_tensorboard.check_avx_support', return_value=True)
|
||||
def test_tensorboard_visibility_when_tensorboard_is_absent(self, mock_avx, mock_which):
|
||||
importlib.reload(kohya_gui.class_tensorboard)
|
||||
self.assertFalse(kohya_gui.class_tensorboard.visibility)
|
||||
|
||||
@patch('shutil.which', return_value='/usr/bin/tensorboard')
|
||||
@patch('kohya_gui.class_tensorboard.check_avx_support', return_value=False)
|
||||
def test_tensorboard_visibility_when_avx_is_absent(self, mock_avx, mock_which):
|
||||
importlib.reload(kohya_gui.class_tensorboard)
|
||||
self.assertFalse(kohya_gui.class_tensorboard.visibility)
|
||||
|
||||
@patch('builtins.open', mock_open(read_data="flags : avx"))
|
||||
def test_check_avx_support_present(self):
|
||||
self.assertTrue(kohya_gui.class_tensorboard.check_avx_support())
|
||||
|
||||
@patch('builtins.open', mock_open(read_data="flags : sse"))
|
||||
def test_check_avx_support_absent(self):
|
||||
self.assertFalse(kohya_gui.class_tensorboard.check_avx_support())
|
||||
|
||||
@patch('builtins.open', side_effect=FileNotFoundError)
|
||||
def test_check_avx_support_file_not_found(self, mock_open):
|
||||
self.assertFalse(kohya_gui.class_tensorboard.check_avx_support())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue