# Mirror of https://github.com/vladmandic/automatic (89 lines, 3.2 KiB, Python)
import threading
|
|
import time
|
|
from collections import defaultdict
|
|
import torch
|
|
from modules import shared
|
|
|
|
|
|
class MemUsageMonitor(threading.Thread):
    """Daemon thread that records GPU memory usage between ``monitor()`` and ``stop()``.

    While the run flag is set, ``run()`` samples free device memory
    ``opts.memmon_poll_rate`` times per second and keeps the lowest value seen
    in ``data["min_free"]``. ``read()`` / ``stop()`` then add current and peak
    figures from the torch allocator. The CUDA backend is used unless
    ``shared.cmd_opts.use_ipex`` is set, in which case ``torch.xpu`` is used
    consistently in every method. Reported sizes are in bytes.

    If the backend cannot be probed at construction time the monitor is
    ``disabled``: the thread exits immediately and ``read()`` returns ``data``
    without touching the device.
    """

    run_flag = None  # threading.Event; set while a monitoring session is active
    device = None  # torch device whose memory is reported
    disabled = False  # True once the memory backend is found unusable
    opts = None  # settings object; must provide ``memmon_poll_rate``
    data = None  # defaultdict(int) holding the collected statistics

    def __init__(self, name, device, opts):
        """Create the monitor thread and probe the memory backend once.

        Args:
            name: thread name.
            device: torch device to monitor; its ``index`` may be ``None``
                (the backend's current device is then used).
            opts: settings object read for ``memmon_poll_rate``.
        """
        super().__init__(name=name, daemon=True)
        self.device = device
        self.opts = opts
        self.run_flag = threading.Event()
        self.data = defaultdict(int)
        try:
            if shared.cmd_opts.use_ipex:
                # torch.cuda.is_available() reports False under IPEX, so probe the XPU backend directly.
                self.cuda_mem_get_info()
                torch.xpu.memory_stats(self.device)
            elif torch.cuda.is_available():
                self.cuda_mem_get_info()
                torch.cuda.memory_stats(self.device)
            else:
                self.disabled = True
        except Exception:
            # Monitoring is best-effort: an unusable backend disables it instead of failing startup.
            self.disabled = True

    def cuda_mem_get_info(self):
        """Return ``(free, total)`` device memory in bytes for ``self.device``.

        When ``self.device`` has no explicit index, the backend's current
        device is queried instead.
        """
        if shared.cmd_opts.use_ipex:
            index = self.device.index if self.device.index is not None else torch.xpu.current_device()
            total = torch.xpu.get_device_properties(index).total_memory
            # NOTE(review): only memory allocated through this process's torch allocator is subtracted, so
            # "free" is an approximation (torch.cuda.mem_get_info instead asks the driver for global free memory).
            return total - torch.xpu.memory_allocated(index), total
        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
        return torch.cuda.mem_get_info(index)

    def run(self):
        """Thread body: for each monitoring session, track the minimum free memory.

        Blocks until ``monitor()`` sets the run flag, resets the allocator's
        peak statistics and ``data``, then polls until ``stop()`` clears the
        flag. A poll rate <= 0 ends the session immediately without sampling.
        """
        if self.disabled:
            return
        while True:
            self.run_flag.wait()
            # Reset peaks on the monitored device (not merely the current one) to match what read() queries.
            if shared.cmd_opts.use_ipex:
                torch.xpu.reset_peak_memory_stats(self.device)
            else:
                torch.cuda.reset_peak_memory_stats(self.device)
            self.data.clear()
            poll_rate = self.opts.memmon_poll_rate
            if poll_rate <= 0:
                self.run_flag.clear()
                continue
            # Read the rate once per session so a setting changed to 0 mid-session cannot divide by zero.
            interval = 1 / poll_rate
            self.data["min_free"] = self.cuda_mem_get_info()[0]
            while self.run_flag.is_set():
                free, _total = self.cuda_mem_get_info()
                self.data["min_free"] = min(self.data["min_free"], free)
                time.sleep(interval)

    def monitor(self):
        """Start a monitoring session by setting the run flag that ``run()`` waits on."""
        self.run_flag.set()

    def read(self):
        """Collect current and peak memory statistics into ``data`` and return it.

        Keys written: ``free`` / ``total`` (device memory), ``active`` /
        ``reserved`` (current allocator usage), ``active_peak`` /
        ``reserved_peak`` (allocator peaks since the last reset) and
        ``system_peak`` (total minus the lowest free memory observed).
        Any backend error disables the monitor; ``data`` (possibly only
        partially updated) is still returned.
        """
        if self.disabled:
            return self.data
        try:
            free, total = self.cuda_mem_get_info()
            self.data["free"] = free
            self.data["total"] = total
            backend = torch.xpu if shared.cmd_opts.use_ipex else torch.cuda
            torch_stats = backend.memory_stats(self.device)
            # Use the *_bytes counters throughout; "active.all.current" is a block count, not a size.
            self.data["active"] = torch_stats["active_bytes.all.current"]
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            # The current reading is also a sample; without it, an unpolled session (min_free defaulting
            # to 0) would report the whole device as used.
            min_free = min(self.data.get("min_free", free), free)
            self.data["system_peak"] = total - min_free
        except Exception:
            self.disabled = True
        return self.data

    def stop(self):
        """End the current monitoring session and return the statistics from ``read()``."""
        self.run_flag.clear()
        return self.read()
|