mirror of https://github.com/vladmandic/automatic
56 lines
2.3 KiB
Python
56 lines
2.3 KiB
Python
from collections import defaultdict
|
|
import torch
|
|
|
|
|
|
class MemUsageMonitor():
    """Collect CUDA memory usage figures for a single torch device.

    The device is probed once at construction. If CUDA is unavailable or the
    probe raises, the monitor marks itself ``disabled`` and every later call
    becomes a cheap no-op. Readings accumulate in ``self.data`` (a
    ``defaultdict(int)``), which :meth:`read` returns.
    """

    device = None  # device being monitored; overwritten per instance in __init__
    disabled = False  # set True when CUDA is unavailable or a query fails
    opts = None  # not used within this class; presumably assigned by external callers — verify
    data = None  # per-instance defaultdict(int) of readings, created in __init__

    def __init__(self, name, device):
        """Create a monitor for *device*.

        Args:
            name: label stored as ``self.name`` (not otherwise used here).
            device: the ``torch.device`` to query.
        """
        self.name = name
        self.device = device
        self.data = defaultdict(int)
        if not torch.cuda.is_available():
            self.disabled = True
            return
        try:
            # Probe both APIs that read() relies on, so an unsupported device
            # disables the monitor up front instead of failing mid-run.
            torch.cuda.mem_get_info(self._device_index(self.device))
            torch.cuda.memory_stats(self.device)
        except Exception:
            self.disabled = True

    @staticmethod
    def _device_index(device):
        """Return *device*'s integer index, or the current CUDA device if it has none.

        A device such as ``torch.device("cuda")`` carries ``index=None``; in that
        case ``torch.cuda.current_device()`` is used instead.
        """
        return device.index if device.index is not None else torch.cuda.current_device()

    @staticmethod
    def _parse_stats(torch_stats):
        """Pick the monitored counters out of a ``torch.cuda.memory_stats()`` dict.

        Raises:
            KeyError: if one of the required ``*_bytes`` counters is missing.
        """
        return {
            # Use the byte counter so "active" shares units with active_peak and
            # reserved*. Per the PyTorch memory_stats docs, "active.all.current"
            # counts blocks, not bytes.
            "active": torch_stats["active_bytes.all.current"],
            "active_peak": torch_stats["active_bytes.all.peak"],
            "reserved": torch_stats["reserved_bytes.all.current"],
            "reserved_peak": torch_stats["reserved_bytes.all.peak"],
            # -1 signals that the backend does not report the counter.
            "retries": torch_stats.get("num_alloc_retries", -1),
            "oom": torch_stats.get("num_ooms", -1),
        }

    def cuda_mem_get_info(self):  # legacy for extensions only
        """Return ``torch.cuda.mem_get_info`` for the device, or ``(0, 0)`` when disabled."""
        if self.disabled:
            return 0, 0
        return torch.cuda.mem_get_info(self._device_index(self.device))

    def reset(self):
        """Reset peak memory statistics and zero the retry/OOM counters (best effort)."""
        if self.disabled:
            return
        try:
            torch.cuda.reset_peak_memory_stats(self.device)
            self.data['retries'] = 0
            self.data['oom'] = 0
            # NOTE(review): per the PyTorch docs, reset_peak_memory_stats() resets
            # only "peak" stats. num_alloc_retries/num_ooms are accumulated stats,
            # so the next read() may replace these zeros with cumulative totals.
            # reset_accumulated_memory_stats() was deliberately left disabled.
        except Exception:
            # Best effort: a failed reset must not break the caller.
            pass

    def read(self):
        """Refresh and return ``self.data``.

        On success, ``self.data`` holds ``free``/``total`` from
        ``torch.cuda.mem_get_info``, ``used = total - free``, and the fields
        produced by :meth:`_parse_stats`. Any failure disables the monitor
        permanently. The previously collected data is still returned.
        """
        if self.disabled:
            return self.data
        try:
            free, total = torch.cuda.mem_get_info(self._device_index(self.device))
            self.data["free"] = free
            self.data["total"] = total
            self.data["used"] = total - free
            self.data.update(self._parse_stats(torch.cuda.memory_stats(self.device)))
        except Exception:
            self.disabled = True
        return self.data
|