mirror of https://github.com/vladmandic/automatic
fix code formatting under modules/dml
parent
ef29f1a238
commit
ff2c1db1cc
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
from typing import Optional
|
||||
import torch
|
||||
|
||||
|
||||
class Generator(torch.Generator):
|
||||
def __init__(self, device: Optional[torch.device] = None):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import torch
|
|||
from modules.errors import log
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
|
||||
|
||||
memory_providers = ["None", "atiadlxx (AMD only)"]
|
||||
default_memory_provider = "None"
|
||||
if platform.system() == "Windows":
|
||||
|
|
@ -12,6 +13,7 @@ if platform.system() == "Windows":
|
|||
do_nothing = lambda: None # pylint: disable=unnecessary-lambda-assignment
|
||||
do_nothing_with_self = lambda self: None # pylint: disable=unnecessary-lambda-assignment
|
||||
|
||||
|
||||
def _set_memory_provider():
|
||||
from modules.shared import opts, cmd_opts
|
||||
if opts.directml_memory_provider == "Performance Counter":
|
||||
|
|
@ -35,6 +37,7 @@ def _set_memory_provider():
|
|||
torch.dml.mem_get_info = mem_get_info
|
||||
torch.cuda.mem_get_info = torch.dml.mem_get_info
|
||||
|
||||
|
||||
def directml_init():
|
||||
try:
|
||||
from modules.dml.backend import DirectML # pylint: disable=ungrouped-imports
|
||||
|
|
@ -63,6 +66,7 @@ def directml_init():
|
|||
return False, e
|
||||
return True, None
|
||||
|
||||
|
||||
def directml_do_hijack():
|
||||
import modules.dml.hijack # pylint: disable=unused-import
|
||||
from modules.devices import device
|
||||
|
|
@ -79,17 +83,20 @@ def directml_do_hijack():
|
|||
|
||||
_set_memory_provider()
|
||||
|
||||
|
||||
class OverrideItem(NamedTuple):
|
||||
value: str
|
||||
condition: Optional[Callable]
|
||||
message: Optional[str]
|
||||
|
||||
|
||||
opts_override_table = {
|
||||
"diffusers_generator_device": OverrideItem("CPU", None, "DirectML does not support torch Generator API"),
|
||||
"diffusers_model_cpu_offload": OverrideItem(False, None, "Diffusers model CPU offloading does not support DirectML devices"),
|
||||
"diffusers_seq_cpu_offload": OverrideItem(False, lambda opts: opts.diffusers_pipeline != "Stable Diffusion XL", "Diffusers sequential CPU offloading is available only on StableDiffusionXLPipeline with DirectML devices"),
|
||||
}
|
||||
|
||||
|
||||
def directml_override_opts():
|
||||
from modules import shared
|
||||
|
||||
|
|
|
|||
|
|
@ -2,12 +2,14 @@ import importlib
|
|||
from typing import Any, Optional
|
||||
import torch
|
||||
|
||||
|
||||
ops = ["torch.Tensor.__matmul__", "torch.addbmm", "torch.addmm", "torch.addmv", "torch.addr", "torch.baddbmm", "torch.bmm", "torch.chain_matmul", "torch.linalg.multi_dot", "torch.nn.functional.conv1d", "torch.nn.functional.conv2d", "torch.nn.functional.conv3d", "torch.nn.functional.conv_transpose1d", "torch.nn.functional.conv_transpose2d", "torch.nn.functional.conv_transpose3d", "torch.nn.GRUCell", "torch.nn.functional.linear", "torch.nn.LSTMCell", "torch.matmul", "torch.mm", "torch.mv", "torch.prelu", "torch.nn.RNNCell", "torch.embedding"]
|
||||
supported_cast_pairs = {
|
||||
torch.float16: (torch.float32,),
|
||||
torch.float32: (torch.float16,),
|
||||
}
|
||||
|
||||
|
||||
def forward(op, args: tuple, kwargs: dict):
|
||||
if not torch.dml.is_autocast_enabled:
|
||||
return op(*args, **kwargs)
|
||||
|
|
@ -16,6 +18,7 @@ def forward(op, args: tuple, kwargs: dict):
|
|||
kwargs[kwarg] = cast(kwargs[kwarg])
|
||||
return op(*args, **kwargs)
|
||||
|
||||
|
||||
def cast(tensor: torch.Tensor):
|
||||
if not torch.is_tensor(tensor):
|
||||
return tensor
|
||||
|
|
@ -24,6 +27,7 @@ def cast(tensor: torch.Tensor):
|
|||
return tensor
|
||||
return tensor.type(torch.dml.autocast_gpu_dtype)
|
||||
|
||||
|
||||
def cond(op: str):
|
||||
if isinstance(op, str):
|
||||
func_path = op.split('.')
|
||||
|
|
@ -38,8 +42,10 @@ def cond(op: str):
|
|||
op = getattr(resolved_obj, func_path[-1])
|
||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: forward(op, args, kwargs))
|
||||
|
||||
for op in ops:
|
||||
cond(op)
|
||||
|
||||
for o in ops:
|
||||
cond(o)
|
||||
|
||||
|
||||
class autocast:
|
||||
prev: bool
|
||||
|
|
|
|||
|
|
@ -3,26 +3,29 @@ from typing import Optional, Callable
|
|||
import torch
|
||||
import torch_directml # pylint: disable=import-error
|
||||
import modules.dml.amp as amp
|
||||
|
||||
from .utils import rDevice, get_device
|
||||
from .device import device
|
||||
from .device import Device
|
||||
from .Generator import Generator
|
||||
from .device_properties import DeviceProperties
|
||||
|
||||
|
||||
def amd_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
|
||||
from .memory_amd import AMDMemoryProvider
|
||||
return AMDMemoryProvider.mem_get_info(get_device(device).index)
|
||||
|
||||
|
||||
def pdh_mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
|
||||
mem_info = DirectML.memory_provider.get_memory(get_device(device).index)
|
||||
return (mem_info["total_committed"] - mem_info["dedicated_usage"], mem_info["total_committed"])
|
||||
|
||||
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]:
|
||||
|
||||
def mem_get_info(device: Optional[rDevice]=None) -> tuple[int, int]: # pylint: disable=unused-argument
|
||||
return (8589934592, 8589934592)
|
||||
|
||||
|
||||
class DirectML:
|
||||
amp = amp
|
||||
device = device
|
||||
device = Device
|
||||
Generator = Generator
|
||||
|
||||
context_device: Optional[torch.device] = None
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
from typing import Optional
|
||||
import torch
|
||||
|
||||
from .utils import rDevice, get_device
|
||||
|
||||
class device:
|
||||
|
||||
class Device:
|
||||
def __enter__(self, device: Optional[rDevice]=None):
|
||||
torch.dml.context_device = get_device(device)
|
||||
|
||||
def __init__(self, device: Optional[rDevice]=None) -> torch.device:
|
||||
def __init__(self, device: Optional[rDevice]=None) -> torch.device: # pylint: disable=return-in-init
|
||||
return get_device(device)
|
||||
|
||||
def __exit__(self, type, val, tb):
|
||||
def __exit__(self, t, v, tb):
|
||||
torch.dml.context_device = None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
|
||||
|
||||
class DeviceProperties:
|
||||
type: str = "directml"
|
||||
name: str
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import torch
|
|||
from typing import Callable
|
||||
from modules.shared import log, opts
|
||||
|
||||
|
||||
def catch_nan(func: Callable[[], torch.Tensor]):
|
||||
if not opts.directml_catch_nan:
|
||||
return func()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from os import getpid
|
||||
from collections import defaultdict
|
||||
|
||||
from modules.dml.pdh import HQuery, HCounter, expand_wildcard_path
|
||||
|
||||
|
||||
class MemoryProvider:
|
||||
hQuery: HQuery
|
||||
hCounters: defaultdict[str, list[HCounter]]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
from .driver.atiadlxx import ATIADLxx
|
||||
|
||||
|
||||
class AMDMemoryProvider:
|
||||
driver: ATIADLxx = ATIADLxx()
|
||||
|
||||
@staticmethod
|
||||
def mem_get_info(index):
|
||||
usage = AMDMemoryProvider.driver.get_dedicated_vram_usage(index) * (1 << 20)
|
||||
return (AMDMemoryProvider.driver.iHyperMemorySize - usage, AMDMemoryProvider.driver.iHyperMemorySize)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
import ctypes as C
|
||||
from .atiadlxx_apis import *
|
||||
from .atiadlxx_structures import *
|
||||
from .atiadlxx_defines import *
|
||||
from modules.dml.memory_amd.driver.atiadlxx_apis import ADL2_Main_Control_Create, ADL_Main_Memory_Alloc, ADL2_Adapter_NumberOfAdapters_Get, ADL2_Adapter_AdapterInfo_Get, ADL2_Adapter_MemoryInfo2_Get, ADL2_Adapter_DedicatedVRAMUsage_Get, ADL2_Adapter_VRAMUsage_Get
|
||||
from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, AdapterInfo, LPAdapterInfo, ADLMemoryInfo2
|
||||
from modules.dml.memory_amd.driver.atiadlxx_defines import ADL_OK
|
||||
|
||||
class ATIADLxx(object):
|
||||
|
||||
class ATIADLxx:
|
||||
iHyperMemorySize = 0
|
||||
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -1,25 +1,30 @@
|
|||
import ctypes as C
|
||||
from platform import platform
|
||||
from .atiadlxx_structures import *
|
||||
from modules.dml.memory_amd.driver.atiadlxx_structures import ADL_CONTEXT_HANDLE, LPAdapterInfo, ADLMemoryInfo2
|
||||
|
||||
|
||||
if 'Windows' in platform():
|
||||
atiadlxx = C.WinDLL("atiadlxx.dll")
|
||||
else:
|
||||
atiadlxx = C.CDLL("libatiadlxx.so") # Not tested on Linux system. But will be supported.
|
||||
|
||||
|
||||
ADL_MAIN_MALLOC_CALLBACK = C.CFUNCTYPE(C.c_void_p, C.c_int)
|
||||
ADL_MAIN_FREE_CALLBACK = C.CFUNCTYPE(None, C.POINTER(C.c_void_p))
|
||||
|
||||
|
||||
@ADL_MAIN_MALLOC_CALLBACK
|
||||
def ADL_Main_Memory_Alloc(iSize):
|
||||
return C._malloc(iSize)
|
||||
|
||||
|
||||
@ADL_MAIN_FREE_CALLBACK
|
||||
def ADL_Main_Memory_Free(lpBuffer):
|
||||
if lpBuffer[0] is not None:
|
||||
C._free(lpBuffer[0])
|
||||
lpBuffer[0] = None
|
||||
|
||||
|
||||
ADL2_Main_Control_Create = atiadlxx.ADL2_Main_Control_Create
|
||||
ADL2_Main_Control_Create.restype = C.c_int
|
||||
ADL2_Main_Control_Create.argtypes = [ADL_MAIN_MALLOC_CALLBACK, C.c_int, ADL_CONTEXT_HANDLE]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import ctypes as C
|
||||
|
||||
|
||||
class _ADLPMActivity(C.Structure):
|
||||
__slot__ = [
|
||||
'iActivityPercent',
|
||||
|
|
@ -13,7 +14,7 @@ class _ADLPMActivity(C.Structure):
|
|||
'iSize',
|
||||
'iVddc',
|
||||
]
|
||||
_ADLPMActivity._fields_ = [
|
||||
_ADLPMActivity._fields_ = [ # pylint: disable=protected-access
|
||||
('iActivityPercent', C.c_int),
|
||||
('iCurrentBusLanes', C.c_int),
|
||||
('iCurrentBusSpeed', C.c_int),
|
||||
|
|
@ -27,6 +28,7 @@ _ADLPMActivity._fields_ = [
|
|||
]
|
||||
ADLPMActivity = _ADLPMActivity
|
||||
|
||||
|
||||
class _ADLMemoryInfo2(C.Structure):
|
||||
__slot__ = [
|
||||
'iHyperMemorySize',
|
||||
|
|
@ -36,7 +38,7 @@ class _ADLMemoryInfo2(C.Structure):
|
|||
'iVisibleMemorySize',
|
||||
'strMemoryType'
|
||||
]
|
||||
_ADLMemoryInfo2._fields_ = [
|
||||
_ADLMemoryInfo2._fields_ = [ # pylint: disable=protected-access
|
||||
('iHyperMemorySize', C.c_longlong),
|
||||
('iInvisibleMemorySize', C.c_longlong),
|
||||
('iMemoryBandwidth', C.c_longlong),
|
||||
|
|
@ -46,6 +48,7 @@ _ADLMemoryInfo2._fields_ = [
|
|||
]
|
||||
ADLMemoryInfo2 = _ADLMemoryInfo2
|
||||
|
||||
|
||||
class _AdapterInfo(C.Structure):
|
||||
__slot__ = [
|
||||
'iSize',
|
||||
|
|
@ -64,7 +67,7 @@ class _AdapterInfo(C.Structure):
|
|||
'strPNPString',
|
||||
'iOSDisplayIndex',
|
||||
]
|
||||
_AdapterInfo._fields_ = [
|
||||
_AdapterInfo._fields_ = [ # pylint: disable=protected-access
|
||||
('iSize', C.c_int),
|
||||
('iAdapterIndex', C.c_int),
|
||||
('strUDID', C.c_char * 256),
|
||||
|
|
|
|||
|
|
@ -1,22 +1,24 @@
|
|||
from ctypes import *
|
||||
from ctypes.wintypes import *
|
||||
from typing import NamedTuple, TypeVar
|
||||
|
||||
from .apis import PdhExpandWildCardPathW, PdhOpenQueryW, PdhAddEnglishCounterW, PdhCollectQueryData, PdhGetFormattedCounterValue, PdhGetFormattedCounterArrayW, PdhCloseQuery
|
||||
from .structures import PDH_HQUERY, PDH_HCOUNTER, PDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
|
||||
from .defines import *
|
||||
from .msvcrt import malloc
|
||||
from .errors import PDHError
|
||||
|
||||
|
||||
class __InternalAbstraction(NamedTuple):
|
||||
flag: int
|
||||
attr_name: str
|
||||
|
||||
|
||||
_type_map = {
|
||||
int: __InternalAbstraction(PDH_FMT_LARGE, "largeValue"),
|
||||
float: __InternalAbstraction(PDH_FMT_DOUBLE, "doubleValue"),
|
||||
}
|
||||
|
||||
|
||||
def expand_wildcard_path(path: str) -> list[str]:
|
||||
listLength = DWORD(0)
|
||||
if PdhExpandWildCardPathW(None, LPCWSTR(path), None, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_MORE_DATA:
|
||||
|
|
@ -24,33 +26,35 @@ def expand_wildcard_path(path: str) -> list[str]:
|
|||
expanded = (WCHAR * listLength.value)()
|
||||
if PdhExpandWildCardPathW(None, LPCWSTR(path), expanded, byref(listLength), PDH_NOEXPANDCOUNTERS) != PDH_OK:
|
||||
raise PDHError(f"Couldn't expand wildcard path '{path}'")
|
||||
result = list()
|
||||
cur = str()
|
||||
for chr in expanded:
|
||||
if chr == '\0':
|
||||
result = []
|
||||
cur = ""
|
||||
for c in expanded:
|
||||
if c == '\0':
|
||||
result.append(cur)
|
||||
cur = str()
|
||||
cur = ""
|
||||
else:
|
||||
cur += chr
|
||||
cur += c
|
||||
result.pop()
|
||||
return result
|
||||
|
||||
|
||||
T = TypeVar("T", *_type_map.keys())
|
||||
|
||||
|
||||
class HCounter(PDH_HCOUNTER):
|
||||
def get_formatted_value(self, type: T) -> T:
|
||||
if type not in _type_map:
|
||||
raise PDHError(f"Invalid value type: {type}")
|
||||
flag, attr_name = _type_map[type]
|
||||
def get_formatted_value(self, typ: T) -> T:
|
||||
if typ not in _type_map:
|
||||
raise PDHError(f"Invalid value type: {typ}")
|
||||
flag, attr_name = _type_map[typ]
|
||||
value = PDH_FMT_COUNTERVALUE()
|
||||
if PdhGetFormattedCounterValue(self, DWORD(flag | PDH_FMT_NOSCALE), None, byref(value)) != PDH_OK:
|
||||
raise PDHError("Couldn't get formatted counter value.")
|
||||
return getattr(value.u, attr_name)
|
||||
|
||||
def get_formatted_dict(self, type: T) -> dict[str, T]:
|
||||
if type not in _type_map:
|
||||
raise PDHError(f"Invalid value type: {type}")
|
||||
flag, attr_name = _type_map[type]
|
||||
def get_formatted_dict(self, typ: T) -> dict[str, T]:
|
||||
if typ not in _type_map:
|
||||
raise PDHError(f"Invalid value type: {typ}")
|
||||
flag, attr_name = _type_map[typ]
|
||||
bufferSize = DWORD(0)
|
||||
itemCount = DWORD(0)
|
||||
if PdhGetFormattedCounterArrayW(self, DWORD(flag | PDH_FMT_NOSCALE), byref(bufferSize), byref(itemCount), None) != PDH_MORE_DATA:
|
||||
|
|
@ -64,9 +68,10 @@ class HCounter(PDH_HCOUNTER):
|
|||
result[item.szName] = getattr(item.FmtValue.u, attr_name)
|
||||
return result
|
||||
|
||||
|
||||
class HQuery(PDH_HQUERY):
|
||||
def __init__(self):
|
||||
super(HQuery, self).__init__()
|
||||
super().__init__()
|
||||
if PdhOpenQueryW(None, None, byref(self)) != PDH_OK:
|
||||
raise PDHError("Couldn't open PDH query.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
from ctypes import *
|
||||
from ctypes.wintypes import *
|
||||
from ctypes import CDLL, POINTER
|
||||
from ctypes.wintypes import LPCWSTR, LPDWORD, DWORD
|
||||
from typing import Callable
|
||||
from .structures import PDH_HQUERY, PDH_HCOUNTER, PPDH_FMT_COUNTERVALUE, PPDH_FMT_COUNTERVALUE_ITEM_W
|
||||
from .defines import PDH_FUNCTION, PZZWSTR, DWORD_PTR
|
||||
|
||||
from .structures import *
|
||||
from .defines import *
|
||||
|
||||
pdh = CDLL("pdh.dll")
|
||||
|
||||
|
||||
PdhExpandWildCardPathW: Callable = pdh.PdhExpandWildCardPathW
|
||||
PdhExpandWildCardPathW.restype = PDH_FUNCTION
|
||||
PdhExpandWildCardPathW.argtypes = [LPCWSTR, LPCWSTR, PZZWSTR, LPDWORD, DWORD]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from ctypes import *
|
||||
from ctypes.wintypes import *
|
||||
from ctypes import c_int, POINTER
|
||||
from ctypes.wintypes import DWORD, WCHAR
|
||||
|
||||
|
||||
PDH_FUNCTION = c_int
|
||||
PDH_OK = 0x00000000
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
class PDHError(Exception):
|
||||
def __init__(self, message: str):
|
||||
super(PDHError, self).__init__(message)
|
||||
super().__init__(message)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from ctypes import *
|
||||
from ctypes import CDLL, c_void_p, c_size_t
|
||||
|
||||
|
||||
msvcrt = CDLL("msvcrt")
|
||||
|
||||
|
||||
malloc = msvcrt.malloc
|
||||
malloc.restype = c_void_p
|
||||
malloc.argtypes = [c_size_t]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from ctypes import *
|
||||
from ctypes.wintypes import *
|
||||
from ctypes import Union, c_double, c_longlong, Structure, POINTER
|
||||
from ctypes.wintypes import HANDLE, LONG, LPCSTR, LPCWSTR, DWORD, LPWSTR
|
||||
|
||||
|
||||
PDH_HQUERY = HANDLE
|
||||
PDH_HCOUNTER = HANDLE
|
||||
|
||||
|
||||
class PDH_FMT_COUNTERVALUE_U(Union):
|
||||
_fields_ = [
|
||||
("longValue", LONG),
|
||||
|
|
@ -19,6 +21,7 @@ class PDH_FMT_COUNTERVALUE_U(Union):
|
|||
AnsiStringValue: LPCSTR
|
||||
WideStringValue: LPCWSTR
|
||||
|
||||
|
||||
class PDH_FMT_COUNTERVALUE(Structure):
|
||||
_anonymous_ = ("u",)
|
||||
_fields_ = [
|
||||
|
|
@ -30,6 +33,7 @@ class PDH_FMT_COUNTERVALUE(Structure):
|
|||
u: PDH_FMT_COUNTERVALUE_U
|
||||
PPDH_FMT_COUNTERVALUE = POINTER(PDH_FMT_COUNTERVALUE)
|
||||
|
||||
|
||||
class PDH_FMT_COUNTERVALUE_ITEM_W(Structure):
|
||||
_fields_ = [
|
||||
("szName", LPWSTR),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Optional, Union
|
||||
import torch
|
||||
|
||||
|
||||
rDevice = Union[torch.device, int]
|
||||
def get_device(device: Optional[rDevice]=None) -> torch.device:
|
||||
if device is None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue