Asynchronous Data Copies in CuTe DSL
I’ve always been interested in writing performant code, and recently I’ve been learning how to do that for NVIDIA GPUs. Specifically, I’ve been looking into CuTe DSL, a Python-based domain-specific language from NVIDIA. It offers a less verbose and somewhat less complicated way to program GPUs than CUDA C++, while being more expressive and powerful than Triton.
My ultimate goal is to learn how to write something like FlashAttention 4 in CuTe DSL. However, even a “simple” general matrix multiply (GEMM) kernel for an A100 is about 800 lines of code with many abstractions mixed in, and is difficult for me to grasp, despite the many examples provided on GitHub. The docs are also rather terse. So I thought I’d incrementally work up to complicated kernels, starting from a simple copy kernel to understand how to do asynchronous data copies (which are needed for the fancier kernels).
Prerequisites: I assume a basic familiarity with core GPU programming concepts (see, e.g., sections 1 and 2 of ThunderKittens), as well as some knowledge of CuTe DSL, at the level of the elementwise-addition kernel from the docs. Simon Veitner’s An applied introduction to CuTeDSL and Thread-value layouts in CuTe are great supplements to that example.
Basic info
Our goal in this post is to write a kernel that copies data from a source matrix to a destination matrix, both of shape (M, N). The general idea is to split the source and destination matrices into tiles of a fixed size, and have one thread block (a group of threads that execute on the same streaming multiprocessor) perform the copy for one tile. We will copy source tiles from global memory (GMEM) to shared memory (SMEM) asynchronously, and then copy these to the destination tiles in GMEM synchronously. The asynchronous copy is not strictly necessary in this example, but it will be for more complicated kernels where we need to overlap computation with memory operations.
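To make the decomposition concrete, here is a plain-Python sketch (no CuTe involved; the 32 × 128 tile shape matches the values the host code derives later, and tile_extent is purely illustrative):

```python
M, N = 8192, 8192          # matrix shape
tile_m, tile_n = 32, 128   # tile shape (matches the values derived in the host code)

# One thread block per tile; block (bidx, bidy) covers these rows and columns:
def tile_extent(bidx, bidy):
    rows = range(bidy * tile_m, (bidy + 1) * tile_m)
    cols = range(bidx * tile_n, (bidx + 1) * tile_n)
    return rows, cols

num_tiles = (M // tile_m) * (N // tile_n)
print(num_tiles)          # 16384 tiles, i.e. 16384 thread blocks
print(tile_extent(1, 2))  # (range(64, 96), range(128, 256))
```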
Setup
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
if __name__ == "__main__":
# Create source and destination matrices
M, N = 8192, 8192
src = torch.randn(M, N, device="cuda", dtype=torch.float16)
dst = torch.empty_like(src)
src_ = from_dlpack(src, assumed_align=16)
dst_ = from_dlpack(dst, assumed_align=16)
# Compile and run the kernel
op = TensorCopyAsync()
compiled = cute.compile(op, src_, dst_)
compiled(src_, dst_)
torch.testing.assert_close(dst, src)

This is straightforward. We create source and destination matrices, both of type float16. The from_dlpack function converts a PyTorch tensor (or a tensor from another DLPack-compatible framework like JAX) into a CuTe tensor with static shape and stride. The TensorCopyAsync class holds the kernel, which we will write next, and after we compile and run it, we verify that dst matches src.
Host code
We initialize the class by providing the number of rows we want in a tile (tile_m) and the number of threads in a thread block (num_threads). We also define a synchronization barrier for the thread block (cta_sync_barrier), where cta stands for “Cooperative Thread Array”, which is CUDA nomenclature for a thread block.
import cutlass
import cutlass.pipeline as pipeline
class TensorCopyAsync:
def __init__(self, tile_m: int = 32, num_threads: int = 512):
if num_threads % tile_m != 0:
raise ValueError("num_threads must be divisible by tile_m")
self._tile_m = tile_m
self._num_threads = num_threads
self.cta_sync_barrier = pipeline.NamedBarrier(
barrier_id=1, num_threads=num_threads
)

For the __call__ method, we first calculate the size of the tile based on how many elements we want a single thread to move. Since a single copy instruction can move 32, 64, or 128 bits, here we choose 128, which corresponds to 128 // 16 = 8 elements per thread.
@cute.jit
def __call__(self, mSrc: cute.Tensor, mDst: cute.Tensor):
copy_bits = 128
vector_elems = copy_bits // mSrc.element_type.width
threads_per_row = self._num_threads // self._tile_m
tile_n = threads_per_row * vector_elems
if cutlass.const_expr(
mSrc.shape[0] % self._tile_m != 0 or mSrc.shape[1] % tile_n != 0
):
raise ValueError(
f"mSrc/mDst shape must be divisible by ({self._tile_m}, {tile_n})"
)

Here, each thread will copy vector_elems (i.e., 8) elements that are laid out sequentially in a single row.
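With the default tile_m = 32, num_threads = 512, and float16 elements, the arithmetic works out as follows. This is a plain-Python check mirroring the code above; the thread_chunk helper is purely illustrative, sketching which slice of the tile each thread ends up covering under the thread-value layout defined below:

```python
tile_m, num_threads = 32, 512
elem_bits = 16                             # float16
copy_bits = 128                            # one 128-bit copy per thread

vector_elems = copy_bits // elem_bits      # 8 elements per copy
threads_per_row = num_threads // tile_m    # 16 threads per tile row
tile_n = threads_per_row * vector_elems    # 128 columns per tile

# Each thread handles one contiguous 8-element chunk of one row of the tile:
def thread_chunk(tid):
    row = tid // threads_per_row
    col = (tid % threads_per_row) * vector_elems
    return row, range(col, col + vector_elems)

print(tile_n)            # 128
print(thread_chunk(17))  # (1, range(8, 16))
```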
The next order of business is to define the copy operations we want. We first want to copy data from src to a staging tile in SMEM, then from that tile back to dst on GMEM. In CuTe, we specify the copy instruction by constructing a CopyAtom, which is a Python class that holds information related to the instruction, such as which copy operation to use and the data type of the elements to be copied. Here we define the two copy atoms: atom_async_copy is for the GMEM to SMEM copy, whereas atom_store is for the reverse.
atom_async_copy = cute.make_copy_atom(
cute.nvgpu.cpasync.CopyG2SOp(), # Asynchronous GMEM -> SMEM copy operation
mSrc.element_type,
num_bits_per_copy=copy_bits,
)
atom_store = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), # General-purpose copy operation
mDst.element_type,
num_bits_per_copy=copy_bits,
)

The rest of the code defines a thread-value layout for the tiles, and divides the source and destination tensors according to the tile shape with cute.zipped_divide. Note the use of cute.make_tiled_copy_tv to create tiled versions of the copy instructions we defined earlier.
thr_layout = cute.make_layout(
(self._tile_m, threads_per_row), stride=(threads_per_row, 1)
)
val_layout = cute.make_layout((1, vector_elems))
tiled_copy_load = cute.make_tiled_copy_tv(
atom_async_copy, thr_layout, val_layout
)
tiled_copy_store = cute.make_tiled_copy_tv(atom_store, thr_layout, val_layout)
sSrc_layout = cute.make_layout((self._tile_m, tile_n), stride=(tile_n, 1))
gSrc = cute.zipped_divide(mSrc, (self._tile_m, tile_n))
gDst = cute.zipped_divide(mDst, (self._tile_m, tile_n))

We finally launch the kernel with one tile assigned to one thread block. We use cute.ceil_div to calculate how many tiles we need to cover the entire matrix.
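As a plain-Python illustration of the shapes involved (no CuTe; the numbers assume the default 32 × 128 tile), zipped_divide reshapes the (M, N) tensor into a view whose first mode is the tile and whose second mode indexes the tiles, while ceil_div gives the tile counts used for the grid:

```python
from math import ceil

M, N = 8192, 8192
tile_m, tile_n = 32, 128

# zipped_divide(mSrc, (tile_m, tile_n)) yields logical shape
# ((tile_m, tile_n), (M // tile_m, N // tile_n)):
zipped_shape = ((tile_m, tile_n), (M // tile_m, N // tile_n))

# ceil_div, as used to size the launch grid:
tiles_mn = (ceil(M / tile_m), ceil(N / tile_n))

print(zipped_shape)  # ((32, 128), (256, 64))
print(tiles_mn)      # (256, 64); grid = [64, 256, 1] thread blocks
```

The kernel then picks one tile by indexing the second mode with the block indices, which is exactly what gSrc[((None, None), (bidy, bidx))] does below.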
tiles_mn = cute.ceil_div(mSrc.shape, (self._tile_m, tile_n))
self.kernel(
gSrc,
gDst,
sSrc_layout,
tiled_copy_load,
tiled_copy_store,
).launch(
grid=[tiles_mn[1], tiles_mn[0], 1],
block=[self._num_threads, 1, 1],
)

Device code
For the kernel itself, we start by accessing the source and destination tiles assigned to this thread block.
@cute.kernel
def kernel(
self,
gSrc: cute.Tensor,
gDst: cute.Tensor,
sSrc_layout: cute.Layout,
tiled_copy_load: cute.TiledCopy,
tiled_copy_store: cute.TiledCopy,
):
tidx, _, _ = cute.arch.thread_idx()
bidx, bidy, _ = cute.arch.block_idx()
blkSrc = gSrc[((None, None), (bidy, bidx))]
blkDst = gDst[((None, None), (bidy, bidx))]This is followed by allocating memory in SMEM, which we will use for staging.
smem = cutlass.utils.SmemAllocator()
sSrc = smem.allocate_tensor(gSrc.element_type, sSrc_layout, 16)

We now need to create copy instructions for the elements of the tile that only the current thread is responsible for. We do this in CuTe using the following snippet (the function names are self-explanatory):
thr_copy_load = tiled_copy_load.get_slice(tidx)
tSgSrc = thr_copy_load.partition_S(blkSrc)
tSsSrc = thr_copy_load.partition_D(sSrc)
cute.copy(tiled_copy_load, tSgSrc, tSsSrc)

When using cute.copy with an asynchronous copy instruction, we also have to commit the copy and wait for it to complete before the data in SMEM can safely be read:
cute.arch.cp_async_commit_group() # Submit this thread's queued operations
cute.arch.cp_async_wait_group(0) # Wait until this thread's operations are done
self.cta_sync_barrier.arrive_and_wait() # Sync between all threads in the thread block

We finally copy data from the SMEM tensor to dst.
thr_copy_store = tiled_copy_store.get_slice(tidx)
tSsStore = thr_copy_store.partition_S(sSrc)
tSgDst = thr_copy_store.partition_D(blkDst)
cute.copy(tiled_copy_store, tSsStore, tSgDst)

That’s it! On profiling this kernel with Nsight Compute on an A100, it achieves 87% of the GPU’s memory bandwidth, which is pretty good.
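As a rough cross-check of such figures without Nsight Compute, one can time the kernel and divide the bytes moved by the elapsed time. The helper below is a generic sketch (not part of the original code); the commented-out usage assumes the compiled, src_, and dst_ objects from the setup section, and the resulting GB/s would be compared against the datasheet peak of the GPU at hand:

```python
import time

def measure_bandwidth_gbps(fn, num_bytes, iters=100, sync=lambda: None):
    """Rough effective-bandwidth estimate for a function that moves num_bytes.
    Pass sync=torch.cuda.synchronize when timing GPU work."""
    fn()  # warm-up (also triggers any lazy compilation)
    sync()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    seconds = (time.perf_counter() - t0) / iters
    return num_bytes / seconds / 1e9

# Each element is read from GMEM once and written once: 2 * M * N * 2 bytes for fp16.
# Hypothetical usage with the objects from the setup code:
# bw = measure_bandwidth_gbps(lambda: compiled(src_, dst_), 2 * 8192 * 8192 * 2,
#                             sync=torch.cuda.synchronize)
# print(f"{bw:.0f} GB/s")
```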
Parting thoughts
The full code is available here. I tried experimenting with different tile sizes and numbers of threads, but 87% is the most I could achieve. If you have ideas on how to improve the throughput even more, please send me an email or open a GitHub issue!
Next up is a GEMM kernel for the Ampere generation (following the example here).