import pandas as pd
import numpy as np
from math import prod
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
Day 10 - Improving Conv2d performance
from lovely_numpy import Lo
from lovely_tensors import monkey_patch; monkey_patch()
from torch import Tensor
from torch.nn.functional import conv2d
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
# Initialise the CUDA driver API and report which GPU the benchmarks run on.
cuda.init()
device = cuda.Device(0)
device
# Single quotes inside the f-string: reusing the outer double quotes inside the
# replacement field is only legal from Python 3.12 (PEP 701) onwards.
print(f"Cuda version: {'.'.join(str(i) for i in cuda.get_version())}")
print(f"Device:\t{device.name()}")
Cuda version: 12.8.0
Device: NVIDIA GeForce RTX 3080 Laptop GPU
# Source file of the naive kernel (relative path, resolved against the working directory).
cu_file_naive = "kernels/conv2d/conv2d_naive.cu"
kernels/conv2d/conv2d_naive.cu
The first implementation, from the previous day:
#include <stdint.h>
#include <stdio.h>
#include "conv2d-helpers.h"
/* 2D convolution with "same" padding: the output has the same height and width
 * as the input, and filter taps that fall outside the input read the value `pad`.
 * Channel-first layouts. One thread per output pixel; each thread loops over all
 * output channels.
 *
 *   in     - [in_channels][h][w]
 *   out    - [out_channels][h][w]
 *   filter - [out_channels][in_channels][filter_size][filter_size]
 *
 * ACCUM_DTYPE selects the accumulator type (set via -DACCUM_DTYPE=... when compiling).
 */
__global__ void conv2d_pad(float *in,
                           float *out,
                           float *filter,
                           int h,
                           int w,
                           int in_channels,
                           int out_channels,
                           int filter_size /* Must be an odd number */,
                           float pad) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;
    int filter_r = (filter_size - 1) / 2;

    // Threads outside the image (partial edge blocks) have nothing to compute.
    if (x >= w || y >= h) return;

#ifdef DEBUG
    if (x == 0 && y == 0) {
        PRINT_INPUTS();
    }
#endif

    // Sizes of one 2D plane, in elements. Kept in size_t so that channel offsets
    // (channel * h * w) cannot overflow int arithmetic on large tensors.
    const size_t plane = (size_t)h * w;
    const size_t filter_plane = (size_t)filter_size * filter_size;

    // Loop over the output channels
    for (int out_c = 0; out_c < out_channels; out_c++) {
        ACCUM_DTYPE R = 0;
        // Pointer to the 2d slice of the output
        float *sub_output = out + out_c * plane;
        // Loop over the input channels
        for (int in_c = 0; in_c < in_channels; in_c++) {
            // Pointer to the 2d slice of the filter that corresponds to the active
            // input and output channels
            float *sub_filter = filter + filter_plane * ((size_t)in_channels * out_c + in_c);
            // Pointer to the current channel in the input
            float *sub_input = in + plane * in_c;
            // Apply the filter to the input, or the pad value for outside indices.
            for (int filter_y = 0; filter_y < filter_size; filter_y++) {
                for (int filter_x = 0; filter_x < filter_size; filter_x++) {
                    float v = pad;
                    int input_x = x - filter_r + filter_x;
                    int input_y = y - filter_r + filter_y;
                    if (input_x >= 0 && input_x < w && input_y >= 0 && input_y < h) {
                        v = sub_input[input_y * w + input_x];
                    }
                    R += v * sub_filter[filter_y * filter_size + filter_x];
                }
            }
        }
        sub_output[y * w + x] = R;
    }
}
# Source file of the z-out kernel (relative path, resolved against the working directory).
cu_file_z_out = "kernels/conv2d/conv2d-z-out.cu"
kernels/conv2d/conv2d-z-out.cu
This implementation uses a separate slice of blocks (the grid's z dimension) per output channel, so each thread computes only one output channel.
#include <stdint.h>
#include <stdio.h>
#include "conv2d-helpers.h"
// This version uses the z grid dimension for output channels: blockIdx.z selects
// the output channel, so each thread calculates only one output channel instead of
// looping over all of them.
//
// Same contract as conv2d_pad: "same" padding (the output is h x w, taps outside the
// input read `pad`), channel-first layouts
//   in     - [in_channels][h][w]
//   out    - [out_channels][h][w]
//   filter - [out_channels][in_channels][filter_size][filter_size]
// Launch with gridDim.z >= out_channels; surplus z-slices exit immediately.
__global__ void conv2d_pad_z_out(float *in,
                                 float *out,
                                 float *filter,
                                 int h,
                                 int w,
                                 int in_channels,
                                 int out_channels,
                                 int filter_size /* Must be an odd number */,
                                 float pad) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;
    int out_ch = blockIdx.z;
    int filter_r = (filter_size - 1) / 2;

    // Skip threads outside the output in any of the three grid dimensions.
    if (x >= w || y >= h || out_ch >= out_channels) return;

#ifdef DEBUG
    // Only the first z-slice prints, so the inputs are shown once per launch
    // rather than once per output channel.
    if (x == 0 && y == 0 && out_ch == 0) PRINT_INPUTS();
#endif

    // Plane sizes in elements, in size_t so channel offsets cannot overflow int.
    const size_t plane = (size_t)h * w;
    const size_t filter_plane = (size_t)filter_size * filter_size;

    ACCUM_DTYPE R = 0;
    // Pointer to the 2d slice of the output for this thread's output channel
    float *sub_output = out + out_ch * plane;
    // Loop over the input channels
    for (int in_c = 0; in_c < in_channels; in_c++) {
        // Pointer to the 2d slice of the filter that corresponds to the active input
        // and output channels
        float *sub_filter = filter + filter_plane * ((size_t)in_channels * out_ch + in_c);
        // Pointer to the current channel in the input
        float *sub_input = in + plane * in_c;
        // Apply the filter to the input, or the pad value for outside indices.
        for (int filter_y = 0; filter_y < filter_size; filter_y++) {
            for (int filter_x = 0; filter_x < filter_size; filter_x++) {
                float v = pad;
                int input_x = x - filter_r + filter_x;
                int input_y = y - filter_r + filter_y;
                if (input_x >= 0 && input_x < w && input_y >= 0 && input_y < h) {
                    v = sub_input[input_y * w + input_x];
                }
                R += v * sub_filter[filter_y * filter_size + filter_x];
            }
        }
    }
    sub_output[y * w + x] = R;
}
def benchmark_conv2d_pad(ctx, kernel, input, filter, pad, block_size, grid_size, repeat=10, warmup=True):
    """Run a padded conv2d CUDA kernel and time it with CUDA events.

    Args:
        ctx: the active pycuda context; only ``ctx.synchronize()`` is used.
        kernel: compiled kernel called as
            ``kernel(in, out, filter, h, w, in_channels, out_channels, filter_size, pad,
            grid=..., block=...)``.
        input: channel-first input, shape ``(channels, height, width)``.
        filter: filter bank, shape ``(out_channels, in_channels, height, width)``;
            only square filters are supported.
        pad: value the kernel uses for input taps outside the image.
        block_size: passed to the launch as ``block=``.
        grid_size: passed to the launch as ``grid=``.
        repeat: number of timed launches to average over (must be >= 1).
        warmup: if true, launch once untimed before each timed launch.

    Returns:
        ``(out, timing)``: ``out`` is the float32 result of shape
        ``(out_channels, height, width)`` copied back after the last launch;
        ``timing`` is the mean time per timed launch in milliseconds (the unit of
        pycuda's ``Event.time_since``).

    Raises:
        ValueError: on bad shapes or ``repeat < 1``.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    # The kernels index raw float32 buffers, so upload C-contiguous float32 data.
    # No copy is made when the arrays already have that layout.
    input = np.ascontiguousarray(input, dtype=np.float32)
    filter = np.ascontiguousarray(filter, dtype=np.float32)

    # input, channel-first: (channel, height, width)
    if input.ndim != 3:
        raise ValueError(f"Expected a 3D (C, H, W) input, got shape={input.shape}")
    # filter: (out channels, in channels, height, width)
    if filter.ndim != 4:
        raise ValueError(f"Expected a 4D (O, I, H, W) filter, got shape={filter.shape}")
    in_ch, h, w = input.shape
    out_ch, in_ch2, fh, fw = filter.shape
    if fh != fw:
        raise ValueError(f"Only square filters supported, got shape={filter.shape}")
    if in_ch != in_ch2:
        raise ValueError(f"Input has {in_ch} channels but filter expects {in_ch2}")

    out = np.empty((out_ch, h, w), dtype=np.float32)

    # One device buffer each, reused across all repeats.
    gpu_input = cuda.mem_alloc_like(input)
    gpu_filter = cuda.mem_alloc_like(filter)
    gpu_out = cuda.mem_alloc_like(out)
    try:
        cuda.memcpy_htod(gpu_input, input)
        cuda.memcpy_htod(gpu_filter, filter)
        ctx.synchronize()

        args = (
            gpu_input, gpu_out, gpu_filter,
            np.int32(h),
            np.int32(w),
            np.int32(in_ch),
            np.int32(out_ch),
            np.int32(fh),
            np.float32(pad),
        )
        start, end = cuda.Event(), cuda.Event()
        total_ms = 0.0
        for _ in range(repeat):
            if warmup:
                kernel(*args, grid=grid_size, block=block_size)
                ctx.synchronize()
            start.record()
            kernel(*args, grid=grid_size, block=block_size)
            end.record()
            end.synchronize()
            total_ms += end.time_since(start)

        cuda.memcpy_dtoh(out, gpu_out)
    finally:
        # Release device memory now instead of waiting for garbage collection;
        # large test cases can occupy a significant part of GPU memory.
        for buf in (gpu_input, gpu_filter, gpu_out):
            buf.free()

    return out, total_ms / repeat
Test matrix
Sample some random combinations (not too large) of image size, input/output channel counts and filter size.
from itertools import product

# Candidate values for each dimension of the test matrix.
in_chan_range = [1, 3, 8, 32, 128, 512]
out_chan_range = [1, 4, 8, 32, 128, 512]
filter_size = [1, 3, 5, 9]
img_size_range = [64, 128, 256, 512, 1024]

# Let's sample from the available options.
n_samples = 50
# Fixed seed: every run (and every kernel revision) is benchmarked on the same
# shapes, so timings stay comparable between runs.
seed = 0

# All (in_ch, out_ch, filter_size, img_size) combinations, skipping those that are
# too large (in_ch * out_ch * img_size**2 >= 1024*1024*32*32).
combinations = [
    (in_ch, out_ch, fs, img_size)
    for in_ch, out_ch, fs, img_size in product(
        in_chan_range, out_chan_range, filter_size, img_size_range
    )
    if in_ch * out_ch * img_size * img_size < 1024 * 1024 * 32 * 32
]

n_samples = min(n_samples, len(combinations))
rng = np.random.default_rng(seed)
sampled_combinations = rng.choice(len(combinations), size=n_samples, replace=False)
test_cases = [combinations[i] for i in sampled_combinations]
Run the tests
import warn_options

tile_width = 32
data = []

ctx = device.make_context()
try:
    # Both kernel sources include "conv2d-helpers.h" from their own directory and
    # are compiled with a float accumulator.
    mod_naive = SourceModule(
        Path(cu_file_naive).read_text(),
        options=warn_options.warn_options + ["-DACCUM_DTYPE=float"],
        include_dirs=[str(Path(cu_file_naive).parent.absolute())],
    )
    mod_z_out = SourceModule(
        Path(cu_file_z_out).read_text(),
        options=warn_options.warn_options + ["-DACCUM_DTYPE=float"],
        include_dirs=[str(Path(cu_file_z_out).parent.absolute())],
    )
    kernels = {
        "conv2d_pad": mod_naive.get_function("conv2d_pad"),
        "conv2d_pad_z_out": mod_z_out.get_function("conv2d_pad_z_out"),
    }

    block_size = (tile_width, tile_width, 1)
    for ch_in, ch_out, fs, pixels in tqdm(test_cases):
        array_in = np.random.randn(ch_in, pixels, pixels).astype(np.float32)
        weights = np.random.randn(ch_out, ch_in, fs, fs).astype(np.float32)
        # Reference result: zero "same" padding, matching pad=0 passed to the kernels.
        torch_out = conv2d(Tensor(array_in), Tensor(weights), padding="same").numpy()
        # Tolerance scaled to the output magnitude: float32 summation order differs
        # between PyTorch and our kernels, while an indexing bug produces errors of
        # the same order as the outputs themselves.
        atol = 1e-4 * float(np.abs(torch_out).max())

        tiles = (pixels + tile_width - 1) // tile_width
        timings = {}
        for kernel_name, kernel in kernels.items():
            # The z-out kernel takes its output channel from blockIdx.z.
            grid_size = (
                tiles,
                tiles,
                ch_out if kernel_name == "conv2d_pad_z_out" else 1,
            )
            out, timing = benchmark_conv2d_pad(
                ctx, kernel, array_in, weights, 0, block_size, grid_size, repeat=5, warmup=True
            )
            if not np.allclose(out, torch_out, rtol=1e-4, atol=atol):
                print("### Result mismatch ###")
                print(f"Kernel: {kernel_name}")
                print(f"Input shape: {array_in.shape}")
                print(f"Filter shape: {weights.shape}")
                print(f"Result shape: {(weights.shape[0], array_in.shape[1], array_in.shape[2])}")
                print(f"Grid size: {grid_size}")
                print(f"Block size: {block_size}")
                print(f"Total threads: {prod((*grid_size, *block_size))}")
                print(f"Max abs error: {np.abs(out - torch_out).max()}")
            timings[kernel_name] = timing
        data.append({'in_ch': ch_in,
                     'out_ch': ch_out,
                     'filter_size': fs,
                     'img_size': pixels,
                     } | timings)
finally:
    ctx.pop()
    ctx.detach()

results = pd.DataFrame(data)
Test results
# Sort by conv2d_pad timing
results_sorted = results.sort_values(by='conv2d_pad')

# Create a plot comparing the two kernels
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(12, 8))
sns.set_style("whitegrid")

# Create labels for x-axis that include dimensions, e.g. "256×256×3 -> 8, f:3×3"
results_sorted['dimensions'] = results_sorted.apply(
    lambda row: f"{int(row['img_size'])}×{int(row['img_size'])}×{int(row['in_ch'])} -> {int(row['out_ch'])}, f:{int(row['filter_size'])}×{int(row['filter_size'])}",
    axis=1,
)

# Melt the dataframe to get it in the right format for seaborn: one row per
# (test case, kernel), kernel name in 'kernel' and its timing in 'time'.
melted_results = pd.melt(
    results_sorted,
    id_vars=['in_ch', 'out_ch', 'filter_size', 'img_size', 'dimensions'],
    value_vars=['conv2d_pad', 'conv2d_pad_z_out'],
    var_name='kernel',
    value_name='time',
)

# Create a barplot with dimensions as x-ticks
ax = sns.barplot(x='dimensions', y='time', hue='kernel', data=melted_results)

# Add labels. The timings are averages of pycuda Event.time_since, which reports
# milliseconds, not seconds.
ax.set_xlabel('Input and Filter Dimensions')
ax.set_ylabel('Time (ms)')
ax.set_title('Performance Comparison of conv2d_pad vs conv2d_pad_z_out')
plt.xticks(rotation=90)

# Add a legend
ax.legend(title='Kernel')

# Show the plot
plt.tight_layout()
plt.show()

# Also display the sorted results table
results_sorted
 | in_ch | out_ch | filter_size | img_size | conv2d_pad | conv2d_pad_z_out | dimensions |
---|---|---|---|---|---|---|---|
49 | 1 | 1 | 1 | 256 | 0.019046 | 0.019046 | 256×256×1 -> 1, f:1×1 |
8 | 8 | 1 | 1 | 128 | 0.021709 | 0.020890 | 128×128×8 -> 1, f:1×1 |
45 | 1 | 1 | 1 | 512 | 0.022528 | 0.020275 | 512×512×1 -> 1, f:1×1 |
27 | 3 | 8 | 1 | 64 | 0.026419 | 0.014336 | 64×64×3 -> 8, f:1×1 |
16 | 3 | 4 | 1 | 256 | 0.031744 | 0.033382 | 256×256×3 -> 4, f:1×1 |
17 | 1 | 4 | 5 | 64 | 0.035840 | 0.022118 | 64×64×1 -> 4, f:5×5 |
38 | 1 | 1 | 1 | 1024 | 0.056525 | 0.055706 | 1024×1024×1 -> 1, f:1×1 |
40 | 32 | 1 | 3 | 128 | 0.087450 | 0.084378 | 128×128×32 -> 1, f:3×3 |
4 | 1 | 8 | 5 | 256 | 0.095232 | 0.082739 | 256×256×1 -> 8, f:5×5 |
37 | 3 | 8 | 3 | 256 | 0.107520 | 0.096256 | 256×256×3 -> 8, f:3×3 |
36 | 1 | 8 | 9 | 128 | 0.117555 | 0.055910 | 128×128×1 -> 8, f:9×9 |
5 | 1 | 4 | 9 | 256 | 0.119194 | 0.094003 | 256×256×1 -> 4, f:9×9 |
43 | 8 | 4 | 1 | 512 | 0.132710 | 0.177357 | 512×512×8 -> 4, f:1×1 |
19 | 128 | 1 | 1 | 256 | 0.204186 | 0.203981 | 256×256×128 -> 1, f:1×1 |
39 | 8 | 8 | 1 | 512 | 0.231834 | 0.336896 | 512×512×8 -> 8, f:1×1 |
13 | 1 | 8 | 5 | 512 | 0.247194 | 0.273203 | 512×512×1 -> 8, f:5×5 |
24 | 3 | 8 | 3 | 512 | 0.292659 | 0.353075 | 512×512×3 -> 8, f:3×3 |
10 | 3 | 8 | 9 | 64 | 0.312934 | 0.050995 | 64×64×3 -> 8, f:9×9 |
0 | 8 | 4 | 9 | 64 | 0.417997 | 0.113254 | 64×64×8 -> 4, f:9×9 |
32 | 1 | 512 | 1 | 128 | 0.425779 | 0.243507 | 128×128×1 -> 512, f:1×1 |
31 | 128 | 4 | 1 | 256 | 0.743834 | 0.588595 | 256×256×128 -> 4, f:1×1 |
26 | 3 | 512 | 1 | 64 | 0.885965 | 0.120422 | 64×64×3 -> 512, f:1×1 |
18 | 8 | 4 | 5 | 512 | 0.916275 | 0.968294 | 512×512×8 -> 4, f:5×5 |
1 | 1 | 512 | 3 | 128 | 1.092198 | 0.514048 | 128×128×1 -> 512, f:3×3 |
46 | 3 | 32 | 3 | 512 | 1.098752 | 1.355366 | 512×512×3 -> 32, f:3×3 |
9 | 8 | 4 | 3 | 1024 | 1.419674 | 2.212659 | 1024×1024×8 -> 4, f:3×3 |
21 | 1 | 128 | 1 | 1024 | 1.955635 | 5.893126 | 1024×1024×1 -> 128, f:1×1 |
3 | 32 | 4 | 1 | 1024 | 2.050253 | 2.409267 | 1024×1024×32 -> 4, f:1×1 |
30 | 512 | 1 | 3 | 256 | 2.841395 | 2.728755 | 256×256×512 -> 1, f:3×3 |
33 | 8 | 32 | 9 | 64 | 3.557990 | 0.306995 | 64×64×8 -> 32, f:9×9 |
14 | 8 | 512 | 1 | 256 | 4.713267 | 4.175462 | 256×256×8 -> 512, f:1×1 |
47 | 32 | 32 | 5 | 128 | 5.372723 | 1.687757 | 128×128×32 -> 32, f:5×5 |
11 | 128 | 4 | 5 | 256 | 5.690982 | 3.754189 | 256×256×128 -> 4, f:5×5 |
6 | 512 | 8 | 1 | 256 | 6.060646 | 4.183040 | 256×256×512 -> 8, f:1×1 |
34 | 3 | 8 | 9 | 1024 | 7.197081 | 8.368333 | 1024×1024×3 -> 8, f:9×9 |
25 | 3 | 32 | 9 | 512 | 7.662797 | 7.265075 | 512×512×3 -> 32, f:9×9 |
12 | 3 | 32 | 5 | 1024 | 10.742989 | 14.039655 | 1024×1024×3 -> 32, f:5×5 |
22 | 512 | 8 | 3 | 64 | 11.946598 | 1.192960 | 64×64×512 -> 8, f:3×3 |
28 | 1 | 128 | 5 | 1024 | 14.895719 | 19.455181 | 1024×1024×1 -> 128, f:5×5 |
23 | 128 | 4 | 9 | 256 | 14.967603 | 10.709811 | 256×256×128 -> 4, f:9×9 |
35 | 128 | 8 | 3 | 512 | 18.949939 | 16.436429 | 512×512×128 -> 8, f:3×3 |
7 | 3 | 512 | 3 | 512 | 20.223590 | 24.182989 | 512×512×3 -> 512, f:3×3 |
20 | 128 | 32 | 3 | 256 | 23.864934 | 15.436800 | 256×256×128 -> 32, f:3×3 |
44 | 512 | 4 | 5 | 256 | 25.573376 | 16.888422 | 256×256×512 -> 4, f:5×5 |
48 | 128 | 8 | 9 | 256 | 30.146976 | 20.129997 | 256×256×128 -> 8, f:9×9 |
15 | 32 | 8 | 5 | 1024 | 35.380025 | 33.746944 | 1024×1024×32 -> 8, f:5×5 |
2 | 32 | 128 | 9 | 128 | 59.110400 | 19.480167 | 128×128×32 -> 128, f:9×9 |
42 | 128 | 4 | 5 | 1024 | 72.265121 | 66.161868 | 1024×1024×128 -> 4, f:5×5 |
29 | 128 | 512 | 3 | 64 | 155.834366 | 12.261376 | 64×64×128 -> 512, f:3×3 |
41 | 32 | 512 | 9 | 64 | 236.589465 | 19.932570 | 64×64×32 -> 512, f:9×9 |