Apple Silicone Support (This time not on Master) (#3174)

* Adding some changes to support current apple silicone

Adding a note that MPS is detected in validation, and a current set of packages that offer MPS torch acceleration

* Adding MPS support for blip2
pull/3205/head
Ryan 2025-04-19 07:22:53 -07:00 committed by GitHub
parent b4ea70b72d
commit 9d500d99c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 4 deletions

View File

@ -13,10 +13,17 @@ log = setup_logging()
def load_model():
# Set the device to GPU if available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
if hasattr(torch, 'cuda') and torch.cuda.is_available():
device = 'cuda'
elif hasattr(torch, 'mps') and torch.mps.is_available():
device = 'mps'
else:
device = 'cpu'
# Initialize the BLIP2 processor
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
print('processor:' + str(processor))
# Initialize the BLIP2 model
model = Blip2ForConditionalGeneration.from_pretrained(

View File

@ -1,5 +1,10 @@
torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
xformers bitsandbytes==0.43.3
tensorflow-macos tensorflow-metal tensorboard==2.14.1
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
torch==2.8.0.*
torchvision==0.22.*
xformers==0.0.29.*
git+https://github.com/bitsandbytes-foundation/bitsandbytes.git/#0.45.5
tensorflow-macos
tensorflow-metal
tensorboard==2.14.1
onnxruntime==1.17.1
-r requirements.txt

View File

@ -96,6 +96,9 @@ def check_torch():
log.debug("XPU is available, logging XPU info...")
log_xpu_info(torch, ipex)
# Log a warning if no GPU is available
elif hasattr(torch, "mps") and torch.mps.is_available():
log.info("MPS is available, logging MPS info...")
log_mps_info(torch)
else:
log.warning("Torch reports GPU not available")
@ -130,6 +133,15 @@ def log_cuda_info(torch):
f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Arch {props.major}.{props.minor} Cores {props.multi_processor_count}"
)
def log_mps_info(torch):
"""Log information about Apple Silicone (MPS)"""
max_reccomended_mem = round(torch.mps.recommended_max_memory() / 1024**2)
log.info(
f"Torch detected Apple MPS: {max_reccomended_mem}MB Unified Memory Available"
)
log.warning('MPS support is still experimental, proceed with caution.')
def log_xpu_info(torch, ipex):
"""Log information about Intel XPU-enabled GPUs."""
# Log the Intel Extension for PyTorch (IPEX) version if available