import numpy as np
from PIL import Image
from pathlib import Path
Day 6 - Tiled matmul
Let's start with a square matrix whose size is a multiple of the block width.
TODO: Check the performance penalty of the boundary check. It might be better to just force the matrices to be multiples of the block size by padding them.
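Just to note the padding idea down, here is a rough NumPy sketch (a hypothetical helper, not used anywhere below): zero-pad both inputs up to the next multiple of the block width, run the multiply, then crop the result. The extra zero rows/columns contribute nothing to the products, so the cropped result is unchanged.
def pad_to_block(m, block=16):
    # Zero-pad a 2D matrix so both dims become multiples of `block`.
    pad_rows = (-m.shape[0]) % block
    pad_cols = (-m.shape[1]) % block
    return np.pad(m, ((0, pad_rows), (0, pad_cols)))

# Hypothetical usage: run the kernel on pad_to_block(m1) and pad_to_block(m2),
# then crop the padded result back to the true output shape:
# res = padded_res[:m1.shape[0], :m2.shape[1]]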
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
cuda.init()
device = cuda.Device(0)

print(f"Cuda version: {'.'.join([str(i) for i in cuda.get_version()])}")
print(f"Device:\t{device.name()}")
Cuda version: 12.8.0
Device: NVIDIA GeForce RTX 3080 Laptop GPU
= "kernels/matmul/matmul-tiled.cu" cu_file
kernels/matmul/matmul-tiled.cu
#include <stdint.h>
#include <stdio.h>

// We will use square blocks to keep things sane.
#define BLOCK_WIDTH 16

__global__ void matmul_fp32_tiled(float *m1, float *m2, float *res,
                                  uint32_t out_shape_0, uint32_t out_shape_1,
                                  uint32_t inner_dim) {
  int x = blockIdx.x * blockDim.x + threadIdx.x;
  int y = blockIdx.y * blockDim.y + threadIdx.y;

  __shared__ float m1_tile[BLOCK_WIDTH][BLOCK_WIDTH];
  __shared__ float m2_tile[BLOCK_WIDTH][BLOCK_WIDTH];

  int m1_x = inner_dim;    // row stride of m1
  int m2_x = out_shape_1;  // row stride of m2

  // Assume the matrices are multiples of the block size on both dims.
  float R = 0;
  for (int tile = 0; tile < inner_dim / BLOCK_WIDTH; tile++) {
    // Each thread loads one element of each input tile into shared memory.
    m1_tile[threadIdx.y][threadIdx.x] = m1[y * m1_x + tile * BLOCK_WIDTH + threadIdx.x];
    m2_tile[threadIdx.y][threadIdx.x] = m2[(tile * BLOCK_WIDTH + threadIdx.y) * m2_x + x];

    __syncthreads();

    for (int i = 0; i < BLOCK_WIDTH; i++) {
      R += m1_tile[threadIdx.y][i] * m2_tile[i][threadIdx.x];
    }

    __syncthreads();
  }
  res[y * out_shape_1 + x] = R;
}
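To double-check my understanding of the indexing, here is the same tiling scheme written out in NumPy (an illustrative sketch, not part of the notebook's pipeline): each thread block accumulates one BLOCK_WIDTH x BLOCK_WIDTH output tile by walking pairs of input tiles along the inner dimension, which is what the shared-memory loop above does. It assumes the dimensions are multiples of the block width.
def matmul_tiled_reference(m1, m2, block=16):
    out = np.zeros((m1.shape[0], m2.shape[1]), dtype=m1.dtype)
    for by in range(0, m1.shape[0], block):          # blockIdx.y
        for bx in range(0, m2.shape[1], block):      # blockIdx.x
            acc = np.zeros((block, block), dtype=m1.dtype)
            for t in range(0, m1.shape[1], block):   # the `tile` loop in the kernel
                acc += m1[by:by+block, t:t+block] @ m2[t:t+block, bx:bx+block]
            out[by:by+block, bx:bx+block] = acc
    return out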
from lovely_numpy import Lo
m1 = np.random.randn(1024, 1024).astype(np.float32)
m2 = np.random.randn(1024, 1024).astype(np.float32)

np_res = np.matmul(m1, m2)
Lo(np_res)
array[1024, 1024] f32 n=1048576 (4Mb) x∈[-146.704, 153.927] μ=-0.014 σ=32.009
BLOCK_SIZE = 16 # 16x16
assert(len(m1.shape) == 2)
assert(len(m2.shape) == 2)
assert(m1.shape == m2.shape) # Make them equal for now
out_shape = (m1.shape[0], m2.shape[1])
try:
    ctx = device.make_context()

    mod = SourceModule(Path(cu_file).read_text(), options=[
        '-Xcompiler', '-Wall',
        '-Xcompiler', '-Wextra',
        '-Xcompiler', '-Wsign-conversion',
        '-Xcompiler', '-Wcast-qual',
        '-Xcompiler', '-Wunused-parameter',
        '-Xcompiler', '-Wdouble-promotion',
        '-Xcompiler', '-Wformat=2',
        '-Xcompiler', '-Wfloat-equal',
        '-Xcompiler', '-Wshadow',
    ])

    matmul_tiled = mod.get_function("matmul_fp32_tiled")

    gpu_m1 = cuda.mem_alloc_like(m1)
    gpu_m2 = cuda.mem_alloc_like(m2)

    res = np.empty(out_shape, dtype=np.float32)
    gpu_res = cuda.mem_alloc_like(res)

    cuda.memcpy_htod(gpu_m1, m1)
    cuda.memcpy_htod(gpu_m2, m2)

    block_size = (BLOCK_SIZE, BLOCK_SIZE, 1)
    grid_size = (
        (out_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE,
        (out_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE,
        1,
    )

    print(f"Matrix 1 shape: {m1.shape}")
    print(f"Matrix 2 shape: {m2.shape}")
    print(f"Result shape: {out_shape}")
    print(f"Grid size: {grid_size}")
    print(f"Block size: {block_size}")
    print(f"Total threads: {grid_size[0] * grid_size[1] * block_size[0] * block_size[1]}")

    ctx.synchronize()
    matmul_tiled(gpu_m1, gpu_m2, gpu_res,
                 np.uint32(out_shape[1]), np.uint32(out_shape[0]), np.uint32(m1.shape[1]),
                 grid=grid_size, block=block_size)
    ctx.synchronize()

    cuda.memcpy_dtoh(res, gpu_res)
    ctx.synchronize()
finally:
    ctx.pop()
    ctx.detach()
Lo(res)
Matrix 1 shape: (1024, 1024)
Matrix 2 shape: (1024, 1024)
Result shape: (1024, 1024)
Grid size: (64, 64, 1)
Block size: (16, 16, 1)
Total threads: 1048576
array[1024, 1024] f32 n=1048576 (4Mb) x∈[-146.704, 153.927] μ=-0.014 σ=32.009
float(np.isclose(res, np_res).mean())
0.9787321090698242
Numerical stability
We have the same numerical error situation as with the naive matmul: float32 addition is not associative, and the kernel accumulates each dot product in a different order than np.matmul does, so only ~98% of the elements pass np.isclose against the NumPy result. The tiled and naive kernels, however, walk the inner dimension in exactly the same order, so let's compare with the non-tiled result - it should match exactly.
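A tiny illustration of why the accumulation order matters (a hypothetical snippet, not from the original notebook): summing the same float32 dot product sequentially vs. letting np.dot do it typically disagrees in the last bits.
rng = np.random.default_rng(0)
v1 = rng.standard_normal(1024).astype(np.float32)
v2 = rng.standard_normal(1024).astype(np.float32)

acc = np.float32(0)
for a, b in zip(v1, v2):   # sequential accumulation, like the kernels do
    acc += a * b

print(acc, np.dot(v1, v2))   # typically agree only to ~6-7 significant digits
print(acc == np.dot(v1, v2)) # frequently False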
BLOCK_SIZE = 16 # 16x16
assert(len(m1.shape) == 2)
assert(len(m2.shape) == 2)
assert(m1.shape == m2.shape) # Make them equal for now
out_shape = (m1.shape[0], m2.shape[1])
try:
    ctx = device.make_context()

    mod = SourceModule(Path("kernels/matmul/matmul.cu").read_text(), options=[
        '-Xcompiler', '-Wall',
        '-Xcompiler', '-Wextra',
        '-Xcompiler', '-Wsign-conversion',
        '-Xcompiler', '-Wcast-qual',
        '-Xcompiler', '-Wunused-parameter',
        '-Xcompiler', '-Wdouble-promotion',
        '-Xcompiler', '-Wformat=2',
        '-Xcompiler', '-Wfloat-equal',
        '-Xcompiler', '-Wshadow',
    ])

    matmul_naive = mod.get_function("matmul_f32")

    gpu_m1 = cuda.mem_alloc_like(m1)
    gpu_m2 = cuda.mem_alloc_like(m2)

    res_naive = np.empty(out_shape, dtype=np.float32)
    gpu_res_naive = cuda.mem_alloc_like(res_naive)

    cuda.memcpy_htod(gpu_m1, m1)
    cuda.memcpy_htod(gpu_m2, m2)

    block_size = (BLOCK_SIZE, BLOCK_SIZE, 1)
    grid_size = (
        (out_shape[1] + BLOCK_SIZE - 1) // BLOCK_SIZE,
        (out_shape[0] + BLOCK_SIZE - 1) // BLOCK_SIZE,
        1,
    )

    print(f"Matrix 1 shape: {m1.shape}")
    print(f"Matrix 2 shape: {m2.shape}")
    print(f"Result shape: {out_shape}")
    print(f"Grid size: {grid_size}")
    print(f"Block size: {block_size}")
    print(f"Total threads: {grid_size[0] * grid_size[1] * block_size[0] * block_size[1]}")

    ctx.synchronize()
    matmul_naive(gpu_m1, gpu_m2, gpu_res_naive,
                 np.uint32(out_shape[1]), np.uint32(out_shape[0]), np.uint32(m1.shape[1]),
                 grid=grid_size, block=block_size)
    ctx.synchronize()

    cuda.memcpy_dtoh(res_naive, gpu_res_naive)
    ctx.synchronize()
finally:
    ctx.pop()
    ctx.detach()
Lo(res_naive)
Matrix 1 shape: (1024, 1024)
Matrix 2 shape: (1024, 1024)
Result shape: (1024, 1024)
Grid size: (64, 64, 1)
Block size: (16, 16, 1)
Total threads: 1048576
array[1024, 1024] f32 n=1048576 (4Mb) x∈[-146.704, 153.927] μ=-0.014 σ=32.009
np.isclose(res, res_naive).all()
np.True_
Yaaay, they match. This was a first attempt based on my memory/understanding of what I read in chapter 4, and it worked on the first try 😎