automatic/modules/ipex_specific/__init__.py

152 lines
8.0 KiB
Python

import os
import contextlib
import torch
import intel_extension_for_pytorch as ipex
from modules import shared
from modules.sd_hijack_utils import CondFunc
#ControlNet depth_leres++
class DummyDataParallel(torch.nn.Module):
def __new__(cls, module, device_ids=None, output_device=None, dim=0):
if type(device_ids) is list and len(device_ids) > 1:
shared.log.warning("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to(shared.device)
def ipex_no_cuda(orig_func, *args, **kwargs): # pylint: disable=redefined-outer-name
torch.cuda.is_available = lambda: False
orig_func(*args, **kwargs)
torch.cuda.is_available = torch.xpu.is_available
def return_null_context(*args, **kwargs):
return contextlib.nullcontext()
def ipex_init():
#Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
torch.cuda.device_count = torch.xpu.device_count
torch.cuda.device_of = torch.xpu.device_of
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
torch.cuda.get_device_name = torch.xpu.get_device_name
torch.cuda.get_device_properties = torch.xpu.get_device_properties
torch.cuda.init = torch.xpu.init
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
#Memory:
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
#RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
torch.cuda.manual_seed = torch.xpu.manual_seed
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
torch.cuda.seed = torch.xpu.seed
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
#Training:
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception:
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
#C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
#Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu" # pylint: disable=protected-access
torch.xpu.empty_cache = torch.xpu.empty_cache if "WSL2" not in os.popen("uname -a").read() else lambda: None
torch.cuda.get_device_properties.major = 2023
torch.cuda.get_device_properties.minor = 2
torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel
torch.cuda.ipc_collect = lambda: None
torch.cuda.utilization = lambda: 0
#Libraries that blindly uses cuda:
#Adetailer:
CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, shared.device, *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: (type(device) is torch.device and device.type == "cuda") or (type(device) is str and "cuda" in device))
CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=shared.device, **kwargs),
lambda orig_func, *args, device=None, **kwargs: (type(device) is torch.device and device.type == "cuda") or (type(device) is str and "cuda" in device))
#ControlNet depth_leres
CondFunc('torch.load',
lambda orig_func, f, map_location=None, *args, **kwargs: orig_func(f, shared.device, *args, **kwargs),
lambda orig_func, f, map_location=None, *args, **kwargs: (map_location is None) or (type(map_location) is torch.device and map_location.type == "cuda") or (type(map_location) is str and "cuda" in map_location))
#Broken functions when torch.cuda.is_available is True:
#Pin Memory:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True)
#Functions with dtype errors:
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
#Diffusers bfloat16:
CondFunc('torch.nn.modules.Conv2d._conv_forward',
lambda orig_func, self, input, weight, bias=None: orig_func(self, input.to(weight.data.dtype), weight, bias=bias),
lambda orig_func, self, input, weight, bias=None: input.dtype != weight.data.dtype)
#Functions that does not work with the XPU:
#UniPC:
CondFunc('torch.linalg.solve',
lambda orig_func, A, B, *args, **kwargs: orig_func(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(shared.device),
lambda orig_func, A, B, *args, **kwargs: A.device != torch.device("cpu") or B.device != torch.device("cpu"))
#SDE Samplers:
CondFunc('torch.Generator',
lambda orig_func, device: torch.xpu.Generator(device),
lambda orig_func, device: device != torch.device("cpu") and device != "cpu")
#Latent antialias:
CondFunc('torch.nn.functional.interpolate',
lambda orig_func, input, *args, **kwargs: orig_func(input.to("cpu"), *args, **kwargs).to(shared.device),
lambda orig_func, input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False: antialias)
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float)
#ControlNet and TiledVAE:
CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=shared.device),
bias if bias is not None else torch.zeros(input.size()[1], device=shared.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
#ControlNet
CondFunc('torch.instance_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=shared.device),
bias if bias is not None else torch.zeros(input.size()[1], device=shared.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))