Add option to install Triton for Windows to setup menu

pull/2137/head
bmaltais 2024-03-19 19:04:51 -04:00
parent 99691b7b5d
commit e9cee44db2
2 changed files with 54 additions and 31 deletions

View File

@ -385,6 +385,7 @@ The documentation in this section will be moved to a separate document later.
### 2024/03/20 (v23.0.15)
- Add support for toml dataset configuration fole to all trainers
- Add new setup menu option to install Triton 2.1.0 for Windows
### 2024/03/19 (v23.0.14)

View File

@ -22,9 +22,12 @@ def cudnn_install():
"nvidia-cudnn-cu11 8.9.5.29",
reinstall=True,
)
# Original path with "..\\venv"
original_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..\\venv\\Lib\\site-packages\\nvidia\\cudnn\\bin")
original_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..\\venv\\Lib\\site-packages\\nvidia\\cudnn\\bin",
)
# Normalize the path to resolve "..\\venv"
cudnn_src = os.path.abspath(original_path)
cudnn_dest = os.path.join(sysconfig.get_paths()["purelib"], "torch", "lib")
@ -35,7 +38,7 @@ def cudnn_install():
# check for different files
filecmp.clear_cache()
for file in os.listdir(cudnn_src):
if file.lower().endswith('.dll'): # Check if the file is a .dll file
if file.lower().endswith(".dll"): # Check if the file is a .dll file
src_file = os.path.join(cudnn_src, file)
dest_file = os.path.join(cudnn_dest, file)
# if dest file exists, check if it's different
@ -110,16 +113,18 @@ def install_kohya_ss_torch2(headless: bool = False):
setup_common.check_repo_version()
if not setup_common.check_python_version():
exit(1)
setup_common.update_submodule()
setup_common.install("pip")
setup_common.install_requirements(
"requirements_windows_torch2.txt", check_no_verify_flag=False
)
setup_common.configure_accelerate(run_accelerate=not headless) # False if headless is True and vice versa
setup_common.configure_accelerate(
run_accelerate=not headless
) # False if headless is True and vice versa
def install_bitsandbytes_0_35_0():
@ -147,6 +152,7 @@ def install_bitsandbytes_0_41_1():
reinstall=True,
)
def install_bitsandbytes_0_41_2():
log.info("Installing bitsandbytes 0.41.2...")
setup_common.install(
@ -155,21 +161,34 @@ def install_bitsandbytes_0_41_2():
reinstall=True,
)
def install_triton_2_1_0():
log.info("Installing triton 2.1.0...")
setup_common.install(
"--upgrade https://huggingface.co/Rodeszones/CogVLM-grounding-generalist-hf-quant4/resolve/main/triton-2.1.0-cp310-cp310-win_amd64.whl?download=true",
"triton 2.1.0",
reinstall=True,
)
def main_menu(headless: bool = False):
if headless:
install_kohya_ss_torch2(headless=headless)
else:
setup_common.clear_screen()
while True:
print("\nKohya_ss GUI setup menu:\n")
print("1. Install kohya_ss gui")
print("2. (Optional) Install cudnn files (if you want to use latest supported cudnn version)")
print("3. (Optional) Install specific bitsandbytes versions")
print("4. (Optional) Manually configure accelerate")
print("5. (Optional) Start Kohya_ss GUI in browser")
print("6. Quit")
print("\nKohya_ss setup menu:\n")
print("1. Install kohya_ss GUI")
print(
"2. (Optional) Install CuDNN files (to use the latest supported CuDNN version)"
)
print("3. (Optional) Install Triton 2.1.0 for Windows")
print("4. (Optional) Install specific version of bitsandbytes")
print("5. (Optional) Manually configure Accelerate")
print("6. (Optional) Launch Kohya_ss GUI in browser")
print("7. Exit Setup")
choice = input("\nEnter your choice: ")
choice = input("\nSelect an option: ")
print("")
if choice == "1":
@ -177,22 +196,25 @@ def main_menu(headless: bool = False):
elif choice == "2":
cudnn_install()
elif choice == "3":
install_triton_2_1_0()
elif choice == "4":
while True:
print("1. (Optional) Force installation of bitsandbytes 0.35.0")
print("\nBitsandBytes Installation Menu:")
print("1. Force install Bitsandbytes 0.35.0")
print(
"2. (Optional) Force installation of bitsandbytes 0.40.1 for new optimizer options support and pre-bugfix results"
"2. Force install Bitsandbytes 0.40.1 (supports new optimizer options, pre-bugfix results)"
)
print(
"3. (Optional) Force installation of bitsandbytes 0.41.1 for new optimizer options support"
"3. Force installation Bitsandbytes 0.41.1 (supports new optimizer options)"
)
print(
"4. (Recommended) Force installation of bitsandbytes 0.41.2 for new optimizer options support"
"4. (Recommended) Force install Bitsandbytes 0.41.2 (supports new optimizer options)"
)
print(
"5. (Danger) Install bitsandbytes-windows (this package has been reported to cause issues for most... avoid...)"
"5. (Warning) Install bitsandbytes-windows (may cause issues, use with caution)"
)
print("6. Exit")
choice_torch = input("\nEnter your choice: ")
print("6. Return to Previous Menu:")
choice_torch = input("\nSelect an option: ")
print("")
if choice_torch == "1":
@ -215,29 +237,29 @@ def main_menu(headless: bool = False):
elif choice_torch == "6":
break
else:
print("Invalid choice. Please enter a number between 1-3.")
elif choice == "4":
setup_common.run_cmd("accelerate config")
print("Invalid choice. Please chose an option between 1-6.")
elif choice == "5":
setup_common.run_cmd("accelerate config")
elif choice == "6":
subprocess.Popen(
"start cmd /k .\gui.bat --inbrowser", shell=True
) # /k keep the terminal open on quit. /c would close the terminal instead
elif choice == "6":
print("Quitting the program.")
elif choice == "7":
print("Exiting setup.")
break
else:
print("Invalid choice. Please enter a number between 1-5.")
print("Invalid selection. Please choose an option between 1-7.")
if __name__ == "__main__":
setup_common.ensure_base_requirements()
setup_common.setup_logging()
# Setup argument parser
parser = argparse.ArgumentParser(description="Your Script Description")
parser.add_argument('--headless', action='store_true', help='Run in headless mode')
parser.add_argument("--headless", action="store_true", help="Run in headless mode")
# Parse arguments
args = parser.parse_args()
main_menu(headless=args.headless)