Adventures with TinyGrad

Misc 2 - CUDA Runtime library

TinyGrad has two backends for Nvidia GPUs: CUDA and NV.

  • CUDA just performs calls into the CUDA Runtime library, not much different from host code in a .cu file.
  • NV skips the library, and talks to the driver directly.

We will play with the CUDA library here, and do things NV-style in the next chapter.

See CUDA Driver API.

saxpy.cu

Let’s look at a free-standing CUDA kernel. This does not include any host code:

saxpy.cu

#include <stdint.h>

// SAXPY kernel: y = alpha*x + y
__global__ void saxpy(float alpha, float *x, float *y, int32_t n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = alpha * x[i] + y[i];
    }
}
!nvcc -cubin saxpy.cu -o saxpy.cubin -arch=sm_86
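To make the index arithmetic concrete, here is a plain-Python sketch that runs the kernel body once for every (block, thread) pair of a 1-D grid. It only mimics what each thread computes. A real GPU runs these threads in parallel (in warps), not in a sequential loop. The names saxpy_emulated, grid_dim and block_dim are hypothetical.

```python
def saxpy_emulated(alpha, x, y, n, grid_dim, block_dim):
    """Run the saxpy kernel body sequentially for every thread of a 1-D grid."""
    for block_idx in range(grid_dim):            # plays the role of blockIdx.x
        for thread_idx in range(block_dim):      # plays the role of threadIdx.x
            i = block_idx * block_dim + thread_idx  # global thread index
            if i < n:                            # surplus threads in the last block do nothing
                y[i] = alpha * x[i] + y[i]
    return y


x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [10.0] * 5
# 2 blocks of 4 threads = 8 threads for 5 elements; threads 5..7 fail the bounds check
print(saxpy_emulated(2.0, x, y, len(x), grid_dim=2, block_dim=4))
# [12.0, 14.0, 16.0, 18.0, 20.0]
```

The i < n check is what lets the grid be rounded up to a whole number of blocks without writing past the end of the arrays.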

CUDA compilation

.cu files often contain a mixture of host and device code. NVCC splits them apart and compiles each part separately, with different compilers for host and device. Our example has no host code: we will make all the calls into the CUDA library from Python using ctypes.

NVCC compiles the device part of a .cu file in two stages:

Source: NVCC Documentation

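  • PTX is a higher-level, more device-independent, assembly-like intermediate language.

  • SASS is the final assembly code that maps 1:1 to machine code. Its compatibility is limited: a cubin built for compute capability X.y runs only on devices of compute capability X.z with z ≥ y (sm_80 code runs on sm_86, but not the other way around, and not across major versions). If you compile old PTX for a newer target, you may get more optimal SASS, but that SASS won’t run on the older devices.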
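The SASS compatibility rule above fits in a few lines. This is a sketch of the rule as stated in NVIDIA's programming guide; cubin_runs_on is a hypothetical helper, and it deliberately ignores special cases such as architecture-specific targets like sm_90a, as well as PTX, which the driver can JIT-compile for newer GPUs.

```python
def cubin_runs_on(built_for, device):
    """Sketch of the SASS binary-compatibility rule.

    built_for, device: (major, minor) compute capabilities, e.g. (8, 6) for sm_86.
    A cubin runs only on devices with the same major version and an equal or
    higher minor version.
    """
    b_major, b_minor = built_for
    d_major, d_minor = device
    return b_major == d_major and d_minor >= b_minor


print(cubin_runs_on((8, 0), (8, 6)))  # True: newer minor revision
print(cubin_runs_on((8, 6), (8, 0)))  # False: older minor revision
print(cubin_runs_on((7, 5), (8, 6)))  # False: different major revision
```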

Here we combined the two steps into one and generated a .cubin. It’s actually just an ELF file that contains metadata and the machine code for the kernel:

!objdump --headers saxpy.cubin

saxpy.cubin:     file format elf64-little

Sections:
Idx Name          Size      VMA               LMA               File off  Algn
  0 .debug_frame  00000070  0000000000000000  0000000000000000  00000320  2**0
                  CONTENTS, RELOC, READONLY, DEBUGGING, OCTETS
  1 .nv.info      00000024  0000000000000000  0000000000000000  00000390  2**2
                  CONTENTS, READONLY
  2 .nv.info._Z5saxpyfPfS_i 0000006c  0000000000000000  0000000000000000  000003b4  2**2
                  CONTENTS, READONLY
  3 .nv.callgraph 00000020  0000000000000000  0000000000000000  00000420  2**2
                  CONTENTS, READONLY
  4 .nv.rel.action 00000010  0000000000000000  0000000000000000  00000440  2**3
                  CONTENTS, READONLY
  5 .nv.constant0._Z5saxpyfPfS_i 0000017c  0000000000000000  0000000000000000  00000460  2**2
                  CONTENTS, ALLOC, LOAD, READONLY, DATA
  6 .text._Z5saxpyfPfS_i 00000180  0000000000000000  0000000000000000  00000600  2**7
                  CONTENTS, ALLOC, LOAD, READONLY, CODE

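.text._Z5saxpyfPfS_i holds the machine code for the kernel. _Z5saxpyfPfS_i is the C++-mangled name of saxpy(float, float*, float*, int); c++filt will demangle it.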

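I don’t know for sure why .nv.constant0._Z5saxpyfPfS_i is so large (0x17c = 380 bytes) while being all zeros. The disassembly below offers a hint: this section is constant bank 0, the kernel reads its parameters from the end of it (alpha from c[0x0][0x160], x from c[0x0][0x168], y from c[0x0][0x170], n from c[0x0][0x178]), and lower offsets hold launch values such as blockDim.x at c[0x0][0x0]. Presumably the driver fills the bank in at launch time, so the file only reserves the space.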
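Since a cubin is a regular ELF file, its header can be read with nothing but the standard library. This is a minimal sketch that decodes just enough of the header to recognise a cubin. elf_ident is a hypothetical helper, and EM_CUDA = 190 is the e_machine value assigned to NVIDIA CUDA in the ELF machine list. It is tested here on a synthetic header rather than a real file.

```python
import struct

EM_CUDA = 190  # e_machine value for the NVIDIA CUDA architecture


def elf_ident(data: bytes):
    """Decode the first few fields of an ELF header."""
    if data[:4] != b"\x7fELF":
        raise ValueError("not an ELF file")
    elf_class = {1: "ELF32", 2: "ELF64"}[data[4]]   # EI_CLASS
    byte_order = {1: "little", 2: "big"}[data[5]]   # EI_DATA
    fmt = "<HH" if byte_order == "little" else ">HH"
    # e_type and e_machine follow the 16-byte e_ident block
    e_type, e_machine = struct.unpack_from(fmt, data, 16)
    return {"class": elf_class, "byte_order": byte_order, "type": e_type, "machine": e_machine}


# Synthetic 64-bit little-endian header: e_type = 2 (ET_EXEC), e_machine = EM_CUDA
fake = b"\x7fELF" + bytes([2, 1, 1]) + bytes(9) + struct.pack("<HH", 2, EM_CUDA)
print(elf_ident(fake))
# {'class': 'ELF64', 'byte_order': 'little', 'type': 2, 'machine': 190}
```

On the real file (with open("saxpy.cubin", "rb") as f: print(elf_ident(f.read()))) I’d expect a 64-bit little-endian ELF, matching objdump’s elf64-little, with type 2, matching nvdisasm’s ET_EXEC.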

There is also cuobjdump, which can parse the CUDA-specific sections:

!cuobjdump saxpy.cubin --dump-resource-usage

Resource usage:
 Common:
  GLOBAL:0
 Function _Z5saxpyfPfS_i:
  REG:10 STACK:0 SHARED:0 LOCAL:0 CONSTANT[0]:380 TEXTURE:0 SURFACE:0 SAMPLER:0

And the nvdisasm disassembler:

!nvdisasm -c saxpy.cubin
    .headerflags    @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM86 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM86)"
    .elftype    @"ET_EXEC"


//--------------------- .text._Z5saxpyfPfS_i      --------------------------
    .section    .text._Z5saxpyfPfS_i,"ax",@progbits
    .sectioninfo    @"SHI_REGISTERS=10"
    .align  128
        .global         _Z5saxpyfPfS_i
        .type           _Z5saxpyfPfS_i,@function
        .size           _Z5saxpyfPfS_i,(.L_x_1 - _Z5saxpyfPfS_i)
        .other          _Z5saxpyfPfS_i,@"STO_CUDA_ENTRY STV_DEFAULT"
_Z5saxpyfPfS_i:
.text._Z5saxpyfPfS_i:
        /*0000*/                   MOV R1, c[0x0][0x28] ;
        /*0010*/                   S2R R4, SR_CTAID.X ;
        /*0020*/                   S2R R3, SR_TID.X ;
        /*0030*/                   IMAD R4, R4, c[0x0][0x0], R3 ;
        /*0040*/                   ISETP.GE.AND P0, PT, R4, c[0x0][0x178], PT ;
        /*0050*/               @P0 EXIT ;
        /*0060*/                   MOV R5, 0x4 ;
        /*0070*/                   ULDC.64 UR4, c[0x0][0x118] ;
        /*0080*/                   IMAD.WIDE R2, R4, R5, c[0x0][0x168] ;
        /*0090*/                   IMAD.WIDE R4, R4, R5, c[0x0][0x170] ;
        /*00a0*/                   LDG.E R2, [R2.64] ;
        /*00b0*/                   LDG.E R7, [R4.64] ;
        /*00c0*/                   FFMA R7, R2, c[0x0][0x160], R7 ;
        /*00d0*/                   STG.E [R4.64], R7 ;
        /*00e0*/                   EXIT ;
.L_x_0:
        /*00f0*/                   BRA `(.L_x_0);
        /*0100*/                   NOP;
        /*0110*/                   NOP;
        /*0120*/                   NOP;
        /*0130*/                   NOP;
        /*0140*/                   NOP;
        /*0150*/                   NOP;
        /*0160*/                   NOP;
        /*0170*/                   NOP;
.L_x_1:
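The constant-bank offsets in the SASS above can be reproduced with ctypes. This sketch assumes that the parameters start at 0x160 (read off this sm_86 disassembly, not taken from documentation) and that they follow natural C alignment. Under those assumptions, the parameters end exactly at the 0x17c = 380 bytes reported by objdump and cuobjdump. SaxpyParams and PARAM_BASE are hypothetical names.

```python
import ctypes

PARAM_BASE = 0x160  # where this kernel's parameters start in c[0x0] (inferred from the SASS)


class SaxpyParams(ctypes.Structure):
    # Same order and types as saxpy(float alpha, float *x, float *y, int32_t n).
    # Device pointers are 64-bit, so they are modeled as c_uint64.
    _fields_ = [
        ("alpha", ctypes.c_float),
        ("x", ctypes.c_uint64),
        ("y", ctypes.c_uint64),
        ("n", ctypes.c_int32),
    ]


for name, _ in SaxpyParams._fields_:
    offset = PARAM_BASE + getattr(SaxpyParams, name).offset
    print(f"{name:5s} c[0x0][{offset:#x}]")

end = PARAM_BASE + SaxpyParams.n.offset + ctypes.sizeof(ctypes.c_int32)
print(f"end of parameters: {end:#x} = {end} bytes")
```

This prints 0x160, 0x168, 0x170 and 0x178 for alpha, x, y and n. The 4 bytes of padding after alpha come from aligning x to 8 bytes.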

Device init

TinyGrad autogenerates ctypes bindings for the CUDA library, so let’s just use them.

This works pretty much the same way as calling those functions from C.

I will use the lower-level CUDA Driver API functions here.

import ctypes
from tinygrad.runtime.autogen import cuda  # autogenerated ctypes bindings
from tinygrad.helpers import init_c_var

# import os
# os.environ["IOCTL"] = "1"
# import nv_ioctl

def check(status):
    if status != 0:
        error_str = ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(),
                                               lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()
        raise RuntimeError(f"CUDA Error {status}, {error_str}")

https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__INITIALIZE.html

check(cuda.cuInit(0)) # The flags argument must currently be 0
cu_device = ctypes.c_int() # A CUdevice is just an int holding the device ID
# cuDeviceGet fails if you ask for a device that does not exist; otherwise it just returns the device ID you gave it.
check(cuda.cuDeviceGet(cu_device, 0)) # cu_device is passed by pointer; ctypes converts it automatically based on the cuDeviceGet signature
cu_device.value
0

Device information

https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html

device_count = ctypes.c_int()
check(cuda.cuDeviceGetCount(device_count))
print(f"Number of CUDA devices available: {device_count.value}")
Number of CUDA devices available: 1
device_name = ctypes.create_string_buffer(100)
check(cuda.cuDeviceGetName(device_name, len(device_name), cu_device))
print(device_name.value.decode())
NVIDIA GeForce RTX 3080 Laptop GPU
minor = ctypes.c_int()
major = ctypes.c_int()
check(cuda.cuDeviceComputeCapability(major, minor, cu_device))
major.value, minor.value
(8, 6)
# Get total memory on the device
total_memory = ctypes.c_size_t()
check(cuda.cuDeviceTotalMem_v2(ctypes.byref(total_memory), cu_device))
print(f"Total memory on device: {total_memory.value / (1024**3):.2f} GB")
Total memory on device: 15.59 GB

Context

In CUDA, a context represents the primary object for managing resources and execution on a GPU device.

Each host thread has at most one current context at a time, and contexts can be pushed onto and popped from a per-thread stack. A context must be created before any CUDA operations can be performed on a device.

https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html
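The per-thread stack is easy to picture with a toy model. This is a sketch of the semantics only, not the driver's implementation. ToyContextStack is hypothetical, and its push/pop/current methods only loosely mirror cuCtxPushCurrent, cuCtxPopCurrent and cuCtxGetCurrent. According to the docs, cuCtxCreate also makes the new context current for the calling thread, which is why no explicit push is needed below.

```python
import threading


class ToyContextStack:
    """Toy model of a per-thread context stack (NOT the real driver code)."""

    def __init__(self):
        self._local = threading.local()  # each host thread gets its own stack

    def _stack(self):
        if not hasattr(self._local, "stack"):
            self._local.stack = []
        return self._local.stack

    def push(self, ctx):   # loosely like cuCtxPushCurrent
        self._stack().append(ctx)

    def pop(self):         # loosely like cuCtxPopCurrent
        return self._stack().pop()

    def current(self):     # loosely like cuCtxGetCurrent: top of stack, or None
        stack = self._stack()
        return stack[-1] if stack else None


ctxs = ToyContextStack()
ctxs.push("ctx_A")
ctxs.push("ctx_B")
print(ctxs.current())  # ctx_B

seen = []
t = threading.Thread(target=lambda: seen.append(ctxs.current()))
t.start()
t.join()
print(seen)  # [None]: another thread has its own, empty stack

print(ctxs.pop(), ctxs.current())  # ctx_B ctx_A
```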

cu_context = cuda.CUcontext()

# cuCtxCreate_v2 is actually identical to cuCtxCreate:
# include/cuda.h:
# ...
# #define cuCtxCreate                         cuCtxCreate_v2
check(cuda.cuCtxCreate_v2(cu_context, 0, cu_device))
cu_context.contents # CUcontext is a pointer to an opaque struct; its size and layout are not public.
<tinygrad.runtime.autogen.cuda.struct_CUctx_st>

Memory management

The device is ready. Let’s allocate the memory for the input and output arrays.

https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html

import numpy as np
N = 128

# saxpy performs the operation y = a*x + y
dev_x = cuda.CUdeviceptr() # 64 bit pointer to device memory
dev_y = cuda.CUdeviceptr()


# Allocate the buffers on the device
# include/cuda.h:
# #define cuMemAlloc                          cuMemAlloc_v2
check(cuda.cuMemAlloc_v2(dev_x, N*4))
check(cuda.cuMemAlloc_v2(dev_y, N*4))
host_x = np.linspace(0, 100, N).astype(np.float32)
host_y = np.zeros(N).astype(np.float32)

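# Copy data to the device. cuda.h #defines cuMemcpyHtoD as cuMemcpyHtoD_v2; the _v2 entry points
# were added when the API moved to 64-bit device pointers and sizes, and the old ones were kept for compatibility.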
check(cuda.cuMemcpyHtoD_v2(dev_x, host_x.ctypes.data_as(ctypes.c_void_p), N*4))
check(cuda.cuMemcpyHtoD_v2(dev_y, host_y.ctypes.data_as(ctypes.c_void_p), N*4))

Load and run the kernel

Normally, when you build a .cu file that contains both host and device code, the device code gets compiled into a CUDA binary, and the host code gets compiled with that CUDA binary embedded in it as a blob.

A launch such as saxpy_parallel<<<nblocks,1024>>>(alpha,d_x,d_y,N); is, roughly speaking, syntactic sugar for:

  • m = cuModuleLoadData(blob) - loads the blob as a “CUDA module”
  • fx = cuModuleGetFunction(m, "saxpy") - gets a handle to the kernel
  • cuLaunchKernel(fx, grid, block, &params) - launches the kernel with the grid/block configuration and parameters
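The nblocks in a launch like the one above is usually a ceiling division of the element count by the block size. 1024 is the maximum number of threads per block on current GPUs. This is a minimal sketch, and launch_config is a hypothetical helper. Its tuples are shaped so they could be splatted into cuLaunchKernel the same way grid_size and block_size are used in this notebook.

```python
def launch_config(n, block_size=1024):
    """Hypothetical helper: 1-D (grid, block) dims that cover n elements."""
    n_blocks = (n + block_size - 1) // block_size  # ceiling division
    return (n_blocks, 1, 1), (block_size, 1, 1)


print(launch_config(128, 128))   # ((1, 1, 1), (128, 1, 1)): the config used in this notebook
print(launch_config(1_000_000))  # ((977, 1, 1), (1024, 1, 1))
```

Rounding up means the last block can have idle threads, which is exactly why the kernel checks i < n.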

We built the .cubin as a separate file, so let’s do the same thing manually.

https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MODULE.html

with open("saxpy.cubin", "rb") as f:
    image = f.read()
module = cuda.CUmodule() # Another opaque object.

check(cuda.cuModuleLoadData(module, image))
module
<tinygrad.runtime.autogen.cuda.LP_struct_CUmod_st>
fx = cuda.CUfunction() # You guessed it, another opaque object
check(cuda.cuModuleGetFunction(fx, module, "_Z5saxpyfPfS_i".encode("utf-8")))
fx
<tinygrad.runtime.autogen.cuda.LP_struct_CUfunc_st>

Create the parameter array for the kernel

grid_size = (1, 1, 1)   # a single block
block_size = (N, 1, 1)  # N = 128 threads, well under the 1024-threads-per-block limit

shared_mem_size = 0 # We don't need shared memory for this kernel

# Args
# 0 - alpha (float)
# 1 - x (pointer to float, input)
# 2 - y (pointer to float, input and output)
# 3 - N (int32)

# The params argument to cuLaunchKernel is a pointer to an array of pointers to the parameters.
# The CUDA library knows the size of each parameter from the kernel's metadata (the .nv.info sections of the cubin),
# so it can figure out how to pass them to the kernel.
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html

alpha_val = ctypes.c_float(2.0)
n = ctypes.c_int32(N)

alpha_ptr = ctypes.cast(ctypes.addressof(alpha_val), ctypes.c_void_p) # Pointer to alpha value
dev_x_ptr = ctypes.cast(ctypes.addressof(dev_x), ctypes.c_void_p) # Pointer to the host variable holding x's device address
dev_y_ptr = ctypes.cast(ctypes.addressof(dev_y), ctypes.c_void_p) # Pointer to the host variable holding y's device address
n_ptr = ctypes.cast(ctypes.addressof(n), ctypes.c_void_p) # Pointer to the N value

VoidPtrArrayType = ctypes.c_void_p * 4
params = VoidPtrArrayType() # Create the array to hold pointers to args

# Populate the array with pointers to the actual kernel arguments
params[0] = alpha_ptr
params[1] = dev_x_ptr
params[2] = dev_y_ptr
params[3] = n_ptr
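The same double indirection can be demonstrated without a GPU. This is a sketch: pack_kernel_params is a hypothetical helper, and the device address is made up. It uses fresh demo_* names so it doesn't clobber the real alpha_val, dev_x and n above. Each slot holds the address of a host-side ctypes object, and reading through the slot recovers the argument's value, which is roughly what the driver does with the size it gets from the metadata.

```python
import ctypes


def pack_kernel_params(*args):
    """Build the void*[] that cuLaunchKernel's kernelParams expects.

    Each entry is the address of a ctypes object holding one argument's value.
    The caller must keep the argument objects alive until the launch is issued.
    """
    return (ctypes.c_void_p * len(args))(*(ctypes.addressof(a) for a in args))


demo_alpha = ctypes.c_float(2.0)
demo_ptr = ctypes.c_uint64(0x7F0000000000)  # stand-in for a CUdeviceptr (made-up address)
demo_n = ctypes.c_int32(128)
demo_params = pack_kernel_params(demo_alpha, demo_ptr, demo_n)

# Dereference each slot, reading sizeof(arg) bytes from the stored address:
print(ctypes.c_float.from_address(demo_params[0]).value)       # 2.0
print(hex(ctypes.c_uint64.from_address(demo_params[1]).value))  # 0x7f0000000000
print(ctypes.c_int32.from_address(demo_params[2]).value)       # 128
```

Assuming tinygrad's CUdeviceptr is a plain 64-bit ctypes integer, pack_kernel_params(alpha_val, dev_x, dev_y, n) should build an array equivalent to params above.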

Run the kernel

https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html

check(cuda.cuLaunchKernel(fx, *grid_size, *block_size, shared_mem_size, None, params, None))
# Copy the result back to the host. This synchronous copy on the default stream runs after the kernel finishes.
check(cuda.cuMemcpyDtoH_v2(host_y.ctypes.data_as(ctypes.c_void_p), dev_y, N*4))

Verify the result

host_x[:16]
array([ 0.        ,  0.78740156,  1.5748031 ,  2.3622048 ,  3.1496062 ,
        3.937008  ,  4.7244096 ,  5.5118113 ,  6.2992125 ,  7.086614  ,
        7.874016  ,  8.661417  ,  9.448819  , 10.23622   , 11.0236225 ,
       11.811024  ], dtype=float32)
host_y[:16]
array([ 0.       ,  1.5748031,  3.1496062,  4.7244096,  6.2992125,
        7.874016 ,  9.448819 , 11.0236225, 12.598425 , 14.173228 ,
       15.748032 , 17.322834 , 18.897638 , 20.47244  , 22.047245 ,
       23.622047 ], dtype=float32)
(host_y == host_x * alpha_val.value).all()
np.True_