From ff2c1db1cc7e68f3a9e2e71f30b5a7f6b6618f8d Mon Sep 17 00:00:00 2001 From: Seunghoon Lee Date: Mon, 5 Feb 2024 22:43:10 +0900 Subject: [PATCH] fix code formatting under modules/dml --- modules/dml/Generator.py | 3 +- modules/dml/__init__.py | 7 ++++ modules/dml/amp/autocast_mode.py | 10 ++++- modules/dml/backend.py | 11 ++++-- modules/dml/device.py | 8 ++-- modules/dml/device_properties.py | 1 + modules/dml/hijack/utils.py | 1 + modules/dml/memory.py | 2 +- modules/dml/memory_amd/__init__.py | 3 ++ modules/dml/memory_amd/driver/atiadlxx.py | 9 +++-- .../dml/memory_amd/driver/atiadlxx_apis.py | 7 +++- .../memory_amd/driver/atiadlxx_structures.py | 9 +++-- modules/dml/pdh/__init__.py | 37 +++++++++++-------- modules/dml/pdh/apis.py | 9 +++-- modules/dml/pdh/defines.py | 5 ++- modules/dml/pdh/errors.py | 2 +- modules/dml/pdh/msvcrt.py | 4 +- modules/dml/pdh/structures.py | 8 +++- modules/dml/utils.py | 1 + 19 files changed, 91 insertions(+), 46 deletions(-) diff --git a/modules/dml/Generator.py b/modules/dml/Generator.py index c35f87361..ea273310c 100644 --- a/modules/dml/Generator.py +++ b/modules/dml/Generator.py @@ -1,5 +1,6 @@ -import torch from typing import Optional +import torch + class Generator(torch.Generator): def __init__(self, device: Optional[torch.device] = None): diff --git a/modules/dml/__init__.py b/modules/dml/__init__.py index 358885650..8eed1aad1 100644 --- a/modules/dml/__init__.py +++ b/modules/dml/__init__.py @@ -4,6 +4,7 @@ import torch from modules.errors import log from modules.sd_hijack_utils import CondFunc + memory_providers = ["None", "atiadlxx (AMD only)"] default_memory_provider = "None" if platform.system() == "Windows": @@ -12,6 +13,7 @@ if platform.system() == "Windows": do_nothing = lambda: None # pylint: disable=unnecessary-lambda-assignment do_nothing_with_self = lambda self: None # pylint: disable=unnecessary-lambda-assignment + def _set_memory_provider(): from modules.shared import opts, cmd_opts if opts.directml_memory_provider == "Performance Counter": @@ -35,6 +37,7 @@ def _set_memory_provider(): torch.dml.mem_get_info = mem_get_info torch.cuda.mem_get_info = torch.dml.mem_get_info + def directml_init(): try: from modules.dml.backend import DirectML # pylint: disable=ungrouped-imports @@ -63,6 +66,7 @@ def directml_init(): return False, e return True, None + def directml_do_hijack(): import modules.dml.hijack # pylint: disable=unused-import from modules.devices import device @@ -79,17 +83,20 @@ def directml_do_hijack(): _set_memory_provider() + class OverrideItem(NamedTuple): value: str condition: Optional[Callable] message: Optional[str] + opts_override_table = { "diffusers_generator_device": OverrideItem("CPU", None, "DirectML does not support torch Generator API"), "diffusers_model_cpu_offload": OverrideItem(False, None, "Diffusers model CPU offloading does not support DirectML devices"), "diffusers_seq_cpu_offload": OverrideItem(False, lambda opts: opts.diffusers_pipeline != "Stable Diffusion XL", "Diffusers sequential CPU offloading is available only on StableDiffusionXLPipeline with DirectML devices"), } + def directml_override_opts(): from modules import shared diff --git a/modules/dml/amp/autocast_mode.py b/modules/dml/amp/autocast_mode.py index a5766dd32..401d26d9e 100644 --- a/modules/dml/amp/autocast_mode.py +++ b/modules/dml/amp/autocast_mode.py @@ -2,12 +2,14 @@ import importlib from typing import Any, Optional import torch + ops = ["torch.Tensor.__matmul__", "torch.addbmm", "torch.addmm", "torch.addmv", "torch.addr", "torch.baddbmm", "torch.bmm", "torch.chain_matmul", "torch.linalg.multi_dot", "torch.nn.functional.conv1d", "torch.nn.functional.conv2d", "torch.nn.functional.conv3d", "torch.nn.functional.conv_transpose1d", "torch.nn.functional.conv_transpose2d", "torch.nn.functional.conv_transpose3d", "torch.nn.GRUCell", "torch.nn.functional.linear", "torch.nn.LSTMCell", "torch.matmul", "torch.mm", "torch.mv", "torch.prelu", "torch.nn.RNNCell", "torch.embedding"] supported_cast_pairs = { torch.float16: (torch.float32,), torch.float32: (torch.float16,), } + def forward(op, args: tuple, kwargs: dict): if not torch.dml.is_autocast_enabled: return op(*args, **kwargs) @@ -16,6 +18,7 @@ def forward(op, args: tuple, kwargs: dict): kwargs[kwarg] = cast(kwargs[kwarg]) return op(*args, **kwargs) + def cast(tensor: torch.Tensor): if not torch.is_tensor(tensor): return tensor @@ -24,6 +27,7 @@ def cast(tensor: torch.Tensor): return tensor return tensor.type(torch.dml.autocast_gpu_dtype) + def cond(op: str): if isinstance(op, str): func_path = op.split('.') @@ -38,8 +42,10 @@ def cond(op: str): op = getattr(resolved_obj, func_path[-1]) setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(op, args, kwargs)) -for op in ops: - cond(op) + +for o in ops: + cond(o) + class autocast: prev: bool diff --git a/modules/dml/backend.py b/modules/dml/backend.py index 16fd5231b..7947dc81b 100644 --- a/modules/dml/backend.py +++ b/modules/dml/backend.py @@ -3,26 +3,29 @@ from typing import Optional, Callable import torch import torch_directml # pylint: disable=import-error import modules.dml.amp as amp - from .utils import rDevice, get_device -from .device import device +from .device import Device from .Generator import Generator from .device_properties import DeviceProperties + def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: from .memory_amd import AMDMemoryProvider return AMDMemoryProvider.mem_get_info(get_device(device).index) + def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: mem_info = DirectML.memory_provider.get_memory(get_device(device).index) return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"]) -def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: + +def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument return (8589934592, 8589934592) + class DirectML: amp = amp - device = device + device = Device Generator = Generator context_device: Optional[torch.device] = None diff --git a/modules/dml/device.py b/modules/dml/device.py index c84de1b9a..b5e2c8a36 100644 --- a/modules/dml/device.py +++ b/modules/dml/device.py @@ -1,14 +1,14 @@ from typing import Optional import torch - from .utils import rDevice, get_device -class device: + +class Device: def __enter__(self, device: Optional[rDevice]=None): torch.dml.context_device = get_device(device) - def __init__(self, device: Optional[rDevice]=None) -> torch.device: + def __init__(self, device: Optional[rDevice]=None) -> torch.device: # pylint: disable=return-in-init return get_device(device) - def __exit__(self, type, val, tb): + def __exit__(self, t, v, tb): torch.dml.context_device = None diff --git a/modules/dml/device_properties.py b/modules/dml/device_properties.py index abd146653..1d8328478 100644 --- a/modules/dml/device_properties.py +++ b/modules/dml/device_properties.py @@ -1,5 +1,6 @@ import torch + class DeviceProperties: type: str = "directml" name: str diff --git a/modules/dml/hijack/utils.py b/modules/dml/hijack/utils.py index 9e0fb2480..659431c22 100644 --- a/modules/dml/hijack/utils.py +++ b/modules/dml/hijack/utils.py @@ -2,6 +2,7 @@ import torch from typing import Callable from modules.shared import log, opts + def catch_nan(func: Callable[[], torch.Tensor]): if not opts.directml_catch_nan: return func() diff --git a/modules/dml/memory.py b/modules/dml/memory.py index af2d8060f..c3ca959b6 100644 --- a/modules/dml/memory.py +++ b/modules/dml/memory.py @@ -1,8 +1,8 @@ from os import getpid from collections import defaultdict - from modules.dml.pdh import HQuery, HCounter, expand_wildcard_path + class MemoryProvider: hQuery: HQuery hCounters: defaultdict[str, list[HCounter]] diff --git a/modules/dml/memory_amd/__init__.py b/modules/dml/memory_amd/__init__.py index 9928d5bc5..43907656c 100644 --- a/modules/dml/memory_amd/__init__.py +++ b/modules/dml/memory_amd/__init__.py @@ -1,7 +1,10 @@ from .driver.atiadlxx import ATIADLxx + class AMDMemoryProvider: driver: ATIADLxx = ATIADLxx() + + @staticmethod def mem_get_info(index): usage = AMDMemoryProvider.driver.get_dedicated_vram_usage(index) * (1 << 20) return (AMDMemoryProvider.driver.iHyperMemorySize - usage, AMDMemoryProvider.driver.iHyperMemorySize) diff --git a/modules/dml/memory_amd/driver/atiadlxx.py b/modules/dml/memory_amd/driver/atiadlxx.py index e0b602e28..ea47e8c4d 100644 --- a/modules/dml/memory_amd/driver/atiadlxx.py +++ b/modules/dml/memory_amd/driver/atiadlxx.py @@ -1,9 +1,10 @@ import ctypes as C -from .atiadlxx_apis import * -from .atiadlxx_structures import * -from .atiadlxx_defines import * +from modules.dml.memory_amd.driver.atiadlxx_apis import ADL2_Main_Control_Create, ADL_Main_Memory_Alloc, ADL2_Adapter_NumberOfAdapters_Get, ADL2_Adapter_AdapterInfo_Get, ADL2_Adapter_MemoryInfo2_Get, ADL2_Adapter_DedicatedVRAMUsage_Get, ADL2_Adapter_VRAMUsage_Get +from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, AdapterInfo, LPAdapterInfo, ADLMemoryInfo2 +from modules.dml.memory_amd.driver.atiadlxx_defines import ADL_OK -class ATIADLxx(object): + +class ATIADLxx: iHyperMemorySize = 0 def __init__(self): diff --git a/modules/dml/memory_amd/driver/atiadlxx_apis.py b/modules/dml/memory_amd/driver/atiadlxx_apis.py index 25e6390ef..789622a4c 100644 --- a/modules/dml/memory_amd/driver/atiadlxx_apis.py +++ b/modules/dml/memory_amd/driver/atiadlxx_apis.py @@ -1,25 +1,30 @@ import ctypes as C from platform import platform -from .atiadlxx_structures import * +from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, LPAdapterInfo, ADLMemoryInfo2 + if 'Windows' in platform(): atiadlxx = C.WinDLL("atiadlxx.dll") else: atiadlxx = C.CDLL("libatiadlxx.so") # Not tested on Linux system. But will be supported. + ADL_MAIN_MALLOC_CALLBACK = C.CFUNCTYPE(C.c_void_p, C.c_int) ADL_MAIN_FREE_CALLBACK = C.CFUNCTYPE(None, C.POINTER(C.c_void_p)) + @ADL_MAIN_MALLOC_CALLBACK def ADL_Main_Memory_Alloc(iSize): return C._malloc(iSize) + @ADL_MAIN_FREE_CALLBACK def ADL_Main_Memory_Free(lpBuffer): if lpBuffer[0] is not None: C._free(lpBuffer[0]) lpBuffer[0] = None + ADL2_Main_Control_Create = atiadlxx.ADL2_Main_Control_Create ADL2_Main_Control_Create.restype = C.c_int ADL2_Main_Control_Create.argtypes = [ADL_MAIN_MALLOC_CALLBACK, C.c_int, ADL_CONTEXT_HANDLE] diff --git a/modules/dml/memory_amd/driver/atiadlxx_structures.py b/modules/dml/memory_amd/driver/atiadlxx_structures.py index c0443592d..db8c421c7 100644 --- a/modules/dml/memory_amd/driver/atiadlxx_structures.py +++ b/modules/dml/memory_amd/driver/atiadlxx_structures.py @@ -1,5 +1,6 @@ import ctypes as C + class _ADLPMActivity(C.Structure): __slot__ = [ 'iActivityPercent', @@ -13,7 +14,7 @@ class _ADLPMActivity(C.Structure): 'iSize', 'iVddc', ] -_ADLPMActivity._fields_ = [ +_ADLPMActivity._fields_ = [ # pylint: disable=protected-access ('iActivityPercent', C.c_int), ('iCurrentBusLanes', C.c_int), ('iCurrentBusSpeed', C.c_int), @@ -27,6 +28,7 @@ _ADLPMActivity._fields_ = [ ] ADLPMActivity = _ADLPMActivity + class _ADLMemoryInfo2(C.Structure): __slot__ = [ 'iHyperMemorySize', @@ -36,7 +38,7 @@ class _ADLMemoryInfo2(C.Structure): 'iVisibleMemorySize', 'strMemoryType' ] -_ADLMemoryInfo2._fields_ = [ +_ADLMemoryInfo2._fields_ = [ # pylint: disable=protected-access ('iHyperMemorySize', C.c_longlong), ('iInvisibleMemorySize', C.c_longlong), ('iMemoryBandwidth', C.c_longlong), @@ -46,6 +48,7 @@ _ADLMemoryInfo2._fields_ = [ ] ADLMemoryInfo2 = _ADLMemoryInfo2 + class _AdapterInfo(C.Structure): __slot__ = [ 'iSize', @@ -64,7 +67,7 @@ class _AdapterInfo(C.Structure): 'strPNPString', 'iOSDisplayIndex', ] -_AdapterInfo._fields_ = [ +_AdapterInfo._fields_ = [ # pylint: disable=protected-access ('iSize', C.c_int), ('iAdapterIndex', C.c_int), ('strUDID', C.c_char * 256), diff --git a/modules/dml/pdh/__init__.py b/modules/dml/pdh/__init__.py index 0dcd466cb..5f9980cf4 100644 --- a/modules/dml/pdh/__init__.py +++ b/modules/dml/pdh/__init__.py @@ -1,22 +1,24 @@ from ctypes import * from ctypes.wintypes import * from typing import NamedTuple, TypeVar - from .apis import PdhExpandWildCardPathW, PdhOpenQueryW, PdhAddEnglishCounterW, PdhCollectQueryData, PdhGetFormattedCounterValue, PdhGetFormattedCounterArrayW, PdhCloseQuery from .structures import PDH_HQUERY, PDH_HCOUNTER, PDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W from .defines import * from .msvcrt import malloc from .errors import PDHError + class __InternalAbstraction(NamedTuple): flag: int attr_name: str + _type_map = { int: __InternalAbstraction(PDH_FMT_LARGE, "largeValue"), float: __InternalAbstraction(PDH_FMT_DOUBLE, "doubleValue"), } + def expand_wildcard_path(path: str) -> list[str]: listLength = DWORD(0) if PdhExpandWildCardPathW(None, LPCWSTR(path), None, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_MORE_DATA: @@ -24,33 +26,35 @@ def expand_wildcard_path(path: str) -> list[str]: expanded = (WCHAR * listLength.value)() if PdhExpandWildCardPathW(None, LPCWSTR(path), expanded, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_OK: raise PDHError(f"Couldn't expand wildcard path '{path}'") - result = list() - cur = str() - for chr in expanded: - if chr == '\0': + result = [] + cur = "" + for c in expanded: + if c == '\0': result.append(cur) - cur = str() + cur = "" else: - cur += chr + cur += c result.pop() return result + T = TypeVar("T", *_type_map.keys()) + class HCounter(PDH_HCOUNTER): - def get_formatted_value(self, type: T) -> T: - if type not in _type_map: - raise PDHError(f"Invalid value type: {type}") - flag, attr_name = _type_map[type] + def get_formatted_value(self, typ: T) -> T: + if typ not in _type_map: + raise PDHError(f"Invalid value type: {typ}") + flag, attr_name = _type_map[typ] value = PDH_FMT_COUNTERVALUE() if PdhGetFormattedCounterValue(self, DWORD(flag | PDH_FMT_NOSCALE), None, byref(value)) != PDH_OK: raise PDHError("Couldn't get formatted counter value.") return getattr(value.u, attr_name) - def get_formatted_dict(self, type: T) -> dict[str, T]: - if type not in _type_map: - raise PDHError(f"Invalid value type: {type}") - flag, attr_name = _type_map[type] + def get_formatted_dict(self, typ: T) -> dict[str, T]: + if typ not in _type_map: + raise PDHError(f"Invalid value type: {typ}") + flag, attr_name = _type_map[typ] bufferSize = DWORD(0) itemCount = DWORD(0) if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), None) != PDH_MORE_DATA: @@ -64,9 +68,10 @@ class HCounter(PDH_HCOUNTER): result[item.szName] = getattr(item.FmtValue.u, attr_name) return result + class HQuery(PDH_HQUERY): def __init__(self): - super(HQuery, self).__init__() + super().__init__() if PdhOpenQueryW(None, None, byref(self)) != PDH_OK: raise PDHError("Couldn't open PDH query.") diff --git a/modules/dml/pdh/apis.py b/modules/dml/pdh/apis.py index 87c1d1204..f01222b45 100644 --- a/modules/dml/pdh/apis.py +++ b/modules/dml/pdh/apis.py @@ -1,12 +1,13 @@ -from ctypes import * -from ctypes.wintypes import * +from ctypes import CDLL, POINTER +from ctypes.wintypes import LPCWSTR, LPDWORD, DWORD from typing import Callable +from .structures import PDH_HQUERY, PDH_HCOUNTER, PPDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W +from .defines import PDH_FUNCTION, PZZWSTR, DWORD_PTR -from .structures import * -from .defines import * pdh = CDLL("pdh.dll") + PdhExpandWildCardPathW: Callable = pdh.PdhExpandWildCardPathW PdhExpandWildCardPathW.restype = PDH_FUNCTION PdhExpandWildCardPathW.argtypes = [LPCWSTR, LPCWSTR, PZZWSTR, LPDWORD, DWORD] diff --git a/modules/dml/pdh/defines.py b/modules/dml/pdh/defines.py index a5ea1d479..1d55bcf73 100644 --- a/modules/dml/pdh/defines.py +++ b/modules/dml/pdh/defines.py @@ -1,5 +1,6 @@ -from ctypes import * -from ctypes.wintypes import * +from ctypes import c_int, POINTER +from ctypes.wintypes import DWORD, WCHAR + PDH_FUNCTION = c_int PDH_OK = 0x00000000 diff --git a/modules/dml/pdh/errors.py b/modules/dml/pdh/errors.py index 60d9ab8f7..3d8270219 100644 --- a/modules/dml/pdh/errors.py +++ b/modules/dml/pdh/errors.py @@ -1,3 +1,3 @@ class PDHError(Exception): def __init__(self, message: str): - super(PDHError, self).__init__(message) + super().__init__(message) diff --git a/modules/dml/pdh/msvcrt.py b/modules/dml/pdh/msvcrt.py index bc5d93031..22374cccd 100644 --- a/modules/dml/pdh/msvcrt.py +++ b/modules/dml/pdh/msvcrt.py @@ -1,7 +1,9 @@ -from ctypes import * +from ctypes import CDLL, c_void_p, c_size_t + msvcrt = CDLL("msvcrt") + malloc = msvcrt.malloc malloc.restype = c_void_p malloc.argtypes = [c_size_t] diff --git a/modules/dml/pdh/structures.py b/modules/dml/pdh/structures.py index 8fb09e6cb..e38556ccf 100644 --- a/modules/dml/pdh/structures.py +++ b/modules/dml/pdh/structures.py @@ -1,9 +1,11 @@ -from ctypes import * -from ctypes.wintypes import * +from ctypes import Union, c_double, c_longlong, Structure, POINTER +from ctypes.wintypes import HANDLE, LONG, LPCSTR, LPCWSTR, DWORD, LPWSTR + PDH_HQUERY = HANDLE PDH_HCOUNTER = HANDLE + class PDH_FMT_COUNTERVALUE_U(Union): _fields_ = [ ("longValue", LONG), @@ -19,6 +21,7 @@ class PDH_FMT_COUNTERVALUE_U(Union): AnsiStringValue: LPCSTR WideStringValue: LPCWSTR + class PDH_FMT_COUNTERVALUE(Structure): _anonymous_ = ("u",) _fields_ = [ @@ -30,6 +33,7 @@ class PDH_FMT_COUNTERVALUE(Structure): u: PDH_FMT_COUNTERVALUE_U PPDH_FMT_COUNTERVALUE = POINTER(PDH_FMT_COUNTERVALUE) + class PDH_FMT_COUNTERVALUE_ITEM_W(Structure): _fields_ = [ ("szName", LPWSTR), diff --git a/modules/dml/utils.py b/modules/dml/utils.py index 5ed48c1ba..cb19ed900 100644 --- a/modules/dml/utils.py +++ b/modules/dml/utils.py @@ -1,6 +1,7 @@ from typing import Optional, Union import torch + rDevice = Union[torch.device, int] def get_device(device: Optional[rDevice]=None) -> torch.device: if device is None: