mirror of https://github.com/bmaltais/kohya_ss
Apple Silicon Support (This time not on Master) (#3174)
* Adding some changes to support current Apple Silicon. Adding a note that MPS is detected in validation, and a current set of packages that offer MPS torch acceleration. * Adding MPS support for blip2 (pull/3205/head)
parent
b4ea70b72d
commit
9d500d99c2
|
|
@ -13,10 +13,17 @@ log = setup_logging()
|
|||
|
||||
def load_model():
|
||||
# Set the device to GPU if available, otherwise use CPU
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if hasattr(torch, 'cuda') and torch.cuda.is_available():
|
||||
device = 'cuda'
|
||||
elif hasattr(torch, 'mps') and torch.mps.is_available():
|
||||
device = 'mps'
|
||||
else:
|
||||
device = 'cpu'
|
||||
|
||||
|
||||
# Initialize the BLIP2 processor
|
||||
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
|
||||
print('processor:' + str(processor))
|
||||
|
||||
# Initialize the BLIP2 model
|
||||
model = Blip2ForConditionalGeneration.from_pretrained(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
torch==2.0.0 torchvision==0.15.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
xformers bitsandbytes==0.43.3
|
||||
tensorflow-macos tensorflow-metal tensorboard==2.14.1
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
torch==2.8.0.*
|
||||
torchvision==0.22.*
|
||||
xformers==0.0.29.*
|
||||
git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@0.45.5
|
||||
tensorflow-macos
|
||||
tensorflow-metal
|
||||
tensorboard==2.14.1
|
||||
onnxruntime==1.17.1
|
||||
-r requirements.txt
|
||||
|
|
|
|||
|
|
@ -96,6 +96,9 @@ def check_torch():
|
|||
log.debug("XPU is available, logging XPU info...")
|
||||
log_xpu_info(torch, ipex)
|
||||
# Log a warning if no GPU is available
|
||||
elif hasattr(torch, "mps") and torch.mps.is_available():
|
||||
log.info("MPS is available, logging MPS info...")
|
||||
log_mps_info(torch)
|
||||
else:
|
||||
log.warning("Torch reports GPU not available")
|
||||
|
||||
|
|
@ -130,6 +133,15 @@ def log_cuda_info(torch):
|
|||
f"Torch detected GPU: {props.name} VRAM {round(props.total_memory / 1024 / 1024)}MB Arch {props.major}.{props.minor} Cores {props.multi_processor_count}"
|
||||
)
|
||||
|
||||
def log_mps_info(torch):
|
||||
"""Log information about Apple Silicone (MPS)"""
|
||||
max_reccomended_mem = round(torch.mps.recommended_max_memory() / 1024**2)
|
||||
log.info(
|
||||
f"Torch detected Apple MPS: {max_reccomended_mem}MB Unified Memory Available"
|
||||
)
|
||||
log.warning('MPS support is still experimental, proceed with caution.')
|
||||
|
||||
|
||||
def log_xpu_info(torch, ipex):
|
||||
"""Log information about Intel XPU-enabled GPUs."""
|
||||
# Log the Intel Extension for PyTorch (IPEX) version if available
|
||||
|
|
|
|||
Loading…
Reference in New Issue