diff --git a/kohya_gui/tensorboard_gui.py b/kohya_gui/tensorboard_gui.py index 7f01140..85335e6 100644 --- a/kohya_gui/tensorboard_gui.py +++ b/kohya_gui/tensorboard_gui.py @@ -16,6 +16,7 @@ TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe' DEFAULT_TENSORBOARD_PORT = 6006 def start_tensorboard(headless, logging_dir, wait_time=5): + os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' global tensorboard_proc headless_bool = True if headless.get('label') == 'True' else False @@ -49,7 +50,13 @@ def start_tensorboard(headless, logging_dir, wait_time=5): # Start background process log.info('Starting TensorBoard on port {}'.format(tensorboard_port)) try: - tensorboard_proc = subprocess.Popen(run_cmd) + # Copy the current environment + env = os.environ.copy() + + # Set your specific environment variable + env['TF_ENABLE_ONEDNN_OPTS'] = '0' + + tensorboard_proc = subprocess.Popen(run_cmd, env=env) except Exception as e: log.error('Failed to start Tensorboard:', e) return diff --git a/requirements.txt b/requirements.txt index 0ddc62b..48344ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,6 +27,7 @@ onnxruntime-gpu==1.16.0 # onnxruntime==1.16.0 # this is for onnx: # tensorboard==2.14.1 +prodigyopt==1.0 protobuf==3.20.3 # open clip for SDXL open-clip-torch==2.20.0 @@ -35,6 +36,7 @@ prodigyopt==1.0 pytorch-lightning==1.9.0 rich==13.7.0 safetensors==0.4.2 +scipy==1.11.4 timm==0.6.12 tk==0.1.0 toml==0.10.2 diff --git a/requirements_pytorch_windows.txt b/requirements_pytorch_windows.txt new file mode 100644 index 0000000..63fd7e0 --- /dev/null +++ b/requirements_pytorch_windows.txt @@ -0,0 +1,3 @@ +torch==2.1.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +torchvision==0.16.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +xformers==0.0.23.post1+cu118 --index-url https://download.pytorch.org/whl/cu118 diff --git a/requirements_windows.txt b/requirements_windows.txt new file mode 100644 index 0000000..e641755 --- /dev/null +++ b/requirements_windows.txt @@ -0,0 +1,4 @@ +bitsandbytes==0.43.0 +tensorboard +tensorflow +-r requirements.txt \ No newline at end of file diff --git a/requirements_windows_torch2.txt b/requirements_windows_torch2.txt deleted file mode 100644 index b381420..0000000 --- a/requirements_windows_torch2.txt +++ /dev/null @@ -1,4 +0,0 @@ -torch==2.1.2+cu118 torchvision==0.16.2+cu118 torchaudio==2.1.2+cu118 xformers==0.0.23.post1+cu118 --index-url https://download.pytorch.org/whl/cu118 -bitsandbytes==0.43.0 -tensorboard==2.14.1 tensorflow==2.14.0 --r requirements.txt diff --git a/setup/setup_common.py b/setup/setup_common.py index 7875668..35db54f 100644 --- a/setup/setup_common.py +++ b/setup/setup_common.py @@ -218,6 +218,24 @@ def setup_logging(clean=False): log.addHandler(rh) +def install_requirements_inbulk(requirements_file, show_stdout=True, optional_parm="", upgrade = False): + if not os.path.exists(requirements_file): + log.error(f'Could not find the requirements file in {requirements_file}.') + return + + log.info(f'Installing requirements from {requirements_file}...') + + if upgrade: + optional_parm += " -U" + + if show_stdout: + run_cmd(f'pip install -r {requirements_file} {optional_parm}') + else: + run_cmd(f'pip install -r {requirements_file} {optional_parm} --quiet') + log.info(f'Requirements from {requirements_file} installed.') + + + def configure_accelerate(run_accelerate=False): # # This function was taken and adapted from code written by jstayco diff --git a/setup/setup_windows.py b/setup/setup_windows.py index 7fe9307..1f9a102 100644 --- a/setup/setup_windows.py +++ b/setup/setup_windows.py @@ -118,8 +118,16 @@ def install_kohya_ss_torch2(headless: bool = False): setup_common.install("pip") - setup_common.install_requirements( - "requirements_windows_torch2.txt", check_no_verify_flag=False + # setup_common.install_requirements( + # "requirements_windows_torch2.txt", check_no_verify_flag=False + # ) + + setup_common.install_requirements_inbulk( + "requirements_pytorch_windows.txt", show_stdout=True, optional_parm="--index-url https://download.pytorch.org/whl/cu118" + ) + + setup_common.install_requirements_inbulk( + "requirements_windows.txt", show_stdout=True, upgrade=True ) setup_common.configure_accelerate( diff --git a/setup/validate_requirements.py b/setup/validate_requirements.py index f756684..5fbadf5 100644 --- a/setup/validate_requirements.py +++ b/setup/validate_requirements.py @@ -113,8 +113,8 @@ def main(): if args.requirements: setup_common.install_requirements(args.requirements, check_no_verify_flag=True) else: - setup_common.install_requirements('requirements_windows_torch2.txt', check_no_verify_flag=True) - + setup_common.install_requirements('requirements_pytorch_windows.txt', check_no_verify_flag=True) + setup_common.install_requirements('requirements_windows.txt', check_no_verify_flag=True) if __name__ == '__main__': main()