import numpy as np
from PIL import Image
from pathlib import Path
Day 8 - Thread coarsening
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
# Initialise the CUDA driver API and take a handle to GPU 0; the later cells
# create their contexts from this `device` object.
cuda.init()
device = cuda.Device(0)

# Build the dotted version string outside the f-string so the line also parses
# on Python < 3.12, where the enclosing quote character cannot be reused
# inside a replacement field.
cuda_version = ".".join(str(part) for part in cuda.get_version())
print(f"Cuda version: {cuda_version}")
print(f"Device:\t{device.name()}")
Cuda version: 12.8.0
Device: NVIDIA GeForce RTX 3080 Laptop GPU
# Location of the CUDA source for the thread-coarsened tiled matmul kernel;
# it is read with Path(cu_file).read_text() when the module is compiled.
cu_file = "kernels/matmul/matmul-thread-coarsening.cu"
kernels/matmul/matmul-thread-coarsening.cu
#include <stdint.h>
#include <stdio.h>
#ifndef TILE_WIDTH
#ifdef __INTELLISENSE__
#define TILE_WIDTH 16
#else
#error "TILE_WIDTH must be defined"
#endif
#endif
#ifndef THREAD_COARSENING
#ifdef __INTELLISENSE__
#define THREAD_COARSENING 2
#else
#error "THREAD_COARSENING must be defined"
#endif
#endif
/**
 * Tiled FP32 matrix multiplication with thread coarsening along x.
 *
 * Computes res = m1 * m2 for row-major matrices:
 *   m1  is out_shape_0 x inner_dim,
 *   m2  is inner_dim   x out_shape_1,
 *   res is out_shape_0 x out_shape_1.
 *
 * Each thread produces THREAD_COARSENING output elements of one row, located
 * at columns x, x + TILE_WIDTH, ..., x + (THREAD_COARSENING - 1) * TILE_WIDTH.
 * The m1 tile is loaded into shared memory once per step along the inner
 * dimension and reused for all THREAD_COARSENING m2 tiles.
 *
 * Launch requirements (the shared tiles are indexed directly by threadIdx):
 *   blockDim = (TILE_WIDTH, TILE_WIDTH, 1)
 *   gridDim.x >= ceil(out_shape_1 / (TILE_WIDTH * THREAD_COARSENING))
 *   gridDim.y >= ceil(out_shape_0 / TILE_WIDTH)
 *
 * Shapes do not have to be multiples of TILE_WIDTH: out-of-range tile entries
 * are loaded as zero and out-of-range outputs are not written.
 */
__global__ void matmul_fp32_tiled_coarse(float *m1, float *m2, float *res, uint32_t out_shape_0,
                                         uint32_t out_shape_1, uint32_t inner_dim) {
    // First output column of this thread and its output row.
    const uint32_t x = blockIdx.x * blockDim.x * THREAD_COARSENING + threadIdx.x;
    const uint32_t y = blockIdx.y * blockDim.y + threadIdx.y;

    __shared__ float m1_tile[TILE_WIDTH][TILE_WIDTH];
    __shared__ float m2_tile[TILE_WIDTH][TILE_WIDTH];

    // One accumulator per coarsened output element.
    float R[THREAD_COARSENING];
    for (int i = 0; i < THREAD_COARSENING; i++) {
        R[i] = 0.0f;
    }

    // Row strides, widened so that row * stride cannot overflow 32 bits.
    const size_t m1_x = inner_dim;
    const size_t m2_x = out_shape_1;

    // Round up so a partial last tile along the inner dimension is processed.
    const uint32_t num_tiles = (inner_dim + TILE_WIDTH - 1) / TILE_WIDTH;

    // We coarsen the thread over x, so the m1 tile is loaded once per step
    // and shared by all the m2 tiles that follow.
    for (uint32_t tile = 0; tile < num_tiles; tile++) {
        const uint32_t m1_col = tile * TILE_WIDTH + threadIdx.x;
        m1_tile[threadIdx.y][threadIdx.x] =
            (y < out_shape_0 && m1_col < inner_dim) ? m1[y * m1_x + m1_col] : 0.0f;

        const uint32_t m2_row = tile * TILE_WIDTH + threadIdx.y;

        // Now calculate THREAD_COARSENING consecutive tiles one by one, so
        // for each of them we need to load the matching tile of m2.
        for (uint32_t c = 0; c < THREAD_COARSENING; c++) {
            const uint32_t col = c * TILE_WIDTH + x;
            m2_tile[threadIdx.y][threadIdx.x] =
                (m2_row < inner_dim && col < out_shape_1) ? m2[m2_row * m2_x + col] : 0.0f;

            // Both tiles must be fully written before any thread reads them.
            // No thread returns early, so every thread reaches each barrier.
            __syncthreads();

            for (int i = 0; i < TILE_WIDTH; i++) {
                R[c] += m1_tile[threadIdx.y][i] * m2_tile[i][threadIdx.x];
            }

            // All reads must finish before the tiles are overwritten.
            __syncthreads();
        }
    }

    for (uint32_t c = 0; c < THREAD_COARSENING; c++) {
        const uint32_t col = c * TILE_WIDTH + x;
        if (y < out_shape_0 && col < out_shape_1) {
            res[y * m2_x + col] = R[c];
        }
    }
}
from lovely_numpy import Lo
## Compiler options for more compile-time warnings.
# Each warning flag is preceded by '-Xcompiler'; the list is prepended to the
# -D defines handed to SourceModule(options=...).
warn_options = [
    '-Xcompiler', '-Wall',
    '-Xcompiler', '-Wextra',
    '-Xcompiler', '-Wsign-conversion',
    '-Xcompiler', '-Wcast-qual',
    '-Xcompiler', '-Wunused-parameter',
    '-Xcompiler', '-Wdouble-promotion',
    '-Xcompiler', '-Wformat=2',
    '-Xcompiler', '-Wfloat-equal',
    '-Xcompiler', '-Wshadow',
]
def benchmark_matmul(ctx, kernel, m1, m2, block_size, grid_size, repeat=10, warmup=True):
    """Run a matmul kernel on the GPU and return its result and mean run time.

    Args:
        ctx: Active CUDA context; only ``synchronize()`` is called on it.
        kernel: Kernel callable, launched as ``kernel(gpu_m1, gpu_m2, gpu_res,
            rows, cols, inner, grid=grid_size, block=block_size)`` with the
            three sizes passed as ``np.uint32``.
        m1: 2-D host array, the left operand.
        m2: 2-D host array, the right operand; ``m2.shape[0]`` must equal
            ``m1.shape[1]``.
        block_size: Block dimensions forwarded as ``block=``.
        grid_size: Grid dimensions forwarded as ``grid=``.
        repeat: Number of timed launches to average over; must be >= 1.
        warmup: When true, do one untimed launch before the timed ones.

    Returns:
        ``(res, timing)``: the float32 result array of shape
        ``(m1.shape[0], m2.shape[1])`` and the mean of ``Event.time_since``
        over the timed launches (reported as milliseconds by the callers).

    Raises:
        ValueError: If ``repeat`` is smaller than 1.
    """
    assert len(m1.shape) == 2
    assert len(m2.shape) == 2
    assert m1.shape[1] == m2.shape[0]
    if repeat < 1:
        # Previously this ended in a ZeroDivisionError after doing no work.
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    # The kernel reads raw float32 buffers, so make dtype and layout explicit.
    # This is a no-op for arrays that are already C-contiguous float32.
    m1 = np.ascontiguousarray(m1, dtype=np.float32)
    m2 = np.ascontiguousarray(m2, dtype=np.float32)

    out_shape = (m1.shape[0], m2.shape[1])
    res = np.empty(out_shape, dtype=np.float32)
    dims = (np.uint32(out_shape[0]), np.uint32(out_shape[1]), np.uint32(m1.shape[1]))

    gpu_m1 = cuda.mem_alloc_like(m1)
    gpu_m2 = cuda.mem_alloc_like(m2)
    # Allocated once: every launch overwrites the whole output, so a fresh
    # buffer per repetition only adds allocator churn.
    gpu_res = cuda.mem_alloc_like(res)

    def launch():
        kernel(gpu_m1, gpu_m2, gpu_res, *dims, grid=grid_size, block=block_size)

    try:
        cuda.memcpy_htod(gpu_m1, m1)
        cuda.memcpy_htod(gpu_m2, m2)
        ctx.synchronize()

        if warmup:
            # A single untimed launch is enough; repeating it before every
            # timed launch doubled the GPU work without affecting the timing.
            launch()
            ctx.synchronize()

        timing = 0.0
        for _ in range(repeat):
            start = cuda.Event()
            end = cuda.Event()

            start.record()
            launch()
            end.record()
            end.synchronize()

            timing += end.time_since(start)
        timing /= repeat

        cuda.memcpy_dtoh(res, gpu_res)
    finally:
        # Release device memory now instead of waiting for garbage collection.
        for buf in (gpu_m1, gpu_m2, gpu_res):
            buf.free()

    return res, timing
# Random float32 inputs plus the NumPy product used as the reference result.
m1 = np.random.randn(8192, 8192).astype(np.float32)
m2 = np.random.randn(8192, 8192).astype(np.float32)

np_res = np.matmul(m1, m2)

tile_width = 32
coarsening = 4

ctx = device.make_context()
try:
    # TILE_WIDTH and THREAD_COARSENING are compile-time constants of the
    # kernel, so they are injected as -D defines.
    mod = SourceModule(
        Path(cu_file).read_text(),
        options=warn_options + [
            f"-D TILE_WIDTH={tile_width}",
            f"-D THREAD_COARSENING={coarsening}",
        ],
    )
    kernel = mod.get_function("matmul_fp32_tiled_coarse")

    out_shape = (m1.shape[0], m2.shape[1])

    block_size = (tile_width, tile_width, 1)
    # Each block covers tile_width * coarsening output columns and tile_width
    # output rows; ceil-divide so partial tiles still get a block.
    grid_size = (
        ((out_shape[1] + tile_width * coarsening - 1) // (tile_width * coarsening)),
        ((out_shape[0] + tile_width - 1) // tile_width),
        1,
    )

    print(f"Matrix 1 shape: {m1.shape}")
    print(f"Matrix 2 shape: {m2.shape}")
    print(f"Result shape: {out_shape}")
    print(f"Grid size: {grid_size}")
    print(f"Block size: {block_size}")
    print(f"Total threads: {grid_size[0] * grid_size[1] * block_size[0] * block_size[1]}")

    res, timing = benchmark_matmul(ctx, kernel, m1, m2, block_size, grid_size, repeat=2, warmup=True)
finally:
    # Always release the context, even if compilation or the launch failed.
    ctx.pop()
    ctx.detach()

print(Lo(res))
print(f"Took {timing:.3f}ms")
Matrix 1 shape: (8192, 8192)
Matrix 2 shape: (8192, 8192)
Result shape: (8192, 8192)
Grid size: (64, 256, 1)
Block size: (32, 32, 1)
Total threads: 16777216
array[8192, 8192] f32 n=67108864 (0.2Gb) x∈[-517.091, 500.085] μ=0.005 σ=90.506
Took 1331.796ms
# Fraction of output elements that match the NumPy reference under
# np.isclose's default tolerances (rtol=1e-05, atol=1e-08).
# NOTE(review): the non-matching fraction is presumably float32
# accumulation-order rounding rather than a kernel bug; confirm by repeating
# the comparison with a looser tolerance.
np.isclose(res, np_res).mean()
np.float64(0.9412752240896225)
Run the test
def benchmark(dim, tile_width, coarsening, *, verbose=False):
    """Compile and time the coarsened matmul kernel on random square matrices.

    Args:
        dim: Side length of the two random float32 input matrices.
        tile_width: Value compiled in as ``TILE_WIDTH``; also the block side.
        coarsening: Value compiled in as ``THREAD_COARSENING``.
        verbose: When true, print the matrix shapes and launch configuration.

    Returns:
        ``(res, timing)`` as returned by ``benchmark_matmul`` with
        ``repeat=2`` and ``warmup=True``.
    """
    m1 = np.random.randn(dim, dim).astype(np.float32)
    m2 = np.random.randn(dim, dim).astype(np.float32)

    ctx = device.make_context()
    try:
        # The kernel is recompiled on every call because tile width and
        # coarsening factor are compile-time constants.
        mod = SourceModule(
            Path(cu_file).read_text(),
            options=warn_options + [
                f"-D TILE_WIDTH={tile_width}",
                f"-D THREAD_COARSENING={coarsening}",
            ],
        )
        kernel = mod.get_function("matmul_fp32_tiled_coarse")

        out_shape = (m1.shape[0], m2.shape[1])

        block_size = (tile_width, tile_width, 1)
        # Each block covers tile_width * coarsening output columns and
        # tile_width output rows; ceil-divide to cover partial tiles.
        grid_size = (
            ((out_shape[1] + tile_width * coarsening - 1) // (tile_width * coarsening)),
            ((out_shape[0] + tile_width - 1) // tile_width),
            1,
        )

        if verbose:
            print(f"Matrix 1 shape: {m1.shape}")
            print(f"Matrix 2 shape: {m2.shape}")
            print(f"Result shape: {out_shape}")
            print(f"Grid size: {grid_size}")
            print(f"Block size: {block_size}")
            print(f"Total threads: {grid_size[0] * grid_size[1] * block_size[0] * block_size[1]}")

        res, timing = benchmark_matmul(ctx, kernel, m1, m2, block_size, grid_size, repeat=2, warmup=True)
    finally:
        # Always release the context, even if compilation or the launch failed.
        ctx.pop()
        ctx.detach()

    return res, timing
# Time the same 8192x8192 product for several coarsening factors.
print("Matmul 8192x8192 with tile size 32x32 and thread coarsening along x:")
for c in [1, 2, 4, 8]:
    res, timing = benchmark(8192, 32, c)
    print(f"coarsening = {c}: {timing:.2f}ms")
Matmul 8192x8192 with tile size 32x32 and thread coarsening along x:
coarsening = 1: 1357.31ms
coarsening = 2: 1323.33ms
coarsening = 4: 1300.67ms
coarsening = 8: 1300.94ms
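Coarsening helps, but only slightly: for the 8192x8192 case the run time drops from about 1357 ms (coarsening = 1) to about 1300 ms (coarsening = 4 or 8), roughly a 4% improvement.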