From e9cee44db21792e793a25d9bf33f9a12bf2a2090 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 19 Mar 2024 19:04:51 -0400 Subject: [PATCH] Add option to install Triton for Windows to setup menu --- README.md | 1 + setup/setup_windows.py | 84 ++++++++++++++++++++++++++---------------- 2 files changed, 54 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index aeb5bd0..1658f38 100644 --- a/README.md +++ b/README.md @@ -385,6 +385,7 @@ The documentation in this section will be moved to a separate document later. ### 2024/03/20 (v23.0.15) - Add support for toml dataset configuration fole to all trainers +- Add new setup menu option to install Triton 2.1.0 for Windows ### 2024/03/19 (v23.0.14) diff --git a/setup/setup_windows.py b/setup/setup_windows.py index b91f027..7fe9307 100644 --- a/setup/setup_windows.py +++ b/setup/setup_windows.py @@ -22,9 +22,12 @@ def cudnn_install(): "nvidia-cudnn-cu11 8.9.5.29", reinstall=True, ) - + # Original path with "..\\venv" - original_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..\\venv\\Lib\\site-packages\\nvidia\\cudnn\\bin") + original_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "..\\venv\\Lib\\site-packages\\nvidia\\cudnn\\bin", + ) # Normalize the path to resolve "..\\venv" cudnn_src = os.path.abspath(original_path) cudnn_dest = os.path.join(sysconfig.get_paths()["purelib"], "torch", "lib") @@ -35,7 +38,7 @@ def cudnn_install(): # check for different files filecmp.clear_cache() for file in os.listdir(cudnn_src): - if file.lower().endswith('.dll'): # Check if the file is a .dll file + if file.lower().endswith(".dll"): # Check if the file is a .dll file src_file = os.path.join(cudnn_src, file) dest_file = os.path.join(cudnn_dest, file) # if dest file exists, check if it's different @@ -110,16 +113,18 @@ def install_kohya_ss_torch2(headless: bool = False): setup_common.check_repo_version() if not setup_common.check_python_version(): exit(1) - + setup_common.update_submodule() - + setup_common.install("pip") setup_common.install_requirements( "requirements_windows_torch2.txt", check_no_verify_flag=False ) - - setup_common.configure_accelerate(run_accelerate=not headless) # False if headless is True and vice versa + + setup_common.configure_accelerate( + run_accelerate=not headless + ) # False if headless is True and vice versa def install_bitsandbytes_0_35_0(): @@ -147,6 +152,7 @@ def install_bitsandbytes_0_41_1(): reinstall=True, ) + def install_bitsandbytes_0_41_2(): log.info("Installing bitsandbytes 0.41.2...") setup_common.install( @@ -155,21 +161,34 @@ def install_bitsandbytes_0_41_2(): reinstall=True, ) + +def install_triton_2_1_0(): + log.info("Installing triton 2.1.0...") + setup_common.install( + "--upgrade https://huggingface.co/Rodeszones/CogVLM-grounding-generalist-hf-quant4/resolve/main/triton-2.1.0-cp310-cp310-win_amd64.whl?download=true", + "triton 2.1.0", + reinstall=True, + ) + + def main_menu(headless: bool = False): if headless: install_kohya_ss_torch2(headless=headless) else: setup_common.clear_screen() while True: - print("\nKohya_ss GUI setup menu:\n") - print("1. Install kohya_ss gui") - print("2. (Optional) Install cudnn files (if you want to use latest supported cudnn version)") - print("3. (Optional) Install specific bitsandbytes versions") - print("4. (Optional) Manually configure accelerate") - print("5. (Optional) Start Kohya_ss GUI in browser") - print("6. Quit") + print("\nKohya_ss setup menu:\n") + print("1. Install kohya_ss GUI") + print( + "2. (Optional) Install CuDNN files (to use the latest supported CuDNN version)" + ) + print("3. (Optional) Install Triton 2.1.0 for Windows") + print("4. (Optional) Install specific version of bitsandbytes") + print("5. (Optional) Manually configure Accelerate") + print("6. (Optional) Launch Kohya_ss GUI in browser") + print("7. Exit Setup") - choice = input("\nEnter your choice: ") + choice = input("\nSelect an option: ") print("") if choice == "1": @@ -177,22 +196,25 @@ def main_menu(headless: bool = False): elif choice == "2": cudnn_install() elif choice == "3": + install_triton_2_1_0() + elif choice == "4": while True: - print("1. (Optional) Force installation of bitsandbytes 0.35.0") + print("\nBitsandBytes Installation Menu:") + print("1. Force install Bitsandbytes 0.35.0") print( - "2. (Optional) Force installation of bitsandbytes 0.40.1 for new optimizer options support and pre-bugfix results" + "2. Force install Bitsandbytes 0.40.1 (supports new optimizer options, pre-bugfix results)" ) print( - "3. (Optional) Force installation of bitsandbytes 0.41.1 for new optimizer options support" + "3. Force installation Bitsandbytes 0.41.1 (supports new optimizer options)" ) print( - "4. (Recommended) Force installation of bitsandbytes 0.41.2 for new optimizer options support" + "4. (Recommended) Force install Bitsandbytes 0.41.2 (supports new optimizer options)" ) print( - "5. (Danger) Install bitsandbytes-windows (this package has been reported to cause issues for most... avoid...)" + "5. (Warning) Install bitsandbytes-windows (may cause issues, use with caution)" ) - print("6. Exit") - choice_torch = input("\nEnter your choice: ") + print("6. Return to Previous Menu:") + choice_torch = input("\nSelect an option: ") print("") if choice_torch == "1": @@ -215,29 +237,29 @@ def main_menu(headless: bool = False): elif choice_torch == "6": break else: - print("Invalid choice. Please enter a number between 1-3.") - elif choice == "4": - setup_common.run_cmd("accelerate config") + print("Invalid choice. Please chose an option between 1-6.") elif choice == "5": + setup_common.run_cmd("accelerate config") + elif choice == "6": subprocess.Popen( "start cmd /k .\gui.bat --inbrowser", shell=True ) # /k keep the terminal open on quit. /c would close the terminal instead - elif choice == "6": - print("Quitting the program.") + elif choice == "7": + print("Exiting setup.") break else: - print("Invalid choice. Please enter a number between 1-5.") + print("Invalid selection. Please choose an option between 1-7.") if __name__ == "__main__": setup_common.ensure_base_requirements() setup_common.setup_logging() - + # Setup argument parser parser = argparse.ArgumentParser(description="Your Script Description") - parser.add_argument('--headless', action='store_true', help='Run in headless mode') + parser.add_argument("--headless", action="store_true", help="Run in headless mode") # Parse arguments args = parser.parse_args() - + main_menu(headless=args.headless)