fix code formatting under modules/dml

pull/2803/head
Seunghoon Lee 2024-02-05 22:43:10 +09:00
parent ef29f1a238
commit ff2c1db1cc
No known key found for this signature in database
GPG Key ID: 436E38F4E70BD152
19 changed files with 91 additions and 46 deletions

View File

@@ -1,5 +1,6 @@
import torch
from typing import Optional from typing import Optional
import torch
class Generator(torch.Generator): class Generator(torch.Generator):
def __init__(self, device: Optional[torch.device] = None): def __init__(self, device: Optional[torch.device] = None):

View File

@@ -4,6 +4,7 @@ import torch
from modules.errors import log from modules.errors import log
from modules.sd_hijack_utils import CondFunc from modules.sd_hijack_utils import CondFunc
memory_providers = ["None", "atiadlxx (AMD only)"] memory_providers = ["None", "atiadlxx (AMD only)"]
default_memory_provider = "None" default_memory_provider = "None"
if platform.system() == "Windows": if platform.system() == "Windows":
@@ -12,6 +13,7 @@ if platform.system() == "Windows":
do_nothing = lambda: None # pylint: disable=unnecessary-lambda-assignment do_nothing = lambda: None # pylint: disable=unnecessary-lambda-assignment
do_nothing_with_self = lambda self: None # pylint: disable=unnecessary-lambda-assignment do_nothing_with_self = lambda self: None # pylint: disable=unnecessary-lambda-assignment
def _set_memory_provider(): def _set_memory_provider():
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
if opts.directml_memory_provider == "Performance Counter": if opts.directml_memory_provider == "Performance Counter":
@@ -35,6 +37,7 @@ def _set_memory_provider():
torch.dml.mem_get_info = mem_get_info torch.dml.mem_get_info = mem_get_info
torch.cuda.mem_get_info = torch.dml.mem_get_info torch.cuda.mem_get_info = torch.dml.mem_get_info
def directml_init(): def directml_init():
try: try:
from modules.dml.backend import DirectML # pylint: disable=ungrouped-imports from modules.dml.backend import DirectML # pylint: disable=ungrouped-imports
@@ -63,6 +66,7 @@ def directml_init():
return False, e return False, e
return True, None return True, None
def directml_do_hijack(): def directml_do_hijack():
import modules.dml.hijack # pylint: disable=unused-import import modules.dml.hijack # pylint: disable=unused-import
from modules.devices import device from modules.devices import device
@@ -79,17 +83,20 @@ def directml_do_hijack():
_set_memory_provider() _set_memory_provider()
class OverrideItem(NamedTuple): class OverrideItem(NamedTuple):
value: str value: str
condition: Optional[Callable] condition: Optional[Callable]
message: Optional[str] message: Optional[str]
opts_override_table = { opts_override_table = {
"diffusers_generator_device": OverrideItem("CPU", None, "DirectML does not support torch Generator API"), "diffusers_generator_device": OverrideItem("CPU", None, "DirectML does not support torch Generator API"),
"diffusers_model_cpu_offload": OverrideItem(False, None, "Diffusers model CPU offloading does not support DirectML devices"), "diffusers_model_cpu_offload": OverrideItem(False, None, "Diffusers model CPU offloading does not support DirectML devices"),
"diffusers_seq_cpu_offload": OverrideItem(False, lambda opts: opts.diffusers_pipeline != "Stable Diffusion XL", "Diffusers sequential CPU offloading is available only on StableDiffusionXLPipeline with DirectML devices"), "diffusers_seq_cpu_offload": OverrideItem(False, lambda opts: opts.diffusers_pipeline != "Stable Diffusion XL", "Diffusers sequential CPU offloading is available only on StableDiffusionXLPipeline with DirectML devices"),
} }
def directml_override_opts(): def directml_override_opts():
from modules import shared from modules import shared

View File

@@ -2,12 +2,14 @@ import importlib
from typing import Any, Optional from typing import Any, Optional
import torch import torch
ops = ["torch.Tensor.__matmul__", "torch.addbmm", "torch.addmm", "torch.addmv", "torch.addr", "torch.baddbmm", "torch.bmm", "torch.chain_matmul", "torch.linalg.multi_dot", "torch.nn.functional.conv1d", "torch.nn.functional.conv2d", "torch.nn.functional.conv3d", "torch.nn.functional.conv_transpose1d", "torch.nn.functional.conv_transpose2d", "torch.nn.functional.conv_transpose3d", "torch.nn.GRUCell", "torch.nn.functional.linear", "torch.nn.LSTMCell", "torch.matmul", "torch.mm", "torch.mv", "torch.prelu", "torch.nn.RNNCell", "torch.embedding"] ops = ["torch.Tensor.__matmul__", "torch.addbmm", "torch.addmm", "torch.addmv", "torch.addr", "torch.baddbmm", "torch.bmm", "torch.chain_matmul", "torch.linalg.multi_dot", "torch.nn.functional.conv1d", "torch.nn.functional.conv2d", "torch.nn.functional.conv3d", "torch.nn.functional.conv_transpose1d", "torch.nn.functional.conv_transpose2d", "torch.nn.functional.conv_transpose3d", "torch.nn.GRUCell", "torch.nn.functional.linear", "torch.nn.LSTMCell", "torch.matmul", "torch.mm", "torch.mv", "torch.prelu", "torch.nn.RNNCell", "torch.embedding"]
supported_cast_pairs = { supported_cast_pairs = {
torch.float16: (torch.float32,), torch.float16: (torch.float32,),
torch.float32: (torch.float16,), torch.float32: (torch.float16,),
} }
def forward(op, args: tuple, kwargs: dict): def forward(op, args: tuple, kwargs: dict):
if not torch.dml.is_autocast_enabled: if not torch.dml.is_autocast_enabled:
return op(*args, **kwargs) return op(*args, **kwargs)
@@ -16,6 +18,7 @@ def forward(op, args: tuple, kwargs: dict):
kwargs[kwarg] = cast(kwargs[kwarg]) kwargs[kwarg] = cast(kwargs[kwarg])
return op(*args, **kwargs) return op(*args, **kwargs)
def cast(tensor: torch.Tensor): def cast(tensor: torch.Tensor):
if not torch.is_tensor(tensor): if not torch.is_tensor(tensor):
return tensor return tensor
@@ -24,6 +27,7 @@ def cast(tensor: torch.Tensor):
return tensor return tensor
return tensor.type(torch.dml.autocast_gpu_dtype) return tensor.type(torch.dml.autocast_gpu_dtype)
def cond(op: str): def cond(op: str):
if isinstance(op, str): if isinstance(op, str):
func_path = op.split('.') func_path = op.split('.')
@@ -38,8 +42,10 @@ def cond(op: str):
op = getattr(resolved_obj, func_path[-1]) op = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(op, args, kwargs)) setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(op, args, kwargs))
for op in ops: for o in ops:
cond(op) cond(o)
class autocast: class autocast:
prev: bool prev: bool

View File

@@ -3,26 +3,29 @@ from typing import Optional, Callable
import torch import torch
import torch_directml # pylint: disable=import-error import torch_directml # pylint: disable=import-error
import modules.dml.amp as amp import modules.dml.amp as amp
from .utils import rDevice, get_device from .utils import rDevice, get_device
from .device import device from .device import Device
from .Generator import Generator from .Generator import Generator
from .device_properties import DeviceProperties from .device_properties import DeviceProperties
def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
from .memory_amd import AMDMemoryProvider from .memory_amd import AMDMemoryProvider
return AMDMemoryProvider.mem_get_info(get_device(device).index) return AMDMemoryProvider.mem_get_info(get_device(device).index)
def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
mem_info = DirectML.memory_provider.get_memory(get_device(device).index) mem_info = DirectML.memory_provider.get_memory(get_device(device).index)
return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"]) return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"])
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
return (8589934592, 8589934592) return (8589934592, 8589934592)
class DirectML: class DirectML:
amp = amp amp = amp
device = device device = Device
Generator = Generator Generator = Generator
context_device: Optional[torch.device] = None context_device: Optional[torch.device] = None

View File

@@ -1,14 +1,14 @@
from typing import Optional from typing import Optional
import torch import torch
from .utils import rDevice, get_device from .utils import rDevice, get_device
class device: class Device:
def __enter__(self, device: Optional[rDevice]=None): def __enter__(self, device: Optional[rDevice]=None):
torch.dml.context_device = get_device(device) torch.dml.context_device = get_device(device)
def __init__(self, device: Optional[rDevice]=None) -> torch.device: def __init__(self, device: Optional[rDevice]=None) -> torch.device: # pylint: disable=return-in-init
return get_device(device) return get_device(device)
def __exit__(self, type, val, tb): def __exit__(self, t, v, tb):
torch.dml.context_device = None torch.dml.context_device = None

View File

@@ -1,5 +1,6 @@
import torch import torch
class DeviceProperties: class DeviceProperties:
type: str = "directml" type: str = "directml"
name: str name: str

View File

@@ -2,6 +2,7 @@ import torch
from typing import Callable from typing import Callable
from modules.shared import log, opts from modules.shared import log, opts
def catch_nan(func: Callable[[], torch.Tensor]): def catch_nan(func: Callable[[], torch.Tensor]):
if not opts.directml_catch_nan: if not opts.directml_catch_nan:
return func() return func()

View File

@@ -1,8 +1,8 @@
from os import getpid from os import getpid
from collections import defaultdict from collections import defaultdict
from modules.dml.pdh import HQuery, HCounter, expand_wildcard_path from modules.dml.pdh import HQuery, HCounter, expand_wildcard_path
class MemoryProvider: class MemoryProvider:
hQuery: HQuery hQuery: HQuery
hCounters: defaultdict[str, list[HCounter]] hCounters: defaultdict[str, list[HCounter]]

View File

@@ -1,7 +1,10 @@
from .driver.atiadlxx import ATIADLxx from .driver.atiadlxx import ATIADLxx
class AMDMemoryProvider: class AMDMemoryProvider:
driver: ATIADLxx = ATIADLxx() driver: ATIADLxx = ATIADLxx()
@staticmethod
def mem_get_info(index): def mem_get_info(index):
usage = AMDMemoryProvider.driver.get_dedicated_vram_usage(index) * (1 << 20) usage = AMDMemoryProvider.driver.get_dedicated_vram_usage(index) * (1 << 20)
return (AMDMemoryProvider.driver.iHyperMemorySize - usage, AMDMemoryProvider.driver.iHyperMemorySize) return (AMDMemoryProvider.driver.iHyperMemorySize - usage, AMDMemoryProvider.driver.iHyperMemorySize)

View File

@@ -1,9 +1,10 @@
import ctypes as C import ctypes as C
from .atiadlxx_apis import * from modules.dml.memory_amd.driver.atiadlxx_apis import ADL2_Main_Control_Create, ADL_Main_Memory_Alloc, ADL2_Adapter_NumberOfAdapters_Get, ADL2_Adapter_AdapterInfo_Get, ADL2_Adapter_MemoryInfo2_Get, ADL2_Adapter_DedicatedVRAMUsage_Get, ADL2_Adapter_VRAMUsage_Get
from .atiadlxx_structures import * from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, AdapterInfo, LPAdapterInfo, ADLMemoryInfo2
from .atiadlxx_defines import * from modules.dml.memory_amd.driver.atiadlxx_defines import ADL_OK
class ATIADLxx(object): class ATIADLxx:
iHyperMemorySize = 0 iHyperMemorySize = 0
def __init__(self): def __init__(self):

View File

@@ -1,25 +1,30 @@
import ctypes as C import ctypes as C
from platform import platform from platform import platform
from .atiadlxx_structures import * from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, LPAdapterInfo, ADLMemoryInfo2
if 'Windows' in platform(): if 'Windows' in platform():
atiadlxx = C.WinDLL("atiadlxx.dll") atiadlxx = C.WinDLL("atiadlxx.dll")
else: else:
atiadlxx = C.CDLL("libatiadlxx.so") # Not tested on Linux system. But will be supported. atiadlxx = C.CDLL("libatiadlxx.so") # Not tested on Linux system. But will be supported.
ADL_MAIN_MALLOC_CALLBACK = C.CFUNCTYPE(C.c_void_p, C.c_int) ADL_MAIN_MALLOC_CALLBACK = C.CFUNCTYPE(C.c_void_p, C.c_int)
ADL_MAIN_FREE_CALLBACK = C.CFUNCTYPE(None, C.POINTER(C.c_void_p)) ADL_MAIN_FREE_CALLBACK = C.CFUNCTYPE(None, C.POINTER(C.c_void_p))
@ADL_MAIN_MALLOC_CALLBACK @ADL_MAIN_MALLOC_CALLBACK
def ADL_Main_Memory_Alloc(iSize): def ADL_Main_Memory_Alloc(iSize):
return C._malloc(iSize) return C._malloc(iSize)
@ADL_MAIN_FREE_CALLBACK @ADL_MAIN_FREE_CALLBACK
def ADL_Main_Memory_Free(lpBuffer): def ADL_Main_Memory_Free(lpBuffer):
if lpBuffer[0] is not None: if lpBuffer[0] is not None:
C._free(lpBuffer[0]) C._free(lpBuffer[0])
lpBuffer[0] = None lpBuffer[0] = None
ADL2_Main_Control_Create = atiadlxx.ADL2_Main_Control_Create ADL2_Main_Control_Create = atiadlxx.ADL2_Main_Control_Create
ADL2_Main_Control_Create.restype = C.c_int ADL2_Main_Control_Create.restype = C.c_int
ADL2_Main_Control_Create.argtypes = [ADL_MAIN_MALLOC_CALLBACK, C.c_int, ADL_CONTEXT_HANDLE] ADL2_Main_Control_Create.argtypes = [ADL_MAIN_MALLOC_CALLBACK, C.c_int, ADL_CONTEXT_HANDLE]

View File

@@ -1,5 +1,6 @@
import ctypes as C import ctypes as C
class _ADLPMActivity(C.Structure): class _ADLPMActivity(C.Structure):
__slot__ = [ __slot__ = [
'iActivityPercent', 'iActivityPercent',
@@ -13,7 +14,7 @@ class _ADLPMActivity(C.Structure):
'iSize', 'iSize',
'iVddc', 'iVddc',
] ]
_ADLPMActivity._fields_ = [ _ADLPMActivity._fields_ = [ # pylint: disable=protected-access
('iActivityPercent', C.c_int), ('iActivityPercent', C.c_int),
('iCurrentBusLanes', C.c_int), ('iCurrentBusLanes', C.c_int),
('iCurrentBusSpeed', C.c_int), ('iCurrentBusSpeed', C.c_int),
@@ -27,6 +28,7 @@ _ADLPMActivity._fields_ = [
] ]
ADLPMActivity = _ADLPMActivity ADLPMActivity = _ADLPMActivity
class _ADLMemoryInfo2(C.Structure): class _ADLMemoryInfo2(C.Structure):
__slot__ = [ __slot__ = [
'iHyperMemorySize', 'iHyperMemorySize',
@@ -36,7 +38,7 @@ class _ADLMemoryInfo2(C.Structure):
'iVisibleMemorySize', 'iVisibleMemorySize',
'strMemoryType' 'strMemoryType'
] ]
_ADLMemoryInfo2._fields_ = [ _ADLMemoryInfo2._fields_ = [ # pylint: disable=protected-access
('iHyperMemorySize', C.c_longlong), ('iHyperMemorySize', C.c_longlong),
('iInvisibleMemorySize', C.c_longlong), ('iInvisibleMemorySize', C.c_longlong),
('iMemoryBandwidth', C.c_longlong), ('iMemoryBandwidth', C.c_longlong),
@@ -46,6 +48,7 @@ _ADLMemoryInfo2._fields_ = [
] ]
ADLMemoryInfo2 = _ADLMemoryInfo2 ADLMemoryInfo2 = _ADLMemoryInfo2
class _AdapterInfo(C.Structure): class _AdapterInfo(C.Structure):
__slot__ = [ __slot__ = [
'iSize', 'iSize',
@@ -64,7 +67,7 @@ class _AdapterInfo(C.Structure):
'strPNPString', 'strPNPString',
'iOSDisplayIndex', 'iOSDisplayIndex',
] ]
_AdapterInfo._fields_ = [ _AdapterInfo._fields_ = [ # pylint: disable=protected-access
('iSize', C.c_int), ('iSize', C.c_int),
('iAdapterIndex', C.c_int), ('iAdapterIndex', C.c_int),
('strUDID', C.c_char * 256), ('strUDID', C.c_char * 256),

View File

@@ -1,22 +1,24 @@
from ctypes import * from ctypes import *
from ctypes.wintypes import * from ctypes.wintypes import *
from typing import NamedTuple, TypeVar from typing import NamedTuple, TypeVar
from .apis import PdhExpandWildCardPathW, PdhOpenQueryW, PdhAddEnglishCounterW, PdhCollectQueryData, PdhGetFormattedCounterValue, PdhGetFormattedCounterArrayW, PdhCloseQuery from .apis import PdhExpandWildCardPathW, PdhOpenQueryW, PdhAddEnglishCounterW, PdhCollectQueryData, PdhGetFormattedCounterValue, PdhGetFormattedCounterArrayW, PdhCloseQuery
from .structures import PDH_HQUERY, PDH_HCOUNTER, PDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W from .structures import PDH_HQUERY, PDH_HCOUNTER, PDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
from .defines import * from .defines import *
from .msvcrt import malloc from .msvcrt import malloc
from .errors import PDHError from .errors import PDHError
class __InternalAbstraction(NamedTuple): class __InternalAbstraction(NamedTuple):
flag: int flag: int
attr_name: str attr_name: str
_type_map = { _type_map = {
int: __InternalAbstraction(PDH_FMT_LARGE, "largeValue"), int: __InternalAbstraction(PDH_FMT_LARGE, "largeValue"),
float: __InternalAbstraction(PDH_FMT_DOUBLE, "doubleValue"), float: __InternalAbstraction(PDH_FMT_DOUBLE, "doubleValue"),
} }
def expand_wildcard_path(path: str) -> list[str]: def expand_wildcard_path(path: str) -> list[str]:
listLength = DWORD(0) listLength = DWORD(0)
if PdhExpandWildCardPathW(None, LPCWSTR(path), None, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_MORE_DATA: if PdhExpandWildCardPathW(None, LPCWSTR(path), None, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_MORE_DATA:
@@ -24,33 +26,35 @@ def expand_wildcard_path(path: str) -> list[str]:
expanded = (WCHAR * listLength.value)() expanded = (WCHAR * listLength.value)()
if PdhExpandWildCardPathW(None, LPCWSTR(path), expanded, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_OK: if PdhExpandWildCardPathW(None, LPCWSTR(path), expanded, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_OK:
raise PDHError(f"Couldn't expand wildcard path '{path}'") raise PDHError(f"Couldn't expand wildcard path '{path}'")
result = list() result = []
cur = str() cur = ""
for chr in expanded: for c in expanded:
if chr == '\0': if c == '\0':
result.append(cur) result.append(cur)
cur = str() cur = ""
else: else:
cur += chr cur += c
result.pop() result.pop()
return result return result
T = TypeVar("T", *_type_map.keys()) T = TypeVar("T", *_type_map.keys())
class HCounter(PDH_HCOUNTER): class HCounter(PDH_HCOUNTER):
def get_formatted_value(self, type: T) -> T: def get_formatted_value(self, typ: T) -> T:
if type not in _type_map: if typ not in _type_map:
raise PDHError(f"Invalid value type: {type}") raise PDHError(f"Invalid value type: {typ}")
flag, attr_name = _type_map[type] flag, attr_name = _type_map[typ]
value = PDH_FMT_COUNTERVALUE() value = PDH_FMT_COUNTERVALUE()
if PdhGetFormattedCounterValue(self, DWORD(flag | PDH_FMT_NOSCALE), None, byref(value)) != PDH_OK: if PdhGetFormattedCounterValue(self, DWORD(flag | PDH_FMT_NOSCALE), None, byref(value)) != PDH_OK:
raise PDHError("Couldn't get formatted counter value.") raise PDHError("Couldn't get formatted counter value.")
return getattr(value.u, attr_name) return getattr(value.u, attr_name)
def get_formatted_dict(self, type: T) -> dict[str, T]: def get_formatted_dict(self, typ: T) -> dict[str, T]:
if type not in _type_map: if typ not in _type_map:
raise PDHError(f"Invalid value type: {type}") raise PDHError(f"Invalid value type: {typ}")
flag, attr_name = _type_map[type] flag, attr_name = _type_map[typ]
bufferSize = DWORD(0) bufferSize = DWORD(0)
itemCount = DWORD(0) itemCount = DWORD(0)
if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), None) != PDH_MORE_DATA: if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), None) != PDH_MORE_DATA:
@@ -64,9 +68,10 @@ class HCounter(PDH_HCOUNTER):
result[item.szName] = getattr(item.FmtValue.u, attr_name) result[item.szName] = getattr(item.FmtValue.u, attr_name)
return result return result
class HQuery(PDH_HQUERY): class HQuery(PDH_HQUERY):
def __init__(self): def __init__(self):
super(HQuery, self).__init__() super().__init__()
if PdhOpenQueryW(None, None, byref(self)) != PDH_OK: if PdhOpenQueryW(None, None, byref(self)) != PDH_OK:
raise PDHError("Couldn't open PDH query.") raise PDHError("Couldn't open PDH query.")

View File

@@ -1,12 +1,13 @@
from ctypes import * from ctypes import CDLL, POINTER
from ctypes.wintypes import * from ctypes.wintypes import LPCWSTR, LPDWORD, DWORD
from typing import Callable from typing import Callable
from .structures import PDH_HQUERY, PDH_HCOUNTER, PPDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
from .defines import PDH_FUNCTION, PZZWSTR, DWORD_PTR
from .structures import *
from .defines import *
pdh = CDLL("pdh.dll") pdh = CDLL("pdh.dll")
PdhExpandWildCardPathW: Callable = pdh.PdhExpandWildCardPathW PdhExpandWildCardPathW: Callable = pdh.PdhExpandWildCardPathW
PdhExpandWildCardPathW.restype = PDH_FUNCTION PdhExpandWildCardPathW.restype = PDH_FUNCTION
PdhExpandWildCardPathW.argtypes = [LPCWSTR, LPCWSTR, PZZWSTR, LPDWORD, DWORD] PdhExpandWildCardPathW.argtypes = [LPCWSTR, LPCWSTR, PZZWSTR, LPDWORD, DWORD]

View File

@@ -1,5 +1,6 @@
from ctypes import * from ctypes import c_int, POINTER
from ctypes.wintypes import * from ctypes.wintypes import DWORD, WCHAR
PDH_FUNCTION = c_int PDH_FUNCTION = c_int
PDH_OK = 0x00000000 PDH_OK = 0x00000000

View File

@@ -1,3 +1,3 @@
class PDHError(Exception): class PDHError(Exception):
def __init__(self, message: str): def __init__(self, message: str):
super(PDHError, self).__init__(message) super().__init__(message)

View File

@@ -1,7 +1,9 @@
from ctypes import * from ctypes import CDLL, c_void_p, c_size_t
msvcrt = CDLL("msvcrt") msvcrt = CDLL("msvcrt")
malloc = msvcrt.malloc malloc = msvcrt.malloc
malloc.restype = c_void_p malloc.restype = c_void_p
malloc.argtypes = [c_size_t] malloc.argtypes = [c_size_t]

View File

@@ -1,9 +1,11 @@
from ctypes import * from ctypes import Union, c_double, c_longlong, Structure, POINTER
from ctypes.wintypes import * from ctypes.wintypes import HANDLE, LONG, LPCSTR, LPCWSTR, DWORD, LPWSTR
PDH_HQUERY = HANDLE PDH_HQUERY = HANDLE
PDH_HCOUNTER = HANDLE PDH_HCOUNTER = HANDLE
class PDH_FMT_COUNTERVALUE_U(Union): class PDH_FMT_COUNTERVALUE_U(Union):
_fields_ = [ _fields_ = [
("longValue", LONG), ("longValue", LONG),
@@ -19,6 +21,7 @@ class PDH_FMT_COUNTERVALUE_U(Union):
AnsiStringValue: LPCSTR AnsiStringValue: LPCSTR
WideStringValue: LPCWSTR WideStringValue: LPCWSTR
class PDH_FMT_COUNTERVALUE(Structure): class PDH_FMT_COUNTERVALUE(Structure):
_anonymous_ = ("u",) _anonymous_ = ("u",)
_fields_ = [ _fields_ = [
@@ -30,6 +33,7 @@ class PDH_FMT_COUNTERVALUE(Structure):
u: PDH_FMT_COUNTERVALUE_U u: PDH_FMT_COUNTERVALUE_U
PPDH_FMT_COUNTERVALUE = POINTER(PDH_FMT_COUNTERVALUE) PPDH_FMT_COUNTERVALUE = POINTER(PDH_FMT_COUNTERVALUE)
class PDH_FMT_COUNTERVALUE_ITEM_W(Structure): class PDH_FMT_COUNTERVALUE_ITEM_W(Structure):
_fields_ = [ _fields_ = [
("szName", LPWSTR), ("szName", LPWSTR),

View File

@@ -1,6 +1,7 @@
from typing import Optional, Union from typing import Optional, Union
import torch import torch
rDevice = Union[torch.device, int] rDevice = Union[torch.device, int]
def get_device(device: Optional[rDevice]=None) -> torch.device: def get_device(device: Optional[rDevice]=None) -> torch.device:
if device is None: if device is None: