cuda-100
  1. Day 8 - Thread coarsening
  • 100 days of CUDA
  • Day 0 - playing with PyCUDA
  • Day 1 - playing with nvcc
  • Day 2 - RGB to grayscale
  • Day 3 - RGB blur
  • Day 4 - Naive matmul+exercises
  • Day 5 - Matrix-vector multiplication
  • Day 6 - Tiled matmul
  • Day 7 - Tiled matmul experiments
  • Day 8 - Thread coarsening
  • Day 9 - Conv 2D
  • Day 10 - Improving Conv2d performance
  • Day 11 - conv2d with shared memory
  • Day 12 - conv2d with shared memory and halo

On this page

  • kernels/matmul/matmul-thread-coarsening.cu
  • Run the test

Day 8 - Thread coarsening

import numpy as np
from PIL import Image
from pathlib import Path
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
cuda.init()

# PyCUDA handle for GPU 0; the benchmark contexts below are created on it.
device = cuda.Device(0)

# Build the version string outside the f-string: reusing double quotes inside a
# double-quoted f-string is a SyntaxError before Python 3.12 (PEP 701).
cuda_version = ".".join(str(part) for part in cuda.get_version())
print(f"Cuda version: {cuda_version}")
print(f"Device:\t{device.name()}")
Cuda version: 12.8.0
Device: NVIDIA GeForce RTX 3080 Laptop GPU
# Path of the CUDA source compiled with SourceModule in the cells below.
cu_file = "kernels/matmul/matmul-thread-coarsening.cu"

kernels/matmul/matmul-thread-coarsening.cu

#include <stdint.h>
#include <stdio.h>

// Compile-time kernel configuration. The host code passes both values on the
// compiler command line (-D TILE_WIDTH=... -D THREAD_COARSENING=...). The
// __INTELLISENSE__ branches presumably exist only so an editor's code parser
// sees placeholder values; a real build without the defines stops with #error.
#if !defined(TILE_WIDTH) && defined(__INTELLISENSE__)
#define TILE_WIDTH 16
#elif !defined(TILE_WIDTH)
#error "TILE_WIDTH must be defined"
#endif

#if !defined(THREAD_COARSENING) && defined(__INTELLISENSE__)
#define THREAD_COARSENING 2
#elif !defined(THREAD_COARSENING)
#error "THREAD_COARSENING must be defined"
#endif

// Tiled FP32 matmul with thread coarsening along the output column (x) axis.
//
// Computes res = m1 @ m2 for row-major matrices:
//   m1:  out_shape_0 x inner_dim
//   m2:  inner_dim   x out_shape_1
//   res: out_shape_0 x out_shape_1
//
// Launch contract (matches the host code that calls this kernel):
//   blockDim = (TILE_WIDTH, TILE_WIDTH, 1)
//   gridDim  = (ceil(out_shape_1 / (TILE_WIDTH * THREAD_COARSENING)),
//               ceil(out_shape_0 / TILE_WIDTH), 1)
// Each block covers TILE_WIDTH rows and TILE_WIDTH * THREAD_COARSENING columns.
// Each thread produces THREAD_COARSENING outputs in one row, spaced TILE_WIDTH
// columns apart. The m1 tile is loaded once per step along inner_dim and reused
// for all THREAD_COARSENING m2 tiles.
//
// Shapes need not be multiples of TILE_WIDTH: out-of-range loads contribute 0,
// and out-of-range outputs are not written.
__global__ void matmul_fp32_tiled_coarse(const float *m1, const float *m2, float *res,
                                         uint32_t out_shape_0, uint32_t out_shape_1,
                                         uint32_t inner_dim) {
    __shared__ float m1_tile[TILE_WIDTH][TILE_WIDTH];
    __shared__ float m2_tile[TILE_WIDTH][TILE_WIDTH];

    const uint32_t tx = threadIdx.x;
    const uint32_t ty = threadIdx.y;
    const uint32_t row = blockIdx.y * TILE_WIDTH + ty;
    // Column of this thread's first output; output c lives at col0 + c * TILE_WIDTH.
    const uint32_t col0 = blockIdx.x * TILE_WIDTH * THREAD_COARSENING + tx;

    float acc[THREAD_COARSENING];
    for (int c = 0; c < THREAD_COARSENING; c++) {
        acc[c] = 0.0f;
    }

    // Round up so a partial last tile along inner_dim is still accumulated.
    const uint32_t n_tiles = (inner_dim + TILE_WIDTH - 1) / TILE_WIDTH;

    for (uint32_t tile = 0; tile < n_tiles; tile++) {
        const uint32_t m1_col = tile * TILE_WIDTH + tx;
        const uint32_t m2_row = tile * TILE_WIDTH + ty;

        // Zero-fill out-of-range elements so they add nothing to the dot product.
        m1_tile[ty][tx] = (row < out_shape_0 && m1_col < inner_dim)
                              ? m1[(size_t)row * inner_dim + m1_col]
                              : 0.0f;

        // Reuse the same m1 tile against THREAD_COARSENING consecutive m2 tiles.
        for (int c = 0; c < THREAD_COARSENING; c++) {
            const uint32_t col = col0 + (uint32_t)c * TILE_WIDTH;
            m2_tile[ty][tx] = (m2_row < inner_dim && col < out_shape_1)
                                  ? m2[(size_t)m2_row * out_shape_1 + col]
                                  : 0.0f;

            // Every thread must reach both barriers, so they stay outside all
            // bounds checks. This one makes the m1/m2 tile writes visible.
            __syncthreads();

            for (int i = 0; i < TILE_WIDTH; i++) {
                acc[c] += m1_tile[ty][i] * m2_tile[i][tx];
            }

            // Finish reading the tiles before any thread overwrites them.
            __syncthreads();
        }
    }

    if (row < out_shape_0) {
        for (int c = 0; c < THREAD_COARSENING; c++) {
            const uint32_t col = col0 + (uint32_t)c * TILE_WIDTH;
            if (col < out_shape_1) {
                res[(size_t)row * out_shape_1 + col] = acc[c];
            }
        }
    }
}
from lovely_numpy import Lo
# Extra warning flags for the kernel build. nvcc's -Xcompiler forwards each flag
# to the host C/C++ compiler, so these affect host-side compilation; device code
# is presumably not covered by them.
warn_options = [
    option
    for host_flag in (
        "-Wall",
        "-Wextra",
        "-Wsign-conversion",
        "-Wcast-qual",
        "-Wunused-parameter",
        "-Wdouble-promotion",
        "-Wformat=2",
        "-Wfloat-equal",
        "-Wshadow",
    )
    for option in ("-Xcompiler", host_flag)
]
def benchmark_matmul(ctx, kernel, m1, m2, block_size, grid_size, repeat=10, warmup=True):
    """Run a matmul kernel on ``m1 @ m2`` and return ``(result, mean_ms)``.

    Args:
        ctx: Active PyCUDA context; only used for ``synchronize()``.
        kernel: Compiled kernel, launched as
            ``kernel(m1, m2, res, out_rows, out_cols, inner_dim, grid=..., block=...)``
            with the three sizes passed as ``np.uint32``.
        m1: 2-D host array of shape (M, K). Converted to C-contiguous float32
            if needed, because the kernel reads raw float32 memory.
        m2: 2-D host array of shape (K, N), converted like ``m1``.
        block_size: CUDA block dimensions passed as ``block=``.
        grid_size: CUDA grid dimensions passed as ``grid=``.
        repeat: Number of timed launches; must be >= 1.
        warmup: If true, do one untimed launch before timing.

    Returns:
        The (M, N) float32 result of the last launch, copied back to the host,
        and the mean per-launch time in milliseconds measured with CUDA events.

    Raises:
        ValueError: If an input is not 2-D, the inner dimensions differ, or
            ``repeat < 1``.
    """
    m1 = np.ascontiguousarray(m1, dtype=np.float32)
    m2 = np.ascontiguousarray(m2, dtype=np.float32)
    if m1.ndim != 2 or m2.ndim != 2:
        raise ValueError(f"expected 2-D matrices, got shapes {m1.shape} and {m2.shape}")
    if m1.shape[1] != m2.shape[0]:
        raise ValueError(f"inner dimensions differ: {m1.shape} @ {m2.shape}")
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    out_rows, inner_dim = m1.shape
    out_cols = m2.shape[1]
    res = np.empty((out_rows, out_cols), dtype=np.float32)

    gpu_m1 = cuda.mem_alloc_like(m1)
    gpu_m2 = cuda.mem_alloc_like(m2)
    # One device output buffer, shared by every launch.
    gpu_res = cuda.mem_alloc_like(res)

    cuda.memcpy_htod(gpu_m1, m1)
    cuda.memcpy_htod(gpu_m2, m2)
    ctx.synchronize()

    size_args = (np.uint32(out_rows), np.uint32(out_cols), np.uint32(inner_dim))

    def launch():
        kernel(gpu_m1, gpu_m2, gpu_res, *size_args, grid=grid_size, block=block_size)

    if warmup:
        # A single untimed launch keeps first-launch overhead out of the timings.
        launch()
        ctx.synchronize()

    start, end = cuda.Event(), cuda.Event()
    total_ms = 0.0
    for _ in range(repeat):
        start.record()
        launch()
        end.record()
        end.synchronize()
        total_ms += end.time_since(start)

    cuda.memcpy_dtoh(res, gpu_res)
    return res, total_ms / repeat
# Random float32 inputs and the CPU reference product, used for comparison with the GPU output.
m1 = np.random.randn(8192, 8192).astype(np.float32)
m2 = np.random.randn(8192, 8192).astype(np.float32)
np_res = m1 @ m2

# Square tile_width x tile_width blocks; each block covers
# tile_width * coarsening output columns.
tile_width, coarsening = 32, 4

ctx = device.make_context()
try:
    compile_flags = [
        *warn_options,
        f"-D TILE_WIDTH={tile_width}",
        f"-D THREAD_COARSENING={coarsening}",
    ]
    mod = SourceModule(Path(cu_file).read_text(), options=compile_flags)
    kernel = mod.get_function("matmul_fp32_tiled_coarse")

    out_rows, out_cols = m1.shape[0], m2.shape[1]
    out_shape = (out_rows, out_cols)

    block_size = (tile_width, tile_width, 1)
    cols_per_block = tile_width * coarsening
    # Ceil-division via negated floor-division.
    grid_size = (-(-out_cols // cols_per_block), -(-out_rows // tile_width), 1)

    print(f"Matrix 1 shape: {m1.shape}")
    print(f"Matrix 2 shape: {m2.shape}")
    print(f"Result shape: {out_shape}")
    print(f"Grid size: {grid_size}")
    print(f"Block size: {block_size}")
    print(f"Total threads: {grid_size[0] * grid_size[1] * block_size[0] * block_size[1]}")

    res, timing = benchmark_matmul(ctx, kernel, m1, m2, block_size, grid_size, repeat=2, warmup=True)
finally:
    ctx.pop()
    ctx.detach()

print(Lo(res))
print(f"Took {timing:.3f}ms")
Matrix 1 shape: (8192, 8192)
Matrix 2 shape: (8192, 8192)
Result shape: (8192, 8192)
Grid size: (64, 256, 1)
Block size: (32, 32, 1)
Total threads: 16777216
array[8192, 8192] f32 n=67108864 (0.2Gb) x∈[-517.091, 500.085] μ=0.005 σ=90.506
Took 1331.796ms
# Fraction of GPU outputs that match the CPU reference under np.isclose's
# default tolerances (rtol=1e-5, atol=1e-8). NOTE(review): these are tight for
# float32 sums of 8192 products; rounding-order differences presumably explain
# a fraction below 1.0 — confirm with a looser rtol.
np.mean(np.isclose(res, np_res))
np.float64(0.9412752240896225)

Run the test

def benchmark(dim, tile_width, coarsening, *, verbose=False, repeat=2):
    """Compile the coarsened kernel and time it on random ``dim x dim`` matrices.

    A fresh context is created on ``device``. The kernel in ``cu_file`` is
    compiled with ``warn_options`` plus ``-D TILE_WIDTH`` / ``-D THREAD_COARSENING``,
    and the context is popped and detached even if compiling or running fails.

    Args:
        dim: Side length of the square float32 input matrices.
        tile_width: Shared-memory tile width; blocks are tile_width x tile_width threads.
        coarsening: Number of output tiles each thread computes along x.
        verbose: If true, print the shapes and launch configuration.
        repeat: Number of timed launches passed to ``benchmark_matmul``.

    Returns:
        ``(res, timing)`` from ``benchmark_matmul`` (result array, mean ms).

    Raises:
        ValueError: If ``dim``, ``tile_width`` or ``coarsening`` is < 1. This is
            checked before any GPU work.
    """
    for name, value in (("dim", dim), ("tile_width", tile_width), ("coarsening", coarsening)):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")

    m1 = np.random.randn(dim, dim).astype(np.float32)
    m2 = np.random.randn(dim, dim).astype(np.float32)

    ctx = device.make_context()
    try:
        mod = SourceModule(
            Path(cu_file).read_text(),
            options=warn_options + [
                f"-D TILE_WIDTH={tile_width}",
                f"-D THREAD_COARSENING={coarsening}",
            ])
        kernel = mod.get_function("matmul_fp32_tiled_coarse")

        out_shape = (m1.shape[0], m2.shape[1])

        # Each block covers tile_width rows and tile_width * coarsening columns.
        cols_per_block = tile_width * coarsening
        block_size = (tile_width, tile_width, 1)
        grid_size = (
            (out_shape[1] + cols_per_block - 1) // cols_per_block,
            (out_shape[0] + tile_width - 1) // tile_width,
            1,
        )

        if verbose:
            print(f"Matrix 1 shape: {m1.shape}")
            print(f"Matrix 2 shape: {m2.shape}")
            print(f"Result shape: {out_shape}")
            print(f"Grid size: {grid_size}")
            print(f"Block size: {block_size}")
            print(f"Total threads: {grid_size[0] * grid_size[1] * block_size[0] * block_size[1]}")

        res, timing = benchmark_matmul(ctx, kernel, m1, m2, block_size, grid_size, repeat=repeat, warmup=True)
    finally:
        ctx.pop()
        ctx.detach()

    return res, timing



# Sweep the coarsening factor at a fixed problem size (8192x8192) and tile width (32).
print("Matmul 8192x8192 with tile size 32x32 and thread coarsening along x:")
for c in (1, 2, 4, 8):
    res, timing = benchmark(dim=8192, tile_width=32, coarsening=c)
    print(f"coarsening = {c}: {timing:.2f}ms")
Matmul 8192x8192 with tile size 32x32 and thread coarsening along x:
coarsening = 1: 1357.31ms
coarsening = 2: 1323.33ms
coarsening = 4: 1300.67ms
coarsening = 8: 1300.94ms

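Coarsening helps, but only slightly. Going from a factor of 1 to 4 cuts the kernel time by about 4% (1357 ms → 1301 ms), and a factor of 8 brings no further gain.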