Day 5 - Matrix-vector multiplication

Chapter 3, exercise 2:

A matrix-vector multiplication takes an input matrix B and a vector C
and produces one output vector A. Each element of the output vector A
is the dot product of one row of the input matrix B and C, that is,
A[i] = sum_j (B[i][j] * C[j]). For simplicity we will handle only square
matrices whose elements are single-precision floating-point numbers. Write
a matrix-vector multiplication kernel and the host stub function that can
be called with four parameters: pointer to the output vector, pointer to
the input matrix, pointer to the input vector, and the number of elements
in each dimension. Use one thread to calculate an output vector element.

I will actually implement it for matrices of any shape, not just square ones.
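
As a reference for what the kernel should compute, here is the same operation written as explicit loops over NumPy arrays (a sketch only; below I simply use m @ v as the reference result):

import numpy as np

def mat_vec_ref(m, v):
    # out[i] = sum_j m[i, j] * v[j] -- one dot product per matrix row.
    out = np.zeros(m.shape[0], dtype=m.dtype)
    for i in range(m.shape[0]):
        for j in range(m.shape[1]):
            out[i] += m[i, j] * v[j]
    return out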

import numpy as np
from PIL import Image
from pathlib import Path
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
cuda.init()

device = cuda.Device(0)

print(f"Cuda version: {".".join([str(i) for i in cuda.get_version()])}")
print(f"Device:\t{device.name()}")
Cuda version: 12.8.0
Device: NVIDIA GeForce RTX 3080 Laptop GPU
cu_file = "kernels/misc/matrix-vector-mul.cu"

kernels/misc/matrix-vector-mul.cu

#include <stdint.h>
#include <stdio.h>

// One thread per output element: thread y computes the dot product of
// row y of m (m_height x m_width) with the vector v.
__global__ void mat_vec_mul(float* m, float* v, float* res,
                            uint32_t m_height,
                            uint32_t m_width) {

    uint32_t y = blockIdx.y * blockDim.y + threadIdx.y;

    if (y < m_height) {
        float out = 0.0f;
        // Accumulate the dot product of row y with v.
        for (uint32_t i = 0; i < m_width; i++) {
            out += m[y * m_width + i] * v[i];
        }
        res[y] = out;
    }
}
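
Each thread produces one output element, so along y the grid only needs enough blocks of BLOCK_SIZE_Y threads to cover m_height rows. A minimal sketch of that ceiling division (the helper name is mine, used only for illustration):

def blocks_needed(n_rows, block_y):
    # Smallest block count such that block_y * blocks >= n_rows.
    return (n_rows + block_y - 1) // block_y

blocks_needed(2000, 128)  # -> 16, matching the grid size printed below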
from lovely_numpy import Lo
m = np.random.randn(2000, 1000).astype(np.float32)
v = np.random.randn(1000).astype(np.float32)

np_res = m @ v
Lo(np_res)
array[2000] f32 7.8Kb x∈[-101.809, 89.337] μ=-0.495 σ=31.320

Testing the kernel

# One thread per output row: a single column of 128 threads per block.
BLOCK_SIZE_X = 1
BLOCK_SIZE_Y = 128

assert(len(m.shape) == 2)
assert(len(v.shape) == 1)
assert(m.shape[1] == v.shape[0])

out_dim = m.shape[0]

try:
    ctx = device.make_context()

    mod = SourceModule(Path(cu_file).read_text(),
        options=[
            '-Xcompiler', '-Wall',
            '-Xcompiler', '-Wextra',
            '-Xcompiler', '-Wsign-conversion',
            '-Xcompiler', '-Wcast-qual',
            '-Xcompiler', '-Wunused-parameter',
            '-Xcompiler', '-Wdouble-promotion',
            '-Xcompiler', '-Wformat=2',
            '-Xcompiler', '-Wfloat-equal',
            '-Xcompiler', '-Wshadow'
        ]
        )

    mat_vec_mul = mod.get_function("mat_vec_mul")

    # Allocate device buffers matching the host arrays.
    gpu_m = cuda.mem_alloc_like(m)
    gpu_v = cuda.mem_alloc_like(v)

    res = np.empty((out_dim, ), dtype=np.float32)
    gpu_res = cuda.mem_alloc_like(res)

    # Copy the inputs to the device.
    cuda.memcpy_htod(gpu_m, m)
    cuda.memcpy_htod(gpu_v, v)

    # One thread per output row; ceiling division so the grid covers all rows.
    block_size = (BLOCK_SIZE_X, BLOCK_SIZE_Y, 1)
    grid_size = (
        1,
        ((out_dim + BLOCK_SIZE_Y - 1) // BLOCK_SIZE_Y),
        1
    )


    print(f"Matrix shape: {m.shape}")
    print(f"Vector shape: {v.shape}")
    print(f"Grid size: {grid_size}")
    print(f"Block size: {block_size}")
    print(f"Result dimension: {out_dim}")
    print(f"Total threads: {grid_size[0] * grid_size[1] * block_size[0] * block_size[1]}")

    ctx.synchronize()

    mat_vec_mul(gpu_m, gpu_v, gpu_res,
                np.uint32(m.shape[0]), np.uint32(m.shape[1]),
                grid=grid_size, block=block_size)

    ctx.synchronize()

    cuda.memcpy_dtoh(res, gpu_res)
    ctx.synchronize()


finally:
    ctx.pop()
    ctx.detach()

Lo(res)
Matrix shape: (2000, 1000)
Vector shape: (1000,)
Grid size: (1, 16, 1)
Block size: (1, 128, 1)
Result dimension: 2000
Total threads: 2048
array[2000] f32 7.8Kb x∈[-101.809, 89.337] μ=-0.495 σ=31.320
np.isclose(res, np_res)
array([ True,  True,  True, ...,  True,  True,  True], shape=(2000,))
np.isclose(res, np_res).mean()
np.float64(0.9825)

Numerical stability

We hit the same floating-point accumulation issue as with matmul: only about 98% of the elements pass np.isclose's default tolerances, most likely because the kernel sums each row in a different order than NumPy's BLAS-backed product. The differences are tiny, and otherwise the kernel works fine.
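
To check that the mismatches really are just float32 accumulation-order noise, one option is to compare both results against a float64 reference (a sketch reusing the arrays from above; I have not re-run these numbers):

ref64 = m.astype(np.float64) @ v.astype(np.float64)

print("max |gpu - numpy|   :", np.abs(res - np_res).max())
print("max |gpu - ref64|   :", np.abs(res - ref64).max())
print("max |numpy - ref64| :", np.abs(np_res - ref64).max())

# Fraction of elements that pass with a looser relative tolerance than np.isclose's default:
print(np.isclose(res, np_res, rtol=1e-4).mean())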