mirror of https://github.com/vladmandic/automatic
fix code formatting under modules/dml
parent
ef29f1a238
commit
ff2c1db1cc
|
|
@ -1,5 +1,6 @@
|
||||||
import torch
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class Generator(torch.Generator):
|
class Generator(torch.Generator):
|
||||||
def __init__(self, device: Optional[torch.device] = None):
|
def __init__(self, device: Optional[torch.device] = None):
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import torch
|
||||||
from modules.errors import log
|
from modules.errors import log
|
||||||
from modules.sd_hijack_utils import CondFunc
|
from modules.sd_hijack_utils import CondFunc
|
||||||
|
|
||||||
|
|
||||||
memory_providers = ["None", "atiadlxx (AMD only)"]
|
memory_providers = ["None", "atiadlxx (AMD only)"]
|
||||||
default_memory_provider = "None"
|
default_memory_provider = "None"
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
|
|
@ -12,6 +13,7 @@ if platform.system() == "Windows":
|
||||||
do_nothing = lambda: None # pylint: disable=unnecessary-lambda-assignment
|
do_nothing = lambda: None # pylint: disable=unnecessary-lambda-assignment
|
||||||
do_nothing_with_self = lambda self: None # pylint: disable=unnecessary-lambda-assignment
|
do_nothing_with_self = lambda self: None # pylint: disable=unnecessary-lambda-assignment
|
||||||
|
|
||||||
|
|
||||||
def _set_memory_provider():
|
def _set_memory_provider():
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
if opts.directml_memory_provider == "Performance Counter":
|
if opts.directml_memory_provider == "Performance Counter":
|
||||||
|
|
@ -35,6 +37,7 @@ def _set_memory_provider():
|
||||||
torch.dml.mem_get_info = mem_get_info
|
torch.dml.mem_get_info = mem_get_info
|
||||||
torch.cuda.mem_get_info = torch.dml.mem_get_info
|
torch.cuda.mem_get_info = torch.dml.mem_get_info
|
||||||
|
|
||||||
|
|
||||||
def directml_init():
|
def directml_init():
|
||||||
try:
|
try:
|
||||||
from modules.dml.backend import DirectML # pylint: disable=ungrouped-imports
|
from modules.dml.backend import DirectML # pylint: disable=ungrouped-imports
|
||||||
|
|
@ -63,6 +66,7 @@ def directml_init():
|
||||||
return False, e
|
return False, e
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
def directml_do_hijack():
|
def directml_do_hijack():
|
||||||
import modules.dml.hijack # pylint: disable=unused-import
|
import modules.dml.hijack # pylint: disable=unused-import
|
||||||
from modules.devices import device
|
from modules.devices import device
|
||||||
|
|
@ -79,17 +83,20 @@ def directml_do_hijack():
|
||||||
|
|
||||||
_set_memory_provider()
|
_set_memory_provider()
|
||||||
|
|
||||||
|
|
||||||
class OverrideItem(NamedTuple):
|
class OverrideItem(NamedTuple):
|
||||||
value: str
|
value: str
|
||||||
condition: Optional[Callable]
|
condition: Optional[Callable]
|
||||||
message: Optional[str]
|
message: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
opts_override_table = {
|
opts_override_table = {
|
||||||
"diffusers_generator_device": OverrideItem("CPU", None, "DirectML does not support torch Generator API"),
|
"diffusers_generator_device": OverrideItem("CPU", None, "DirectML does not support torch Generator API"),
|
||||||
"diffusers_model_cpu_offload": OverrideItem(False, None, "Diffusers model CPU offloading does not support DirectML devices"),
|
"diffusers_model_cpu_offload": OverrideItem(False, None, "Diffusers model CPU offloading does not support DirectML devices"),
|
||||||
"diffusers_seq_cpu_offload": OverrideItem(False, lambda opts: opts.diffusers_pipeline != "Stable Diffusion XL", "Diffusers sequential CPU offloading is available only on StableDiffusionXLPipeline with DirectML devices"),
|
"diffusers_seq_cpu_offload": OverrideItem(False, lambda opts: opts.diffusers_pipeline != "Stable Diffusion XL", "Diffusers sequential CPU offloading is available only on StableDiffusionXLPipeline with DirectML devices"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def directml_override_opts():
|
def directml_override_opts():
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,14 @@ import importlib
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
ops = ["torch.Tensor.__matmul__", "torch.addbmm", "torch.addmm", "torch.addmv", "torch.addr", "torch.baddbmm", "torch.bmm", "torch.chain_matmul", "torch.linalg.multi_dot", "torch.nn.functional.conv1d", "torch.nn.functional.conv2d", "torch.nn.functional.conv3d", "torch.nn.functional.conv_transpose1d", "torch.nn.functional.conv_transpose2d", "torch.nn.functional.conv_transpose3d", "torch.nn.GRUCell", "torch.nn.functional.linear", "torch.nn.LSTMCell", "torch.matmul", "torch.mm", "torch.mv", "torch.prelu", "torch.nn.RNNCell", "torch.embedding"]
|
ops = ["torch.Tensor.__matmul__", "torch.addbmm", "torch.addmm", "torch.addmv", "torch.addr", "torch.baddbmm", "torch.bmm", "torch.chain_matmul", "torch.linalg.multi_dot", "torch.nn.functional.conv1d", "torch.nn.functional.conv2d", "torch.nn.functional.conv3d", "torch.nn.functional.conv_transpose1d", "torch.nn.functional.conv_transpose2d", "torch.nn.functional.conv_transpose3d", "torch.nn.GRUCell", "torch.nn.functional.linear", "torch.nn.LSTMCell", "torch.matmul", "torch.mm", "torch.mv", "torch.prelu", "torch.nn.RNNCell", "torch.embedding"]
|
||||||
supported_cast_pairs = {
|
supported_cast_pairs = {
|
||||||
torch.float16: (torch.float32,),
|
torch.float16: (torch.float32,),
|
||||||
torch.float32: (torch.float16,),
|
torch.float32: (torch.float16,),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def forward(op, args: tuple, kwargs: dict):
|
def forward(op, args: tuple, kwargs: dict):
|
||||||
if not torch.dml.is_autocast_enabled:
|
if not torch.dml.is_autocast_enabled:
|
||||||
return op(*args, **kwargs)
|
return op(*args, **kwargs)
|
||||||
|
|
@ -16,6 +18,7 @@ def forward(op, args: tuple, kwargs: dict):
|
||||||
kwargs[kwarg] = cast(kwargs[kwarg])
|
kwargs[kwarg] = cast(kwargs[kwarg])
|
||||||
return op(*args, **kwargs)
|
return op(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def cast(tensor: torch.Tensor):
|
def cast(tensor: torch.Tensor):
|
||||||
if not torch.is_tensor(tensor):
|
if not torch.is_tensor(tensor):
|
||||||
return tensor
|
return tensor
|
||||||
|
|
@ -24,6 +27,7 @@ def cast(tensor: torch.Tensor):
|
||||||
return tensor
|
return tensor
|
||||||
return tensor.type(torch.dml.autocast_gpu_dtype)
|
return tensor.type(torch.dml.autocast_gpu_dtype)
|
||||||
|
|
||||||
|
|
||||||
def cond(op: str):
|
def cond(op: str):
|
||||||
if isinstance(op, str):
|
if isinstance(op, str):
|
||||||
func_path = op.split('.')
|
func_path = op.split('.')
|
||||||
|
|
@ -38,8 +42,10 @@ def cond(op: str):
|
||||||
op = getattr(resolved_obj, func_path[-1])
|
op = getattr(resolved_obj, func_path[-1])
|
||||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(op, args, kwargs))
|
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(op, args, kwargs))
|
||||||
|
|
||||||
for op in ops:
|
|
||||||
cond(op)
|
for o in ops:
|
||||||
|
cond(o)
|
||||||
|
|
||||||
|
|
||||||
class autocast:
|
class autocast:
|
||||||
prev: bool
|
prev: bool
|
||||||
|
|
|
||||||
|
|
@ -3,26 +3,29 @@ from typing import Optional, Callable
|
||||||
import torch
|
import torch
|
||||||
import torch_directml # pylint: disable=import-error
|
import torch_directml # pylint: disable=import-error
|
||||||
import modules.dml.amp as amp
|
import modules.dml.amp as amp
|
||||||
|
|
||||||
from .utils import rDevice, get_device
|
from .utils import rDevice, get_device
|
||||||
from .device import device
|
from .device import Device
|
||||||
from .Generator import Generator
|
from .Generator import Generator
|
||||||
from .device_properties import DeviceProperties
|
from .device_properties import DeviceProperties
|
||||||
|
|
||||||
|
|
||||||
def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
|
def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
|
||||||
from .memory_amd import AMDMemoryProvider
|
from .memory_amd import AMDMemoryProvider
|
||||||
return AMDMemoryProvider.mem_get_info(get_device(device).index)
|
return AMDMemoryProvider.mem_get_info(get_device(device).index)
|
||||||
|
|
||||||
|
|
||||||
def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
|
def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
|
||||||
mem_info = DirectML.memory_provider.get_memory(get_device(device).index)
|
mem_info = DirectML.memory_provider.get_memory(get_device(device).index)
|
||||||
return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"])
|
return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"])
|
||||||
|
|
||||||
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
|
|
||||||
|
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
|
||||||
return (8589934592, 8589934592)
|
return (8589934592, 8589934592)
|
||||||
|
|
||||||
|
|
||||||
class DirectML:
|
class DirectML:
|
||||||
amp = amp
|
amp = amp
|
||||||
device = device
|
device = Device
|
||||||
Generator = Generator
|
Generator = Generator
|
||||||
|
|
||||||
context_device: Optional[torch.device] = None
|
context_device: Optional[torch.device] = None
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .utils import rDevice, get_device
|
from .utils import rDevice, get_device
|
||||||
|
|
||||||
class device:
|
|
||||||
|
class Device:
|
||||||
def __enter__(self, device: Optional[rDevice]=None):
|
def __enter__(self, device: Optional[rDevice]=None):
|
||||||
torch.dml.context_device = get_device(device)
|
torch.dml.context_device = get_device(device)
|
||||||
|
|
||||||
def __init__(self, device: Optional[rDevice]=None) -> torch.device:
|
def __init__(self, device: Optional[rDevice]=None) -> torch.device: # pylint: disable=return-in-init
|
||||||
return get_device(device)
|
return get_device(device)
|
||||||
|
|
||||||
def __exit__(self, type, val, tb):
|
def __exit__(self, t, v, tb):
|
||||||
torch.dml.context_device = None
|
torch.dml.context_device = None
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class DeviceProperties:
|
class DeviceProperties:
|
||||||
type: str = "directml"
|
type: str = "directml"
|
||||||
name: str
|
name: str
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import torch
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from modules.shared import log, opts
|
from modules.shared import log, opts
|
||||||
|
|
||||||
|
|
||||||
def catch_nan(func: Callable[[], torch.Tensor]):
|
def catch_nan(func: Callable[[], torch.Tensor]):
|
||||||
if not opts.directml_catch_nan:
|
if not opts.directml_catch_nan:
|
||||||
return func()
|
return func()
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from os import getpid
|
from os import getpid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from modules.dml.pdh import HQuery, HCounter, expand_wildcard_path
|
from modules.dml.pdh import HQuery, HCounter, expand_wildcard_path
|
||||||
|
|
||||||
|
|
||||||
class MemoryProvider:
|
class MemoryProvider:
|
||||||
hQuery: HQuery
|
hQuery: HQuery
|
||||||
hCounters: defaultdict[str, list[HCounter]]
|
hCounters: defaultdict[str, list[HCounter]]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
from .driver.atiadlxx import ATIADLxx
|
from .driver.atiadlxx import ATIADLxx
|
||||||
|
|
||||||
|
|
||||||
class AMDMemoryProvider:
|
class AMDMemoryProvider:
|
||||||
driver: ATIADLxx = ATIADLxx()
|
driver: ATIADLxx = ATIADLxx()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def mem_get_info(index):
|
def mem_get_info(index):
|
||||||
usage = AMDMemoryProvider.driver.get_dedicated_vram_usage(index) * (1 << 20)
|
usage = AMDMemoryProvider.driver.get_dedicated_vram_usage(index) * (1 << 20)
|
||||||
return (AMDMemoryProvider.driver.iHyperMemorySize - usage, AMDMemoryProvider.driver.iHyperMemorySize)
|
return (AMDMemoryProvider.driver.iHyperMemorySize - usage, AMDMemoryProvider.driver.iHyperMemorySize)
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
import ctypes as C
|
import ctypes as C
|
||||||
from .atiadlxx_apis import *
|
from modules.dml.memory_amd.driver.atiadlxx_apis import ADL2_Main_Control_Create, ADL_Main_Memory_Alloc, ADL2_Adapter_NumberOfAdapters_Get, ADL2_Adapter_AdapterInfo_Get, ADL2_Adapter_MemoryInfo2_Get, ADL2_Adapter_DedicatedVRAMUsage_Get, ADL2_Adapter_VRAMUsage_Get
|
||||||
from .atiadlxx_structures import *
|
from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, AdapterInfo, LPAdapterInfo, ADLMemoryInfo2
|
||||||
from .atiadlxx_defines import *
|
from modules.dml.memory_amd.driver.atiadlxx_defines import ADL_OK
|
||||||
|
|
||||||
class ATIADLxx(object):
|
|
||||||
|
class ATIADLxx:
|
||||||
iHyperMemorySize = 0
|
iHyperMemorySize = 0
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,30 @@
|
||||||
import ctypes as C
|
import ctypes as C
|
||||||
from platform import platform
|
from platform import platform
|
||||||
from .atiadlxx_structures import *
|
from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, LPAdapterInfo, ADLMemoryInfo2
|
||||||
|
|
||||||
|
|
||||||
if 'Windows' in platform():
|
if 'Windows' in platform():
|
||||||
atiadlxx = C.WinDLL("atiadlxx.dll")
|
atiadlxx = C.WinDLL("atiadlxx.dll")
|
||||||
else:
|
else:
|
||||||
atiadlxx = C.CDLL("libatiadlxx.so") # Not tested on Linux system. But will be supported.
|
atiadlxx = C.CDLL("libatiadlxx.so") # Not tested on Linux system. But will be supported.
|
||||||
|
|
||||||
|
|
||||||
ADL_MAIN_MALLOC_CALLBACK = C.CFUNCTYPE(C.c_void_p, C.c_int)
|
ADL_MAIN_MALLOC_CALLBACK = C.CFUNCTYPE(C.c_void_p, C.c_int)
|
||||||
ADL_MAIN_FREE_CALLBACK = C.CFUNCTYPE(None, C.POINTER(C.c_void_p))
|
ADL_MAIN_FREE_CALLBACK = C.CFUNCTYPE(None, C.POINTER(C.c_void_p))
|
||||||
|
|
||||||
|
|
||||||
@ADL_MAIN_MALLOC_CALLBACK
|
@ADL_MAIN_MALLOC_CALLBACK
|
||||||
def ADL_Main_Memory_Alloc(iSize):
|
def ADL_Main_Memory_Alloc(iSize):
|
||||||
return C._malloc(iSize)
|
return C._malloc(iSize)
|
||||||
|
|
||||||
|
|
||||||
@ADL_MAIN_FREE_CALLBACK
|
@ADL_MAIN_FREE_CALLBACK
|
||||||
def ADL_Main_Memory_Free(lpBuffer):
|
def ADL_Main_Memory_Free(lpBuffer):
|
||||||
if lpBuffer[0] is not None:
|
if lpBuffer[0] is not None:
|
||||||
C._free(lpBuffer[0])
|
C._free(lpBuffer[0])
|
||||||
lpBuffer[0] = None
|
lpBuffer[0] = None
|
||||||
|
|
||||||
|
|
||||||
ADL2_Main_Control_Create = atiadlxx.ADL2_Main_Control_Create
|
ADL2_Main_Control_Create = atiadlxx.ADL2_Main_Control_Create
|
||||||
ADL2_Main_Control_Create.restype = C.c_int
|
ADL2_Main_Control_Create.restype = C.c_int
|
||||||
ADL2_Main_Control_Create.argtypes = [ADL_MAIN_MALLOC_CALLBACK, C.c_int, ADL_CONTEXT_HANDLE]
|
ADL2_Main_Control_Create.argtypes = [ADL_MAIN_MALLOC_CALLBACK, C.c_int, ADL_CONTEXT_HANDLE]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import ctypes as C
|
import ctypes as C
|
||||||
|
|
||||||
|
|
||||||
class _ADLPMActivity(C.Structure):
|
class _ADLPMActivity(C.Structure):
|
||||||
__slot__ = [
|
__slot__ = [
|
||||||
'iActivityPercent',
|
'iActivityPercent',
|
||||||
|
|
@ -13,7 +14,7 @@ class _ADLPMActivity(C.Structure):
|
||||||
'iSize',
|
'iSize',
|
||||||
'iVddc',
|
'iVddc',
|
||||||
]
|
]
|
||||||
_ADLPMActivity._fields_ = [
|
_ADLPMActivity._fields_ = [ # pylint: disable=protected-access
|
||||||
('iActivityPercent', C.c_int),
|
('iActivityPercent', C.c_int),
|
||||||
('iCurrentBusLanes', C.c_int),
|
('iCurrentBusLanes', C.c_int),
|
||||||
('iCurrentBusSpeed', C.c_int),
|
('iCurrentBusSpeed', C.c_int),
|
||||||
|
|
@ -27,6 +28,7 @@ _ADLPMActivity._fields_ = [
|
||||||
]
|
]
|
||||||
ADLPMActivity = _ADLPMActivity
|
ADLPMActivity = _ADLPMActivity
|
||||||
|
|
||||||
|
|
||||||
class _ADLMemoryInfo2(C.Structure):
|
class _ADLMemoryInfo2(C.Structure):
|
||||||
__slot__ = [
|
__slot__ = [
|
||||||
'iHyperMemorySize',
|
'iHyperMemorySize',
|
||||||
|
|
@ -36,7 +38,7 @@ class _ADLMemoryInfo2(C.Structure):
|
||||||
'iVisibleMemorySize',
|
'iVisibleMemorySize',
|
||||||
'strMemoryType'
|
'strMemoryType'
|
||||||
]
|
]
|
||||||
_ADLMemoryInfo2._fields_ = [
|
_ADLMemoryInfo2._fields_ = [ # pylint: disable=protected-access
|
||||||
('iHyperMemorySize', C.c_longlong),
|
('iHyperMemorySize', C.c_longlong),
|
||||||
('iInvisibleMemorySize', C.c_longlong),
|
('iInvisibleMemorySize', C.c_longlong),
|
||||||
('iMemoryBandwidth', C.c_longlong),
|
('iMemoryBandwidth', C.c_longlong),
|
||||||
|
|
@ -46,6 +48,7 @@ _ADLMemoryInfo2._fields_ = [
|
||||||
]
|
]
|
||||||
ADLMemoryInfo2 = _ADLMemoryInfo2
|
ADLMemoryInfo2 = _ADLMemoryInfo2
|
||||||
|
|
||||||
|
|
||||||
class _AdapterInfo(C.Structure):
|
class _AdapterInfo(C.Structure):
|
||||||
__slot__ = [
|
__slot__ = [
|
||||||
'iSize',
|
'iSize',
|
||||||
|
|
@ -64,7 +67,7 @@ class _AdapterInfo(C.Structure):
|
||||||
'strPNPString',
|
'strPNPString',
|
||||||
'iOSDisplayIndex',
|
'iOSDisplayIndex',
|
||||||
]
|
]
|
||||||
_AdapterInfo._fields_ = [
|
_AdapterInfo._fields_ = [ # pylint: disable=protected-access
|
||||||
('iSize', C.c_int),
|
('iSize', C.c_int),
|
||||||
('iAdapterIndex', C.c_int),
|
('iAdapterIndex', C.c_int),
|
||||||
('strUDID', C.c_char * 256),
|
('strUDID', C.c_char * 256),
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,24 @@
|
||||||
from ctypes import *
|
from ctypes import *
|
||||||
from ctypes.wintypes import *
|
from ctypes.wintypes import *
|
||||||
from typing import NamedTuple, TypeVar
|
from typing import NamedTuple, TypeVar
|
||||||
|
|
||||||
from .apis import PdhExpandWildCardPathW, PdhOpenQueryW, PdhAddEnglishCounterW, PdhCollectQueryData, PdhGetFormattedCounterValue, PdhGetFormattedCounterArrayW, PdhCloseQuery
|
from .apis import PdhExpandWildCardPathW, PdhOpenQueryW, PdhAddEnglishCounterW, PdhCollectQueryData, PdhGetFormattedCounterValue, PdhGetFormattedCounterArrayW, PdhCloseQuery
|
||||||
from .structures import PDH_HQUERY, PDH_HCOUNTER, PDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
|
from .structures import PDH_HQUERY, PDH_HCOUNTER, PDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
|
||||||
from .defines import *
|
from .defines import *
|
||||||
from .msvcrt import malloc
|
from .msvcrt import malloc
|
||||||
from .errors import PDHError
|
from .errors import PDHError
|
||||||
|
|
||||||
|
|
||||||
class __InternalAbstraction(NamedTuple):
|
class __InternalAbstraction(NamedTuple):
|
||||||
flag: int
|
flag: int
|
||||||
attr_name: str
|
attr_name: str
|
||||||
|
|
||||||
|
|
||||||
_type_map = {
|
_type_map = {
|
||||||
int: __InternalAbstraction(PDH_FMT_LARGE, "largeValue"),
|
int: __InternalAbstraction(PDH_FMT_LARGE, "largeValue"),
|
||||||
float: __InternalAbstraction(PDH_FMT_DOUBLE, "doubleValue"),
|
float: __InternalAbstraction(PDH_FMT_DOUBLE, "doubleValue"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def expand_wildcard_path(path: str) -> list[str]:
|
def expand_wildcard_path(path: str) -> list[str]:
|
||||||
listLength = DWORD(0)
|
listLength = DWORD(0)
|
||||||
if PdhExpandWildCardPathW(None, LPCWSTR(path), None, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_MORE_DATA:
|
if PdhExpandWildCardPathW(None, LPCWSTR(path), None, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_MORE_DATA:
|
||||||
|
|
@ -24,33 +26,35 @@ def expand_wildcard_path(path: str) -> list[str]:
|
||||||
expanded = (WCHAR * listLength.value)()
|
expanded = (WCHAR * listLength.value)()
|
||||||
if PdhExpandWildCardPathW(None, LPCWSTR(path), expanded, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_OK:
|
if PdhExpandWildCardPathW(None, LPCWSTR(path), expanded, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_OK:
|
||||||
raise PDHError(f"Couldn't expand wildcard path '{path}'")
|
raise PDHError(f"Couldn't expand wildcard path '{path}'")
|
||||||
result = list()
|
result = []
|
||||||
cur = str()
|
cur = ""
|
||||||
for chr in expanded:
|
for c in expanded:
|
||||||
if chr == '\0':
|
if c == '\0':
|
||||||
result.append(cur)
|
result.append(cur)
|
||||||
cur = str()
|
cur = ""
|
||||||
else:
|
else:
|
||||||
cur += chr
|
cur += c
|
||||||
result.pop()
|
result.pop()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", *_type_map.keys())
|
T = TypeVar("T", *_type_map.keys())
|
||||||
|
|
||||||
|
|
||||||
class HCounter(PDH_HCOUNTER):
|
class HCounter(PDH_HCOUNTER):
|
||||||
def get_formatted_value(self, type: T) -> T:
|
def get_formatted_value(self, typ: T) -> T:
|
||||||
if type not in _type_map:
|
if typ not in _type_map:
|
||||||
raise PDHError(f"Invalid value type: {type}")
|
raise PDHError(f"Invalid value type: {typ}")
|
||||||
flag, attr_name = _type_map[type]
|
flag, attr_name = _type_map[typ]
|
||||||
value = PDH_FMT_COUNTERVALUE()
|
value = PDH_FMT_COUNTERVALUE()
|
||||||
if PdhGetFormattedCounterValue(self, DWORD(flag | PDH_FMT_NOSCALE), None, byref(value)) != PDH_OK:
|
if PdhGetFormattedCounterValue(self, DWORD(flag | PDH_FMT_NOSCALE), None, byref(value)) != PDH_OK:
|
||||||
raise PDHError("Couldn't get formatted counter value.")
|
raise PDHError("Couldn't get formatted counter value.")
|
||||||
return getattr(value.u, attr_name)
|
return getattr(value.u, attr_name)
|
||||||
|
|
||||||
def get_formatted_dict(self, type: T) -> dict[str, T]:
|
def get_formatted_dict(self, typ: T) -> dict[str, T]:
|
||||||
if type not in _type_map:
|
if typ not in _type_map:
|
||||||
raise PDHError(f"Invalid value type: {type}")
|
raise PDHError(f"Invalid value type: {typ}")
|
||||||
flag, attr_name = _type_map[type]
|
flag, attr_name = _type_map[typ]
|
||||||
bufferSize = DWORD(0)
|
bufferSize = DWORD(0)
|
||||||
itemCount = DWORD(0)
|
itemCount = DWORD(0)
|
||||||
if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), None) != PDH_MORE_DATA:
|
if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), None) != PDH_MORE_DATA:
|
||||||
|
|
@ -64,9 +68,10 @@ class HCounter(PDH_HCOUNTER):
|
||||||
result[item.szName] = getattr(item.FmtValue.u, attr_name)
|
result[item.szName] = getattr(item.FmtValue.u, attr_name)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class HQuery(PDH_HQUERY):
|
class HQuery(PDH_HQUERY):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(HQuery, self).__init__()
|
super().__init__()
|
||||||
if PdhOpenQueryW(None, None, byref(self)) != PDH_OK:
|
if PdhOpenQueryW(None, None, byref(self)) != PDH_OK:
|
||||||
raise PDHError("Couldn't open PDH query.")
|
raise PDHError("Couldn't open PDH query.")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
from ctypes import *
|
from ctypes import CDLL, POINTER
|
||||||
from ctypes.wintypes import *
|
from ctypes.wintypes import LPCWSTR, LPDWORD, DWORD
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
from .structures import PDH_HQUERY, PDH_HCOUNTER, PPDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
|
||||||
|
from .defines import PDH_FUNCTION, PZZWSTR, DWORD_PTR
|
||||||
|
|
||||||
from .structures import *
|
|
||||||
from .defines import *
|
|
||||||
|
|
||||||
pdh = CDLL("pdh.dll")
|
pdh = CDLL("pdh.dll")
|
||||||
|
|
||||||
|
|
||||||
PdhExpandWildCardPathW: Callable = pdh.PdhExpandWildCardPathW
|
PdhExpandWildCardPathW: Callable = pdh.PdhExpandWildCardPathW
|
||||||
PdhExpandWildCardPathW.restype = PDH_FUNCTION
|
PdhExpandWildCardPathW.restype = PDH_FUNCTION
|
||||||
PdhExpandWildCardPathW.argtypes = [LPCWSTR, LPCWSTR, PZZWSTR, LPDWORD, DWORD]
|
PdhExpandWildCardPathW.argtypes = [LPCWSTR, LPCWSTR, PZZWSTR, LPDWORD, DWORD]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from ctypes import *
|
from ctypes import c_int, POINTER
|
||||||
from ctypes.wintypes import *
|
from ctypes.wintypes import DWORD, WCHAR
|
||||||
|
|
||||||
|
|
||||||
PDH_FUNCTION = c_int
|
PDH_FUNCTION = c_int
|
||||||
PDH_OK = 0x00000000
|
PDH_OK = 0x00000000
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,3 @@
|
||||||
class PDHError(Exception):
|
class PDHError(Exception):
|
||||||
def __init__(self, message: str):
|
def __init__(self, message: str):
|
||||||
super(PDHError, self).__init__(message)
|
super().__init__(message)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
from ctypes import *
|
from ctypes import CDLL, c_void_p, c_size_t
|
||||||
|
|
||||||
|
|
||||||
msvcrt = CDLL("msvcrt")
|
msvcrt = CDLL("msvcrt")
|
||||||
|
|
||||||
|
|
||||||
malloc = msvcrt.malloc
|
malloc = msvcrt.malloc
|
||||||
malloc.restype = c_void_p
|
malloc.restype = c_void_p
|
||||||
malloc.argtypes = [c_size_t]
|
malloc.argtypes = [c_size_t]
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
from ctypes import *
|
from ctypes import Union, c_double, c_longlong, Structure, POINTER
|
||||||
from ctypes.wintypes import *
|
from ctypes.wintypes import HANDLE, LONG, LPCSTR, LPCWSTR, DWORD, LPWSTR
|
||||||
|
|
||||||
|
|
||||||
PDH_HQUERY = HANDLE
|
PDH_HQUERY = HANDLE
|
||||||
PDH_HCOUNTER = HANDLE
|
PDH_HCOUNTER = HANDLE
|
||||||
|
|
||||||
|
|
||||||
class PDH_FMT_COUNTERVALUE_U(Union):
|
class PDH_FMT_COUNTERVALUE_U(Union):
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("longValue", LONG),
|
("longValue", LONG),
|
||||||
|
|
@ -19,6 +21,7 @@ class PDH_FMT_COUNTERVALUE_U(Union):
|
||||||
AnsiStringValue: LPCSTR
|
AnsiStringValue: LPCSTR
|
||||||
WideStringValue: LPCWSTR
|
WideStringValue: LPCWSTR
|
||||||
|
|
||||||
|
|
||||||
class PDH_FMT_COUNTERVALUE(Structure):
|
class PDH_FMT_COUNTERVALUE(Structure):
|
||||||
_anonymous_ = ("u",)
|
_anonymous_ = ("u",)
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
|
|
@ -30,6 +33,7 @@ class PDH_FMT_COUNTERVALUE(Structure):
|
||||||
u: PDH_FMT_COUNTERVALUE_U
|
u: PDH_FMT_COUNTERVALUE_U
|
||||||
PPDH_FMT_COUNTERVALUE = POINTER(PDH_FMT_COUNTERVALUE)
|
PPDH_FMT_COUNTERVALUE = POINTER(PDH_FMT_COUNTERVALUE)
|
||||||
|
|
||||||
|
|
||||||
class PDH_FMT_COUNTERVALUE_ITEM_W(Structure):
|
class PDH_FMT_COUNTERVALUE_ITEM_W(Structure):
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
("szName", LPWSTR),
|
("szName", LPWSTR),
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
rDevice = Union[torch.device, int]
|
rDevice = Union[torch.device, int]
|
||||||
def get_device(device: Optional[rDevice]=None) -> torch.device:
|
def get_device(device: Optional[rDevice]=None) -> torch.device:
|
||||||
if device is None:
|
if device is None:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue