Adventures with TinyGrad
2 - View and ShapeTracker

So far we have been making scalar UOps that don’t have a shape associated with them.

We have been getting away with it, but the UOp trees we made are not really valid without a shape.

import os
os.environ["CPU"] = "1"
os.environ["DEBUG"]="4"

from tinygrad import dtypes
from tinygrad.ops import UOp, Ops
a = UOp.const(dtypes.float, 1)
a
UOp(Ops.CONST, dtypes.float, arg=1.0, src=())
try:
    print(a.shape)
except Exception as e:
    print_last_frame_context(e)
AssertionError in /home/xl0/work/projects/grads/tinygrad/tinygrad/helpers.py:61 in unwrap()

Code context:
       59   return ret
       60 def unwrap(x:Optional[T]) -> T:
--->   61   assert x is not None
       62   return x
       63 def get_single_element(x:list[T]) -> T:

Another thing we were missing is the device:

try:
    print(a.device)
except Exception as e:
    print_last_frame_context(e)
AssertionError in /home/xl0/work/projects/grads/tinygrad/tinygrad/helpers.py:61 in unwrap()

Code context:
       59   return ret
       60 def unwrap(x:Optional[T]) -> T:
--->   61   assert x is not None
       62   return x
       63 def get_single_element(x:list[T]) -> T:

Let’s fix this real quick

from tinygrad.shape.shapetracker import ShapeTracker, View

a = UOp.const(dtypes.float, 1).replace(src=(
        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker.from_shape( (0,) ), src=(
            UOp(Ops.DEVICE, dtypes.void, arg="CPU", src=()),)),))
a
UOp(Ops.CONST, dtypes.float, arg=1.0, src=(
  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(0,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=(
    UOp(Ops.DEVICE, dtypes.void, arg='CPU', src=()),)),))
a.shape
(0,)
a.device
'CPU'

Looks better.

Now, what’s up with that ShapeTracker and View? Let’s start with the latter.

Shape and Stride

You are probably familiar with how shape and strides work in PyTorch or NumPy:

import torch
a = torch.linspace(0, 31, 32, dtype=torch.int32).view(4, 8)
a
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31]], dtype=torch.int32)
a.shape
torch.Size([4, 8])

The shape defines, well, the shape of the array. It can have any number of dimensions (2 in this case), and each dimension has its own size.

Under the hood, the data of a Tensor is just a linear (1-d) array, and the shape is there for convenience, because we usually want to work with multi-dimensional data.

We can change the shape, as long as the number of elements in the new shape stays the same.

b = a.view(2,4,4) # This creates a view that refers to the same data, but now it's seen as a 3-d array.
b
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]]], dtype=torch.int32)

The stride tells us how many elements we need to move in the underlying 1-d array (the base) to get to the next element in the given dimension.

For our 2x4x4 array, to move by 1 element along a row (the last dimension), we need to move … 1 element in the base.

And to move by one element down a column (the middle dimension), we need to move by 4 elements in the base, because each row is 4 elements long.

This is the standard C, or row-major order format for multidimensional data.

Row-major and Column-major order

You might have seen references to the F, or column-major order at some point. Historically this is how data was stored in Fortran, and I’m sure they had their reasons for it, but it’s definitely less intuitive.

To move by one step in the first (outermost) dimension, we have to skip a whole 4x4 block: 4 rows of 4 elements each, so 16 in total:

b.stride()
(16, 4, 1)
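The row-major rule fits in a few lines of plain Python. This is a minimal sketch, and the helper names contiguous_strides and flat_offset are hypothetical (not tinygrad or PyTorch APIs): each stride is the product of the sizes of all dimensions to its right, and an element’s position in the base is the dot product of its index with the strides, plus an offset.

```python
from math import prod

def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Row-major (C-order) strides: each stride is the product of all sizes to its right."""
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))

def flat_offset(index: tuple[int, ...], strides: tuple[int, ...], offset: int = 0) -> int:
    """Position of a multi-dimensional index in the underlying 1-d base array."""
    return offset + sum(i * s for i, s in zip(index, strides))

print(contiguous_strides((4, 8)))     # (8, 1)
print(contiguous_strides((2, 4, 4)))  # (16, 4, 1)
# Element b[1][2][3] of the 2x4x4 view lives at 1*16 + 2*4 + 3*1 = 27 in the base
print(flat_offset((1, 2, 3), contiguous_strides((2, 4, 4))))  # 27
```

Since the values in a are equal to their positions in the base, b[1][2][3] is indeed 27.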

Now, if the stride always matched the shape, things would be boring. We can set the stride independently.

Let’s go back to our 4x8 view to make things easier. In this case we need to skip 8 elements to move by one row:

print(a)
print("Shape: ",a.shape)
print("Stride:",a.stride())
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31]], dtype=torch.int32)
Shape:  torch.Size([4, 8])
Stride: (8, 1)

What if we want to create a view that skips every other element in the rows? We can do this by creating a view with shape (PyTorch calls it size) 4x4 and stride (8, 2)!

c = a.as_strided(size=(4,4), stride=(8, 2))
c
tensor([[ 0,  2,  4,  6],
        [ 8, 10, 12, 14],
        [16, 18, 20, 22],
        [24, 26, 28, 30]], dtype=torch.int32)

We can also specify an offset from the start of the base array. This will give us the odd elements in each row:

a.as_strided(size=(4,4), stride=(8, 2), storage_offset=1)
tensor([[ 1,  3,  5,  7],
        [ 9, 11, 13, 15],
        [17, 19, 21, 23],
        [25, 27, 29, 31]], dtype=torch.int32)
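Nothing about as_strided is magic: it only gathers elements from the flat base using the formula offset + sum(index * stride). A plain-Python stand-in (strided_view is a hypothetical helper, not a PyTorch function) reproduces both views above from a flat list:

```python
def strided_view(storage, shape, strides, offset=0):
    """Pure-Python stand-in for torch.as_strided over a flat list.

    The element at index (i0, i1, ...) is
    storage[offset + i0*strides[0] + i1*strides[1] + ...].
    """
    def build(dim, pos):
        if dim == len(shape):
            # Python lists would silently wrap negative positions, so check bounds
            if not 0 <= pos < len(storage):
                raise IndexError(f"position {pos} is outside the storage")
            return storage[pos]
        return [build(dim + 1, pos + i * strides[dim]) for i in range(shape[dim])]
    return build(0, offset)

base = list(range(32))  # the same data as `a`, flattened

evens = strided_view(base, (4, 4), (8, 2))           # every other element in each row
odds = strided_view(base, (4, 4), (8, 2), offset=1)  # shifted by one: the odd elements
print(evens)  # [[0, 2, 4, 6], [8, 10, 12, 14], [16, 18, 20, 22], [24, 26, 28, 30]]
print(odds)   # [[1, 3, 5, 7], [9, 11, 13, 15], [17, 19, 21, 23], [25, 27, 29, 31]]
```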

Let’s create a view that has the diagonal elements of a (0, 9, 18, 27)

a.as_strided(size=(4,), stride=(9,))
tensor([ 0,  9, 18, 27], dtype=torch.int32)
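The stride of 9 follows from the 2-d strides: stepping one row down (stride 8) and one column right (stride 1) at the same time moves 8 + 1 = 9 elements in the base. A small sketch, using a hypothetical diagonal_view helper, that derives the diagonal view from any 2-d (shape, strides, offset):

```python
def diagonal_view(shape, strides, offset=0):
    """(shape, strides, offset) of the main diagonal of a 2-d strided view.

    Moving along the diagonal advances both indices by 1, so the
    diagonal stride is the sum of the two strides.
    """
    rows, cols = shape
    return (min(rows, cols),), (strides[0] + strides[1],), offset

base = list(range(32))  # the 4x8 array `a`, flattened
shape, strides, offset = diagonal_view((4, 8), (8, 1))
print(shape, strides)  # (4,) (9,)
print([base[offset + i * strides[0]] for i in range(shape[0])])  # [0, 9, 18, 27]
```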

Another fun thing we can do is set one or more of the strides to 0, to duplicate (broadcast) dimensions:

d = torch.linspace(1, 4, 4, dtype=torch.int32)
d
tensor([1, 2, 3, 4], dtype=torch.int32)
d.as_strided(size=(4,4), stride=(1, 0))
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]], dtype=torch.int32)

For each step down a column (dimension 0), we take 1 step in the base, and for each step along a row (dimension 1), we don’t take any steps at all!

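That’s essentially how tinygrad’s .full() works: it creates 1 single element and expands it to the requested shape, so every element of the Tensor refers to it through 0 strides. (PyTorch’s torch.full() actually allocates every element; in PyTorch the stride-0 trick is what .expand() does.)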

e = torch.Tensor([1]).to(torch.int32)
e
tensor([1], dtype=torch.int32)
e.as_strided(size=(4,4), stride=(0,0))
tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]], dtype=torch.int32)
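To make the "no extra memory" point concrete, here is a plain-Python sketch (base_positions is a hypothetical helper) that lists which base elements a strided view actually reads. With a zero stride, many logical positions collapse onto the same stored element:

```python
from itertools import product

def base_positions(shape, strides, offset=0):
    """All base-array positions a strided view touches, in row-major order."""
    return [offset + sum(i * s for i, s in zip(idx, strides))
            for idx in product(*(range(n) for n in shape))]

# d = [1, 2, 3, 4] viewed as 4x4 with strides (1, 0): each row repeats one element
d = [1, 2, 3, 4]
rows = base_positions((4, 4), (1, 0))
print([d[p] for p in rows])  # [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]

# e = [1] viewed as 4x4 with strides (0, 0): 16 logical elements, 1 stored element
e = [1]
full = base_positions((4, 4), (0, 0))
print(len(full), len(set(full)))  # 16 1
```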

class View

The View class is intended to keep track of the shape and stride of the data. Let’s play with it a bit.

v = View(shape=(4,8), strides=(8,1), offset=0, mask=None, contiguous=True)
v # A normal array 4x8
View(shape=(4, 8), strides=(8, 1), offset=0, mask=None, contiguous=True)
a
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31]], dtype=torch.int32)
a.as_strided(size=v.shape, stride=v.strides)
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31]], dtype=torch.int32)
v32 = v.reshape( (32,) ) # 1-d array of 32 elements
v32
View(shape=(32,), strides=(1,), offset=0, mask=None, contiguous=True)
a.as_strided(size=v32.shape, stride=v32.strides)
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       dtype=torch.int32)
v_flip = v.flip( (False, True) ) # Flip the last dimension
v_flip
View(shape=(4, 8), strides=(8, -1), offset=7, mask=None, contiguous=False)
try:
    a.as_strided(size=v_flip.shape, stride=v_flip.strides)
except Exception as e:
    print_last_frame_context(e, 0)
RuntimeError in /tmp/ipykernel_795477/772861321.py:2 in <module>()

Code context:
--->    2     a.as_strided(size=v_flip.shape, stride=v_flip.strides)

Oops, PyTorch does not support negative strides. Using .flip() instead, the result looks like this:

a.flip((1,))
tensor([[ 7,  6,  5,  4,  3,  2,  1,  0],
        [15, 14, 13, 12, 11, 10,  9,  8],
        [23, 22, 21, 20, 19, 18, 17, 16],
        [31, 30, 29, 28, 27, 26, 25, 24]], dtype=torch.int32)
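Even though as_strided refuses them, negative strides are perfectly meaningful when we do the indexing ourselves. This plain-Python sketch (flip_view and gather are hypothetical helpers, not tinygrad’s implementation) reproduces the strides=(8, -1), offset=7 that View.flip gave us: for a flipped axis of size n, index i becomes n - 1 - i, so the stride is negated and the offset moves to what used to be the last element along that axis.

```python
def flip_view(shape, strides, offset, flip_axes):
    """Flip selected axes of a strided view without touching the data."""
    new_strides, new_offset = [], offset
    for n, s, f in zip(shape, strides, flip_axes):
        if f:
            new_offset += (n - 1) * s  # start from the last element along this axis
            s = -s                     # and walk backwards
        new_strides.append(s)
    return tuple(shape), tuple(new_strides), new_offset

def gather(storage, shape, strides, offset):
    """Materialize a 2-d view from a flat list (negative strides are fine here)."""
    return [[storage[offset + r * strides[0] + c * strides[1]] for c in range(shape[1])]
            for r in range(shape[0])]

base = list(range(32))  # `a` flattened
shape, strides, offset = flip_view((4, 8), (8, 1), 0, (False, True))
print(strides, offset)                          # (8, -1) 7
print(gather(base, shape, strides, offset)[0])  # [7, 6, 5, 4, 3, 2, 1, 0]
```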

The Mask

Now, what’s up with the mask? It allows us to create arrays with elements that are outside of the underlying storage!

For example, if we want to pad a 2-d array, we don’t want to allocate a new array - just mark the padded elements as being not valid!

v
View(shape=(4, 8), strides=(8, 1), offset=0, mask=None, contiguous=True)
v.pad(((2,2),(2,2))) # (before, after) per dimension: top-bottom for dim 0, left-right for dim 1
View(shape=(8, 12), strides=(8, 1), offset=-18, mask=((2, 6), (2, 10)), contiguous=False)

PyTorch does not allow negative offsets either, so we can’t reproduce this one with as_strided, but I think the idea is clear.
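The arithmetic is still easy to follow by hand. Here is a plain-Python sketch (pad_view and gather_masked are hypothetical helpers, and we assume padded elements read as 0) that reproduces the shape, offset and mask of the padded View, then reads it back while only touching memory where the mask says the element is valid:

```python
def pad_view(shape, strides, offset, padding):
    """Pad a strided view by adjusting shape, offset and mask only; no data is copied.

    padding is one (before, after) pair per dimension. The offset shifts back so
    the first valid element still lands on the original start, and the mask
    records the [start, stop) range of valid indices in each dimension.
    """
    new_shape = tuple(b + n + a for n, (b, a) in zip(shape, padding))
    new_offset = offset - sum(b * s for (b, _), s in zip(padding, strides))
    mask = tuple((b, b + n) for n, (b, _) in zip(shape, padding))
    return new_shape, tuple(strides), new_offset, mask

def gather_masked(storage, shape, strides, offset, mask, fill=0):
    """Materialize a 2-d masked view; indices outside the mask read `fill` instead of memory."""
    out = []
    for r in range(shape[0]):
        row = []
        for c in range(shape[1]):
            valid = mask[0][0] <= r < mask[0][1] and mask[1][0] <= c < mask[1][1]
            row.append(storage[offset + r * strides[0] + c * strides[1]] if valid else fill)
        out.append(row)
    return out

base = list(range(32))  # `a` flattened, 4x8
shape, strides, offset, mask = pad_view((4, 8), (8, 1), 0, ((2, 2), (2, 2)))
print(shape, strides, offset, mask)  # (8, 12) (8, 1) -18 ((2, 6), (2, 10))
for row in gather_masked(base, shape, strides, offset, mask):
    print(row)
```

The negative offset is harmless because the masked-out positions (the only ones that would land before the start of the base) are never read.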

class ShapeTracker

Not all transforms can be represented with a single View.

v = View.create((3,2))
v_padded = v.pad(((1,1),(1,1)))
v_padded
View(shape=(5, 4), strides=(2, 1), offset=-3, mask=((1, 4), (1, 3)), contiguous=False)
v_padded_reshaped = v_padded.reshape((20,)) # Linearize into a 1-d array
print(v_padded_reshaped)
None

Oops, we get a None, which means the result could not be represented with a single View!

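It makes sense: in the linearized 20-element array, the valid elements end up at positions 5, 6, 9, 10, 13 and 14. That can’t be described with a single start/stop range, and the base positions they read from don’t follow a single stride either.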
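We can check both problems in plain Python. The sketch below (flat_view is a hypothetical helper; the numbers are copied from the v_padded View) walks the 5x4 view in row-major order, as a reshape to (20,) would, and records the base position and validity of each element. It then tries to fit a single offset and stride through the valid elements:

```python
from itertools import product

def flat_view(shape, strides, offset, mask):
    """Walk a 2-d masked view in row-major order, as a reshape to 1-d would.

    Returns one (base position, is_valid) pair per linear position.
    """
    out = []
    for r, c in product(range(shape[0]), range(shape[1])):
        valid = mask[0][0] <= r < mask[0][1] and mask[1][0] <= c < mask[1][1]
        out.append((offset + r * strides[0] + c * strides[1], valid))
    return out

# v_padded: shape (5, 4), strides (2, 1), offset -3, mask ((1, 4), (1, 3))
flat = flat_view((5, 4), (2, 1), -3, ((1, 4), (1, 3)))
valid = [(i, pos) for i, (pos, ok) in enumerate(flat) if ok]
print(valid)  # [(5, 0), (6, 1), (9, 2), (10, 3), (13, 4), (14, 5)]

# A single 1-d View would need pos == off + i * stride for every valid i.
# Fit off and stride through the first two valid elements...
(i0, p0), (i1, p1) = valid[0], valid[1]
stride = (p1 - p0) // (i1 - i0)
off = p0 - i0 * stride
# ...and check whether they explain the rest.
print([off + i * stride == p for i, p in valid])  # [True, True, False, False, False, False]
```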

The ShapeTracker keeps a list of sequentially applied Views.

st = ShapeTracker((v,))
st
ShapeTracker(views=(View(shape=(3, 2), strides=(2, 1), offset=0, mask=None, contiguous=True),))

If some of the views can be merged together, it will do so.

In the example below, we reshape 3x2 -> 2x3, flip it along the first axis, and pad it by 1 on all sides.

This can be represented with a single View:

st_rfp = st.reshape((2,3)).flip((True, False)).pad( ((1,1),(1,1)) )
st_rfp
ShapeTracker(views=(View(shape=(4, 5), strides=(-3, 1), offset=5, mask=((1, 3), (1, 4)), contiguous=False),))

If the transformation cannot be represented with a single View, the ShapeTracker will keep the Views separate:

st_rfp.reshape( (20,) )
ShapeTracker(views=(View(shape=(4, 5), strides=(-3, 1), offset=5, mask=((1, 3), (1, 4)), contiguous=False), View(shape=(20,), strides=(1,), offset=0, mask=None, contiguous=True)))
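How do two stacked Views turn into a single memory address? A rough plain-Python sketch of the idea, assuming the last View is applied first and its result is then unraveled into the shape of the View before it (view_index, unravel and shapetracker_index are hypothetical helpers, not tinygrad’s code):

```python
def view_index(idx, shape, strides, offset, mask=None):
    """Map a multi-dim index through one View: returns (base position, valid)."""
    valid = mask is None or all(lo <= i < hi for i, (lo, hi) in zip(idx, mask))
    return offset + sum(i * s for i, s in zip(idx, strides)), valid

def unravel(p, shape):
    """Turn a row-major linear position into a multi-dim index for `shape`."""
    idx = []
    for n in reversed(shape):
        idx.append(p % n)
        p //= n
    return tuple(reversed(idx))

# The two Views of st_rfp.reshape((20,)), copied from the output above
first = dict(shape=(4, 5), strides=(-3, 1), offset=5, mask=((1, 3), (1, 4)))
last = dict(shape=(20,), strides=(1,), offset=0, mask=None)

def shapetracker_index(i):
    # The last View maps i to a position in the previous View's contiguous index space
    p, ok_last = view_index((i,), **last)
    # which is unraveled into the previous View's shape and mapped to memory
    pos, ok_first = view_index(unravel(p, first["shape"]), **first)
    return pos, ok_last and ok_first

base = list(range(6))  # the original 3x2 data, flattened
print([base[pos] if ok else 0 for pos, ok in map(shapetracker_index, range(20))])
# [0, 0, 0, 0, 0, 0, 3, 4, 5, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0]
```

That matches doing it by hand: reshaping [0..5] to 2x3 gives [[0, 1, 2], [3, 4, 5]], flipping rows gives [[3, 4, 5], [0, 1, 2]], and padding by 1 surrounds it with zeros.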

We can also generate the UOp trees that represent the expressions for indexing and validating memory access in the generated code:

idx, valid = st_rfp.to_indexed_uops()
idx
UOp(Ops.ADD, dtypes.int, arg=None, src=(
  UOp(Ops.ADD, dtypes.int, arg=None, src=(
    UOp(Ops.MUL, dtypes.int, arg=None, src=(
      UOp(Ops.RANGE, dtypes.int, arg=0, src=(
        x3:=UOp(Ops.CONST, dtypes.int, arg=0, src=()),
        UOp(Ops.CONST, dtypes.int, arg=4, src=()),)),
      UOp(Ops.CONST, dtypes.int, arg=-3, src=()),)),
    UOp(Ops.RANGE, dtypes.int, arg=1, src=(
       x3,
      x7:=UOp(Ops.CONST, dtypes.int, arg=5, src=()),)),)),
   x7,))
valid
UOp(Ops.AND, dtypes.bool, arg=None, src=(
  UOp(Ops.AND, dtypes.bool, arg=None, src=(
    UOp(Ops.AND, dtypes.bool, arg=None, src=(
      UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
        UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
          x4:=UOp(Ops.RANGE, dtypes.int, arg=0, src=(
            x5:=UOp(Ops.CONST, dtypes.int, arg=0, src=()),
            x6:=UOp(Ops.CONST, dtypes.int, arg=4, src=()),)),
          x7:=UOp(Ops.CONST, dtypes.int, arg=1, src=()),)),
        x8:=UOp(Ops.CONST, dtypes.bool, arg=True, src=()),)),
      UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
         x4,
        UOp(Ops.CONST, dtypes.int, arg=3, src=()),)),)),
    UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
      UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
        x13:=UOp(Ops.RANGE, dtypes.int, arg=1, src=(
           x5,
          UOp(Ops.CONST, dtypes.int, arg=5, src=()),)),
         x7,)),
       x8,)),)),
  UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
     x13,
     x6,)),))

Let’s use .render() to try and convert the UOps into equivalent C expressions:

Note: .render() is not how TinyGrad normally generates code; it’s for debugging purposes only.

idx.render()
'(((ridx0*-3)+ridx1)+5)'

And valid can be used to check whether an input element is valid, for example in an if statement:

valid.render()
'(((((ridx0<1)!=True)&(ridx0<3))&((ridx1<1)!=True))&(ridx1<4))'
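To see how the two expressions fit together, here is a plain-Python imitation of the loop a kernel might run over this ShapeTracker. It is a sketch, not TinyGrad’s generated code: the two RANGEs become nested loops, the valid expression guards the load, and we assume padded elements read as 0.

```python
# The two rendered expressions, transcribed into Python:
#   idx   = (((ridx0*-3)+ridx1)+5)
#   valid = ((ridx0<1)!=True) & (ridx0<3) & ((ridx1<1)!=True) & (ridx1<4)
def idx_expr(ridx0, ridx1):
    return ((ridx0 * -3) + ridx1) + 5

def valid_expr(ridx0, ridx1):
    return (not ridx0 < 1) and ridx0 < 3 and (not ridx1 < 1) and ridx1 < 4

src = [0, 1, 2, 3, 4, 5]  # the original 3x2 buffer, flattened
out = []
for ridx0 in range(4):          # RANGE arg=0: 0..3
    row = []
    for ridx1 in range(5):      # RANGE arg=1: 0..4
        # Only dereference memory when the element is valid
        row.append(src[idx_expr(ridx0, ridx1)] if valid_expr(ridx0, ridx1) else 0)
    out.append(row)
for row in out:
    print(row)
# [0, 0, 0, 0, 0]
# [0, 3, 4, 5, 0]
# [0, 0, 1, 2, 0]
# [0, 0, 0, 0, 0]
```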