Misc 3 - Running kernels without CUDA libraries
We will again use our saxpy kernel as an example.
saxpy.cu
#include <stdint.h>
// SAXPY kernel: y = alpha*x + y
__global__ void saxpy(float alpha, float *x, float *y, int32_t n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = alpha * x[i] + y[i];
    }
}
!nvcc -cubin saxpy.cu -o saxpy.cubin -arch=sm_86
This could also be done with NVRTC.
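For reference, here is roughly what the NVRTC route would look like through ctypes. This is a sketch, not from the original flow; it assumes libnvrtc.so is installed and uses the standard NVRTC C API. The resulting PTX can also be fed to cuModuleLoadData later on.
import ctypes
nvrtc = ctypes.CDLL("libnvrtc.so")                         # NVRTC runtime compiler
src = open("saxpy.cu", "rb").read()
prog = ctypes.c_void_p()
nvrtc.nvrtcCreateProgram(ctypes.byref(prog), src, b"saxpy.cu", 0, None, None)
opts = (ctypes.c_char_p * 1)(b"--gpu-architecture=compute_86")
nvrtc.nvrtcCompileProgram(prog, 1, opts)                   # returns 0 on success
sz = ctypes.c_size_t()
nvrtc.nvrtcGetPTXSize(prog, ctypes.byref(sz))
ptx = ctypes.create_string_buffer(sz.value)
nvrtc.nvrtcGetPTX(prog, ptx)                               # ptx.value now holds the compiled PTX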
Device init
TinyGrad autogenerates ctypes bindings both for the CUDA library and for the NVIDIA kernel driver interface, so let’s just use them.
Calling them works pretty much the same way as calling those functions from C.
The device setup below goes through the raw driver ioctls; later on, the CUDA Driver API functions are used to load and run the kernel.
import ctypes
from tinygrad.runtime.autogen import nv_gpu
from tinygrad.runtime.autogen import libc
from tinygrad.helpers import init_c_var
import os
# os.environ["IOCTL"] = "1"
# import nv_ioctl
# def check(status):
# if status != 0:
# error_str = ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(),
# lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()
#             raise RuntimeError(f"CUDA Error {status}, {error_str}")
A class for working with file descriptors, similar to hcq.py:HWInterface.
import fcntl
from tinygrad.runtime.autogen import libc
class FD:
    def __init__(self, path:str="", flags:int=os.O_RDONLY, fd:int|None=None):
        self.path:str = path
        self.fd:int = fd or os.open(path, flags)
    def __del__(self):
        if hasattr(self, 'fd'): os.close(self.fd)
    def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg)
    def mmap(self, start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, self.fd, offset)
stuff = [
"NV01_ROOT_CLIENT",
"NV_ESC_RM_ALLOC"
]
for s in stuff:
    v = getattr(nv_gpu, s)
    print(f"{s:20}: {v} {hex(v)}")
hex(nv_gpu.NV01_ROOT_CLIENT), hex(nv_gpu.NV_ESC_RM_ALLOC)
NV01_ROOT_CLIENT    : 65 0x41
NV_ESC_RM_ALLOC     : 43 0x2b
('0x41', '0x2b')
fd_uvm_2 = FD("/dev/nvidia-uvm", os.O_RDWR | os.O_CLOEXEC)
First thing we do is allocate a “root client”.
I don’t know what the numbers in the macro names mean.
fd_ctl = FD("/dev/nvidiactl", os.O_RDWR | os.O_CLOEXEC)
For nvidiactl ioctls, we need to use the following format.
Arg is expected to be a _PARAMS structure appropriate for that ioctl.
def nv_iowr(fd:FD, nr, args):
"""
Create and execute an IOCTL request for NVIDIA driver.
Note: The ioctls often both read and write to/from the args!
See also:
- kernel-open/common/inc/nv.h: Contains NVIDIA driver interface definitions
"""
size = 0 if args is None else ctypes.sizeof(args)
assert size <= 8192, "args size exceeds 8192 bytes"
# From Linux kernel include/uapi/asm-generic/ioctl.h:
# #define _IOC_NRBITS 8
# #define _IOC_TYPEBITS 8
# #define _IOC_SIZEBITS 14
# #define _IOC_DIRBITS 2
# #define _IOC_NRSHIFT 0
# #define _IOC_TYPESHIFT (_IOC_NRSHIFT+_IOC_NRBITS) // 8
# #define _IOC_SIZESHIFT (_IOC_TYPESHIFT+_IOC_TYPEBITS) // 16
# #define _IOC_DIRSHIFT (_IOC_SIZESHIFT+_IOC_SIZEBITS) // 30
# #define _IOC_WRITE 1U
# #define _IOC_READ 2U
# #define _IOC(dir,type,nr,size) \
# (((dir) << _IOC_DIRSHIFT) | \
# ((type) << _IOC_TYPESHIFT) | \
# ((nr) << _IOC_NRSHIFT) | \
# ((size) << _IOC_SIZESHIFT))
# Create a 32-bit value according to the rules in the comments above
# For NVIDIA ioctls, we want both read and write (dir = 3)
# type = 'F' (NVIDIA uses 'F' as type, which is ASCII 70)
    dir_bits = 3 << 30
    size_bits = size << 16
    type_bits = ord('F') << 8
    nr_bits = nr
    # Combine all parts to form the request code
    request = ctypes.c_uint(dir_bits | type_bits | nr_bits | size_bits)
    request_code = ctypes.cast(ctypes.addressof(request), ctypes.POINTER(ctypes.c_uint32)).contents.value
    if (ret := fd.ioctl(request_code, args)) != 0: raise RuntimeError(f"NVIDIA ioctl command {nr} failed with error code {ret}")
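Just to sanity-check the encoding (this is illustrative, not part of the original flow), here is the request code nv_iowr would build for NV_ESC_RM_ALLOC with an NVOS21_PARAMETERS argument:
# dir=3 (read|write), size=sizeof(params), type='F', nr=NV_ESC_RM_ALLOC
size = ctypes.sizeof(nv_gpu.NVOS21_PARAMETERS)
hex((3 << 30) | (size << 16) | (ord('F') << 8) | nv_gpu.NV_ESC_RM_ALLOC)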
Time to allocate the root client.
I don’t know what those numbers 21 and 01 in NVOS21_PARAMETERS and NV01_ROOT_CLIENT mean exactly.
get_root_params = nv_gpu.NVOS21_PARAMETERS(
hRoot=0,
hObjectParent=0,
hClass=nv_gpu.NV01_ROOT_CLIENT,
pAllocParms=None
)
nv_iowr(fd_ctl, nv_gpu.NV_ESC_RM_ALLOC, get_root_params)
The ioctl both reads and writes to the params structure.
If an invalid ioctl is called, the call itself fails, but if a valid ioctl could not complete successfully, it sets a status field in params.
For the ESC_RM family of calls, the status field seems to be called status:
nv_gpu.NVOS21_PARAMETERS.as_dict(get_root_params)
{'hRoot': 0,
'hObjectParent': 0,
'hObjectNew': 3251634574,
'hClass': 65,
'pAllocParms': None,
'paramsSize': 0,
'status': 0}
Let’s write a function to check the status.
def get_error_str(status): return f"{status}: {nv_gpu.nv_status_codes.get(status, 'Unknown error')}"
def check_nvctl(arg):
    if arg.status != 0: raise RuntimeError(f"IOCTL Error {arg.status}, {get_error_str(arg.status)}")
    return arg
This one was successful, and it set the hObjectNew to point to the newly allocated root client handle.
root = check_nvctl(get_root_params).hObjectNew
Now let’s initialize the Unified Virtual Memory. Here the ioctls are defined differently, with all the information already embedded in the macro:
hex(nv_gpu.UVM_INITIALIZE)
'0x30000001'
At first glance the leading 0x3 looks like the read/write direction bits, but with the _IOC layout above it actually lands in the size field; there are no type bits, and only the low byte carries the ioctl number. The UVM driver just matches the whole constant, so we pass it through unchanged.
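Pulling the constant apart with the same bit positions used in nv_iowr makes that clear:
v = nv_gpu.UVM_INITIALIZE
v >> 30, (v >> 16) & 0x3fff, (v >> 8) & 0xff, v & 0xff   # dir=0, size=0x3000, type=0, nr=1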
def uvm_iowr(fd:FD, nr, args):
    if (ret := fd.ioctl(nr, args)) != 0: raise RuntimeError(f"NVIDIA UVM ioctl command {nr} failed with error code {ret}")
fd_uvm = FD("/dev/nvidia-uvm", os.O_RDWR | os.O_CLOEXEC)
uvm_initialize_params = nv_gpu.UVM_INITIALIZE_PARAMS()
uvm_iowr(fd_uvm, nv_gpu.UVM_INITIALIZE, uvm_initialize_params)
The status field is also named differently:
nv_gpu.UVM_INITIALIZE_PARAMS.as_dict(uvm_initialize_params)
{'flags': 0, 'rmStatus': 0}
def check_uvm(arg):
    if arg.rmStatus != 0: raise RuntimeError(f"IOCTL Error {arg.rmStatus}, {get_error_str(arg.rmStatus)}")
    return arg
check_uvm(uvm_initialize_params);
Let’s get the information for the available GPUs.
gpus_info_type = nv_gpu.nv_ioctl_card_info_t*32
gpus_info = gpus_info_type()
nv_iowr(fd_ctl, nv_gpu.NV_ESC_CARD_INFO, gpus_info)
num_gpus = sum(1 for i in gpus_info if i.valid)
print(f"Number of GPUs: {num_gpus}")
nv_gpu.nv_ioctl_card_info_t.as_dict(gpus_info[0])
Number of GPUs: 1
{'valid': 1,
'pci_info': {'domain': 0,
'bus': 1,
'slot': 0,
'function': 0,
'vendor_id': 4318,
'device_id': 9436},
'gpu_id': 256,
'interrupt_line': 131,
'reg_address': 3489660928,
'reg_size': 16777216,
'fb_address': 1065151889408,
'fb_size': 17179869184,
'minor_number': 0,
'dev_name': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
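The vendor_id and device_id here are plain PCI IDs: 4318 is 0x10de, NVIDIA's PCI vendor ID, and the device ID identifies the specific GPU model. For example:
hex(gpus_info[0].pci_info.vendor_id), hex(gpus_info[0].pci_info.device_id)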
root
3251634574
gpu_info = nv_gpu.NV0000_CTRL_GPU_GET_ID_INFO_V2_PARAMS(gpuId=gpus_info[0].gpu_id)
ctrl_params = nv_gpu.NVOS54_PARAMETERS(hClient=root,
hObject=root,
cmd=nv_gpu.NV0000_CTRL_CMD_GPU_GET_ID_INFO_V2,
paramsSize=ctypes.sizeof(gpu_info),
params=ctypes.cast(ctypes.byref(gpu_info), ctypes.c_void_p))
nv_iowr(fd_ctl, nv_gpu.NV_ESC_RM_CONTROL, ctrl_params)
nv_gpu.NV0000_CTRL_GPU_GET_ID_INFO_V2_PARAMS.as_dict(gpu_info)
{'gpuId': 256,
'gpuFlags': 13,
'deviceInstance': 0,
'subDeviceInstance': 0,
'sliStatus': 65,
'boardId': 256,
'gpuInstance': 0,
'numaId': -1}
fd_dev = FD(f"/dev/nvidia0", os.O_RDWR | os.O_CLOEXEC)
fd_info = nv_gpu.nv_ioctl_register_fd_t(ctl_fd=fd_ctl.fd)
nv_iowr(fd_dev, nv_gpu.NV_ESC_REGISTER_FD, fd_info)
device_params = nv_gpu.NV0080_ALLOC_PARAMETERS(deviceId=gpu_info.deviceInstance, hClientShare=root,
vaMode=nv_gpu.NV_DEVICE_ALLOCATION_VAMODE_MULTIPLE_VASPACES)
nvdevice_alloc_params = nv_gpu.NVOS21_PARAMETERS(hRoot=root, hObjectParent=root, hClass=nv_gpu.NV01_DEVICE_0,
pAllocParms=ctypes.cast(ctypes.byref(device_params), ctypes.c_void_p))
nv_iowr(fd_ctl, nv_gpu.NV_ESC_RM_ALLOC, nvdevice_alloc_params)
nvdevice = check_nvctl(nvdevice_alloc_params)
nv_gpu.NVOS21_PARAMETERS.as_dict(nvdevice)
{'hRoot': 3251634574,
'hObjectParent': 3251634574,
'hObjectNew': 3404726272,
'hClass': 128,
'pAllocParms': 132724301406960,
'paramsSize': 0,
'status': 0}
nvsubdevice_alloc_params = nv_gpu.NVOS21_PARAMETERS(hRoot=root, hObjectParent=nvdevice.hObjectNew, hClass=nv_gpu.NV20_SUBDEVICE_0,
pAllocParms=ctypes.cast(ctypes.byref(device_params), ctypes.c_void_p))
nv_iowr(fd_ctl, nv_gpu.NV_ESC_RM_ALLOC, nvsubdevice_alloc_params)
nvsubdevice = check_nvctl(nvsubdevice_alloc_params)
nv_gpu.NVOS21_PARAMETERS.as_dict(nvsubdevice)
{'hRoot': 3251634574,
'hObjectParent': 3404726272,
'hObjectNew': 3404726273,
'hClass': 8320,
'pAllocParms': 132724301406960,
'paramsSize': 0,
'status': 0}
usermode_alloc_params = nv_gpu.NVOS21_PARAMETERS(hRoot=root, hObjectParent=nvsubdevice.hObjectNew, hClass=nv_gpu.TURING_USERMODE_A,
pAllocParms=ctypes.cast(ctypes.byref(device_params), ctypes.c_void_p))
nv_iowr(fd_ctl, nv_gpu.NV_ESC_RM_ALLOC, usermode_alloc_params)
usermode = check_nvctl(usermode_alloc_params)
nv_gpu.NVOS21_PARAMETERS.as_dict(usermode)
{'hRoot': 3251634574,
'hObjectParent': 3404726273,
'hObjectNew': 3404726274,
'hClass': 50273,
'pAllocParms': 132724301406960,
'paramsSize': 0,
'status': 0}
def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
import mmap
# Allocate a new FD and register it.
mmio_mmap_fd_dev = FD(f"/dev/nvidia0", os.O_RDWR | os.O_CLOEXEC)
mmio_mmap_fd_dev_info = nv_gpu.nv_ioctl_register_fd_t(ctl_fd=fd_ctl.fd)
nv_iowr(mmio_mmap_fd_dev, nv_gpu.NV_ESC_REGISTER_FD, mmio_mmap_fd_dev_info)
mmio_mmap_params = nv_gpu.nv_ioctl_nvos33_parameters_with_fd(fd=mmio_mmap_fd_dev.fd,
params=nv_gpu.NVOS33_PARAMETERS(hClient=root, hDevice=nvdevice.hObjectNew, hMemory=usermode.hObjectNew, length=0x10000, flags=2))
nv_iowr(fd_ctl, nv_gpu.NV_ESC_RM_MAP_MEMORY, mmio_mmap_params)
check_nvctl(mmio_mmap_params.params)
mmio_mmap = mmio_mmap_fd_dev.mmap(None, 0x10000, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED, 0)
gpu_mmio = to_mv(mmio_mmap, 0x10000).cast("I")
# self.gpu_mmio = to_mv(self._gpu_map_to_cpu(self.usermode, mmio_sz:=0x10000, flags=2), mmio_sz).cast("I")
gpu_mmio
<memory>
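The rest of the notebook goes through the autogenerated CUDA Driver API bindings and uses a check() helper, neither of which has been set up above. Here is the missing glue, essentially the commented-out helper from the top of the notebook (assuming the autogen module is tinygrad.runtime.autogen.cuda):
from tinygrad.runtime.autogen import cuda
def check(status):
    if status != 0:
        error_str = ctypes.string_at(init_c_var(ctypes.POINTER(ctypes.c_char)(),
            lambda x: cuda.cuGetErrorString(status, ctypes.byref(x)))).decode()
        raise RuntimeError(f"CUDA Error {status}, {error_str}")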
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__INITIALIZE.html
check(cuda.cuInit(0)) # This 0 is always 0
cu_device = ctypes.c_int() # It's actually just an int with the value of the device ID
# It fails if you try to get a device that does not exist, but otherwise it just returns the device ID you gave it.
check(cuda.cuDeviceGet(cu_device, 0)) # cu_device is passed by pointer, it's converted automatically based on the cuDeviceGet function signature
cu_device.value
0
Device information
device_count = ctypes.c_int()
check(cuda.cuDeviceGetCount(device_count))
print(f"Number of CUDA devices available: {device_count.value}")Number of CUDA devices available: 1
device_name = ctypes.create_string_buffer(100)
check(cuda.cuDeviceGetName(device_name, len(device_name), cu_device))
print(device_name.value.decode())
NVIDIA GeForce RTX 3080 Laptop GPU
minor = ctypes.c_int()
major = ctypes.c_int()
check(cuda.cuDeviceComputeCapability(major, minor, 0))
major.value, minor.value
(8, 6)
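cuDeviceComputeCapability is actually deprecated; the general mechanism for this kind of query is cuDeviceGetAttribute. A sketch, assuming the autogen bindings expose the attribute enum constants from cuda.h:
sm_count = ctypes.c_int()
check(cuda.cuDeviceGetAttribute(sm_count, cuda.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, cu_device))
print(f"Multiprocessor count: {sm_count.value}")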
# Get total memory on the device
total_memory = ctypes.c_size_t()
check(cuda.cuDeviceTotalMem_v2(ctypes.byref(total_memory), cu_device))
print(f"Total memory on device: {total_memory.value / (1024**3):.2f} GB")Total memory on device: 15.59 GB
Context
In CUDA, a context represents the primary object for managing resources and execution on a GPU device.
Each thread can have one active context at a time, and contexts can be pushed/popped from a stack. The context must be created before any CUDA operations can be performed on a device.
cu_context = cuda.CUcontext()
# cuCtxCreate_v2 is actually identical to cuCtxCreate:
# include/cuda.h:
# ...
# #define cuCtxCreate cuCtxCreate_v2
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html
check(cuda.cuCtxCreate_v2(cu_context, 0, cu_device))
cu_context.contents # This is a pointer to the context object. It is opaque; no idea about the size or composition of the context object.
<tinygrad.runtime.autogen.cuda.struct_CUctx_st>
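To illustrate the stack behaviour mentioned above (not needed for the rest of the notebook), the current context can be popped off the thread's stack and pushed back; a sketch, assuming cuCtxPopCurrent_v2/cuCtxPushCurrent_v2 are present in the bindings:
popped = cuda.CUcontext()
check(cuda.cuCtxPopCurrent_v2(popped))   # no context is current on this thread now
check(cuda.cuCtxPushCurrent_v2(popped))  # make it current again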
Memory management
The device is ready. Let’s allocate the memory for the input and output arrays.
import numpy as np
N = 128
# saxpy performs the operation y = a*x + y
dev_x = cuda.CUdeviceptr() # 64 bit pointer to device memory
dev_y = cuda.CUdeviceptr()
# Allocate the buffers on the device
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html
# include/cuda.h:
# #define cuMemAlloc cuMemAlloc_v2
check(cuda.cuMemAlloc_v2(dev_x, N*4))
check(cuda.cuMemAlloc_v2(dev_y, N*4))
host_x = np.linspace(0, 100, N).astype(np.float32)
host_y = np.zeros(N).astype(np.float32)
# Copy data to the device. The _v2 names are the current versions of these entry points; cuda.h #defines the old names to them (see the defines quoted above).
check(cuda.cuMemcpyHtoD_v2(dev_x, host_x.ctypes.data_as(ctypes.c_void_p), N*4))
check(cuda.cuMemcpyHtoD_v2(dev_y, host_y.ctypes.data_as(ctypes.c_void_p), N*4))
Load and run the kernel
image = open("saxpy.cubin", "rb").read()
module = cuda.CUmodule() # Another opaque object.
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MODULE.html
check(cuda.cuModuleLoadData(module, image))
module
<tinygrad.runtime.autogen.cuda.LP_struct_CUmod_st>
fx = cuda.CUfunction() # You guessed it, another opaque object
check(cuda.cuModuleGetFunction(fx, module, "_Z5saxpyfPfS_i".encode("utf-8")))
fx
<tinygrad.runtime.autogen.cuda.LP_struct_CUfunc_st>
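The mangled name _Z5saxpyfPfS_i is just the Itanium C++ mangling of the kernel's signature (it demangles to saxpy(float, float*, float*, int)). The cubin is a regular ELF file, so assuming binutils is available, the symbol can be listed and demangled like this:
!readelf -s saxpy.cubin | grep saxpy
!echo _Z5saxpyfPfS_i | c++filt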
Create the parameter array for the kernel
grid_size = (1,1,1)
block_size = (N,1,1)
shared_mem_size = 0 # We don't need shared memory for this kernel
# Args
# 0 - alpha (float)
# 1 - input (pointer to float)
# 2 - output (pointer to float)
# 3 - N (int32)
# The params argument to cuLaunchKernel is a pointer to an array of pointers to the parameters.
# The CUDA library knows the size of each parameter from the metadata, so it can figure out how to pass them to the kernel.
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html
alpha_val = ctypes.c_float(2.0)
n = ctypes.c_int32(N)
alpha_ptr = ctypes.cast(ctypes.addressof(alpha_val), ctypes.c_void_p) # Pointer to alpha value
dev_x_ptr = ctypes.cast(ctypes.addressof(dev_x), ctypes.c_void_p) # Pointer to the x array
dev_y_ptr = ctypes.cast(ctypes.addressof(dev_y), ctypes.c_void_p) # Pointer to the y array
n_ptr = ctypes.cast(ctypes.addressof(n), ctypes.c_void_p) # Pointer to the N value
VoidPtrArrayType = ctypes.c_void_p * 4
params = VoidPtrArrayType() # Create the array to hold pointers to args
# Populate the array with pointers to the actual kernel arguments
params[0] = alpha_ptr
params[1] = dev_x_ptr
params[2] = dev_y_ptr
params[3] = n_ptr
Run the kernel
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html
check(cuda.cuLaunchKernel(fx, *grid_size, *block_size, shared_mem_size, None, params, None))
# Copy the result back to the host
check(cuda.cuMemcpyDtoH_v2(host_y.ctypes.data_as(ctypes.c_void_p), dev_y, N*4))
Verify the result
host_x[:16]
array([ 0. , 0.78740156, 1.5748031 , 2.3622048 , 3.1496062 ,
3.937008 , 4.7244096 , 5.5118113 , 6.2992125 , 7.086614 ,
7.874016 , 8.661417 , 9.448819 , 10.23622 , 11.0236225 ,
11.811024 ], dtype=float32)
host_y[:16]
array([ 0. , 1.5748031, 3.1496062, 4.7244096, 6.2992125,
7.874016 , 9.448819 , 11.0236225, 12.598425 , 14.173228 ,
15.748032 , 17.322834 , 18.897638 , 20.47244 , 22.047245 ,
23.622047 ], dtype=float32)
(host_y == host_x * alpha_val.value).all()
np.True_
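Not part of the original run, but for completeness the allocations, module and context can be released through the matching driver calls; a minimal cleanup sketch:
check(cuda.cuMemFree_v2(dev_x))
check(cuda.cuMemFree_v2(dev_y))
check(cuda.cuModuleUnload(module))
check(cuda.cuCtxDestroy_v2(cu_context))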