# automatic/modules/zluda_hijacks.py
#
# 39 lines
# 1.2 KiB
# Python

import torch
from modules import rocm
_topk = torch.topk  # original implementation, captured before do_hijack() replaces torch.topk


def topk(input: torch.Tensor, *args, **kwargs):  # pylint: disable=redefined-builtin
    """Compute ``torch.topk`` on the CPU and return the results on the input's device.

    All extra positional/keyword arguments are forwarded unchanged to the
    original ``torch.topk``.

    Returns:
        A ``torch.return_types.topk`` whose ``values`` and ``indices`` live on
        the same device as ``input``.
    """
    home_device = input.device
    cpu_result = _topk(input.cpu(), *args, **kwargs)
    return torch.return_types.topk((
        cpu_result.values.to(home_device),
        cpu_result.indices.to(home_device),
    ))
_fft_fftn = torch.fft.fftn  # original implementation, captured before do_hijack() replaces it


def fft_fftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor:  # pylint: disable=redefined-builtin
    """Compute ``torch.fft.fftn`` on the CPU and move the result to the input's device.

    Extra arguments are forwarded unchanged to the original ``torch.fft.fftn``.
    NOTE(review): CPU FFT kernels may not accept every dtype the GPU path
    does (e.g. float16) — confirm for half-precision callers.
    """
    home_device = input.device
    host_result = _fft_fftn(input.cpu(), *args, **kwargs)
    return host_result.to(home_device)
_fft_ifftn = torch.fft.ifftn  # original implementation, captured before do_hijack() replaces it


def fft_ifftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor:  # pylint: disable=redefined-builtin
    """Compute ``torch.fft.ifftn`` on the CPU and move the result to the input's device.

    Extra arguments are forwarded unchanged to the original ``torch.fft.ifftn``.
    NOTE(review): CPU FFT kernels may not accept every dtype the GPU path
    does (e.g. complex32) — confirm for half-precision callers.
    """
    home_device = input.device
    host_result = _fft_ifftn(input.cpu(), *args, **kwargs)
    return host_result.to(home_device)
_fft_rfftn = torch.fft.rfftn  # original implementation, captured before do_hijack() replaces it


def fft_rfftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor:  # pylint: disable=redefined-builtin
    """Compute ``torch.fft.rfftn`` on the CPU and move the result to the input's device.

    Extra arguments are forwarded unchanged to the original ``torch.fft.rfftn``.
    NOTE(review): CPU FFT kernels may not accept every dtype the GPU path
    does (e.g. float16) — confirm for half-precision callers.
    """
    home_device = input.device
    host_result = _fft_rfftn(input.cpu(), *args, **kwargs)
    return host_result.to(home_device)
def jit_script(f, *_, **__):  # experiment / provide dummy graph
    """Stand-in for ``torch.jit.script`` that skips TorchScript compilation.

    The object is returned unchanged, so it keeps running as eager Python.
    Where possible, an empty ``torch._C.Graph`` is attached as ``f.graph``
    for callers that look for that attribute. Any extra positional or keyword
    arguments are accepted and ignored.

    Args:
        f: The function, class or module that would have been scripted.

    Returns:
        ``f`` itself.
    """
    try:
        f.graph = torch._C.Graph()  # pylint: disable=protected-access
    except (AttributeError, TypeError):
        # Bound/builtin methods, builtins and immutable builtin types reject
        # new attributes. The dummy graph is only a courtesy, so hand the
        # object back as-is instead of crashing the caller.
        pass
    return f
def do_hijack():
    """Patch torch in place with the ZLUDA workarounds defined in this module.

    Sets ``torch.version.hip`` to ``rocm.version``. Presumably this is so code
    that checks it treats the backend as ROCm — confirm. The function then
    swaps these entry points, in order, for the wrappers above:
    ``torch.topk``, ``torch.fft.fftn``, ``torch.fft.ifftn``,
    ``torch.fft.rfftn`` and ``torch.jit.script``. The original
    implementations were captured at import time (``_topk``, ``_fft_fftn``, ...).
    """
    torch.version.hip = rocm.version
    patches = (
        (torch, "topk", topk),
        (torch.fft, "fftn", fft_fftn),
        (torch.fft, "ifftn", fft_ifftn),
        (torch.fft, "rfftn", fft_rfftn),
        (torch.jit, "script", jit_script),
    )
    for namespace, attribute, replacement in patches:
        setattr(namespace, attribute, replacement)