# Mirror of https://github.com/vladmandic/automatic
# (39 lines, 1.2 KiB, Python)
import torch
|
|
from modules import rocm
|
|
|
|
|
|
_topk = torch.topk  # original torch.topk, captured before do_hijack() rebinds it


def topk(input: torch.Tensor, *args, **kwargs): # pylint: disable=redefined-builtin
    """Drop-in replacement for ``torch.topk`` that runs the selection on the CPU.

    Non-CPU inputs are copied to the CPU, the original ``torch.topk`` runs there,
    and ``values``/``indices`` are moved back to ``input``'s device. CPU inputs
    are passed straight through to the original implementation.

    Args:
        input: tensor to select from.
        *args, **kwargs: forwarded to the original ``torch.topk``. A keyword
            ``out=(values, indices)`` pair is honoured: the results are written
            into those tensors (resized as needed) and returned.

    Returns:
        ``torch.return_types.topk`` holding ``values`` and ``indices``.
    """
    if input.device.type == "cpu":
        # Nothing to move; let torch handle every argument (including ``out``) natively.
        return _topk(input, *args, **kwargs)
    # ``out`` tensors may live on a non-CPU device, where the CPU kernel would
    # likely reject them, so compute without them and copy afterwards
    # (``copy_`` works across devices).
    out = kwargs.pop("out", None)
    values, indices = _topk(input.cpu(), *args, **kwargs)
    if out is None:
        return torch.return_types.topk((values.to(input.device), indices.to(input.device)))
    out_values, out_indices = out
    out_values.resize_(values.shape).copy_(values)
    out_indices.resize_(indices.shape).copy_(indices)
    return torch.return_types.topk((out_values, out_indices))
|
|
|
|
|
|
_fft_fftn = torch.fft.fftn  # original torch.fft.fftn, captured before do_hijack() rebinds it


def fft_fftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin
    """Drop-in replacement for ``torch.fft.fftn`` that computes on the CPU.

    Non-CPU inputs are copied to the CPU, transformed with the original
    ``torch.fft.fftn`` and the result is moved back to ``input``'s device.
    CPU inputs are passed straight through to the original implementation.

    Args:
        input: tensor to transform.
        *args, **kwargs: forwarded to the original ``torch.fft.fftn``. A keyword
            ``out`` tensor is honoured: the result is written into it (resized
            as needed) and it is returned.

    Returns:
        The transformed tensor on ``input``'s device, or ``out`` if given.
    """
    if input.device.type == "cpu":
        # Nothing to move; let torch handle every argument (including ``out``) natively.
        return _fft_fftn(input, *args, **kwargs)
    # ``out`` may live on a non-CPU device, where the CPU kernel would likely
    # reject it, so compute without it and copy afterwards (``copy_`` is cross-device).
    out = kwargs.pop("out", None)
    result = _fft_fftn(input.cpu(), *args, **kwargs)
    if out is None:
        return result.to(input.device)
    return out.resize_(result.shape).copy_(result)
|
|
|
|
|
|
_fft_ifftn = torch.fft.ifftn  # original torch.fft.ifftn, captured before do_hijack() rebinds it


def fft_ifftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin
    """Drop-in replacement for ``torch.fft.ifftn`` that computes on the CPU.

    Non-CPU inputs are copied to the CPU, transformed with the original
    ``torch.fft.ifftn`` and the result is moved back to ``input``'s device.
    CPU inputs are passed straight through to the original implementation.

    Args:
        input: tensor to transform.
        *args, **kwargs: forwarded to the original ``torch.fft.ifftn``. A keyword
            ``out`` tensor is honoured: the result is written into it (resized
            as needed) and it is returned.

    Returns:
        The transformed tensor on ``input``'s device, or ``out`` if given.
    """
    if input.device.type == "cpu":
        # Nothing to move; let torch handle every argument (including ``out``) natively.
        return _fft_ifftn(input, *args, **kwargs)
    # ``out`` may live on a non-CPU device, where the CPU kernel would likely
    # reject it, so compute without it and copy afterwards (``copy_`` is cross-device).
    out = kwargs.pop("out", None)
    result = _fft_ifftn(input.cpu(), *args, **kwargs)
    if out is None:
        return result.to(input.device)
    return out.resize_(result.shape).copy_(result)
|
|
|
|
|
|
_fft_rfftn = torch.fft.rfftn  # original torch.fft.rfftn, captured before do_hijack() rebinds it


def fft_rfftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin
    """Drop-in replacement for ``torch.fft.rfftn`` that computes on the CPU.

    Non-CPU inputs are copied to the CPU, transformed with the original
    ``torch.fft.rfftn`` and the result is moved back to ``input``'s device.
    CPU inputs are passed straight through to the original implementation.

    Args:
        input: tensor to transform.
        *args, **kwargs: forwarded to the original ``torch.fft.rfftn``. A keyword
            ``out`` tensor is honoured: the result is written into it (resized
            as needed) and it is returned.

    Returns:
        The transformed tensor on ``input``'s device, or ``out`` if given.
    """
    if input.device.type == "cpu":
        # Nothing to move; let torch handle every argument (including ``out``) natively.
        return _fft_rfftn(input, *args, **kwargs)
    # ``out`` may live on a non-CPU device, where the CPU kernel would likely
    # reject it, so compute without it and copy afterwards (``copy_`` is cross-device).
    out = kwargs.pop("out", None)
    result = _fft_rfftn(input.cpu(), *args, **kwargs)
    if out is None:
        return result.to(input.device)
    return out.resize_(result.shape).copy_(result)
|
|
|
|
|
|
def jit_script(f, *_, **__): # experiment / provide dummy graph
    """Non-compiling stand-in for ``torch.jit.script``.

    Any extra positional or keyword arguments are accepted and ignored.
    Instead of scripting ``f``, an empty ``torch._C.Graph`` is attached to it
    as its ``graph`` attribute, and the very same object is handed back.
    """
    placeholder_graph = torch._C.Graph() # pylint: disable=protected-access
    setattr(f, "graph", placeholder_graph)
    return f
|
|
|
|
|
|
def do_hijack():
    """Install this module's replacements into ``torch``.

    Sets ``torch.version.hip`` to ``rocm.version``. It then rebinds
    ``torch.topk``, ``torch.fft.fftn``, ``torch.fft.ifftn``, ``torch.fft.rfftn``
    and ``torch.jit.script`` to the wrappers defined in this module. The
    wrappers call the ``_topk`` / ``_fft_*`` originals captured at import
    time, so repeated calls rebind the same objects again.
    """
    replacements = (
        (torch.version, "hip", rocm.version),
        (torch, "topk", topk),
        (torch.fft, "fftn", fft_fftn),
        (torch.fft, "ifftn", fft_ifftn),
        (torch.fft, "rfftn", fft_rfftn),
        (torch.jit, "script", jit_script),
    )
    for namespace, attribute, value in replacements:
        setattr(namespace, attribute, value)
|