import numpy as np
from PIL import Image
from pathlib import Path
Day 9 - Conv 2D
Unlike the book, I’m going to implement convolution with an arbitrary number of input and output channels.
from lovely_numpy import Lo
from lovely_tensors import monkey_patch; monkey_patch()
from torch import Tensor
from torch.nn.functional import conv2d
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
# Initialise the CUDA driver API and select the first GPU.
cuda.init()
device = cuda.Device(0)

# Build the version string outside the f-string: nesting the same quote type
# inside an f-string is only valid syntax from Python 3.12 onwards.
cuda_version = ".".join(str(i) for i in cuda.get_version())
print(f"Cuda version: {cuda_version}")
print(f"Device:\t{device.name()}")
Cuda version: 12.8.0
Device: NVIDIA GeForce RTX 3080 Laptop GPU
# File holding the CUDA source of the kernel shown below.
cu_file = "day_09_conv2d.cu"
#include <stdint.h>
#include <stdio.h>
/*
 * 2D convolution over channel-first data, padded so that the output keeps the
 * height and width of the input.
 *
 * Memory layouts, as implied by the indexing below:
 *   in     - [in_channels][h][w]
 *   out    - [out_channels][h][w]
 *   filter - [out_channels][in_channels][filter_size][filter_size]
 *
 * Each thread handles one (x, y) position and writes that position for every
 * output channel. Input taps that fall outside the image use `pad` instead.
 *
 * filter_size must be an odd number so that the filter has a centre element.
 */
__global__ void conv2d_pad(const float *in,
                           float *out,
                           const float *filter,
                           int h,
                           int w,
                           int in_channels,
                           int out_channels,
                           int filter_size /* Must be an odd number */,
                           float pad) {
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    int y = blockIdx.y * blockDim.y + threadIdx.y;

    // Threads that land outside the image have nothing to compute.
    if (x >= w || y >= h) return;

    // Values that stay constant for the whole thread.
    const int filter_r = (filter_size - 1) / 2;
    const int plane_size = w * h;
    const int filter_area = filter_size * filter_size;

    // Loop over the output channels
    for (int out_c = 0; out_c < out_channels; out_c++) {
        float acc = 0.0f;

        // Loop over the input channels
        for (int in_c = 0; in_c < in_channels; in_c++) {
            // 2D slice of the filter for the active (output, input) channel pair.
            const float *sub_filter = filter + filter_area * (in_channels * out_c + in_c);
            // Pointer to the current channel in the input.
            const float *sub_input = in + plane_size * in_c;

            // Apply the filter to the input, or to the pad value for outside indices.
            for (int filter_y = 0; filter_y < filter_size; filter_y++) {
                const int input_y = y - filter_r + filter_y;
                const bool row_inside = (input_y >= 0 && input_y < h);

                for (int filter_x = 0; filter_x < filter_size; filter_x++) {
                    const int input_x = x - filter_r + filter_x;

                    float v = pad;
                    if (row_inside && input_x >= 0 && input_x < w) {
                        v = sub_input[input_y * w + input_x];
                    }
                    acc += v * sub_filter[filter_y * filter_size + filter_x];
                }
            }
        }

        // Store the result in the 2D slice of the output for this channel.
        out[out_c * plane_size + y * w + x] = acc;
    }
}
## Compiler options for more compile-time warnings.
# Each warning flag is forwarded to the host compiler through `-Xcompiler`.
warn_options = [
    '-Xcompiler', '-Wall',
    '-Xcompiler', '-Wextra',
    '-Xcompiler', '-Wsign-conversion',
    '-Xcompiler', '-Wcast-qual',
    '-Xcompiler', '-Wunused-parameter',
    '-Xcompiler', '-Wdouble-promotion',
    '-Xcompiler', '-Wformat=2',
    '-Xcompiler', '-Wfloat-equal',
    '-Xcompiler', '-Wshadow',
]
def benchmark_conv2d_pad(ctx, kernel, input, filter, pad, block_size, grid_size, repeat=10, warmup=True):
    """Run the ``conv2d_pad`` kernel on the GPU and measure its average run time.

    Args:
        ctx: CUDA context; only ``ctx.synchronize()`` is called on it.
        kernel: Compiled kernel, called as
            ``kernel(in, out, filter, h, w, in_ch, out_ch, filter_size, pad,
            grid=..., block=...)``.
        input: Channel-first array of shape ``(in_channels, height, width)``.
        filter: Array of shape ``(out_channels, in_channels, fh, fw)`` with
            ``fh == fw``.
        pad: Padding value, passed to the kernel as a float32.
        block_size: Block dimensions forwarded to the kernel launch.
        grid_size: Grid dimensions forwarded to the kernel launch.
        repeat: Number of timed launches to average over (must be >= 1).
        warmup: When true, every timed launch is preceded by an untimed one.

    Returns:
        ``(out, timing)``: ``out`` is a float32 array of shape
        ``(out_channels, height, width)`` copied back from the GPU, and
        ``timing`` is the mean of ``end.time_since(start)`` over the timed
        launches.

    Raises:
        ValueError: If ``repeat`` is smaller than 1.
        AssertionError: If the input or filter shapes are inconsistent.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")

    # The kernel reads raw float32 buffers, so make sure that is what gets
    # uploaded. This is a no-op for arrays that are already float32 and contiguous.
    input = np.ascontiguousarray(input, dtype=np.float32)
    filter = np.ascontiguousarray(filter, dtype=np.float32)

    # input, channel-first
    #  - Channel
    #  - Height
    #  - Width
    assert len(input.shape) == 3

    # Filter shape should be
    #  - Out channels
    #  - In channels
    #  - Height
    #  - Width
    assert len(filter.shape) == 4

    in_ch, h, w = input.shape
    out_ch, in_ch2, fh, fw = filter.shape
    assert fh == fw, f"Only square filters supported, got shape={filter.shape}"
    assert in_ch == in_ch2

    out = np.empty((out_ch, h, w), dtype=np.float32)

    # Allocate all device buffers once; the output buffer is reused by every launch.
    gpu_input = cuda.mem_alloc_like(input)
    gpu_filter = cuda.mem_alloc_like(filter)
    gpu_out = cuda.mem_alloc_like(out)

    cuda.memcpy_htod(gpu_input, input)
    cuda.memcpy_htod(gpu_filter, filter)
    ctx.synchronize()

    scalar_args = (
        np.int32(h),
        np.int32(w),
        np.int32(in_ch),
        np.int32(out_ch),
        np.int32(fh),
        np.float32(pad),
    )

    def launch():
        kernel(gpu_input, gpu_out, gpu_filter, *scalar_args,
               grid=grid_size, block=block_size)

    timing = 0.0
    for _ in range(repeat):
        start = cuda.Event()
        end = cuda.Event()

        if warmup:
            # Untimed launch; wait for it to finish before starting the clock.
            launch()
            ctx.synchronize()

        start.record()
        launch()
        end.record()
        end.synchronize()

        timing += end.time_since(start)
    timing /= repeat

    cuda.memcpy_dtoh(out, gpu_out)
    return out, timing
input = Image.open("../cat-1.jpg")

# Convert from HWC to CHW format and scale to float32
# (values end up in [0, 1) assuming 8-bit pixels).
input_array = np.ascontiguousarray(np.array(input).transpose(2, 0, 1)).astype(np.float32) / 256

# Alternative inputs kept for debugging:
# input_array = np.linspace(0, 0.5, 3*32*32).reshape(3, 32, 32).astype(np.float32) + 0.5
# input_array = np.ascontiguousarray(input_array.transpose(2, 0, 1))[:1,:,:]

print(Lo(input_array))
Lo(input_array).chans(cl=False)
array[3, 600, 451] f32 n=811800 (3.1Mb) x∈[0., 0.996] μ=0.592 σ=0.147
out_channels = 4
filter_size = 3

# Random filter of shape (out_channels, in_channels, filter_size, filter_size).
filter = np.random.randn(out_channels, input_array.shape[0], filter_size, filter_size).astype(np.float32) / 5

# This filter does nothing to the input image.
# filter = np.array([
#     [[[0, 0, 0],
#       [0, 1, 0],
#       [0, 0 ,0]]]
# ]).astype(np.float32)

Lo(filter)
array[4, 3, 3, 3] f32 n=108 x∈[-0.754, 0.447] μ=0.020 σ=0.204
# input_array = np.random.randn(3, 64,64).astype(np.float32)

# Reference result from PyTorch, used to check the CUDA kernel's output.
torch_res = conv2d(Tensor(input_array), Tensor(filter), padding="same")
print(torch_res)
torch_res.chans(scale=1)
tensor[4, 600, 451] n=1082400 (4.1Mb) x∈[-0.401, 1.342] μ=0.306 σ=0.151
tile_width = 32
ch, h, w = input_array.shape

ctx = device.make_context()
try:
    mod = SourceModule(
        Path(cu_file).read_text(),
        options=warn_options)

    kernel = mod.get_function("conv2d_pad")

    # One thread per output pixel; round the grid up so it covers the whole image.
    block_size = (tile_width, tile_width, 1)
    grid_size = (
        ((w + tile_width - 1) // tile_width),
        ((h + tile_width - 1) // tile_width),
        1
    )

    print(f"Input shape: {input_array.shape}")
    print(f"Filter shape: {filter.shape}")
    print(f"Result shape: {(filter.shape[0], input_array.shape[1], input_array.shape[2])}")
    print(f"Grid size: {grid_size}")
    print(f"Block size: {block_size}")
    print(f"Total threads: {grid_size[0] * grid_size[1] * block_size[0] * block_size[1]}")

    res, timing = benchmark_conv2d_pad(ctx, kernel, input_array, filter, 0, block_size, grid_size, repeat=1, warmup=False)
finally:
    # Always release the context, even if compilation or the launch fails.
    ctx.pop()
    ctx.detach()

print(Lo(res))
print(f"Took {timing:.3f}ms")
Input shape: (3, 600, 451)
Filter shape: (4, 3, 3, 3)
Result shape: (4, 600, 451)
Grid size: (15, 19, 1)
Block size: (32, 32, 1)
Total threads: 291840
array[4, 600, 451] f32 n=1082400 (4.1Mb) x∈[-0.401, 1.342] μ=0.306 σ=0.151
Took 0.142ms
# Fraction of output elements that np.isclose (default tolerances) considers
# equal between the CUDA result and the PyTorch reference.
np.isclose(res, torch_res).mean()
np.float64(0.9994484478935698)
Looks good!
I’ll leave benchmarks and performance improvements for tomorrow.