Adventures with TinyGrad

0 - Introduction

import os

os.environ["CPU"] = "1"
os.environ["TRACEMETA"] = "0"
os.environ["DEBUG"]="4"
import tinygrad as tg
from tinygrad import Tensor, dtypes

Lazy tensors and UOps

The TinyGrad API is quite similar to PyTorch's, with some quirks.

a = Tensor.ones(10, 10)
b = a + 2
b
<Tensor <UOp CPU (10, 10) float (<Ops.ADD: 48>, None)> on CPU with grad None>

TinyGrad is lazy - it does not perform any computation until explicitly asked to.

Instead, it saves the operations required to get the result as a tree of UOps:

b.lazydata # It's called `lazydata` for historic reasons. Rename to `Tensor.uops`?
UOp(Ops.ADD, dtypes.float, arg=None, src=(
  UOp(Ops.EXPAND, dtypes.float, arg=(10, 10), src=(
    UOp(Ops.RESHAPE, dtypes.float, arg=(1, 1), src=(
      UOp(Ops.CONST, dtypes.float, arg=1.0, src=(
        x3:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
          UOp(Ops.DEVICE, dtypes.void, arg='CPU', src=()),)),)),)),)),
  UOp(Ops.EXPAND, dtypes.float, arg=(10, 10), src=(
    UOp(Ops.RESHAPE, dtypes.float, arg=(1, 1), src=(
      UOp(Ops.CONST, dtypes.float, arg=2.0, src=(
         x3,)),)),)),))

The ADD UOp has two sources: the constants 1 and 2, each reshaped to (1, 1) and then expanded to shape (10, 10).
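To make the reshape-and-expand step more concrete, here is a toy model in plain Python. It is only a sketch of the general idea (a view described by a shape and strides over a flat buffer), not TinyGrad's actual code; the `read` helper and the variable names are made up for illustration.

```python
# Toy model of expanding a single constant to (10, 10). Not TinyGrad code.
# A "view" here is just a shape plus per-dimension strides over a flat buffer.

def read(buffer, strides, index):
    """Return the element at a multi-dimensional index using the given strides."""
    offset = sum(i * s for i, s in zip(index, strides))
    return buffer[offset]

buffer = [1.0]  # the constant: exactly one stored element

# A stride of 0 means "moving along this axis does not move in memory",
# so all 100 logical positions read the same stored element.
shape, strides = (10, 10), (0, 0)

expanded = [[read(buffer, strides, (r, c)) for c in range(shape[1])]
            for r in range(shape[0])]
print(len(expanded), len(expanded[0]), expanded[9][9])  # 10 10 1.0
```

The point of the sketch is that an expanded constant does not need 100 stored values; one value and a description of how to index it are enough.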

The CONST UOp takes the value as its argument and has a VIEW UOp as its source, which in turn sources from a DEVICE UOp.

Note the x3:= walrus assignment: the VIEW UOp is bound to x3 the first time it appears and is then reused as the source of the second CONST UOp.
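The same two ideas (nothing is computed until asked, and a sub-tree can be shared instead of duplicated) can be sketched with a tiny stand-in class. `Node` and `evaluate` are hypothetical names invented for this example; real UOps carry a dtype and many more operations.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for a UOp: an operation, an argument, and sources."""
    op: str
    arg: object = None
    src: tuple = ()

def evaluate(node):
    """Walk the tree and compute a value only when asked."""
    if node.op == "CONST":
        return node.arg
    if node.op == "ADD":
        return evaluate(node.src[0]) + evaluate(node.src[1])
    raise ValueError(f"unknown op {node.op}")

# The walrus operator names a sub-expression the first time it appears,
# so the same object can be referenced again later in the same expression.
tree = Node("ADD", src=(
    Node("CONST", 1.0, src=(dev := Node("DEVICE", "CPU"),)),
    Node("CONST", 2.0, src=(dev,)),
))

# Both constants point at the very same DEVICE node, not at two equal copies.
shared = tree.src[0].src[0] is tree.src[1].src[0]
result = evaluate(tree)  # nothing was computed until this call
print(shared, result)  # True 3.0
```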

from tinygrad.ops import UOp, Ops

We will take a detailed look at UOps in the next chapter, but for now, let’s see how to actually compute the value of b.

Kernels on CPU

# This runs the computations needed to get the value of the tensor
# It does not get realized without the .contiguous() though (TODO: Explain why)
# Also, should .contiguous() just always be part of .realize()?
b_realized = b.contiguous().realize()
opened device CPU from pid:958683
E_25_4
 0: (25, 4)                   float.ptr(100)       (4, 1)
[Opt(op=OptOps.UPCAST, axis=0, arg=4)]
typedef float float4 __attribute__((aligned(16),vector_size(16)));
void E_25_4(float* restrict data0) {
  for (int ridx0 = 0; ridx0 < 25; ridx0++) {
    *((float4*)((data0+(ridx0<<2)))) = (float4){3.0f,3.0f,3.0f,3.0f};
  }
}
*** CPU        1 E_25_4                                    arg  1 mem  0.00 GB tm      3.28us/     0.00ms (     0.00 GFLOPS    0.1|0.1     GB/s) 

The debug output gives us a glimpse into how TinyGrad performs the computation. It takes the UOp tree, performs a number of transformations on it, and creates one or more kernels - functions that run on the device (potentially many instances in parallel) and do the actual computation.

In this case, the device is CPU, which means the kernel will be just plain sequential C code, which will be compiled with clang into a small piece of position-independent binary, then loaded and executed using ctypes.

The float4 is a common optimization that you see on both CPU and GPU: it is more efficient to access memory in 128-bit chunks (four 32-bit floats, 16 bytes) at a time, so TinyGrad is being smart here. The optimal width might be device-specific, but 128 bits is common.
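As a sanity check on the numbers, the loop in E_25_4 can be emulated in plain Python with the standard `array` module. This is only an illustration of the indexing and the chunk size; the names `LANES` and `bytes_per_store` are made up for the sketch.

```python
from array import array

# 4 lanes x 32 bits = 128 bits = 16 bytes per float4 store.
LANES = 4
BITS_PER_FLOAT = 32
bytes_per_store = LANES * BITS_PER_FLOAT // 8

# Emulate the loop in E_25_4: 25 iterations, each writing one 4-float chunk.
data0 = array("f", [0.0] * 100)
for ridx0 in range(25):
    start = ridx0 << 2  # same as ridx0 * 4, as in the generated C code
    data0[start:start + LANES] = array("f", [3.0] * LANES)

print(bytes_per_store, len(data0), set(data0))
```

25 iterations of 4 floats cover all 100 elements of the (10, 10) tensor, which is where the kernel name E_25_4 comes from.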

And of course, since we used constants in our computation, there was no need to add 1+2 - TinyGrad was able to just fill the output with the correct value.

If we ran it on an NVIDIA GPU, it would instead generate and run CUDA code, and similarly for other devices.

We will cover the details of transformations done on the UOps tree at a later time, but for now, let’s look at the result.

Here is the buffer that contains the data:

print(type(b_realized.lazydata.base.realized))

b_realized.lazydata.base.realized
<class 'tinygrad.device.Buffer'>
<buf real:True device:CPU size:100 dtype:dtypes.float offset:0>

Since we used the CPU device, the buffer lives in CPU memory, and we can peek into it directly using memoryview:

view = memoryview(b_realized.lazydata.base.realized._buf)
view[:4].hex()
'00004040'

The bytes 00 00 40 40 are the little-endian encoding of 0x40400000, which is the IEEE 754 float32 representation of 3.0. Let’s use numpy to get a better view.
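The byte pattern can be double-checked with the standard `struct` module, without touching TinyGrad at all. The sketch assumes a little-endian layout, which is what the hex dump above shows.

```python
import struct

raw = struct.pack("<f", 3.0)  # float32, little-endian byte order
print(raw.hex())  # 00004040

# Read the same four bytes as an unsigned 32-bit integer to see the bit pattern.
print(hex(struct.unpack("<I", raw)[0]))  # 0x40400000

# And go the other way: from the hex dump back to the float.
value = struct.unpack("<f", bytes.fromhex("00004040"))[0]
print(value)  # 3.0
```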

import numpy as np

# Note: The buffer is shapeless, so we use `.reshape()` to bring it back to the correct shape
np.frombuffer(view, dtype=np.float32).reshape(b.shape)
array([[3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]], dtype=float32)

.numpy() and .tolist()

Of course there is a more convenient way to access the result from Python - use .numpy() on the tensor.

This will make sure the tensor ends up on CPU, realize it, and will give the result the correct shape and dtype.

.numpy() will also create a copy of the data, so the memory buffer does not just disappear from under our feet when the tensor gets garbage collected.
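The difference between a view and a copy can be shown with plain Python objects. This is not what .numpy() does internally; a bytearray simply stands in for a buffer that somebody else owns and may change or free.

```python
# A bytearray stands in for a buffer owned by someone else.
backing = bytearray(bytes.fromhex("00004040"))  # float32 3.0, little-endian

view = memoryview(backing)  # no copy: shares memory with `backing`
snapshot = bytes(view)      # independent copy of the same 4 bytes

backing[3] = 0x41  # the owner changes the buffer (the bytes now encode 12.0)

print(view.hex())      # 00004041  (the view follows the change)
print(snapshot.hex())  # 00004040  (the copy is unaffected)
```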

b.numpy()
*** CPU        2 E_25_4                                    arg  1 mem  0.00 GB tm      8.03us/     0.01ms (     0.00 GFLOPS    0.0|0.0     GB/s) 
array([[3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]], dtype=float32)

Or you can use .tolist() to convert it to a Python list (or a list of lists of lists … for the correct number of dimensions).
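To show what "the correct number of dimensions" means, here is a small sketch that groups a flat sequence into nested lists for a given shape. The `nest` function is a hypothetical helper written for this example, not part of TinyGrad, and it assumes every dimension is non-zero.

```python
def nest(flat, shape):
    """Group a flat sequence into nested lists matching `shape`."""
    if not shape:  # zero dimensions: a single scalar
        return flat[0]
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]  # elements per slice along the first axis
    return [nest(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]

flat = [float(i) for i in range(6)]
print(nest(flat, (2, 3)))     # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
print(nest(flat, (3, 2, 1)))  # [[[0.0], [1.0]], [[2.0], [3.0]], [[4.0], [5.0]]]
```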

b.tolist()
*** CPU        3 E_25_4                                    arg  1 mem  0.00 GB tm      8.17us/     0.02ms (     0.00 GFLOPS    0.0|0.0     GB/s) 
[[3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0]]

Kernels on GPU

TinyGrad has a concept of a “default device”, which we set using the “CPU” environment variable at the beginning of the notebook.

The device and dtype can be set when creating the tensor, and you can also use .to() to copy the tensor to a different device.

a = Tensor(1, device="CUDA", dtype=tg.dtypes.float16)
a
<Tensor <UOp CUDA () half (<Ops.CONST: 74>, None)> on CUDA with grad None>
a.to(device="CPU")
<Tensor <UOp CPU () half (<Ops.COPY: 9>, None)> on CPU with grad None>

Let’s have a look at a CUDA kernel for the same computation:

a = Tensor.ones((10, 10), device="CUDA")
b = a + 2
b.numpy()
opened device CUDA from pid:958683
E_25_4n1
 0: (25, 4)                   float.ptr(100)       (4, 1)
[Opt(op=OptOps.UPCAST, axis=0, arg=4)]
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
extern "C" __global__ void __launch_bounds__(1) E_25_4n1(float* data0) {
  int gidx0 = blockIdx.x; /* 25 */
  *((float4*)((data0+(gidx0<<2)))) = make_float4(3.0f,3.0f,3.0f,3.0f);
}
*** CUDA       4 E_25_4n1                                  arg  1 mem  0.00 GB tm     25.60us/     0.05ms (     0.00 GFLOPS    0.0|0.0     GB/s) 
*** CPU        5 copy      400,     CPU <- CUDA            arg  2 mem  0.00 GB tm     55.73us/     0.10ms (     0.00 GFLOPS    0.0|0.0     GB/s) 
array([[3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]], dtype=float32)

We see a similar pattern, but since the GPU is highly parallel, TinyGrad decided to create 25 threads, each setting its own 4-float chunk.

If you are familiar with CUDA, you might notice that we created 25 separate thread blocks with 1 thread each instead of 1 block with 25 threads, which is definitely suboptimal.
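The two launch shapes can be compared with a short Python sketch of the usual 1-D CUDA index formula, blockIdx.x * blockDim.x + threadIdx.x. The `global_indices` helper is made up for illustration; note that the generated kernel above only uses blockIdx.x, so the second configuration would need a kernel that also reads threadIdx.x.

```python
def global_indices(num_blocks, threads_per_block):
    """Enumerate blockIdx.x * blockDim.x + threadIdx.x for a 1-D launch."""
    return [block * threads_per_block + thread
            for block in range(num_blocks)
            for thread in range(threads_per_block)]

one_thread_per_block = global_indices(25, 1)  # 25 blocks of 1 thread
one_block = global_indices(1, 25)             # 1 block of 25 threads

# Either launch yields chunk indices 0..24; each chunk starts at float offset index * 4.
offsets = [i << 2 for i in one_thread_per_block]
print(one_thread_per_block == one_block, offsets[:3], offsets[-1])  # True [0, 4, 8] 96
```

Both shapes cover the same 25 chunks, so the results are identical; the difference is only in how the hardware schedules the work.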

b_realized = b.contiguous().realize()
*** CUDA       6 E_25_4n1                                  arg  1 mem  0.00 GB tm    574.21us/     0.68ms (     0.00 GFLOPS    0.0|0.0     GB/s) 
print(b_realized.lazydata.base.realized)
print(b_realized.lazydata.base.realized._buf)
<buf real:True device:CUDA size:100 dtype:dtypes.float offset:0>
c_ulong(140551080902656)

As we can see, the output buffer is on the GPU this time, so we can’t access it from the CPU directly.

But trust me, the data is definitely there. Let’s use PyCUDA to peek into the GPU memory.

import pycuda
import pycuda.driver as cuda
import numpy as np

# Create a numpy array to hold the data (100 32-bit floats)
cpu_array = np.empty((10, 10), dtype=np.float32)

# Copy data from GPU to CPU
cuda.memcpy_dtoh(cpu_array, b_realized.lazydata.base.realized._buf.value)

cpu_array
array([[3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
       [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]], dtype=float32)

TinyGrad is a complex beast, so it’s normal if this intro left you with more questions than answers. :)