diff --git a/modules/sdnq/triton_mm.py b/modules/sdnq/triton_mm.py index c8fb1a167..cd04f6631 100644 --- a/modules/sdnq/triton_mm.py +++ b/modules/sdnq/triton_mm.py @@ -117,7 +117,7 @@ def triton_mm_td_kernel( off_k = 0 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + for _ in range(0, K, BLOCK_SIZE_K): a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k]) b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N]) accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)