cuda-100
  1. Day 10 - Improving Conv2d performance
  • 100 days of CUDA
  • Day 0 - playing with PyCUDA
  • Day 1 - playing with nvcc
  • Day 2 - RGB to grayscale
  • Day 3 - RGB blur
  • Day 4 - Naive matmul+exercises
  • Day 5 - Matrix-vector multiplication
  • Day 6 - Tiled matmul
  • Day 7 - Tiled matmul experiments
  • Day 8 - Thread coarsening
  • Day 9 - Conv 2D
  • Day 10 - Improving Conv2d performance
  • Day 11 - conv2d with shared memory
  • Day 12 - conv2d with shared memory and halo

On this page

  • kernels/conv2d/conv2d_naive.cu
  • kernels/conv2d/conv2d-z-out.cu
  • Test matrix
  • Run the tests
  • Test results

Day 10 - Improving Conv2d performance

import pandas as pd
import numpy as np
from math import prod
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
from lovely_numpy import Lo
from lovely_tensors import monkey_patch; monkey_patch()
from torch import Tensor
from torch.nn.functional import conv2d
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
cuda.init()

device = cuda.Device(0)

# Single quotes inside the replacement field: reusing the outer quote character there
# is only valid from Python 3.12 (PEP 701), so this keeps the cell portable.
print(f"Cuda version: {'.'.join(str(i) for i in cuda.get_version())}")
print(f"Device:\t{device.name()}")
Cuda version: 12.8.0
Device: NVIDIA GeForce RTX 3080 Laptop GPU
cu_file_naive = "kernels/conv2d/conv2d_naive.cu"  # source of the day-9 (naive) conv2d kernel

kernels/conv2d/conv2d_naive.cu

The first (naive) implementation from the previous day: each thread computes one output pixel for every output channel.

#include <stdint.h>
#include <stdio.h>

#include "conv2d-helpers.h"

/* 2D convolution, with padding to valid shape. Channel-first.
 *
 * The output has the same h x w as the input; input samples that fall outside the image
 * are replaced by `pad`.
 *
 * Launch: 2D blocks over a 2D grid covering the w x h plane (x -> width, y -> height).
 * Each thread computes one output pixel for every output channel. No shared memory.
 *
 * Shapes:
 *   in     - (in_channels, h, w)
 *   out    - (out_channels, h, w)
 *   filter - (out_channels, in_channels, filter_size, filter_size)
 *
 * ACCUM_DTYPE must be provided at compile time (e.g. -DACCUM_DTYPE=float).
 */
__global__ void conv2d_pad(const float *__restrict__ in,
                           float *__restrict__ out,
                           const float *__restrict__ filter,
                           int h,
                           int w,
                           int in_channels,
                           int out_channels,
                           int filter_size /* Must be an odd number */,
                           float pad) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;

    int filter_r = (filter_size - 1) / 2;

    // In and Out data dimensions:
    // 0 - channel
    // 1 - height
    // 2 - width

    // Filter dimensions:
    // 0 - out channels
    // 1 - in channels
    // 2 - height
    // 3 - width

    if (x >= w || y >= h) return;

#ifdef DEBUG
    if (x == 0 && y == 0) {
        PRINT_INPUTS();
    }
#endif

    // Element counts of one 2D slice. size_t so that channel offsets
    // (channel * plane) cannot overflow int for large tensors.
    const size_t plane = (size_t)w * h;
    const size_t filter_plane = (size_t)filter_size * filter_size;

    // Offset of this thread's pixel within a single 2D slice.
    const int pixel = y * w + x;

    // Loop over the output channels
    for (int out_c = 0; out_c < out_channels; out_c++) {
        ACCUM_DTYPE R = 0;

        // All weights for this output channel: (in_channels, filter_size, filter_size)
        const float *out_filter = filter + filter_plane * in_channels * out_c;

        // Loop over the input channels
        for (int in_c = 0; in_c < in_channels; in_c++) {
            // 2D slice of the filter for the active (output, input) channel pair
            const float *sub_filter = out_filter + filter_plane * in_c;
            // Pointer to the current channel in the input
            const float *sub_input = in + plane * in_c;

            // Apply the filter to the input or the pad value for outside indices.
            for (int filter_y = 0; filter_y < filter_size; filter_y++) {
                for (int filter_x = 0; filter_x < filter_size; filter_x++) {
                    float v = pad;
                    int input_x = x - filter_r + filter_x;
                    int input_y = y - filter_r + filter_y;

                    if (input_x >= 0 && input_x < w && input_y >= 0 && input_y < h) {
                        v = sub_input[input_y * w + input_x];
                    }
                    R += v * sub_filter[filter_y * filter_size + filter_x];
                }
            }
        }
        // Write into the 2D slice of the output for this channel
        out[plane * out_c + pixel] = R;
    }
}
cu_file_z_out = "kernels/conv2d/conv2d-z-out.cu"  # source of the z-per-output-channel kernel

kernels/conv2d/conv2d-z-out.cu

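This implementation launches a separate set of blocks per output channel (selected by the grid's z dimension), so each thread computes only one output channel.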

#include <stdint.h>
#include <stdio.h>

#include "conv2d-helpers.h"

// 2D convolution with padding to the input's h x w, channel-first.
// This version uses the z grid dimension for out channels: blockIdx.z selects the
// output channel, so each thread calculates one output pixel of only one output channel.
//
// Expected launch: block = (bx, by, 1), grid = (ceil(w / bx), ceil(h / by), out_channels).
// Shapes: in (in_channels, h, w), out (out_channels, h, w),
//         filter (out_channels, in_channels, filter_size, filter_size).
// ACCUM_DTYPE must be provided at compile time (e.g. -DACCUM_DTYPE=float).
__global__ void conv2d_pad_z_out(const float *__restrict__ in,
                                 float *__restrict__ out,
                                 const float *__restrict__ filter,
                                 int h,
                                 int w,
                                 int in_channels,
                                 int out_channels,
                                 int filter_size /* Must be an odd number */,
                                 float pad) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;

    int out_ch = blockIdx.z;

    int filter_r = (filter_size - 1) / 2;

    // In and Out data dimensions:
    // 0 - channel
    // 1 - height
    // 2 - width

    // Filter dimensions:
    // 0 - out channels
    // 1 - in channels
    // 2 - height
    // 3 - width

    // Guard the channel as well: a grid with more z-layers than out_channels would
    // otherwise read past the end of `filter` and write past the end of `out`.
    if (x >= w || y >= h || out_ch >= out_channels) return;

#ifdef DEBUG
    // Print once per launch, not once per output channel.
    if (x == 0 && y == 0 && out_ch == 0) PRINT_INPUTS();
#endif

    // Element counts of one 2D slice; size_t so channel offsets cannot overflow int.
    const size_t plane = (size_t)w * h;
    const size_t filter_plane = (size_t)filter_size * filter_size;

    // All weights for this output channel: (in_channels, filter_size, filter_size)
    const float *out_filter = filter + filter_plane * in_channels * out_ch;

    ACCUM_DTYPE R = 0;

    // Loop over the input channels
    for (int in_c = 0; in_c < in_channels; in_c++) {
        // 2D slice of the filter for the active (output, input) channel pair
        const float *sub_filter = out_filter + filter_plane * in_c;
        // Pointer to the current channel in the input
        const float *sub_input = in + plane * in_c;

        // Apply the filter to the input or the pad value for outside indices.
        for (int filter_y = 0; filter_y < filter_size; filter_y++) {
            for (int filter_x = 0; filter_x < filter_size; filter_x++) {
                float v = pad;
                int input_x = x - filter_r + filter_x;
                int input_y = y - filter_r + filter_y;

                if (input_x >= 0 && input_x < w && input_y >= 0 && input_y < h) {
                    v = sub_input[input_y * w + input_x];
                }
                R += v * sub_filter[filter_y * filter_size + filter_x];
            }
        }
    }
    // Write into the 2D slice of the output for this channel
    out[plane * out_ch + y * w + x] = R;
}
def benchmark_conv2d_pad(ctx, kernel, input, filter, pad, block_size, grid_size, repeat=10, warmup=True):
    """Time a conv2d_pad-style kernel and return its output and mean run time.

    Args:
        ctx: PyCUDA context; its ``synchronize()`` is used after uploads and warmups.
        kernel: compiled kernel called as
            ``kernel(in, out, filter, h, w, in_ch, out_ch, filter_size, pad, grid=..., block=...)``.
        input: channel-first input of shape (in_ch, h, w).
        filter: weights of shape (out_ch, in_ch, fh, fw); only square filters are supported.
        pad: value used for input samples outside the image.
        block_size, grid_size: launch configuration, passed to ``kernel`` unchanged.
        repeat: number of timed launches; must be >= 1.
        warmup: if True, run one untimed launch before every timed launch.

    Returns:
        ``(out, timing)``: ``out`` is a float32 array of shape (out_ch, h, w) holding the
        result of the last launch; ``timing`` is the mean time of the timed launches in
        milliseconds (the unit of ``cuda.Event.time_since``).
    """
    # input, channel-first
    # - Channel
    # - Height
    # - Width
    assert len(input.shape) == 3

    # Filter shape should be
    # - Out channels
    # - In  channels
    # - Height
    # - Width
    assert len(filter.shape) == 4

    assert repeat >= 1, f"repeat must be >= 1, got {repeat}"

    in_ch, h, w = input.shape
    out_ch, in_ch2, fh, fw = filter.shape

    assert fh == fw, f"Only square filters supported, got shape={filter.shape}"

    assert in_ch == in_ch2

    # The kernel reads raw float32 buffers; make sure that is what we upload
    # (no-op for arrays that are already contiguous float32).
    input = np.ascontiguousarray(input, dtype=np.float32)
    filter = np.ascontiguousarray(filter, dtype=np.float32)

    out = np.empty((out_ch, h, w), dtype=np.float32)

    scalar_args = (
        np.int32(h),
        np.int32(w),
        np.int32(in_ch),
        np.int32(out_ch),
        np.int32(fh),
        np.float32(pad),
    )

    # Allocate every device buffer once; free them explicitly so nothing is left
    # for the garbage collector to release after the caller pops the context.
    gpu_input = cuda.mem_alloc_like(input)
    gpu_filter = cuda.mem_alloc_like(filter)
    gpu_out = cuda.mem_alloc_like(out)
    try:
        cuda.memcpy_htod(gpu_input, input)
        cuda.memcpy_htod(gpu_filter, filter)
        ctx.synchronize()

        def launch():
            kernel(gpu_input, gpu_out, gpu_filter, *scalar_args,
                   grid=grid_size,
                   block=block_size)

        start = cuda.Event()
        end = cuda.Event()

        total_ms = 0.0
        for _ in range(repeat):
            if warmup:
                launch()
                ctx.synchronize()

            start.record()
            launch()
            end.record()
            end.synchronize()

            total_ms += end.time_since(start)

        cuda.memcpy_dtoh(out, gpu_out)
    finally:
        gpu_input.free()
        gpu_filter.free()
        gpu_out.free()

    return out, total_ms / repeat

Test matrix

Sample random (but not too large) combinations of input channels, output channels, filter sizes and image sizes.

from itertools import product

# Candidate values for each test dimension.
in_chan_range = [1, 3, 8, 32, 128, 512]
out_chan_range = [1, 4, 8, 32, 128, 512]

filter_size = [1, 3, 5, 9]

img_size_range = [64, 128, 256, 512, 1024]

# Let's sample from the available options.
n_samples = 50

# Fixed seed so the sampled shapes (and hence the results table) can be reproduced.
# A dedicated Generator leaves numpy's global random state untouched.
sampling_seed = 0

# Skip combinations whose in_ch * out_ch * img_size**2 reaches this bound (too slow).
max_work = 1024 * 1024 * 32 * 32

# Generate all possible combinations, in the same order as nested loops over the ranges.
combinations = [
    (in_ch, out_ch, fs, img_size)
    for in_ch, out_ch, fs, img_size in product(in_chan_range, out_chan_range, filter_size, img_size_range)
    if in_ch * out_ch * img_size * img_size < max_work
]

n_samples = min(n_samples, len(combinations))
sampling_rng = np.random.default_rng(sampling_seed)
sampled_combinations = sampling_rng.choice(len(combinations), size=n_samples, replace=False)
test_cases = [combinations[i] for i in sampled_combinations]

Run the tests

import warn_options
tile_width = 32
# Both kernels use square tile_width x tile_width thread blocks.
block_size = (tile_width, tile_width, 1)

# One row per test case: shape parameters plus the mean time (ms) of each kernel.
data = []

# test_cases = [(3, 4, 32,32)]

ctx = device.make_context()
try:
    mod_naive = SourceModule(
        Path(cu_file_naive).read_text(),
        options=warn_options.warn_options + ["-DACCUM_DTYPE=float"],
        include_dirs=[str(Path(cu_file_naive).parent.absolute())],
    )

    mod_z_out = SourceModule(
        Path(cu_file_z_out).read_text(),
        options=warn_options.warn_options + ["-DACCUM_DTYPE=float"],
        include_dirs=[str(Path(cu_file_z_out).parent.absolute())],
    )

    kernels = {
        "conv2d_pad": mod_naive.get_function("conv2d_pad"),
        "conv2d_pad_z_out": mod_z_out.get_function("conv2d_pad_z_out"),
    }

    for tc in tqdm(test_cases):
        ch_in, ch_out, fs, pixels = tc

        array_in = np.random.randn(ch_in, pixels, pixels).astype(np.float32)
        # Named `weights` rather than `filter` so the builtin is not shadowed.
        weights = np.random.randn(ch_out, ch_in, fs, fs).astype(np.float32)

        # Reference result, converted to numpy once. padding="same" zero-pads,
        # which matches pad=0 passed to the kernels below.
        torch_out = conv2d(Tensor(array_in), Tensor(weights), padding="same").numpy()

        # Number of tiles needed to cover one image side.
        tiles = (pixels + tile_width - 1) // tile_width

        timings = {}

        for kernel_name, kernel in kernels.items():
            # conv2d_pad_z_out takes the output channel from the grid's z dimension.
            grid_size = (tiles, tiles, ch_out if kernel_name == "conv2d_pad_z_out" else 1)

            out, timing = benchmark_conv2d_pad(ctx, kernel, array_in, weights, 0, block_size, grid_size, repeat=5, warmup=True)

            match_fraction = np.isclose(out, torch_out).mean()
            if match_fraction < 0.8:
                print("### Result mismatch ###")
                print(f"Kernel: {kernel_name}")
                print(f"Matching fraction: {match_fraction:.3f}")
                print(f"Input shape: {array_in.shape}")
                print(f"Filter shape: {weights.shape}")
                print(f"Result shape: {(weights.shape[0], array_in.shape[1], array_in.shape[2])}")
                print(f"Grid size: {grid_size}")
                print(f"Block size: {block_size}")
                print(f"Total threads: {prod((*grid_size, *block_size))}")

            timings[kernel_name] = timing

        data.append({
            'in_ch': ch_in,
            'out_ch': ch_out,
            'filter_size': fs,
            'img_size': pixels,
        } | timings)


finally:
    ctx.pop()
    ctx.detach()

results = pd.DataFrame(data)

Test results

# Sort by conv2d_pad timing
results_sorted = results.sort_values(by='conv2d_pad')

# Create a plot comparing the two kernels
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(12, 8))
sns.set_style("whitegrid")

# Create labels for x-axis that include dimensions
results_sorted['dimensions'] = results_sorted.apply(
    lambda row: f"{int(row['img_size'])}×{int(row['img_size'])}×{int(row['in_ch'])} -> {int(row['out_ch'])}, f:{int(row['filter_size'])}×{int(row['filter_size'])}",
    axis=1
)

# Melt the dataframe to get it in the right format for seaborn
melted_results = pd.melt(
    results_sorted,
    id_vars=['in_ch', 'out_ch', 'filter_size', 'img_size', 'dimensions'],
    value_vars=['conv2d_pad', 'conv2d_pad_z_out'],
    var_name='kernel',
    value_name='time'
)

# Create a barplot with dimensions as x-ticks
ax = sns.barplot(x='dimensions', y='time', hue='kernel', data=melted_results)

# Add labels. Timings come from cuda.Event.time_since, which reports milliseconds.
plt.xlabel('Input and Filter Dimensions')
plt.ylabel('Time (ms)')
plt.title('Performance Comparison of conv2d_pad vs conv2d_pad_z_out')
plt.xticks(rotation=90)

# Add a legend
plt.legend(title='Kernel')

# Show the plot
plt.tight_layout()
plt.show()

# Also display the sorted results table
results_sorted

in_ch out_ch filter_size img_size conv2d_pad conv2d_pad_z_out dimensions
49 1 1 1 256 0.019046 0.019046 256×256×1 -> 1, f:1×1
8 8 1 1 128 0.021709 0.020890 128×128×8 -> 1, f:1×1
45 1 1 1 512 0.022528 0.020275 512×512×1 -> 1, f:1×1
27 3 8 1 64 0.026419 0.014336 64×64×3 -> 8, f:1×1
16 3 4 1 256 0.031744 0.033382 256×256×3 -> 4, f:1×1
17 1 4 5 64 0.035840 0.022118 64×64×1 -> 4, f:5×5
38 1 1 1 1024 0.056525 0.055706 1024×1024×1 -> 1, f:1×1
40 32 1 3 128 0.087450 0.084378 128×128×32 -> 1, f:3×3
4 1 8 5 256 0.095232 0.082739 256×256×1 -> 8, f:5×5
37 3 8 3 256 0.107520 0.096256 256×256×3 -> 8, f:3×3
36 1 8 9 128 0.117555 0.055910 128×128×1 -> 8, f:9×9
5 1 4 9 256 0.119194 0.094003 256×256×1 -> 4, f:9×9
43 8 4 1 512 0.132710 0.177357 512×512×8 -> 4, f:1×1
19 128 1 1 256 0.204186 0.203981 256×256×128 -> 1, f:1×1
39 8 8 1 512 0.231834 0.336896 512×512×8 -> 8, f:1×1
13 1 8 5 512 0.247194 0.273203 512×512×1 -> 8, f:5×5
24 3 8 3 512 0.292659 0.353075 512×512×3 -> 8, f:3×3
10 3 8 9 64 0.312934 0.050995 64×64×3 -> 8, f:9×9
0 8 4 9 64 0.417997 0.113254 64×64×8 -> 4, f:9×9
32 1 512 1 128 0.425779 0.243507 128×128×1 -> 512, f:1×1
31 128 4 1 256 0.743834 0.588595 256×256×128 -> 4, f:1×1
26 3 512 1 64 0.885965 0.120422 64×64×3 -> 512, f:1×1
18 8 4 5 512 0.916275 0.968294 512×512×8 -> 4, f:5×5
1 1 512 3 128 1.092198 0.514048 128×128×1 -> 512, f:3×3
46 3 32 3 512 1.098752 1.355366 512×512×3 -> 32, f:3×3
9 8 4 3 1024 1.419674 2.212659 1024×1024×8 -> 4, f:3×3
21 1 128 1 1024 1.955635 5.893126 1024×1024×1 -> 128, f:1×1
3 32 4 1 1024 2.050253 2.409267 1024×1024×32 -> 4, f:1×1
30 512 1 3 256 2.841395 2.728755 256×256×512 -> 1, f:3×3
33 8 32 9 64 3.557990 0.306995 64×64×8 -> 32, f:9×9
14 8 512 1 256 4.713267 4.175462 256×256×8 -> 512, f:1×1
47 32 32 5 128 5.372723 1.687757 128×128×32 -> 32, f:5×5
11 128 4 5 256 5.690982 3.754189 256×256×128 -> 4, f:5×5
6 512 8 1 256 6.060646 4.183040 256×256×512 -> 8, f:1×1
34 3 8 9 1024 7.197081 8.368333 1024×1024×3 -> 8, f:9×9
25 3 32 9 512 7.662797 7.265075 512×512×3 -> 32, f:9×9
12 3 32 5 1024 10.742989 14.039655 1024×1024×3 -> 32, f:5×5
22 512 8 3 64 11.946598 1.192960 64×64×512 -> 8, f:3×3
28 1 128 5 1024 14.895719 19.455181 1024×1024×1 -> 128, f:5×5
23 128 4 9 256 14.967603 10.709811 256×256×128 -> 4, f:9×9
35 128 8 3 512 18.949939 16.436429 512×512×128 -> 8, f:3×3
7 3 512 3 512 20.223590 24.182989 512×512×3 -> 512, f:3×3
20 128 32 3 256 23.864934 15.436800 256×256×128 -> 32, f:3×3
44 512 4 5 256 25.573376 16.888422 256×256×512 -> 4, f:5×5
48 128 8 9 256 30.146976 20.129997 256×256×128 -> 8, f:9×9
15 32 8 5 1024 35.380025 33.746944 1024×1024×32 -> 8, f:5×5
2 32 128 9 128 59.110400 19.480167 128×128×32 -> 128, f:9×9
42 128 4 5 1024 72.265121 66.161868 1024×1024×128 -> 4, f:5×5
29 128 512 3 64 155.834366 12.261376 64×64×128 -> 512, f:3×3
41 32 512 9 64 236.589465 19.932570 64×64×32 -> 512, f:9×9