mirror of https://github.com/vladmandic/automatic
15 lines
382 B
Python
15 lines
382 B
Python
import torch
|
|
from modules import rocm
|
|
|
|
|
|
# Keep a handle to the unpatched implementation: do_hijack() replaces
# torch.topk with the wrapper below, which still needs the real kernel.
_topk = torch.topk


def topk(input: torch.Tensor, *args, **kwargs):  # pylint: disable=redefined-builtin
    """Drop-in replacement for ``torch.topk`` that performs the selection on the CPU.

    The input is copied to the CPU and the original ``torch.topk`` (``_topk``) is
    applied there. The results are then moved back to the input's device.
    NOTE(review): presumably a workaround for ``torch.topk`` problems on the
    ROCm/HIP backend this module patches for; confirm the upstream issue.

    Args:
        input: Tensor to select from. Its device is where results are returned.
        *args, **kwargs: Forwarded to ``torch.topk`` (``k``, ``dim``, ``largest``,
            ``sorted``). The exception is ``out``, which is handled here (see below).

    Returns:
        ``torch.return_types.topk`` named tuple ``(values, indices)``. When ``out``
        is supplied, the tuple contains the caller's ``out`` tensors, filled in place.
    """
    device = input.device
    # The caller's ``out`` tensors typically live on ``device``. Handing them to
    # the CPU kernel alongside a CPU input would be rejected as a device
    # mismatch, so compute into fresh tensors and fill ``out`` ourselves.
    out = kwargs.pop("out", None)
    values, indices = _topk(input.cpu(), *args, **kwargs)
    if out is not None:
        out_values, out_indices = out
        # resize_ mirrors torch's out= semantics; copy_ also moves the data
        # across devices, so no intermediate .to(device) is needed.
        out_values.resize_(values.shape).copy_(values)
        out_indices.resize_(indices.shape).copy_(indices)
        return torch.return_types.topk((out_values, out_indices))
    return torch.return_types.topk((values.to(device), indices.to(device)))
|
|
|
|
|
|
def do_hijack():
    """Install this module's patches on the global ``torch`` module.

    * ``torch.version.hip`` is overwritten with ``rocm.version`` from ``modules.rocm``.
    * ``torch.topk`` is replaced by this module's ``topk`` wrapper, which runs the
      selection on the CPU and returns results on the input's device.
    """
    patch_table = (
        (torch.version, "hip", rocm.version),
        (torch, "topk", topk),
    )
    for owner, attribute, replacement in patch_table:
        setattr(owner, attribute, replacement)
|