《CuTe C++ 简介01,从示例开始 》 中,最后看到了 计算 gemm 的cuda kernel:gemm_device().

    它使用 NVIDIA CUTLASS 的 CUTe (CUDA Tile) 库实现的高性能 GEMM (通用矩阵乘法) CUDA kernel。接下来解释一下这个内核的各个部分。文末再贴一遍代码,方便查看。

1. 逐行解析 gemm_device

template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class AThreadLayout,
          class TB, class BStride, class BSmemLayout, class BThreadLayout,
          class TC, class CStride, class CSmemLayout, class CThreadLayout,
          class Alpha, class Beta>

功能: 内核的模板参数列表。

作用: 定义了整个计算所需的类型和策略,实现泛型编程。

    ProblemShape: 问题形状,通常是 (M, N, K) 三元组。

    CtaTiler: 定义 CTA (线程块) 如何平铺(tile)整个问题的策略。

    TA, TB, TC: 矩阵 A, B, C 的数据类型。

    AStride, BStride, CStride: 矩阵 A, B, C 在主存中的步长(stride)类型,用于处理非连续内存。

    ASmemLayout, BSmemLayout, CSmemLayout: 矩阵 A, B, C 在共享内存(smem)中的布局类型。

    AThreadLayout, BThreadLayout, CThreadLayout: 定义线程如何从平铺块中划分数据的布局类型。

    Alpha, Beta: GEMM 标量参数 alpha 和 beta 的数据类型。


__global__ static

功能: CUDA 内核限定符。

作用__global__ 表明这是一个 CUDA 内核函数,由主机调用,设备执行。static 限制该函数在本编译单元内可见。


__launch_bounds__(decltype(size(CThreadLayout{}))::value)

功能: 指定内核启动边界。

作用: 向编译器建议该内核函数的最佳线程块大小(每个 CTA 的线程数)。size(CThreadLayout{}) 计算出于处理 C 矩阵而定义的线程布局所需的线程数,这通常是整个 CTA 的线程数。这有助于编译器优化寄存器分配。


void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
            TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
            TC      * C, CStride dC, CSmemLayout          , CThreadLayout tC,
            Alpha alpha, Beta beta)

功能: 内核函数名和参数列表。

作用: 接收执行 GEMM 操作 C = alpha * A * B + beta * C 所需的所有信息

    shape_MNK: 具体的问题尺寸,如 make_shape(M, N, K)

    cta_tiler: CTA 平铺策略的具体实例。A, dA: 矩阵 A 的全局内存指针及其步长。

    sA_layout, tA: 矩阵 A 的共享内存布局和线程划分布局。

    B, dB: 矩阵 B 的全局内存指针及其步长。

    sB_layout, tB: 矩阵 B 的共享内存布局和线程划分布局。

    C, dC: 矩阵 C 的全局内存指针及其步长。tC: 矩阵 C 的线程划分布局(CSmemLayout 参数未使用,用占位符忽略)。

    alpha, beta: GEMM 的标量参数。


  using namespace cute;

功能: 使用 cute 命名空间。

作用: 简化代码,允许直接使用 cute 库中的函数(如 make_tensormake_gmem_ptr 等),而无需前缀。


  Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)

功能: 创建代表全局内存中矩阵 A 的逻辑张量(Tensor)视图。

作用make_gmem_ptr(A) 将指针包装为特殊指针。select<0,2>(shape_MNK) 从 (M,N,K) 中选取第0和2维,得到 (M, K) 作为张量形状。dA 是步长。结果是逻辑上的 (M, K) 张量 mA


  Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)

功能: 创建代表全局内存中矩阵 B 的逻辑张量视图。

作用: 类似上一行,选取第1和2维 (N, K)


  Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

功能: 创建代表全局内存中矩阵 C 的逻辑张量视图。

作用: 选取第0和1维 (M, N)


  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)

功能: 创建当前 CTA (线程块) 的坐标。

作用blockIdx.x 和 blockIdx.y 通常对应网格的 M 和 N 维度。_(下划线)是 cute 中的占位符,表示 K 维度尚未确定或需要后续计算。结果是一个 (m_idx, n_idx, _) 坐标。


  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)

功能: 从全局张量 mA 中提取出当前 CTA 负责处理的那一部分平铺数据(Tile)。

作用cta_tiler 定义了平铺方式(如 BLK_MBLK_NBLK_K)。

    cta_coord 指定要提取哪个块。

    Step<_1, X, _1> 是一个遍历模式(stride pattern):

    _1: 在 M 和 K 维度上,每个 CTA 处理一个块。

    X: 在 N 维度上,所有 CTA 都处理整个 K 维度(这是 GEMM 的特性,每个 CTA 需要读取 A 的整个 K 维度与其 BLK_M 相交的部分)。

    结果张量 gA 的形状可能是 (BLK_M, BLK_K, num_k_tiles)


  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)

功能: 提取当前 CTA 负责的 B 矩阵平铺数据。

作用: 类似上一行。Step<X, _1, _1> 表示:

    X: 在 M 维度上,所有 CTA 都处理整个 K 维度。

    _1: 在 N 和 K 维度上,每个 CTA 处理一个块。

    结果张量 gB 的形状可能是 (BLK_N, BLK_K, num_k_tiles)


  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

功能: 提取当前 CTA 负责计算的 C 矩阵平铺数据。

作用Step<_1, _1, X> 表示:

    _1: 在 M 和 N 维度上,每个 CTA 处理一个块。

    X: 在 K 维度上,进行平铺(因为 C 矩阵没有 K 维度)。

    结果张量 gC 的形状是 (BLK_M, BLK_N)


  __shared__ TA smemA[cosize_v<ASmemLayout>];

功能: 在共享内存中静态分配矩阵 A 所需的缓冲区。

作用cosize_v<ASmemLayout> 在编译时计算布局 sA_layout 定义的总大小。smemA 是这块共享内存的原始数组。


  __shared__ TB smemB[cosize_v<BSmemLayout>];

功能: 在共享内存中静态分配矩阵 B 所需的缓冲区。


  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)

功能: 创建共享内存缓冲区 smemA 的逻辑张量视图。

作用make_smem_ptr 包装指针。sA_layout 定义了数据的逻辑布局(如 (BLK_M, BLK_K)),这个布局通常经过优化以促进高效的内存访问(如避免bank冲突)。


  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)

功能: 创建共享内存中 B 矩阵缓冲区的逻辑张量视图。


  Tensor tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)

功能: 将当前 CTA 的全局数据块 gA 进一步划分给当前线程。

作用tA 定义了线程布局(如 (THR_M, THR_K))。threadIdx.x 是线程索引。结果是当前线程需要从全局内存中加载的 A 矩阵数据块 tAgA,形状如 (THR_M, THR_K, num_k_tiles)


  Tensor tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)

功能: 将共享内存张量 sA 划分给当前线程。

作用: 定义了当前线程负责将数据从 tAgA 写入共享内存 sA 的哪个部分。形状是 (THR_M, THR_K)


  Tensor tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)

功能: 将当前 CTA 的全局数据块 gB 划分给当前线程。


  Tensor tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)

功能: 将共享内存张量 sB 划分给当前线程。


  Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)

功能: 为计算操作,从共享内存 sA 中划分出当前线程需要读取的数据块。

作用: 使用处理 C 矩阵的线程布局 tC 进行划分。

    Step<_1, X> 是重要的模式:_1: 在 M 维度上,每个线程取一部分(THR_M)。

    X: 在 K 维度上,每个线程读取整个维度 (BLK_K)。

    这是为了支持外积(outer-product)类的计算,线程需要一整行/列来与其他张量的整行/列进行计算。


  Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)

功能: 为计算操作,从共享内存 sB 中划分出当前线程需要读取的数据块。

作用Step<X, _1> 模式:

    X: 在 N 维度上,每个线程读取整个维度 (BLK_N)? (注释是 THR_N,这里可能需要根据布局确认,但 Step<X,_1> 通常意味着在第一个模式维度上取全部,在第二个模式维度上划分)。

     _1: 在 K 维度上,每个线程取一部分 (THR_K)? 注释是 BLK_K

    (注释 (THR_N,BLK_K) 和参数 Step<X,_1> 可能看起来矛盾,实际含义取决于 tC 和 sB_layout 的具体定义。其核心思想是划分出计算所需的数据块)


  Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)

功能: 将当前 CTA 的全局输出数据块 gC 划分给当前线程。

作用Step<_1, _1> 表示在两个维度上都进行划分。当前线程负责计算和写入最终结果张量 tCgC,其形状为 (THR_M, THR_N)


  Tensor tCrC = make_tensor_like(tCgC);                                // (THR_M,THR_N)

功能: 在寄存器中创建一个张量,其形状与 tCgC 相同。

作用: 这个寄存器张量 tCrC 用于累加当前线程的中间计算结果。在分步计算 A*B 时,部分结果先存在这里,最后再与 beta*C 结合并写回全局内存。使用寄存器速度极快。


  clear(tCrC);

功能: 将累加寄存器 tCrC 初始化为 0。

作用: 为累加操作做准备。


  auto K_TILE_MAX = size<2>(tAgA);

功能: 获取 K 维度上需要处理的平铺块(Tile)的数量。

作用tAgA 的形状是 (THR_M, THR_K, num_k_tiles)size<2> 获取第2维(K 平铺维度)的大小。


  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
  {

功能: 主循环,遍历 K 维度。

作用: GEMM 计算需要沿 K 维度累加。由于全局内存太大无法一次性装入共享内存,需要分块(Tiling)循环处理。


    copy(tAgA(_,_,k_tile), tAsA);

功能: 数据移动指令。

作用: 将当前线程负责的、第 k_tile 个 K 块中的全局内存数据 tAgA(_,_,k_tile) 异步拷贝到共享内存中它应处的位置 tAsAcute::copy 很可能利用到了 PTX 异步拷贝指令 (cp.async)。


    copy(tBgB(_,_,k_tile), tBsB);

功能: 将当前线程负责的 B 矩阵数据块拷贝到共享内存。


    cp_async_fence();

功能: 发出异步拷贝操作的栅栏(fence)。

作用: 确保之前发起的所有 cp.async 操作对该线程后续的 cp_async_wait 可见。它标记了一个“提交点”。


    cp_async_wait<0>();

功能: 等待异步拷贝完成。

作用<0> 表示等待所有之前发起的异步操作完成。确保在计算开始前,共享内存中的数据已经就绪。


    __syncthreads();

功能: 线程块级同步。

作用: 确保整个 CTA 的所有线程都已将它们的数据拷贝到共享内存,并且共享内存中的数据对所有线程都是完整和可见的。这是必须的,因为一个线程加载的数据可能被另一个线程使用。


    gemm(tCsA, tCsB, tCrC);

功能: 执行核心的矩阵乘加计算。

作用: 这很可能是一个高度优化的、针对 tCsA 和 tCsB 特定布局生成的内部函数或模板函数。它从共享内存读取数据,进行乘加操作 tCrC += tCsA * tCsB,并将结果累加到寄存器 tCrC 中。计算发生在线程的寄存器上,速度极快。


    __syncthreads();

功能: 再次进行线程块同步。

作用: 确保所有线程都已完成对当前共享内存块 sA 和 sB 的读取操作,之后才能开始下一轮循环并用新的数据覆盖共享内存缓冲区。这是防止数据竞争的必要条件。


  axpby(alpha, tCrC, beta, tCgC);

功能: 执行缩放和写回操作。

作用: 这很可能是一个函数或函数模板,执行操作 tCgC = alpha * tCrC + beta * tCgC。它将最终结果从寄存器 tCrC 缩放后,与全局内存中原始的 tCgC(如果 beta != 0)结合,并写回全局内存 tCgC

最后,内核函数结束 所有线程完成执行,CTA 的工作结束。

2. 主要思想

  1. 双缓冲机制: 通过循环处理 K 维度,实现计算和数据加载的重叠;

  2. 高效内存访问: 使用共享内存减少全局内存访问;

  3. 线程级并行: 精细的线程调度和数据分区;

  4. 模板元编程: 编译时优化,生成高度特化的代码;

  5. 异步拷贝: 使用 cp_async 指令隐藏内存延迟;

        评价:这个内核展示了现代 GPU 编程的最佳实践,通过精细的内存层次管理和线程调度来实现高性能的矩阵乘法运算。

3. 附录代码

template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class AThreadLayout,
          class TB, class BStride, class BSmemLayout, class BThreadLayout,
          class TC, class CStride, class CSmemLayout, class CThreadLayout,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
            TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
            TC      * C, CStride dC, CSmemLayout          , CThreadLayout tC,
            Alpha alpha, Beta beta)
{
  using namespace cute;

  // Preconditions
  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)
  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)

  static_assert(is_static<AThreadLayout>::value);
  static_assert(is_static<BThreadLayout>::value);
  static_assert(is_static<CThreadLayout>::value);

  CUTE_STATIC_ASSERT_V(size(tA) == size(tB));                          // NumThreads
  CUTE_STATIC_ASSERT_V(size(tC) == size(tA));                          // NumThreads

  CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{});  // BLK_M / THR_M
  CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{});  // BLK_K / THR_K
  CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{});  // BLK_N / THR_N
  CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{});  // BLK_K / THR_K
  CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{});  // BLK_M / THR_M
  CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{});  // BLK_N / THR_N

  static_assert(is_static<ASmemLayout>::value);
  static_assert(is_static<BSmemLayout>::value);
  static_assert(is_static<CSmemLayout>::value);

  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K

  CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
  CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
  CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ASmemLayout>];
  __shared__ TB smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)

  //
  // Partition the copying of A and B tiles across the threads
  //

  // TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tiles

  Tensor tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)
  Tensor tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)

  Tensor tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)
  Tensor tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)

  CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA));                // THR_M
  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // THR_K
  CUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // THR_K

  //
  // Define A/B partitioning and C accumulators
  //

  // TUTORIAL: Example of partitioning via projections of a ThreadLayout tC

  // Partition sA (BLK_M, BLK_K) by the rows of tC
  Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)
  // Partition sB (BLK_N, BLK_K) by the cols of tC
  Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)
  // Partition gC (M,N) by the tile of tC
  Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)

  // Allocate the accumulators -- same shape/layout as the partitioned data
  Tensor tCrC = make_tensor_like(tCgC);                                // (THR_M,THR_N)

  CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC));                // THR_M
  CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA));                // THR_M
  CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB));                // BLK_K

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print("  mA : "); print(  mA); print("\n");
    print("  gA : "); print(  gA); print("\n");
    print("  sA : "); print(  sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mB : "); print(  mB); print("\n");
    print("  gB : "); print(  gB); print("\n");
    print("  sB : "); print(  sB); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
    print("tBsB : "); print(tBsB); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mC : "); print(  mC); print("\n");
    print("  gC : "); print(  gC); print("\n");
    print("tCsA : "); print(tCsA); print("\n");
    print("tCsB : "); print(tCsB); print("\n");
    print("tCgC : "); print(tCgC); print("\n");
    print("tCrC : "); print(tCrC); print("\n");
  }
#endif

#if 1

  // TUTORIAL: Example of a simple mainloop that read tiles of data into shared memory,
  //           and then computes on those tiles.
  //   copy(.) operates on the global and shared memory via the tA|tB partitioning
  //   gemm(.) operates on the shared and register memory via the tC partitioning

  auto K_TILE_MAX = size<2>(tAgA);

  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
  {
    // Copy gmem to smem with tA|tB thread-partitioned tensors
    copy(tAgA(_,_,k_tile), tAsA);      // A   (THR_M,THR_K) -> (THR_M,THR_K)
    copy(tBgB(_,_,k_tile), tBsB);      // B   (THR_N,THR_K) -> (THR_N,THR_K)

    // TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to
    //   Tensor tAgAk = tAgA(_,_,k_tile);
    //   CUTE_UNROLL
    //   for (int i = 0; i < size(tAsA); ++i) {
    //     tAsA(i) = tAgAk(i);
    //   }

    cp_async_fence();        // Label the end of (potential) cp.async instructions
    cp_async_wait<0>();      // Sync on all (potential) cp.async instructions
    __syncthreads();         // Wait for all threads to write to smem

    // Compute gemm on tC thread-partitioned smem
    gemm(tCsA, tCsB, tCrC);            // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)

    // TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to
    //   CUTE_UNROLL
    //   for (int k = 0; k < size<1>(tCsA); ++k) {
    //     CUTE_UNROLL
    //     for (int m = 0; m < size<0>(tCrC); ++m) {
    //       CUTE_UNROLL
    //       for (int n = 0; n < size<1>(tCrC); ++n) {
    //         tCrC(m,n) += tCsA(m,k) * tCsB(n,k);
    //       }
    //     }
    //   }

    __syncthreads();         // Wait for all threads to read from smem
  }

#endif

  //
  // Epilogue
  //

  axpby(alpha, tCrC, beta, tCgC);

  // TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to
  //   CUTE_UNROLL
  //   for (int i = 0; i < size(tCrC); ++i) {
  //     tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);
  //   }
}

4. 逐行调试追踪

make_gmem_ptr(A):

make_tensor():

(cuda-gdb) s

(cuda-gdb) s

make_layout()

(cuda-gdb) s

Layout()

回到 366 行的 return  Tensor()

(cuda-gdb) s

完成各层函数,回到:

Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)

然后逻辑走到了 cta_coord = make_coord(...);

(cuda-gdb) s

完成后,走到了  gA=local_tile(...)

(cuda-gdb) s

(cuda-gdb) s

参数 dice:

(cuda-gdb) s

从参数返回后,回到cuda kernel 中的 gA = local_tile(...) 函数体中

(cuda-gdb) s

跳出参数后,进入 inner_partition() 的函数体:

s

返回隔层函数,回到  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)

逻辑走到了  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)

(cuda-gdb) s

(cuda-gdb) s

逻辑走到了 cuda kernel 的  Tensor tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)

(cuda-gdb) s

(cuda-gdb) s

进入了参数中的 shape(...)

s

s

到 了 outer_partition( )中的下一个参数的函数体中:

s

进入 outer_partition () 函数体:

逻辑走到了 

  clear(tCrC);

s

s

s

finish

copy() 的代码跟踪:

cp_async_fence() 追踪:

cp_async_wait() 追踪:

gemm() 追踪:

s

s

插曲:

data()

gemm()

s

s 跑这里来了:

gemm()

s

make_fragment()

s

s

gemm()

n

copy()

s

s

copy()

s

n -> copy()

copy()

s

s

next -> copy_if()

s

gemm()

next -> gemm() 496:

s

s

s

s ...

bt

跑的是这个 mma_unpack():

s

s

s

Logo

欢迎来到FlagOS开发社区,这里是一个汇聚了AI开发者、数据科学家、机器学习爱好者以及业界专家的活力平台。我们致力于成为业内领先的Triton技术交流与应用分享的殿堂,为推动人工智能技术的普及与深化应用贡献力量。

更多推荐