This article walks through the standard workflow for integrating a custom CUDA operator into PyTorch, focusing on how PyTorch's JIT compilation mechanism can be used to develop and verify a custom CUDA kernel. The implementation involves the following key pieces:

The Ninja build tool must be installed first:

sudo apt-get install ninja-build
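To confirm that Ninja is available on the PATH (any recent version will do):

ninja --version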

Four files need to be written (a possible project layout is sketched after this list):
.cu file: the kernel implementation
.h file: the launcher declaration
.cpp file: the pybind binding
.py file: loads and uses the extension
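
For reference, a layout that matches the load() call used later might look like this (the directory name is arbitrary):

sgemm_extension/
├── sgemm.cu    # CUDA kernel and launcher
├── sgemm.h     # launcher declaration
├── sgemm.cpp   # pybind11 binding
└── main.py     # JIT-loads and tests the extension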

CUDA kernel implementation (sgemm.cu)

#define LOCATE(row, col, ld) ((row) * (ld) + (col))  // row-major offset of element (row, col) with leading dimension ld

// Tiled SGEMM (C = A * B): each block computes a (BLOCK_SIZE*SPLIT) x (BLOCK_SIZE*SPLIT)
// tile of C, and each thread a SPLIT x SPLIT sub-tile, staging tiles of A and B in shared memory.
__global__ void sgemm_kernel(float *C, float *A, float *B, const int M, const int N, const int K)
{
    const int BLOCK_SIZE = 32;
    const int SPLIT = 2;

    int n = blockIdx.x * blockDim.x * SPLIT + threadIdx.x * SPLIT;
    int m = blockIdx.y * blockDim.y * SPLIT + threadIdx.y * SPLIT;

    int tn = threadIdx.x * SPLIT;
    int tm = threadIdx.y * SPLIT;

    constexpr int SHARE_SIZE = BLOCK_SIZE * SPLIT;

    __shared__ float As[SHARE_SIZE][SHARE_SIZE];
    __shared__ float Bs[SHARE_SIZE][SHARE_SIZE];

    float sum[SPLIT][SPLIT] = {0.0f};

    for (int i = 0; i < (K + SHARE_SIZE - 1) / SHARE_SIZE; i++)
    {
        for (int a = 0; a < SPLIT; a++)
        {
            for (int b = 0; b < SPLIT; b++)
            {
                if ((m + a) < M && (i * SHARE_SIZE + tn + b) < K) // stay within the bounds of A
                {
                    As[tm + a][tn + b] = A[LOCATE((m + a), (i * SHARE_SIZE + tn + b), K)];
                }
                else
                {
                    As[tm + a][tn + b] = 0.0f;
                }
                if ((n + b) < N && ((i * SHARE_SIZE + tm + a) < K)) // stay within the bounds of B
                {
                    Bs[tm + a][tn + b] = B[LOCATE((i * SHARE_SIZE + tm + a), (n + b), N)];
                }
                else
                {
                    Bs[tm + a][tn + b] = 0.0f;
                }
            }
        }
        __syncthreads();
        for (int c = 0; c < SHARE_SIZE; c++)
        {
            for (int a = 0; a < SPLIT; a++)
            {
                for (int b = 0; b < SPLIT; b++)
                {
                    sum[a][b] += As[tm + a][c] * Bs[c][tn + b];
                }
            }
        }
        __syncthreads();
    }

    for (int a = 0; a < SPLIT; a++)
    {
        for (int b = 0; b < SPLIT; b++)
        {
            if ((m + a) < M && (n + b) < N)
            {
                C[LOCATE((m + a), (n + b), N)] = sum[a][b];
            }
        }
    }
}

void launch_sgemm(float *C, float *A, float *B, int M, int N, int K)
{
    dim3 blockDim(32, 32);
    dim3 gridDim((N + 63) / 64, (M + 63) / 64);  // each block covers a 64x64 tile of C (BLOCK_SIZE * SPLIT)
    sgemm_kernel<<<gridDim, blockDim>>>(C, A, B, M, N, K);
}
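
Kernel launch failures (for example an invalid launch configuration) are silent from Python by default. During development it can help to check cudaGetLastError() right after the launch; a minimal sketch, where launch_sgemm_debug is a hypothetical variant of the launcher above:

#include <cstdio>

// Hypothetical debug variant of the launcher: same launch, plus an error check.
void launch_sgemm_debug(float *C, float *A, float *B, int M, int N, int K)
{
    dim3 blockDim(32, 32);
    dim3 gridDim((N + 63) / 64, (M + 63) / 64);
    sgemm_kernel<<<gridDim, blockDim>>>(C, A, B, M, N, K);

    // Surface launch/configuration errors that would otherwise go unnoticed.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("sgemm launch failed: %s\n", cudaGetErrorString(err));
    }
}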

Kernel launcher declaration (sgemm.h): serves as the bridge between C++ and CUDA, declaring the prototype of the host-side launch function.

void launch_sgemm(float *C, float *A, float *B, int M, int N, int K);
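
The header only needs this one prototype; in practice you would usually also add an include guard. A minimal sketch of the full file:

#pragma once

void launch_sgemm(float *C, float *A, float *B, int M, int N, int K);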

PyTorch extension binding (sgemm.cpp)

#include <torch/extension.h>
#include "sgemm.h"

void torch_launch_sgemm(torch::Tensor &c, const torch::Tensor &a, const torch::Tensor &b, int m, int n, int k)
{
    launch_sgemm((float *)c.data_ptr(), (float *)a.data_ptr(), (float *)b.data_ptr(), m, n, k);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("torch_launch_sgemm", &torch_launch_sgemm, "Launch SGEMM kernel");
}
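
The binding forwards raw pointers without any checks, so it will silently compute garbage if a tensor is non-contiguous, on the CPU, or not float32. A more defensive variant might look like the following sketch (the checks are an addition, not part of the original code):

void torch_launch_sgemm(torch::Tensor &c, const torch::Tensor &a, const torch::Tensor &b, int m, int n, int k)
{
    // The raw-pointer kernel assumes contiguous float32 CUDA tensors.
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && c.is_cuda(), "tensors must be CUDA tensors");
    TORCH_CHECK(a.is_contiguous() && b.is_contiguous() && c.is_contiguous(), "tensors must be contiguous");
    TORCH_CHECK(a.scalar_type() == torch::kFloat32 && b.scalar_type() == torch::kFloat32 && c.scalar_type() == torch::kFloat32,
                "tensors must be float32");
    launch_sgemm((float *)c.data_ptr(), (float *)a.data_ptr(), (float *)b.data_ptr(), m, n, k);
}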

Python test script (main.py)

import torch
from torch.utils.cpp_extension import load
import numpy as np
import time

# Source files for the extension
source_files = ["sgemm.cpp", "sgemm.cu"]

# JIT-compile and load the custom extension
custom_ops = load(
    name="custom_sgemm",  # name of the generated module
    sources=source_files,  # list of source files
    verbose=True  # print detailed build output
)

if __name__ == "__main__":
    # Matrix sizes
    M, N, K = 512, 512, 512

    # Create random input matrices on the GPU
    A = torch.randn(M, K, device="cuda", dtype=torch.float32)
    B = torch.randn(K, N, device="cuda", dtype=torch.float32)
    C = torch.empty(M, N, device="cuda", dtype=torch.float32)

    # NumPy copies for the CPU reference
    A_np = A.cpu().numpy()
    B_np = B.cpu().numpy()

    # Time the custom SGEMM operator (kernel launches are asynchronous, so synchronize before reading the clock)
    torch.cuda.synchronize()
    start_time = time.time()
    custom_ops.torch_launch_sgemm(C, A, B, M, N, K)
    torch.cuda.synchronize()
    custom_time = time.time() - start_time

    # Time PyTorch's matmul
    start_time = time.time()
    C_ref = torch.matmul(A, B)  # PyTorch's matmul as the reference
    torch.cuda.synchronize()
    torch_time = time.time() - start_time

    # Time NumPy's matmul
    start_time = time.time()
    C_np = np.matmul(A_np, B_np)
    numpy_time = time.time() - start_time

    # Verify the results
    print("Custom SGEMM result:")
    print(C)
    print("Reference result (PyTorch):")
    print(C_ref)
    print("NumPy result:")
    print(C_np)
    print("Difference (L2 norm):", torch.norm(C - C_ref))
    print("Difference (L2 norm) between NumPy and PyTorch:", np.linalg.norm(C_np - C_ref.cpu().numpy()))

    # Print the timing results
    print(f"Custom SGEMM time: {custom_time:.6f} seconds")
    print(f"PyTorch matmul time: {torch_time:.6f} seconds")
    print(f"NumPy matmul time: {numpy_time:.6f} seconds")