mirror of https://github.com/vladmandic/automatic
29 lines
822 B
Python
29 lines
822 B
Python
import os
|
|
import sys
|
|
import torch
|
|
|
|
|
|
_topk = torch.topk
|
|
def topk(tensor: torch.Tensor, *args, **kwargs):
|
|
device = tensor.device
|
|
values, indices = _topk(tensor.cpu(), *args, **kwargs)
|
|
return torch.return_types.topk((values.to(device), indices.to(device),))
|
|
|
|
|
|
def _join_rocm_home(*paths) -> str:
|
|
from torch.utils.cpp_extension import ROCM_HOME
|
|
return os.path.join(ROCM_HOME, *paths)
|
|
|
|
|
|
def do_hijack():
|
|
torch.version.hip = "5.7"
|
|
torch.topk = topk
|
|
platform = sys.platform
|
|
sys.platform = ""
|
|
from torch.utils import cpp_extension
|
|
sys.platform = platform
|
|
cpp_extension.IS_WINDOWS = platform == "win32"
|
|
cpp_extension.IS_MACOS = False
|
|
cpp_extension.IS_LINUX = platform.startswith('linux')
|
|
cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access
|