mirror of https://github.com/vladmandic/automatic
cleanup
parent
470a0d816e
commit
b2e071dc52
|
|
@ -117,7 +117,7 @@ def triton_mm_td_kernel(
|
||||||
|
|
||||||
off_k = 0
|
off_k = 0
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
|
||||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
for _ in range(0, K, BLOCK_SIZE_K):
|
||||||
a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
|
a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
|
||||||
b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
|
b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
|
||||||
accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
|
accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue