Misc 2 - CUDA Runtime library
TinyGrad has 2 backends for Nvidia GPUs - CUDA and NV.
CUDA just performs calls into the CUDA Runtime library, not much different from the host code in a .cu file. NV skips the library and talks to the driver directly.
We will play with the CUDA library here, and do things NV-style in the next chapter.
See CUDA Driver API.
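By the way, the backend can be forced with environment variables. A minimal sketch (assuming a recent tinygrad; not verified here):
import os
os.environ["CUDA"] = "1"   # or NV=1 for the driver-level backend
from tinygrad import Tensor, Device
print(Device.DEFAULT)      # should print CUDA
print((Tensor([1.0, 2.0]) + 1).numpy())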
saxpy.cu
Let’s look at a free-standing CUDA kernel. This does not include any host code:
saxpy.cu
#include <stdint.h>

// SAXPY kernel: y = alpha*x + y
__global__ void saxpy(float alpha, float *x, float *y, int32_t n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = alpha * x[i] + y[i];
    }
}
CUDA compilation.
.cu files often contain a mixture of host and device code. CUDA splits them apart and compiles them separately, with different compilers for host and device. In our example we don't have any host code - we will perform all calls to the CUDA Runtime library from Python using ctypes.
NVCC compiles the device part of a .cu file in 2 stages:
Source: NVCC Documentation
PTX is the higher-level and more device-independent assembly-like code.
SASS is the final assembly code that maps 1:1 to machine code. PTX is probably backward-compatible: if you compile old PTX into SASS for a new target device, you might get more optimal code that won't run on the old devices.
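The two stages can also be run separately. A sketch using the standard toolchain flags (the intermediate file names here are just for illustration):
!nvcc -ptx saxpy.cu -o saxpy.ptx -arch=compute_86      # stage 1: CUDA C++ -> PTX (virtual architecture)
!ptxas saxpy.ptx -o saxpy_from_ptx.cubin -arch=sm_86   # stage 2: PTX -> SASS, packaged as a cubin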
Here we combined the 2 steps into one and generated a .cubin:
!nvcc -cubin saxpy.cu -o saxpy.cubin -arch=sm_86
It's actually just an ELF file that contains metadata and the machine code for the kernel:
!objdump --headers saxpy.cubin
saxpy.cubin: file format elf64-little
Sections:
Idx Name Size VMA LMA File off Algn
0 .debug_frame 00000070 0000000000000000 0000000000000000 00000320 2**0
CONTENTS, RELOC, READONLY, DEBUGGING, OCTETS
1 .nv.info 00000024 0000000000000000 0000000000000000 00000390 2**2
CONTENTS, READONLY
2 .nv.info._Z5saxpyfPfS_i 0000006c 0000000000000000 0000000000000000 000003b4 2**2
CONTENTS, READONLY
3 .nv.callgraph 00000020 0000000000000000 0000000000000000 00000420 2**2
CONTENTS, READONLY
4 .nv.rel.action 00000010 0000000000000000 0000000000000000 00000440 2**3
CONTENTS, READONLY
5 .nv.constant0._Z5saxpyfPfS_i 0000017c 0000000000000000 0000000000000000 00000460 2**2
CONTENTS, ALLOC, LOAD, READONLY, DATA
6 .text._Z5saxpyfPfS_i 00000180 0000000000000000 0000000000000000 00000600 2**7
CONTENTS, ALLOC, LOAD, READONLY, CODE
.text._Z5saxpyfPfS_i must have the code for the kernel.
I don't know why .nv.constant0._Z5saxpyfPfS_i is so large (0x17c = 380 bytes), but it's all zeros.
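The _Z5saxpyfPfS_i part is just the C++-mangled kernel name. A quick check (assuming binutils' c++filt is installed):
!c++filt _Z5saxpyfPfS_i
It should print saxpy(float, float*, float*, int). Declaring the kernel extern "C" would have kept the plain name.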
There is also cuobjdump that can parse the CUDA-specific sections:
!cuobjdump saxpy.cubin --dump-resource-usage
Resource usage:
Common:
GLOBAL:0
Function _Z5saxpyfPfS_i:
REG:10 STACK:0 SHARED:0 LOCAL:0 CONSTANT[0]:380 TEXTURE:0 SURFACE:0 SAMPLER:0
And the nvdisasm disassembler:
!nvdisasm -c saxpy.cubin
.headerflags @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM86 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM86)"
.elftype @"ET_EXEC"
//--------------------- .text._Z5saxpyfPfS_i --------------------------
.section .text._Z5saxpyfPfS_i,"ax",@progbits
.sectioninfo @"SHI_REGISTERS=10"
.align 128
.global _Z5saxpyfPfS_i
.type _Z5saxpyfPfS_i,@function
.size _Z5saxpyfPfS_i,(.L_x_1 - _Z5saxpyfPfS_i)
.other _Z5saxpyfPfS_i,@"STO_CUDA_ENTRY STV_DEFAULT"
_Z5saxpyfPfS_i:
.text._Z5saxpyfPfS_i:
/*0000*/ MOV R1, c[0x0][0x28] ;
/*0010*/ S2R R4, SR_CTAID.X ;
/*0020*/ S2R R3, SR_TID.X ;
/*0030*/ IMAD R4, R4, c[0x0][0x0], R3 ;
/*0040*/ ISETP.GE.AND P0, PT, R4, c[0x0][0x178], PT ;
/*0050*/ @P0 EXIT ;
/*0060*/ MOV R5, 0x4 ;
/*0070*/ ULDC.64 UR4, c[0x0][0x118] ;
/*0080*/ IMAD.WIDE R2, R4, R5, c[0x0][0x168] ;
/*0090*/ IMAD.WIDE R4, R4, R5, c[0x0][0x170] ;
/*00a0*/ LDG.E R2, [R2.64] ;
/*00b0*/ LDG.E R7, [R4.64] ;
/*00c0*/ FFMA R7, R2, c[0x0][0x160], R7 ;
/*00d0*/ STG.E [R4.64], R7 ;
/*00e0*/ EXIT ;
.L_x_0:
/*00f0*/ BRA `(.L_x_0);
/*0100*/ NOP;
/*0110*/ NOP;
/*0120*/ NOP;
/*0130*/ NOP;
/*0140*/ NOP;
/*0150*/ NOP;
/*0160*/ NOP;
/*0170*/ NOP;
.L_x_1:
Device init
TinyGrad autogenerates ctypes bindings for the CUDA library, so let's just use them.
This works pretty much the same way as calling those functions from C.
I will use the lower-level CUDA Driver API functions here.
import ctypes
from tinygrad.runtime.autogen import cuda
from tinygrad.helpers import init_c_var

# import os
# os.environ["IOCTL"] = "1"
# import nv_ioctl

def check(status):
    if status != 0:
        error_str = ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(),
                                     lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()
        raise RuntimeError(f"CUDA Error {status}, {error_str}")
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__INITIALIZE.html
check(cuda.cuInit(0)) # This 0 is always 0

cu_device = ctypes.c_int() # It's actually just an int with the value of the device ID
# It fails if you try to get a device that does not exist, but otherwise it just returns the device ID you gave it.
check(cuda.cuDeviceGet(cu_device, 0)) # cu_device is passed by pointer, it's converted automatically based on the cuDeviceGet function signature
cu_device.value
0
Device information
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html
device_count = ctypes.c_int()
check(cuda.cuDeviceGetCount(device_count))
print(f"Number of CUDA devices available: {device_count.value}")
Number of CUDA devices available: 1
device_name = ctypes.create_string_buffer(100)
check(cuda.cuDeviceGetName(device_name, len(device_name), cu_device))
print(device_name.value.decode())
NVIDIA GeForce RTX 3080 Laptop GPU
minor = ctypes.c_int()
major = ctypes.c_int()
check(cuda.cuDeviceComputeCapability(major, minor, 0))
major.value, minor.value
(8, 6)
# Get total memory on the device
total_memory = ctypes.c_size_t()
check(cuda.cuDeviceTotalMem_v2(ctypes.byref(total_memory), cu_device))
print(f"Total memory on device: {total_memory.value / (1024**3):.2f} GB")
Total memory on device: 15.59 GB
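Plenty of other properties come from cuDeviceGetAttribute. A sketch (I'm assuming the enum constants are exposed under their CUDA names in the autogen module):
sm_count = ctypes.c_int()
max_threads = ctypes.c_int()
check(cuda.cuDeviceGetAttribute(sm_count, cuda.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, cu_device))
check(cuda.cuDeviceGetAttribute(max_threads, cuda.CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK, cu_device))
print(f"SMs: {sm_count.value}, max threads per block: {max_threads.value}")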
Context
In CUDA, a context represents the primary object for managing resources and execution on a GPU device.
Each thread can have one active context at a time, and contexts can be pushed/popped from a stack. The context must be created before any CUDA operations can be performed on a device.
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html
cu_context = cuda.CUcontext()

# cuCtxCreate_v2 is actually identical to cuCtxCreate:
# include/cuda.h:
# ...
# #define cuCtxCreate cuCtxCreate_v2
check(cuda.cuCtxCreate_v2(cu_context, 0, cu_device))

# This is a pointer to the context object. It is opaque, no idea what the size or composition of the context object is.
cu_context.contents
<tinygrad.runtime.autogen.cuda.struct_CUctx_st>
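Since a thread has one current context, we can double-check that the one we just created is now current. A sketch using cuCtxGetCurrent (same pointer-passing convention as above):
cur_context = cuda.CUcontext()
check(cuda.cuCtxGetCurrent(cur_context))
# Both handles should hold the same opaque pointer value
assert ctypes.cast(cur_context, ctypes.c_void_p).value == ctypes.cast(cu_context, ctypes.c_void_p).value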
Memory management
The device is ready. Let’s allocate the memory for the input and output arrays.
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html
import numpy as np
N = 128

# saxpy performs the operation y = a*x + y
dev_x = cuda.CUdeviceptr() # 64 bit pointer to device memory
dev_y = cuda.CUdeviceptr()

# Allocate the buffers on the device
# include/cuda.h:
# #define cuMemAlloc cuMemAlloc_v2
check(cuda.cuMemAlloc_v2(dev_x, N*4))
check(cuda.cuMemAlloc_v2(dev_y, N*4))

host_x = np.linspace(0, 100, N).astype(np.float32)
host_y = np.zeros(N).astype(np.float32)

# Copy data to device. IDK why they are all called _v2
check(cuda.cuMemcpyHtoD_v2(dev_x, host_x.ctypes.data_as(ctypes.c_void_p), N*4))
check(cuda.cuMemcpyHtoD_v2(dev_y, host_y.ctypes.data_as(ctypes.c_void_p), N*4))
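To sanity-check the allocations, the driver can report free/total memory. A sketch with cuMemGetInfo_v2 (assumed to be in the autogen bindings, like the other _v2 calls):
free_mem = ctypes.c_size_t()
total_mem = ctypes.c_size_t()
check(cuda.cuMemGetInfo_v2(ctypes.byref(free_mem), ctypes.byref(total_mem)))
print(f"Free: {free_mem.value / (1024**3):.2f} GB of {total_mem.value / (1024**3):.2f} GB")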
Load and run the kernel
Normally when you build a .cu file that contains both host and device code, it goes along these lines:
- The device code gets compiled into a cuda binary.
- The host code gets compiled, and the cuda binary is included as a binary blob.
- The saxpy_parallel<<<nblocks,1024>>>(alpha,d_x,d_y,N); call is just syntactic sugar that does:
  - m = cuModuleLoadData(blob) - loads the blob as a "CUDA module"
  - fx = cuModuleGetFunction(m, "saxpy") - creates a handle to the kernel
  - cuLaunchKernel(fx, grid, block, &params) - launches the kernel with the grid/block config and parameters
We built the .cubin as a separate file, so let's do the same thing manually.
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MODULE.html
image = open("saxpy.cubin", "rb").read()
module = cuda.CUmodule() # Another opaque object.
check(cuda.cuModuleLoadData(module, image))
module

<tinygrad.runtime.autogen.cuda.LP_struct_CUmod_st>

fx = cuda.CUfunction() # You guessed it, another opaque object
check(cuda.cuModuleGetFunction(fx, module, "_Z5saxpyfPfS_i".encode("utf-8")))
fx

<tinygrad.runtime.autogen.cuda.LP_struct_CUfunc_st>
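As a cross-check against the cuobjdump resource usage above (REG:10), the same number can be queried from the loaded function. A sketch (assuming the CU_FUNC_ATTRIBUTE_* constants are exposed by the autogen module):
num_regs = ctypes.c_int()
check(cuda.cuFuncGetAttribute(num_regs, cuda.CU_FUNC_ATTRIBUTE_NUM_REGS, fx))
print(f"Registers per thread: {num_regs.value}")  # expect 10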
Create the parameter array for the kernel
grid_size = (1,1,1)
block_size = (N,1,1)

shared_mem_size = 0 # We don't need shared memory for this kernel

# Args
# 0 - alpha (float)
# 1 - input (pointer to float)
# 2 - output (pointer to float)
# 3 - N (int32)
# The params argument to cuLaunchKernel is a pointer to an array of pointers to the parameters.
# The CUDA library knows the size of each parameter from the metadata, so it can figure out how to pass them to the kernel.
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html
alpha_val = ctypes.c_float(2.0)
n = ctypes.c_int32(N)

alpha_ptr = ctypes.cast(ctypes.addressof(alpha_val), ctypes.c_void_p) # Pointer to alpha value
dev_x_ptr = ctypes.cast(ctypes.addressof(dev_x), ctypes.c_void_p) # Pointer to the x array
dev_y_ptr = ctypes.cast(ctypes.addressof(dev_y), ctypes.c_void_p) # Pointer to the y array
n_ptr = ctypes.cast(ctypes.addressof(n), ctypes.c_void_p) # Pointer to the N value

VoidPtrArrayType = ctypes.c_void_p * 4
params = VoidPtrArrayType() # Create the array to hold pointers to args

# Populate the array with pointers to the actual kernel arguments
params[0] = alpha_ptr
params[1] = dev_x_ptr
params[2] = dev_y_ptr
params[3] = n_ptr
Run the kernel
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html
check(cuda.cuLaunchKernel(fx, *grid_size, *block_size, shared_mem_size, None, params, None))
# Copy the result back to the host
check(cuda.cuMemcpyDtoH_v2(host_y.ctypes.data_as(ctypes.c_void_p), dev_y, N*4))
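cuLaunchKernel is asynchronous; the cuMemcpyDtoH_v2 above is synchronous and already waits for the kernel, but an explicit sync makes the ordering obvious (optional, my addition):
check(cuda.cuCtxSynchronize())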
Verify the result
host_x[:16]
array([ 0. , 0.78740156, 1.5748031 , 2.3622048 , 3.1496062 ,
3.937008 , 4.7244096 , 5.5118113 , 6.2992125 , 7.086614 ,
7.874016 , 8.661417 , 9.448819 , 10.23622 , 11.0236225 ,
11.811024 ], dtype=float32)
host_y[:16]
array([ 0. , 1.5748031, 3.1496062, 4.7244096, 6.2992125,
7.874016 , 9.448819 , 11.0236225, 12.598425 , 14.173228 ,
15.748032 , 17.322834 , 18.897638 , 20.47244 , 22.047245 ,
23.622047 ], dtype=float32)
(host_y == host_x * alpha_val.value).all()
np.True_
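We never free anything above. A tidy ending would hand the resources back; a minimal sketch, assuming the usual driver cleanup calls are present in the autogen bindings:
check(cuda.cuMemFree_v2(dev_x))
check(cuda.cuMemFree_v2(dev_y))
check(cuda.cuModuleUnload(module))
check(cuda.cuCtxDestroy_v2(cu_context))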