Day 11 - conv2d with shared memory

import pandas as pd
import numpy as np
from math import prod
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
from lovely_numpy import Lo
from lovely_tensors import monkey_patch; monkey_patch()
from torch import Tensor
from torch.nn.functional import conv2d

import warn_options
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
cuda.init()

device = cuda.Device(0)

print(f"Cuda version: {".".join([str(i) for i in cuda.get_version()])}")
print(f"Device:\t{device.name()}")
Cuda version: 12.8.0
Device: NVIDIA GeForce RTX 3080 Laptop GPU
cu_file="kernels/conv2d/conv2d-z-out-shared.cu"

kernels/conv2d/conv2d-z-out-shared.cu

#include <stdint.h>
#include <stdio.h>

#include "conv2d-helpers.h"

// This version copies each input channel into shared memory before performing the
// convolution. Grid Z is used for output channels, so each thread only handles one
// output channel
__global__ void conv2d_pad_z_out_shared(float *in,
                                        float *out,
                                        float *filter,
                                        int h,
                                        int w,
                                        int in_channels,
                                        int out_channels,
                                        int filter_size /* Must be an odd number */,
                                        float pad) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;

    int out_ch = blockIdx.z;

    int filter_r = (filter_size - 1) / 2;

    // Dynamically sized shared tile of blockDim.x * blockDim.y floats; the byte count is
    // passed by the host at launch time.
    extern __shared__ float cell[];

    // In and Out data dimensions:
    // 0 - channel
    // 1 - height
    // 2 - width

    // Filter dimensions:
    // 0 - out channels
    // 1 - in channels
    // 2 - height
    // 3 - width
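    // i.e. the input value for (channel c, row y, col x) lives at in[c * h * w + y * w + x],
    // and the filter tap for (out_ch, in_c, fy, fx) lives at
    // filter[((out_ch * in_channels + in_c) * filter_size + fy) * filter_size + fx]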

    // Note: threads outside the image return before the __syncthreads() below; in these
    // experiments w and h are always multiples of the block size, so no thread takes this exit.
    if (x >= w || y >= h) return;

#ifdef DEBUG
    if (x == 0 && y == 0) PRINT_INPUTS();
#endif

    // The output channel is selected by the grid Z index, so there is no loop over
    // output channels here.

    // Pointer to the 2d slice of the output for this output channel
    float *sub_output = out + out_ch * w * h;
    ACCUM_DTYPE R = 0;
    // Loop over the input channels
    for (int in_c = 0; in_c < in_channels; in_c++) {
        // Pointer to the 2d slice of the filter that corresponds to the active input and output
        // channels
        float *sub_filter = filter + (filter_size * filter_size * in_channels * out_ch) +
                            (filter_size * filter_size * in_c);
        // Pointer to the current channel in the input
        float *sub_input = in + (w * h * in_c);

        cell[threadIdx.y * blockDim.x + threadIdx.x] = sub_input[y * w + x];
        __syncthreads();  // Wait for all threads to load the input

        // Apply the filter, reading from the shared tile when possible and falling back to
        // global memory (or the pad value) for positions outside the tile.
        for (int filter_y = 0; filter_y < filter_size; filter_y++) {
            for (int filter_x = 0; filter_x < filter_size; filter_x++) {
                int tile_x = threadIdx.x - filter_r + filter_x;
                int tile_y = threadIdx.y - filter_r + filter_y;

                int input_x = x - filter_r + filter_x;
                int input_y = y - filter_r + filter_y;

                if (tile_x >= 0 && tile_x < blockDim.x && tile_y >= 0 && tile_y < blockDim.y) {
                    // The tap falls inside this block's tile - read it from shared memory.
                    R += cell[tile_y * blockDim.x + tile_x] *
                         sub_filter[filter_y * filter_size + filter_x];
                } else if (input_x >= 0 && input_x < w && input_y >= 0 && input_y < h) {
                    // Halo tap owned by a neighbouring block - read it from global memory.
                    R += sub_input[input_y * w + input_x] *
                         sub_filter[filter_y * filter_size + filter_x];
                } else {
                    // Outside the image - use the padding value.
                    R += pad * sub_filter[filter_y * filter_size + filter_x];
                }
            }
        }

        __syncthreads();  // Wait for all threads to complete before we load the next input
    }

    sub_output[y * w + x] = R;
}
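Before the benchmark, here is a minimal single launch of this kernel as a sanity check, sketched under the same setup as the rest of this notebook (device, cu_file and warn_options from the cells above); the shapes and the 32-wide tile are illustrative. The important parts are the grid Z dimension, which must equal the number of output channels, and the shared argument, which must be blockDim.x * blockDim.y * sizeof(float) bytes for the tile.

import numpy as np
from pathlib import Path
from torch import Tensor
from torch.nn.functional import conv2d as torch_conv2d
import pycuda.driver as cuda
from pycuda.compiler import SourceModule

ctx = device.make_context()
try:
    mod = SourceModule(
        Path(cu_file).read_text(),
        options=warn_options.warn_options + ["-DACCUM_DTYPE=float"],
        include_dirs=[str(Path(cu_file).parent.absolute())]
    )
    kernel = mod.get_function("conv2d_pad_z_out_shared")

    in_ch, out_ch, fs, size, tile = 3, 8, 3, 128, 32  # illustrative sizes
    x = np.random.randn(in_ch, size, size).astype(np.float32)
    f = np.random.randn(out_ch, in_ch, fs, fs).astype(np.float32)
    out = np.empty((out_ch, size, size), dtype=np.float32)

    gpu_x, gpu_f, gpu_out = [cuda.mem_alloc_like(a) for a in (x, f, out)]
    cuda.memcpy_htod(gpu_x, x)
    cuda.memcpy_htod(gpu_f, f)

    kernel(gpu_x, gpu_out, gpu_f,
           np.int32(size), np.int32(size),
           np.int32(in_ch), np.int32(out_ch),
           np.int32(fs), np.float32(0),
           block=(tile, tile, 1),  # one thread per output pixel
           # Grid Z selects the output channel
           grid=((size + tile - 1) // tile, (size + tile - 1) // tile, out_ch),
           shared=tile * tile * 4)  # blockDim.x * blockDim.y floats for the tile
    cuda.memcpy_dtoh(out, gpu_out)

    ref = torch_conv2d(Tensor(x), Tensor(f), padding="same").numpy()
    print("max abs diff:", np.abs(out - ref).max())
finally:
    ctx.pop()
    ctx.detach()

This mirrors what benchmark_conv2d_pad below does for each test case, minus the timing.
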
def benchmark_conv2d_pad(ctx, kernel, input, filter, pad, block_size, grid_size, shared=None, repeat=10, warmup=True):
    # input, channel-first
    # - Channel
    # - Height
    # - Width
    assert len(input.shape) == 3

    # Filter shape should be
    # - Out channels
    # - In  channels
    # - Height
    # - Width
    assert len(filter.shape) == 4

    in_ch, h, w = input.shape
    out_ch, in_ch2, fh, fw = filter.shape

    assert fh == fw, f"Only square filters supported, got shape={filter.shape}"

    assert in_ch == in_ch2

    out_shape = (out_ch, h, w)
    # print(f"shared = {shared}")
    # print(f"out_shape={out_shape}")

    gpu_input = cuda.mem_alloc_like(input)
    gpu_filter = cuda.mem_alloc_like(filter)

    out = np.empty(out_shape, dtype=np.float32)

    cuda.memcpy_htod(gpu_input, input)
    cuda.memcpy_htod(gpu_filter, filter)
    ctx.synchronize()

    timing=0
    for _ in range(repeat):
        start = cuda.Event()
        end = cuda.Event()

        gpu_out = cuda.mem_alloc_like(out)

        if warmup:
            kernel(gpu_input, gpu_out, gpu_filter,
                   np.int32(h),
                   np.int32(w),
                   np.int32(in_ch),
                   np.int32(out_ch),
                   np.int32(fh),
                   np.float32(pad),
                   grid=grid_size,
                   block=block_size,
                   shared=shared
                   )
            ctx.synchronize()

        start.record()
        kernel(gpu_input, gpu_out, gpu_filter,
               np.int32(h),
               np.int32(w),
               np.int32(in_ch),
               np.int32(out_ch),
               np.int32(fh),
               np.float32(pad),
               grid=grid_size,
               block=block_size,
               shared=shared
               )
        end.record()
        end.synchronize()

        timing += end.time_since(start)
    timing /= repeat

    cuda.memcpy_dtoh(out, gpu_out)
    return out, timing
in_chan_range = [1, 3, 8, 32, 128, 512]
out_chan_range = [1, 4, 8, 32, 128, 512]

filter_size = [1, 3, 5]

img_size_range = [64, 128, 256, 512, 1024]

# Let's sample from the available options.
n_samples = 50


# Generate all possible combinations
combinations = []
for in_ch in in_chan_range:
    for out_ch in out_chan_range:
        for fs in filter_size:
            for img_size in img_size_range:
                n = in_ch * out_ch * img_size * img_size

                # Skip combinations that are too large
                if n < 1024*1024*32*32:
                    combinations.append((in_ch, out_ch, fs, img_size))

n_samples = min(n_samples, len(combinations))
sampled_combinations = np.random.choice(len(combinations), size=n_samples, replace=False)
test_cases = [combinations[i] for i in sampled_combinations]
tile_width = 32

data = []

# test_cases = [(512, 8, 9, 64)]

ctx = device.make_context()
try:
    mod = SourceModule(
        Path(cu_file).read_text(),
        options=warn_options.warn_options + ["-DACCUM_DTYPE=float"],
        include_dirs=[str(Path(cu_file).parent.absolute())]
    )

    mod_z_out = SourceModule(
        Path("kernels/conv2d/conv2d-z-out.cu").read_text(),
        options=warn_options.warn_options + ["-DACCUM_DTYPE=float"],
        include_dirs=[str(Path(cu_file).parent.absolute())]
    )

    mod_naive = SourceModule(
        Path("kernels/conv2d/conv2d_naive.cu").read_text(),
        options=warn_options.warn_options + ["-DACCUM_DTYPE=float"],
        include_dirs=[str(Path(cu_file).parent.absolute())]
    )

    kernels = {
        "conv2d_pad_z_out_shared": mod.get_function("conv2d_pad_z_out_shared"),
        "conv2d_pad": mod_naive.get_function("conv2d_pad"),
        "conv2d_pad_z_out": mod_z_out.get_function("conv2d_pad_z_out")
    }

    for tc in tqdm(test_cases):
        ch_in, ch_out, fs, pixels = tc

        array_in = np.random.randn(ch_in, pixels, pixels).astype(np.float32)
        filter = np.random.randn(ch_out, ch_in, fs, fs).astype(np.float32)

        torch_out = conv2d(Tensor(array_in), Tensor(filter), padding="same")

        timings = {}

        for kernel_name, kernel in kernels.items():

            block_size = (tile_width, tile_width, 1)
            grid_size = (((pixels+tile_width-1) // tile_width), ((pixels+tile_width-1) // tile_width),
                         1 if kernel_name == "conv2d_pad" else ch_out)

            out, timing = benchmark_conv2d_pad(
                ctx=ctx,
                kernel=kernel,
                input=array_in,
                filter=filter,
                pad=0,
                block_size=block_size,
                grid_size=grid_size,
                shared=tile_width * tile_width * 4 if kernel_name == "conv2d_pad_z_out_shared" else 0,
                repeat=5,
                warmup=True
            )

            if np.isclose(out, torch_out).mean() < 0.8:
                print("### Result mismatch ###")
                print(f"Kernel: {kernel_name}")
                print(f"Input shape: {array_in.shape}")
                print(f"Filter shape: {filter.shape}")
                print(f"Result shape: {(filter.shape[0], array_in.shape[1], array_in.shape[2])}")
                print(f"Grid size: {grid_size}")
                print(f"Block size: {block_size}")
                print(f"Total threads: {prod((*grid_size, *block_size))}")

            timings[kernel_name] = timing
            # time.sleep(10)

        data.append({
            'in_ch': ch_in,
            'out_ch': ch_out,
            'filter_size': fs,
            'img_size': pixels,
            # 'kernel': kernel_name,
        } | timings)

finally:
    ctx.pop()
    ctx.detach()

results = pd.DataFrame(data)

Results

# Sort by conv2d_pad timing
results_sorted = results.sort_values(by='conv2d_pad')

# Create a plot comparing the two kernels
import matplotlib.pyplot as plt
import seaborn as sns

# Create labels for x-axis that include dimensions
results_sorted['dimensions'] = results_sorted.apply(
    lambda row: f"{int(row['img_size'])}×{int(row['img_size'])}×{int(row['in_ch'])} -> {int(row['out_ch'])}, f:{int(row['filter_size'])}×{int(row['filter_size'])}",
    axis=1
)

# Melt the dataframe to get it in the right format for seaborn
melted_results = pd.melt(
    results_sorted,
    id_vars=['in_ch', 'out_ch', 'filter_size', 'img_size', 'dimensions'],
    value_vars=['conv2d_pad', 'conv2d_pad_z_out', 'conv2d_pad_z_out_shared'],
    var_name='kernel',
    value_name='time'
)

# Split the data into two halves based on timing
midpoint = len(results_sorted) // 2
faster_results = melted_results[melted_results['dimensions'].isin(results_sorted['dimensions'][:midpoint])]
slower_results = melted_results[melted_results['dimensions'].isin(results_sorted['dimensions'][midpoint:])]

# Create a figure with two subplots
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 16))

# Plot faster results in the first subplot
sns.barplot(x='dimensions', y='time', hue='kernel', data=faster_results, ax=ax1)
ax1.set_xlabel('')
ax1.set_ylabel('Time (ms)')
ax1.set_title('Performance Comparison - Faster Results')
ax1.tick_params(axis='x', rotation=90)
ax1.legend(title='Kernel')

# Plot slower results in the second subplot
sns.barplot(x='dimensions', y='time', hue='kernel', data=slower_results, ax=ax2)
ax2.set_xlabel('Input and Filter Dimensions')
ax2.set_ylabel('Time (ms)')
ax2.set_title('Performance Comparison - Slower Results')
ax2.tick_params(axis='x', rotation=90)
ax2.legend(title='Kernel')

# Adjust layout
plt.tight_layout()
plt.show()

# Also display the sorted results table
results_sorted

All timings below are in milliseconds; the unlabeled first column is the test-case index.

in_ch out_ch filter_size img_size conv2d_pad_z_out_shared conv2d_pad conv2d_pad_z_out dimensions
18 1 4 1 128 0.015974 0.014950 0.014336 128×128×1 -> 4, f:1×1
47 1 1 3 64 0.016794 0.018637 0.014541 64×64×1 -> 1, f:3×3
21 3 1 3 256 0.037274 0.027034 0.025805 256×256×3 -> 1, f:3×3
23 1 8 3 64 0.022733 0.033587 0.020480 64×64×1 -> 8, f:3×3
43 3 8 3 64 0.025395 0.058573 0.019866 64×64×3 -> 8, f:3×3
13 3 1 3 512 0.092570 0.058778 0.058163 512×512×3 -> 1, f:3×3
12 32 1 1 256 0.106086 0.062259 0.064922 256×256×32 -> 1, f:1×1
24 1 4 3 512 0.120627 0.073114 0.084378 512×512×1 -> 4, f:3×3
30 1 32 3 64 0.030106 0.080077 0.109363 64×64×1 -> 32, f:3×3
33 32 4 1 128 0.085811 0.086221 0.056525 128×128×32 -> 4, f:1×1
48 3 32 1 256 0.186368 0.126362 0.129638 256×256×3 -> 32, f:1×1
20 1 32 5 128 0.128205 0.169984 0.084173 128×128×1 -> 32, f:5×5
0 1 128 1 128 0.093184 0.192512 0.073523 128×128×1 -> 128, f:1×1
26 3 128 1 128 0.180838 0.233882 0.126566 128×128×3 -> 128, f:1×1
16 1 512 1 128 0.319693 0.369664 0.244326 128×128×1 -> 512, f:1×1
14 8 4 3 512 0.820224 0.403046 0.499302 512×512×8 -> 4, f:3×3
34 32 32 1 128 0.395059 0.556237 0.237363 128×128×32 -> 32, f:1×1
32 3 128 3 64 0.141722 0.725811 0.089702 64×64×3 -> 128, f:3×3
44 512 1 1 256 1.329971 0.743629 0.754483 256×256×512 -> 1, f:1×1
37 8 8 1 1024 1.926963 0.826368 1.283072 1024×1024×8 -> 8, f:1×1
25 8 4 5 512 1.817190 0.915046 0.969114 512×512×8 -> 4, f:5×5
38 32 32 1 256 1.848525 1.133158 1.107354 256×256×32 -> 32, f:1×1
1 3 32 1 1024 3.284378 1.242317 2.644992 1024×1024×3 -> 32, f:1×1
22 1 128 5 256 1.764147 1.248870 1.013350 256×256×1 -> 128, f:5×5
3 512 1 3 64 2.400461 1.253990 1.257882 64×64×512 -> 1, f:3×3
28 32 8 5 128 0.943514 1.274675 0.467763 128×128×32 -> 8, f:5×5
45 512 1 3 128 2.519245 1.280614 1.275904 128×128×512 -> 1, f:3×3
7 3 128 1 512 3.371622 1.326080 2.317722 512×512×3 -> 128, f:1×1
17 3 512 1 256 2.973082 1.753088 1.833779 256×256×3 -> 512, f:1×1
41 3 128 5 64 0.328909 1.810022 0.183296 64×64×3 -> 128, f:5×5
9 3 128 5 128 1.266893 1.844634 0.679117 128×128×3 -> 128, f:5×5
15 32 4 3 512 3.226624 1.954816 1.873306 512×512×32 -> 4, f:3×3
40 8 512 1 64 0.394240 2.289254 0.242483 64×64×8 -> 512, f:1×1
49 32 8 5 256 3.610214 2.547712 1.794867 256×256×32 -> 8, f:5×5
8 3 8 5 1024 5.853798 2.647040 3.288064 1024×1024×3 -> 8, f:5×5
35 512 8 1 64 0.688538 2.897920 0.348160 64×64×512 -> 8, f:1×1
46 128 32 1 128 1.876787 3.093299 0.982426 128×128×128 -> 32, f:1×1
10 1 32 5 1024 9.511936 3.605504 4.901683 1024×1024×1 -> 32, f:5×5
29 3 32 3 1024 11.349197 4.278477 6.614630 1024×1024×3 -> 32, f:3×3
19 8 512 1 256 6.478848 4.831232 4.056064 256×256×8 -> 512, f:1×1
27 32 32 5 64 0.920986 5.430477 0.462234 64×64×32 -> 32, f:5×5
5 128 8 5 64 1.239859 5.528781 0.618906 64×64×128 -> 8, f:5×5
42 1 512 3 512 14.623539 6.328320 9.598566 512×512×1 -> 512, f:3×3
6 8 8 5 1024 16.636518 7.377920 8.506163 1024×1024×8 -> 8, f:5×5
39 32 512 1 128 6.386483 9.455616 3.391898 128×128×32 -> 512, f:1×1
36 32 128 3 128 5.712077 9.837773 3.195904 128×128×32 -> 128, f:3×3
4 32 128 3 256 27.863654 22.527795 15.678259 256×256×32 -> 128, f:3×3
31 512 8 5 128 17.193369 26.259456 7.303168 128×128×512 -> 8, f:5×5
2 32 512 3 64 5.996749 38.739558 3.200000 64×64×32 -> 512, f:3×3
11 512 4 5 512 134.970367 81.715004 68.868095 512×512×512 -> 4, f:5×5

For some reason, the version with shared memory is actually slower than the non-shared conv2d_pad_z_out kernel in almost every case. I'm not entirely sure why, because the output looks correct.
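A rough back-of-envelope estimate (derived from the indexing logic above, not from profiling) of how much of the filter footprint the 32×32 tile can actually serve:

tile = 32
for fs in (1, 3, 5):
    r = (fs - 1) // 2
    # For a tap offset d along one axis, the number of thread positions t in [0, tile)
    # with 0 <= t + d < tile is tile - |d|; average over all offsets, then square for 2D.
    frac_1d = sum(tile - abs(d) for d in range(-r, r + 1)) / (tile * fs)
    print(f"{fs}x{fs} filter: {frac_1d ** 2:.1%} of taps hit the shared tile")

Even though nearly all taps land in the tile, each tile element was still loaded from global memory once to fill it, the halo taps bypass the tile entirely, and every input channel adds two __syncthreads() barriers, so it is plausible that the tile mostly replaces loads the L1/L2 cache would have served anyway.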