mirror of https://github.com/bmaltais/kohya_ss
Feat(tensorboard): Add AVX check to prevent crashes
I added a check for AVX support before enabling the TensorBoard feature. This prevents crashes on systems where the CPU does not support the required instruction sets. I implemented a check for the `avx` flag in `/proc/cpuinfo` and disabled the TensorBoard feature if it is not present. This provides a more robust solution that avoids fatal errors on incompatible hardware.pull/3344/head
parent
4161d1d80a
commit
fbf6eb9ad9
|
|
@ -3,12 +3,25 @@ import gradio as gr
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import webbrowser
|
import webbrowser
|
||||||
|
import shutil
|
||||||
|
import re
|
||||||
|
|
||||||
|
def check_avx_support():
|
||||||
|
try:
|
||||||
|
with open('/proc/cpuinfo') as f:
|
||||||
|
for line in f:
|
||||||
|
if line.startswith('flags'):
|
||||||
|
return 'avx' in line
|
||||||
|
except FileNotFoundError:
|
||||||
|
# /proc/cpuinfo is not available on all platforms.
|
||||||
|
# As a fallback, assume AVX is not supported.
|
||||||
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
if shutil.which("tensorboard") and check_avx_support():
|
||||||
import tensorflow # Attempt to import tensorflow to check if it is installed
|
visibility = True
|
||||||
|
else:
|
||||||
visibility = True
|
visibility = False
|
||||||
except ImportError:
|
except ImportError:
|
||||||
visibility = False
|
visibility = False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch, mock_open
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
# Since we are modifying an existing file, we need to reload it
|
||||||
|
import kohya_gui.class_tensorboard
|
||||||
|
importlib.reload(kohya_gui.class_tensorboard)
|
||||||
|
|
||||||
|
class TestTensorboardVisibility(unittest.TestCase):
|
||||||
|
|
||||||
|
@patch('shutil.which', return_value='/usr/bin/tensorboard')
|
||||||
|
@patch('kohya_gui.class_tensorboard.check_avx_support', return_value=True)
|
||||||
|
def test_tensorboard_visibility_when_tensorboard_and_avx_are_present(self, mock_avx, mock_which):
|
||||||
|
importlib.reload(kohya_gui.class_tensorboard)
|
||||||
|
self.assertTrue(kohya_gui.class_tensorboard.visibility)
|
||||||
|
|
||||||
|
@patch('shutil.which', return_value=None)
|
||||||
|
@patch('kohya_gui.class_tensorboard.check_avx_support', return_value=True)
|
||||||
|
def test_tensorboard_visibility_when_tensorboard_is_absent(self, mock_avx, mock_which):
|
||||||
|
importlib.reload(kohya_gui.class_tensorboard)
|
||||||
|
self.assertFalse(kohya_gui.class_tensorboard.visibility)
|
||||||
|
|
||||||
|
@patch('shutil.which', return_value='/usr/bin/tensorboard')
|
||||||
|
@patch('kohya_gui.class_tensorboard.check_avx_support', return_value=False)
|
||||||
|
def test_tensorboard_visibility_when_avx_is_absent(self, mock_avx, mock_which):
|
||||||
|
importlib.reload(kohya_gui.class_tensorboard)
|
||||||
|
self.assertFalse(kohya_gui.class_tensorboard.visibility)
|
||||||
|
|
||||||
|
@patch('builtins.open', mock_open(read_data="flags : avx"))
|
||||||
|
def test_check_avx_support_present(self):
|
||||||
|
self.assertTrue(kohya_gui.class_tensorboard.check_avx_support())
|
||||||
|
|
||||||
|
@patch('builtins.open', mock_open(read_data="flags : sse"))
|
||||||
|
def test_check_avx_support_absent(self):
|
||||||
|
self.assertFalse(kohya_gui.class_tensorboard.check_avx_support())
|
||||||
|
|
||||||
|
@patch('builtins.open', side_effect=FileNotFoundError)
|
||||||
|
def test_check_avx_support_file_not_found(self, mock_open):
|
||||||
|
self.assertFalse(kohya_gui.class_tensorboard.check_avx_support())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue