fix code formatting under modules/dml

pull/2803/head
Seunghoon Lee 2024-02-05 22:43:10 +09:00
parent ef29f1a238
commit ff2c1db1cc
No known key found for this signature in database
GPG Key ID: 436E38F4E70BD152
19 changed files with 91 additions and 46 deletions

View File

@ -1,5 +1,6 @@
import torch
from typing import Optional
import torch
class Generator(torch.Generator):
def __init__(self, device: Optional[torch.device] = None):

View File

@ -4,6 +4,7 @@ import torch
from modules.errors import log
from modules.sd_hijack_utils import CondFunc
memory_providers = ["None", "atiadlxx (AMD only)"]
default_memory_provider = "None"
if platform.system() == "Windows":
@ -12,6 +13,7 @@ if platform.system() == "Windows":
do_nothing = lambda: None # pylint: disable=unnecessary-lambda-assignment
do_nothing_with_self = lambda self: None # pylint: disable=unnecessary-lambda-assignment
def _set_memory_provider():
from modules.shared import opts, cmd_opts
if opts.directml_memory_provider == "Performance Counter":
@ -35,6 +37,7 @@ def _set_memory_provider():
torch.dml.mem_get_info = mem_get_info
torch.cuda.mem_get_info = torch.dml.mem_get_info
def directml_init():
try:
from modules.dml.backend import DirectML # pylint: disable=ungrouped-imports
@ -63,6 +66,7 @@ def directml_init():
return False, e
return True, None
def directml_do_hijack():
import modules.dml.hijack # pylint: disable=unused-import
from modules.devices import device
@ -79,17 +83,20 @@ def directml_do_hijack():
_set_memory_provider()
class OverrideItem(NamedTuple):
value: str
condition: Optional[Callable]
message: Optional[str]
opts_override_table = {
"diffusers_generator_device": OverrideItem("CPU", None, "DirectML does not support torch Generator API"),
"diffusers_model_cpu_offload": OverrideItem(False, None, "Diffusers model CPU offloading does not support DirectML devices"),
"diffusers_seq_cpu_offload": OverrideItem(False, lambda opts: opts.diffusers_pipeline != "Stable Diffusion XL", "Diffusers sequential CPU offloading is available only on StableDiffusionXLPipeline with DirectML devices"),
}
def directml_override_opts():
from modules import shared

View File

@ -2,12 +2,14 @@ import importlib
from typing import Any, Optional
import torch
ops = ["torch.Tensor.__matmul__", "torch.addbmm", "torch.addmm", "torch.addmv", "torch.addr", "torch.baddbmm", "torch.bmm", "torch.chain_matmul", "torch.linalg.multi_dot", "torch.nn.functional.conv1d", "torch.nn.functional.conv2d", "torch.nn.functional.conv3d", "torch.nn.functional.conv_transpose1d", "torch.nn.functional.conv_transpose2d", "torch.nn.functional.conv_transpose3d", "torch.nn.GRUCell", "torch.nn.functional.linear", "torch.nn.LSTMCell", "torch.matmul", "torch.mm", "torch.mv", "torch.prelu", "torch.nn.RNNCell", "torch.embedding"]
supported_cast_pairs = {
torch.float16: (torch.float32,),
torch.float32: (torch.float16,),
}
def forward(op, args: tuple, kwargs: dict):
if not torch.dml.is_autocast_enabled:
return op(*args, **kwargs)
@ -16,6 +18,7 @@ def forward(op, args: tuple, kwargs: dict):
kwargs[kwarg] = cast(kwargs[kwarg])
return op(*args, **kwargs)
def cast(tensor: torch.Tensor):
if not torch.is_tensor(tensor):
return tensor
@ -24,6 +27,7 @@ def cast(tensor: torch.Tensor):
return tensor
return tensor.type(torch.dml.autocast_gpu_dtype)
def cond(op: str):
if isinstance(op, str):
func_path = op.split('.')
@ -38,8 +42,10 @@ def cond(op: str):
op = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(op, args, kwargs))
for op in ops:
cond(op)
for o in ops:
cond(o)
class autocast:
prev: bool

View File

@ -3,26 +3,29 @@ from typing import Optional, Callable
import torch
import torch_directml # pylint: disable=import-error
import modules.dml.amp as amp
from .utils import rDevice, get_device
from .device import device
from .device import Device
from .Generator import Generator
from .device_properties import DeviceProperties
def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
from .memory_amd import AMDMemoryProvider
return AMDMemoryProvider.mem_get_info(get_device(device).index)
def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
mem_info = DirectML.memory_provider.get_memory(get_device(device).index)
return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"])
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
return (8589934592, 8589934592)
class DirectML:
amp = amp
device = device
device = Device
Generator = Generator
context_device: Optional[torch.device] = None

View File

@ -1,14 +1,14 @@
from typing import Optional
import torch
from .utils import rDevice, get_device
class device:
class Device:
def __enter__(self, device: Optional[rDevice]=None):
torch.dml.context_device = get_device(device)
def __init__(self, device: Optional[rDevice]=None) -> torch.device:
def __init__(self, device: Optional[rDevice]=None) -> torch.device: # pylint: disable=return-in-init
return get_device(device)
def __exit__(self, type, val, tb):
def __exit__(self, t, v, tb):
torch.dml.context_device = None

View File

@ -1,5 +1,6 @@
import torch
class DeviceProperties:
type: str = "directml"
name: str

View File

@ -2,6 +2,7 @@ import torch
from typing import Callable
from modules.shared import log, opts
def catch_nan(func: Callable[[], torch.Tensor]):
if not opts.directml_catch_nan:
return func()

View File

@ -1,8 +1,8 @@
from os import getpid
from collections import defaultdict
from modules.dml.pdh import HQuery, HCounter, expand_wildcard_path
class MemoryProvider:
hQuery: HQuery
hCounters: defaultdict[str, list[HCounter]]

View File

@ -1,7 +1,10 @@
from .driver.atiadlxx import ATIADLxx
class AMDMemoryProvider:
driver: ATIADLxx = ATIADLxx()
@staticmethod
def mem_get_info(index):
usage = AMDMemoryProvider.driver.get_dedicated_vram_usage(index) * (1 << 20)
return (AMDMemoryProvider.driver.iHyperMemorySize - usage, AMDMemoryProvider.driver.iHyperMemorySize)

View File

@ -1,9 +1,10 @@
import ctypes as C
from .atiadlxx_apis import *
from .atiadlxx_structures import *
from .atiadlxx_defines import *
from modules.dml.memory_amd.driver.atiadlxx_apis import ADL2_Main_Control_Create, ADL_Main_Memory_Alloc, ADL2_Adapter_NumberOfAdapters_Get, ADL2_Adapter_AdapterInfo_Get, ADL2_Adapter_MemoryInfo2_Get, ADL2_Adapter_DedicatedVRAMUsage_Get, ADL2_Adapter_VRAMUsage_Get
from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, AdapterInfo, LPAdapterInfo, ADLMemoryInfo2
from modules.dml.memory_amd.driver.atiadlxx_defines import ADL_OK
class ATIADLxx(object):
class ATIADLxx:
iHyperMemorySize = 0
def __init__(self):

View File

@ -1,25 +1,30 @@
import ctypes as C
from platform import platform
from .atiadlxx_structures import *
from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, LPAdapterInfo, ADLMemoryInfo2
if 'Windows' in platform():
atiadlxx = C.WinDLL("atiadlxx.dll")
else:
atiadlxx = C.CDLL("libatiadlxx.so") # Not tested on Linux system. But will be supported.
ADL_MAIN_MALLOC_CALLBACK = C.CFUNCTYPE(C.c_void_p, C.c_int)
ADL_MAIN_FREE_CALLBACK = C.CFUNCTYPE(None, C.POINTER(C.c_void_p))
@ADL_MAIN_MALLOC_CALLBACK
def ADL_Main_Memory_Alloc(iSize):
return C._malloc(iSize)
@ADL_MAIN_FREE_CALLBACK
def ADL_Main_Memory_Free(lpBuffer):
if lpBuffer[0] is not None:
C._free(lpBuffer[0])
lpBuffer[0] = None
ADL2_Main_Control_Create = atiadlxx.ADL2_Main_Control_Create
ADL2_Main_Control_Create.restype = C.c_int
ADL2_Main_Control_Create.argtypes = [ADL_MAIN_MALLOC_CALLBACK, C.c_int, ADL_CONTEXT_HANDLE]

View File

@ -1,5 +1,6 @@
import ctypes as C
class _ADLPMActivity(C.Structure):
__slot__ = [
'iActivityPercent',
@ -13,7 +14,7 @@ class _ADLPMActivity(C.Structure):
'iSize',
'iVddc',
]
_ADLPMActivity._fields_ = [
_ADLPMActivity._fields_ = [ # pylint: disable=protected-access
('iActivityPercent', C.c_int),
('iCurrentBusLanes', C.c_int),
('iCurrentBusSpeed', C.c_int),
@ -27,6 +28,7 @@ _ADLPMActivity._fields_ = [
]
ADLPMActivity = _ADLPMActivity
class _ADLMemoryInfo2(C.Structure):
__slot__ = [
'iHyperMemorySize',
@ -36,7 +38,7 @@ class _ADLMemoryInfo2(C.Structure):
'iVisibleMemorySize',
'strMemoryType'
]
_ADLMemoryInfo2._fields_ = [
_ADLMemoryInfo2._fields_ = [ # pylint: disable=protected-access
('iHyperMemorySize', C.c_longlong),
('iInvisibleMemorySize', C.c_longlong),
('iMemoryBandwidth', C.c_longlong),
@ -46,6 +48,7 @@ _ADLMemoryInfo2._fields_ = [
]
ADLMemoryInfo2 = _ADLMemoryInfo2
class _AdapterInfo(C.Structure):
__slot__ = [
'iSize',
@ -64,7 +67,7 @@ class _AdapterInfo(C.Structure):
'strPNPString',
'iOSDisplayIndex',
]
_AdapterInfo._fields_ = [
_AdapterInfo._fields_ = [ # pylint: disable=protected-access
('iSize', C.c_int),
('iAdapterIndex', C.c_int),
('strUDID', C.c_char * 256),

View File

@ -1,22 +1,24 @@
from ctypes import *
from ctypes.wintypes import *
from typing import NamedTuple, TypeVar
from .apis import PdhExpandWildCardPathW, PdhOpenQueryW, PdhAddEnglishCounterW, PdhCollectQueryData, PdhGetFormattedCounterValue, PdhGetFormattedCounterArrayW, PdhCloseQuery
from .structures import PDH_HQUERY, PDH_HCOUNTER, PDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
from .defines import *
from .msvcrt import malloc
from .errors import PDHError
class __InternalAbstraction(NamedTuple):
flag: int
attr_name: str
_type_map = {
int: __InternalAbstraction(PDH_FMT_LARGE, "largeValue"),
float: __InternalAbstraction(PDH_FMT_DOUBLE, "doubleValue"),
}
def expand_wildcard_path(path: str) -> list[str]:
listLength = DWORD(0)
if PdhExpandWildCardPathW(None, LPCWSTR(path), None, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_MORE_DATA:
@ -24,33 +26,35 @@ def expand_wildcard_path(path: str) -> list[str]:
expanded = (WCHAR * listLength.value)()
if PdhExpandWildCardPathW(None, LPCWSTR(path), expanded, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_OK:
raise PDHError(f"Couldn't expand wildcard path '{path}'")
result = list()
cur = str()
for chr in expanded:
if chr == '\0':
result = []
cur = ""
for c in expanded:
if c == '\0':
result.append(cur)
cur = str()
cur = ""
else:
cur += chr
cur += c
result.pop()
return result
T = TypeVar("T", *_type_map.keys())
class HCounter(PDH_HCOUNTER):
def get_formatted_value(self, type: T) -> T:
if type not in _type_map:
raise PDHError(f"Invalid value type: {type}")
flag, attr_name = _type_map[type]
def get_formatted_value(self, typ: T) -> T:
if typ not in _type_map:
raise PDHError(f"Invalid value type: {typ}")
flag, attr_name = _type_map[typ]
value = PDH_FMT_COUNTERVALUE()
if PdhGetFormattedCounterValue(self, DWORD(flag | PDH_FMT_NOSCALE), None, byref(value)) != PDH_OK:
raise PDHError("Couldn't get formatted counter value.")
return getattr(value.u, attr_name)
def get_formatted_dict(self, type: T) -> dict[str, T]:
if type not in _type_map:
raise PDHError(f"Invalid value type: {type}")
flag, attr_name = _type_map[type]
def get_formatted_dict(self, typ: T) -> dict[str, T]:
if typ not in _type_map:
raise PDHError(f"Invalid value type: {typ}")
flag, attr_name = _type_map[typ]
bufferSize = DWORD(0)
itemCount = DWORD(0)
if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), None) != PDH_MORE_DATA:
@ -64,9 +68,10 @@ class HCounter(PDH_HCOUNTER):
result[item.szName] = getattr(item.FmtValue.u, attr_name)
return result
class HQuery(PDH_HQUERY):
def __init__(self):
super(HQuery, self).__init__()
super().__init__()
if PdhOpenQueryW(None, None, byref(self)) != PDH_OK:
raise PDHError("Couldn't open PDH query.")

View File

@ -1,12 +1,13 @@
from ctypes import *
from ctypes.wintypes import *
from ctypes import CDLL, POINTER
from ctypes.wintypes import LPCWSTR, LPDWORD, DWORD
from typing import Callable
from .structures import PDH_HQUERY, PDH_HCOUNTER, PPDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
from .defines import PDH_FUNCTION, PZZWSTR, DWORD_PTR
from .structures import *
from .defines import *
pdh = CDLL("pdh.dll")
PdhExpandWildCardPathW: Callable = pdh.PdhExpandWildCardPathW
PdhExpandWildCardPathW.restype = PDH_FUNCTION
PdhExpandWildCardPathW.argtypes = [LPCWSTR, LPCWSTR, PZZWSTR, LPDWORD, DWORD]

View File

@ -1,5 +1,6 @@
from ctypes import *
from ctypes.wintypes import *
from ctypes import c_int, POINTER
from ctypes.wintypes import DWORD, WCHAR
PDH_FUNCTION = c_int
PDH_OK = 0x00000000

View File

@ -1,3 +1,3 @@
class PDHError(Exception):
def __init__(self, message: str):
super(PDHError, self).__init__(message)
super().__init__(message)

View File

@ -1,7 +1,9 @@
from ctypes import *
from ctypes import CDLL, c_void_p, c_size_t
msvcrt = CDLL("msvcrt")
malloc = msvcrt.malloc
malloc.restype = c_void_p
malloc.argtypes = [c_size_t]

View File

@ -1,9 +1,11 @@
from ctypes import *
from ctypes.wintypes import *
from ctypes import Union, c_double, c_longlong, Structure, POINTER
from ctypes.wintypes import HANDLE, LONG, LPCSTR, LPCWSTR, DWORD, LPWSTR
PDH_HQUERY = HANDLE
PDH_HCOUNTER = HANDLE
class PDH_FMT_COUNTERVALUE_U(Union):
_fields_ = [
("longValue", LONG),
@ -19,6 +21,7 @@ class PDH_FMT_COUNTERVALUE_U(Union):
AnsiStringValue: LPCSTR
WideStringValue: LPCWSTR
class PDH_FMT_COUNTERVALUE(Structure):
_anonymous_ = ("u",)
_fields_ = [
@ -30,6 +33,7 @@ class PDH_FMT_COUNTERVALUE(Structure):
u: PDH_FMT_COUNTERVALUE_U
PPDH_FMT_COUNTERVALUE = POINTER(PDH_FMT_COUNTERVALUE)
class PDH_FMT_COUNTERVALUE_ITEM_W(Structure):
_fields_ = [
("szName", LPWSTR),

View File

@ -1,6 +1,7 @@
from typing import Optional, Union
import torch
rDevice = Union[torch.device, int]
def get_device(device: Optional[rDevice]=None) -> torch.device:
if device is None: