diff --git a/kohya_gui/blip2_caption_gui.py b/kohya_gui/blip2_caption_gui.py index b326322..6a859af 100644 --- a/kohya_gui/blip2_caption_gui.py +++ b/kohya_gui/blip2_caption_gui.py @@ -13,10 +13,17 @@ log = setup_logging() def load_model(): # Set the device to GPU if available, otherwise use CPU - device = "cuda" if torch.cuda.is_available() else "cpu" + if hasattr(torch, 'cuda') and torch.cuda.is_available(): + device = 'cuda' + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = 'mps' + else: + device = 'cpu' + # Initialize the BLIP2 processor processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") + log.debug('processor:' + str(processor)) # Initialize the BLIP2 model model = Blip2ForConditionalGeneration.from_pretrained( diff --git a/requirements_macos_arm64.txt b/requirements_macos_arm64.txt index 364c44a..c495f9d 100644 --- a/requirements_macos_arm64.txt +++ b/requirements_macos_arm64.txt @@ -1,5 +1,10 @@ -torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html -xformers bitsandbytes==0.43.3 -tensorflow-macos tensorflow-metal tensorboard==2.14.1 +--extra-index-url https://download.pytorch.org/whl/nightly/cpu +torch==2.8.0.* +torchvision==0.23.* +xformers==0.0.29.* +git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@0.45.5 +tensorflow-macos +tensorflow-metal +tensorboard==2.14.1 onnxruntime==1.17.1 -r requirements.txt diff --git a/setup/validate_requirements.py b/setup/validate_requirements.py index f402939..cb13840 100644 --- a/setup/validate_requirements.py +++ b/setup/validate_requirements.py @@ -96,6 +96,9 @@ def check_torch(): log.debug("XPU is available, logging XPU info...") log_xpu_info(torch, ipex) # Log a warning if no GPU is available + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + log.info("MPS is available, logging MPS info...") + log_mps_info(torch) else: log.warning("Torch reports GPU not available") @@ -130,6 +133,15 @@ def log_cuda_info(torch): f"Torch 
detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Arch {props.major}.{props.minor} Cores {props.multi_processor_count}" ) +def log_mps_info(torch): + """Log information about Apple Silicon (MPS)""" + max_recommended_mem = round(torch.mps.recommended_max_memory() / 1024**2) + log.info( + f"Torch detected Apple MPS: {max_recommended_mem}MB Unified Memory Available" + ) + log.warning('MPS support is still experimental, proceed with caution.') + + def log_xpu_info(torch, ipex): """Log information about Intel XPU-enabled GPUs.""" # Log the Intel Extension for PyTorch (IPEX) version if available