Using a Custom CUDA Operator in PyTorch
This article walks through the standard workflow for integrating a custom CUDA operator into PyTorch, focusing on how to develop and verify the operator through the JIT compilation mechanism. The implementation consists of the key components described below.
As a prerequisite, the Ninja build tool must be installed:
sudo apt-get install ninja-build
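Whether PyTorch's extension tooling can actually find Ninja can also be checked from Python; a minimal check, using torch.utils.cpp_extension.is_ninja_available:

from torch.utils.cpp_extension import is_ninja_available

print("ninja available:", is_ninja_available())  # expected to print True once ninja-build is installed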
Four files need to be written:
.cu file: the actual kernel implementation
.h file: the declaration of the launch function
.cpp file: the pybind binding
.py file: loads the compiled extension and uses it
CUDA kernel implementation (sgemm.cu)
#define LOCATE(row, col, ld) ((row) * (ld) + (col))

__global__ void sgemm_kernel(float *C, float *A, float *B, const int M, const int N, const int K)
{
    const int BLOCK_SIZE = 32;  // threads per block in each dimension
    const int SPLIT = 2;        // each thread computes a SPLIT x SPLIT sub-block of C
    // first row/column of this thread's sub-block in C
    int n = blockIdx.x * blockDim.x * SPLIT + threadIdx.x * SPLIT;
    int m = blockIdx.y * blockDim.y * SPLIT + threadIdx.y * SPLIT;
    // first row/column of this thread's sub-block inside the shared-memory tile
    int tn = threadIdx.x * SPLIT;
    int tm = threadIdx.y * SPLIT;

    constexpr int SHARE_SIZE = BLOCK_SIZE * SPLIT;  // 64 x 64 shared-memory tiles
    __shared__ float As[SHARE_SIZE][SHARE_SIZE];
    __shared__ float Bs[SHARE_SIZE][SHARE_SIZE];
    float sum[SPLIT][SPLIT] = {0.0f};

    // walk over the K dimension one SHARE_SIZE-wide tile at a time
    for (int i = 0; i < (K + SHARE_SIZE - 1) / SHARE_SIZE; i++)
    {
        // stage one tile of A and one tile of B in shared memory
        for (int a = 0; a < SPLIT; a++)
        {
            for (int b = 0; b < SPLIT; b++)
            {
                if ((m + a) < M && (i * SHARE_SIZE + tn + b) < K)  // inside the bounds of A
                {
                    As[tm + a][tn + b] = A[LOCATE((m + a), (i * SHARE_SIZE + tn + b), K)];
                }
                else
                {
                    As[tm + a][tn + b] = 0.0f;
                }
                if ((n + b) < N && ((i * SHARE_SIZE + tm + a) < K))  // inside the bounds of B
                {
                    Bs[tm + a][tn + b] = B[LOCATE((i * SHARE_SIZE + tm + a), (n + b), N)];
                }
                else
                {
                    Bs[tm + a][tn + b] = 0.0f;
                }
            }
        }
        __syncthreads();

        // accumulate the partial products contributed by this tile
        for (int c = 0; c < SHARE_SIZE; c++)
        {
            for (int a = 0; a < SPLIT; a++)
            {
                for (int b = 0; b < SPLIT; b++)
                {
                    sum[a][b] += As[tm + a][c] * Bs[c][tn + b];
                }
            }
        }
        __syncthreads();
    }

    // write the SPLIT x SPLIT results back to global memory
    for (int a = 0; a < SPLIT; a++)
    {
        for (int b = 0; b < SPLIT; b++)
        {
            if ((m + a) < M && (n + b) < N)
            {
                C[LOCATE((m + a), (n + b), N)] = sum[a][b];
            }
        }
    }
}

void launch_sgemm(float *C, float *A, float *B, int M, int N, int K)
{
    // each block covers a 64 x 64 tile of C (32 threads x SPLIT = 2 per dimension),
    // so the grid must round up by 64, not 32
    dim3 blockDim(32, 32);
    dim3 gridDim((N + 64 - 1) / 64, (M + 64 - 1) / 64);
    sgemm_kernel<<<gridDim, blockDim>>>(C, A, B, M, N, K);
}
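To make the K-dimension tiling in the kernel easier to follow, here is a minimal NumPy sketch of the same accumulation scheme; tiled_matmul and tile are illustrative names (tile plays the role of SHARE_SIZE), and this sketch is not part of the extension itself:

import numpy as np

# Illustrative sketch only (not part of the extension): accumulate C tile by tile
# along K, the same way the kernel walks over SHARE_SIZE-wide tiles of A and B.
def tiled_matmul(A, B, tile=64):          # tile corresponds to SHARE_SIZE = BLOCK_SIZE * SPLIT
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=A.dtype)
    for k0 in range(0, K, tile):          # the i-loop in the kernel
        As = A[:, k0:k0 + tile]           # tile of A that the kernel stages in shared memory
        Bs = B[k0:k0 + tile, :]           # tile of B that the kernel stages in shared memory
        C += As @ Bs                      # the inner accumulation over this tile
    return C

A = np.random.rand(300, 500).astype(np.float32)
B = np.random.rand(500, 200).astype(np.float32)
assert np.allclose(tiled_matmul(A, B), A @ B, rtol=1e-4)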
Kernel declaration (sgemm.h): the bridge between C++ and CUDA, declaring the prototype of the host-side launch function.
void launch_sgemm(float *C, float *A, float *B, int M, int N, int K);
PyTorch extension binding (sgemm.cpp)
#include <torch/extension.h>
#include "sgemm.h"

// thin wrapper: unpack raw float pointers from the tensors and call the CUDA launcher
void torch_launch_sgemm(torch::Tensor &c, const torch::Tensor &a, const torch::Tensor &b, int m, int n, int k)
{
    launch_sgemm((float *)c.data_ptr(), (float *)a.data_ptr(), (float *)b.data_ptr(), m, n, k);
}

// expose the wrapper to Python under the module name chosen at build time
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("torch_launch_sgemm", &torch_launch_sgemm, "Launch SGEMM kernel");
}
Python test module (main.py)
import torch
from torch.utils.cpp_extension import load
import numpy as np
import time

# source files of the extension
source_files = ["sgemm.cpp", "sgemm.cu"]

# build and load the custom extension via JIT compilation
custom_ops = load(
    name="custom_sgemm",     # name of the generated module
    sources=source_files,    # list of source files
    verbose=True             # print detailed build output
)

if __name__ == "__main__":
    # matrix sizes
    M, N, K = 512, 512, 512

    # random input matrices on the GPU
    A = torch.randn(M, K, device="cuda", dtype=torch.float32)
    B = torch.randn(K, N, device="cuda", dtype=torch.float32)
    C = torch.empty(M, N, device="cuda", dtype=torch.float32)

    # NumPy copies for the CPU reference
    A_np = A.cpu().numpy()
    B_np = B.cpu().numpy()

    # time the custom SGEMM operator (synchronize so the asynchronous kernel actually finishes)
    start_time = time.time()
    custom_ops.torch_launch_sgemm(C, A, B, M, N, K)
    torch.cuda.synchronize()
    custom_time = time.time() - start_time

    # time PyTorch's matmul as the reference
    start_time = time.time()
    C_ref = torch.matmul(A, B)
    torch.cuda.synchronize()
    torch_time = time.time() - start_time

    # time NumPy's matmul on the CPU
    start_time = time.time()
    C_np = np.matmul(A_np, B_np)
    numpy_time = time.time() - start_time

    # verify the results
    print("Custom SGEMM result:")
    print(C)
    print("Reference result (PyTorch):")
    print(C_ref)
    print("NumPy result:")
    print(C_np)
    print("Difference (L2 norm):", torch.norm(C - C_ref))
    print("Difference (L2 norm) between NumPy and PyTorch:", np.linalg.norm(C_np - C_ref.cpu().numpy()))

    # report the timings
    print(f"Custom SGEMM time: {custom_time:.6f} seconds")
    print(f"PyTorch matmul time: {torch_time:.6f} seconds")
    print(f"NumPy matmul time: {numpy_time:.6f} seconds")