import numpy as np
from PIL import Image
from pathlib import Path
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
cuda.init()
device = cuda.Device(0 )
print (f"Cuda version: { "." . join([str (i) for i in cuda.get_version()])} " )
print (f"Device: \t { device. name()} " )
Cuda version: 12.8.0
Device: NVIDIA GeForce RTX 3080 Laptop GPU
cu_file= "day_08_thread-coarsening.cu"
day_08_thread-coarsening.cu
#include <stdint.h>
#include <stdio.h>
#ifndef TILE_WIDTH
#ifdef __INTELLISENSE__
#define TILE_WIDTH 16
#else
#error "TILE_WIDTH must be defined"
#endif
#endif
#ifndef THREAD_COARSENING
#ifdef __INTELLISENSE__
#define THREAD_COARSENING 2
#else
#error "THREAD_COARSENING must be defined"
#endif
#endif
__global__ void matmul_fp32_tiled(float *m1, float *m2, float *res, uint32_t out_shape_0,
uint32_t out_shape_1, uint32_t inner_dim, uint32_t) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
__shared__ float m1_tile[TILE_WIDTH][TILE_WIDTH];
__shared__ float m2_tile[TILE_WIDTH][TILE_WIDTH];
int m1_x = inner_dim;
int m2_x = out_shape_1;
// Assume the matrices are multiples my block size on both dims.
float R = 0;
for (int tile = 0; tile < inner_dim / TILE_WIDTH; tile++) {
m1_tile[threadIdx.y][threadIdx.x] = m1[y * m1_x + tile * TILE_WIDTH + threadIdx.x];
m2_tile[threadIdx.y][threadIdx.x] = m2[(tile * TILE_WIDTH + threadIdx.y) * m2_x + x];
__syncthreads();
for (int i = 0; i < TILE_WIDTH; i++) {
R += m1_tile[threadIdx.y][i] * m2_tile[i][threadIdx.x];
}
__syncthreads();
}
res[y * out_shape_1 + x] = R;
}
__global__ void matmul_fp32_tiled_coarse(float *m1, float *m2, float *res, uint32_t out_shape_0,
uint32_t out_shape_1, uint32_t inner_dim, uint32_t) {
int x = blockIdx.x * blockDim.x * THREAD_COARSENING + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
// if (threadIdx.x == 0 && threadIdx.y == 0) {
// printf("blockIdx = (%d, %d), mx = %d, y = %d\n", blockIdx.x, blockIdx.y, x, y);
// }
__shared__ float m1_tile[TILE_WIDTH][TILE_WIDTH];
__shared__ float m2_tile[TILE_WIDTH][TILE_WIDTH];
float R[THREAD_COARSENING];
for (int i = 0; i < THREAD_COARSENING; i++) {
R[i] = 0;
}
int m1_x = inner_dim;
int m2_x = out_shape_1;
// We are going to coarse the thread over x, so let's load the tile from the
// second matrix.
for (int tile = 0; tile < inner_dim / TILE_WIDTH; tile++) {
m1_tile[threadIdx.y][threadIdx.x] = m1[y * m1_x + tile * TILE_WIDTH + threadIdx.x];
// Now, we are going to calculate a bunch consecutive tiles one by one,
// so we need to load the
for (int c = 0; c < THREAD_COARSENING; c++) {
m2_tile[threadIdx.y][threadIdx.x] =
m2[(tile * TILE_WIDTH + threadIdx.y) * m2_x + c * TILE_WIDTH + x];
__syncthreads();
for (int i = 0; i < TILE_WIDTH; i++) {
R[c] += m1_tile[threadIdx.y][i] * m2_tile[i][threadIdx.x];
}
__syncthreads();
}
}
for (int c = 0; c < THREAD_COARSENING; c++) {
res[y * out_shape_1 + c * TILE_WIDTH + x] = R[c];
}
}
from lovely_numpy import Lo
## Compiler options for more compile-time warnings.
warn_options= [
'-Xcompiler' , '-Wall' ,
'-Xcompiler' , '-Wextra' ,
'-Xcompiler' , '-Wsign-conversion' ,
'-Xcompiler' , '-Wcast-qual' ,
'-Xcompiler' , '-Wunused-parameter' ,
'-Xcompiler' , '-Wdouble-promotion' ,
'-Xcompiler' , '-Wformat=2' ,
'-Xcompiler' , '-Wfloat-equal' ,
'-Xcompiler' , '-Wshadow'
]
def benchmark_matmul(ctx, kernel, m1, m2, block_size, grid_size, repeat= 10 , warmup= True ):
assert len (m1.shape) == 2
assert len (m2.shape) == 2
assert m1.shape[1 ] == m2.shape[0 ]
out_shape = (m1.shape[0 ], m2.shape[1 ])
gpu_m1 = cuda.mem_alloc_like(m1)
gpu_m2 = cuda.mem_alloc_like(m2)
res = np.empty(out_shape, dtype= np.float32)
cuda.memcpy_htod(gpu_m1, m1)
cuda.memcpy_htod(gpu_m2, m2)
ctx.synchronize()
timing= 0
for _ in range (repeat):
start = cuda.Event()
end = cuda.Event()
gpu_res = cuda.mem_alloc_like(res)
if warmup:
kernel(gpu_m1, gpu_m2, gpu_res, np.uint32(out_shape[0 ]), np.uint32(out_shape[1 ]), np.uint32(m1.shape[1 ]), grid= grid_size, block= block_size)
ctx.synchronize()
start.record()
kernel(gpu_m1, gpu_m2, gpu_res, np.uint32(out_shape[0 ]), np.uint32(out_shape[1 ]), np.uint32(m1.shape[1 ]), grid= grid_size, block= block_size)
end.record()
end.synchronize()
timing += end.time_since(start)
timing /= repeat
cuda.memcpy_dtoh(res, gpu_res)
return res, timing;
m1 = np.random.randn(8192 , 8192 ).astype(np.float32)
m2 = np.random.randn(8192 , 8192 ).astype(np.float32)
np_res = np.matmul(m1, m2)
tile_width = 32
coarsening = 4
ctx = device.make_context()
try :
mod = SourceModule(
Path(cu_file).read_text(),
options= warn_options + [
f"-D TILE_WIDTH= { tile_width} " ,
f"-D THREAD_COARSENING= { coarsening} "
])
kernel = mod.get_function("matmul_fp32_tiled_coarse" )
out_shape = (m1.shape[0 ], m2.shape[1 ])
block_size = (tile_width, tile_width, 1 )
grid_size = (
((out_shape[1 ] + tile_width * coarsening - 1 ) // (tile_width * coarsening)),
((out_shape[0 ] + tile_width - 1 ) // tile_width),
1
)
print (f"Matrix 1 shape: { m1. shape} " )
print (f"Matrix 2 shape: { m2. shape} " )
print (f"Result shape: { out_shape} " )
print (f"Grid size: { grid_size} " )
print (f"Block size: { block_size} " )
print (f"Total threads: { grid_size[0 ] * grid_size[1 ] * block_size[0 ] * block_size[1 ]} " )
res, timing = benchmark_matmul(ctx, kernel, m1, m2, block_size, grid_size, repeat= 2 , warmup= True )
finally :
ctx.pop()
ctx.detach()
print (Lo(res))
print (f"Took { timing:.3f} ms" )
Matrix 1 shape: (8192, 8192)
Matrix 2 shape: (8192, 8192)
Result shape: (8192, 8192)
Grid size: (64, 256, 1)
Block size: (32, 32, 1)
Total threads: 16777216
array[8192, 8192] f32 n=67108864 (0.2Gb) x∈[-489.726, 533.151] μ=0.007 σ=90.525
Took 755.906ms
np.isclose(res, np_res).mean()
np.float64(0.9413013905286789)
def benchmark(dim, tile_width, coarsening):
m1 = np.random.randn(dim, dim).astype(np.float32)
m2 = np.random.randn(dim, dim).astype(np.float32)
ctx = device.make_context()
try :
mod = SourceModule(
Path(cu_file).read_text(),
options= warn_options + [
f"-D TILE_WIDTH= { tile_width} " ,
f"-D THREAD_COARSENING= { coarsening} "
])
kernel = mod.get_function("matmul_fp32_tiled_coarse" )
out_shape = (m1.shape[0 ], m2.shape[1 ])
block_size = (tile_width, tile_width, 1 )
grid_size = (
((out_shape[1 ] + tile_width * coarsening - 1 ) // (tile_width * coarsening)),
((out_shape[0 ] + tile_width - 1 ) // tile_width),
1
)
# print(f"Matrix 1 shape: {m1.shape}")
# print(f"Matrix 2 shape: {m2.shape}")
# print(f"Result shape: {out_shape}")
# print(f"Grid size: {grid_size}")
# print(f"Block size: {block_size}")
# print(f"Total threads: {grid_size[0] * grid_size[1] * block_size[0] * block_size[1]}")
res, timing = benchmark_matmul(ctx, kernel, m1, m2, block_size, grid_size, repeat= 2 , warmup= True )
finally :
ctx.pop()
ctx.detach()
return res, timing
print ("Matmul 8192x8192 with tile size 32x32 and thread coarsening along x:" )
for c in [1 , 2 , 4 , 8 ]:
res, timing = benchmark(8192 , 32 , c)
print (f"coarsening = { c} : { timing:.2f} ms" )
Matmul 8192x8192 with tile size 32x32 and thread coarsening along x:
coarsening = 1: 764.46ms
coarsening = 2: 758.32ms
coarsening = 4: 749.86ms
coarsening = 8: 746.75ms
Coarsening helps, but only a small bit.