import os
"CPU"] = "1"
os.environ["DEBUG"]="4"
os.environ[
from tinygrad import dtypes
from tinygrad.ops import UOp, Ops
2 - View and ShapeTracker
So far we have been making scalar UOps that don’t have a shape associated with them.
We’ve been getting away with it, but the UOp trees we made are not really valid without a shape.
a = UOp.const(dtypes.float, 1)
a
UOp(Ops.CONST, dtypes.float, arg=1.0, src=())
try:
    print(a.shape)
except Exception as e:
    print_last_frame_context(e)
AssertionError in /home/xl0/work/projects/grads/tinygrad/tinygrad/helpers.py:61 in unwrap()
Code context:
59 return ret
60 def unwrap(x:Optional[T]) -> T:
---> 61 assert x is not None
62 return x
63 def get_single_element(x:list[T]) -> T:
Another thing we were missing is the device:
try:
    print(a.device)
except Exception as e:
    print_last_frame_context(e)
AssertionError in /home/xl0/work/projects/grads/tinygrad/tinygrad/helpers.py:61 in unwrap()
Code context:
59 return ret
60 def unwrap(x:Optional[T]) -> T:
---> 61 assert x is not None
62 return x
63 def get_single_element(x:list[T]) -> T:
Let’s fix this real quick
from tinygrad.shape.shapetracker import ShapeTracker, View
a = UOp.const(dtypes.float, 1).replace(src=(
    UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker.from_shape( (0,) ), src=(
        UOp(Ops.DEVICE, dtypes.void, arg="CPU", src=()),)),))
a
UOp(Ops.CONST, dtypes.float, arg=1.0, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(0,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.DEVICE, dtypes.void, arg='CPU', src=()),)),))
a.shape
(0,)
a.device
'CPU'
Looks better.
Now, what’s up with that ShapeTracker and View? Let’s start with the latter.
Shape and Stride
You are probably familiar with how shape and strides work in PyTorch or NumPy:
import torch
a = torch.linspace(0, 31, 32, dtype=torch.int32).view(4, 8)
a
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15],
[16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31]], dtype=torch.int32)
a.shape
torch.Size([4, 8])
The shape defines, well, the shape of the array. It can have any number of dimensions (2 in this case), and each dimension has its own size.
A Tensor is just a linear array, and the shape is there for convenience, because we usually want to work with multi-dimensional data.
We can change the shape, as long as the number of elements in the new shape stays the same.
b = a.view(2,4,4) # This creates a view that refers to the same data, but now it's seen as a 3-d array.
b
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]], dtype=torch.int32)
The stride tells us how many elements we need to move in the underlying 1-d array (the base) to get to the next element in a given dimension.
For our 2x4x4 array, to move by 1 element in the row (last dimension), we need to move … 1 element in the base.
And to move by one element in the column dimension, we need to move by 4 elements in the base, because each row is 4 elements.
This is the standard C, or row-major, order format for multidimensional data.
Row-major and Column-major order
You might have seen references to the F, or column-major, order at some point. Historically this is how data was stored in Fortran, and I’m sure they had their reasons for it, but it’s definitely less intuitive.
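As an aside (my own example, not from the original), numpy can create column-major arrays directly, which makes the difference easy to see. Note that numpy reports strides in bytes, while torch reports them in elements:
import numpy as np
c_order = np.arange(32, dtype=np.int32).reshape(4, 8)  # row-major (C order)
f_order = np.asfortranarray(c_order)                   # column-major (F order) copy
print(c_order.strides)  # (32, 4) bytes -> (8, 1) in elements
print(f_order.strides)  # (4, 16) bytes -> (1, 4) in elements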
To move by one element in the first dimension, we’d have to skip 4 rows, and each row is 4 elements, so 16 in total:
b.stride()
(16, 4, 1)
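To make the stride arithmetic concrete, here is a quick check (my addition): for an index (i, j, k), the element lives at flat position i*16 + j*4 + k in the base:
base = b.flatten()  # the underlying 32 elements in linear order
i, j, k = 1, 2, 3
print(b[i, j, k].item())            # 27
print(base[i*16 + j*4 + k].item())  # 27 -- the same element, found via stride arithmetic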
Now, if the stride always matched the shape, things would be boring. We can set the stride independently.
Let’s go back to our 4x8 view to make things easier. In this case we need to skip 8 elements to move by one row:
print(a)
print("Shape: ",a.shape)
print("Stride:",a.stride())
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15],
[16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31]], dtype=torch.int32)
Shape: torch.Size([4, 8])
Stride: (8, 1)
What if we want to create a view that skips every other element in the rows? We can do this by creating a view with shape (torch refers to it as size) 4x4, and stride (8, 2)!
c = a.as_strided(size=(4,4), stride=(8, 2))
c
tensor([[ 0, 2, 4, 6],
[ 8, 10, 12, 14],
[16, 18, 20, 22],
[24, 26, 28, 30]], dtype=torch.int32)
We can also specify an offset from the start of the base array. This will give us the odd elements in each row:
a.as_strided(size=(4,4), stride=(8, 2), storage_offset=1)
tensor([[ 1, 3, 5, 7],
[ 9, 11, 13, 15],
[17, 19, 21, 23],
[25, 27, 29, 31]], dtype=torch.int32)
Let’s create a view that has the diagonal elements of a: (0, 9, 18, 27).
a.as_strided(size=(4,), stride=(9,))
tensor([ 0, 9, 18, 27], dtype=torch.int32)
Another fun thing we can do - set one or more of the strides to 0, to duplicate (broadcast) dimensions:
d = torch.linspace(1, 4, 4, dtype=torch.int32)
d
tensor([1, 2, 3, 4], dtype=torch.int32)
d.as_strided(size=(4,4), stride=(1, 0))
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4]], dtype=torch.int32)
For each step in the output column, we take 1 step in the base, and for each step in the output row, we don’t take any steps at all!
That’s how .full() works - it creates a single element, and makes all elements in the Tensor refer to it by setting the strides to 0.
e = torch.Tensor([1]).to(torch.int32)
e
tensor([1], dtype=torch.int32)
e.as_strided(size=(4,4), stride=(0,0))
tensor([[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]], dtype=torch.int32)
class View
The View class is intended to keep track of the shape and stride of the data. Let’s play with it a bit.
v = View(shape=(4,8), strides=(8,1), offset=0, mask=None, contiguous=True) # A normal 4x8 array
v
View(shape=(4, 8), strides=(8, 1), offset=0, mask=None, contiguous=True)
a
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15],
[16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31]], dtype=torch.int32)
a.as_strided(size=v.shape, stride=v.strides)
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15],
[16, 17, 18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29, 30, 31]], dtype=torch.int32)
v32 = v.reshape( (32,) ) # 1-d array of 32 elements
v32
View(shape=(32,), strides=(1,), offset=0, mask=None, contiguous=True)
a.as_strided(size=v32.shape, stride=v32.strides)
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
dtype=torch.int32)
v_flip = v.flip( (False, True) ) # Flip the last dimension
v_flip
View(shape=(4, 8), strides=(8, -1), offset=7, mask=None, contiguous=False)
try:
    a.as_strided(size=v_flip.shape, stride=v_flip.strides)
except Exception as e:
    print_last_frame_context(e, 0)
RuntimeError in /tmp/ipykernel_795477/772861321.py:2 in <module>()
Code context:
---> 2 a.as_strided(size=v_flip.shape, stride=v_flip.strides)
Oops, torch actually does not support negative strides. It should have looked like this:
a.flip((1,))
tensor([[ 7, 6, 5, 4, 3, 2, 1, 0],
[15, 14, 13, 12, 11, 10, 9, 8],
[23, 22, 21, 20, 19, 18, 17, 16],
[31, 30, 29, 28, 27, 26, 25, 24]], dtype=torch.int32)
The Mask
Now, what’s up with the mask? It allows us to create arrays with elements that are outside of the underlying storage!
For example, if we want to pad a 2-d array, we don’t want to allocate a new array - just mark the padded elements as being not valid!
v
View(shape=(4, 8), strides=(8, 1), offset=0, mask=None, contiguous=True)
v.pad(((2,2),(2,2))) # top-bottom, left-right
View(shape=(8, 12), strides=(8, 1), offset=-18, mask=((2, 6), (2, 10)), contiguous=False)
Torch does not allow negative offsets either, but I think the idea is clear.
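To spell out what that padded View means, here is a hand-evaluation sketch (my own illustration; element and the zero fill-in are not tinygrad API): element (i, j) of the 8x12 result lives at base position offset + i*8 + j, and only has storage behind it if the mask says so:
shape, strides, offset = (8, 12), (8, 1), -18
mask = ((2, 6), (2, 10))  # valid rows [2, 6), valid cols [2, 10)
base = list(range(32))    # the flat 4x8 storage from the torch examples

def element(i, j):
    # Outside the mask -> a padding element with no storage behind it
    if not (mask[0][0] <= i < mask[0][1] and mask[1][0] <= j < mask[1][1]):
        return 0
    return base[offset + i*strides[0] + j*strides[1]]

print(element(2, 3))  # 1 -> base[-18 + 2*8 + 3] = base[1], a real element
print(element(0, 0))  # 0 -> masked out, pure padding
The negative offset is what makes the first valid element (i=2, j=2) land exactly on base[0].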
class ShapeTracker
Not all transforms can be represented with a single View.
v = View.create((3,2))
v_padded = v.pad(((1,1),(1,1)))
v_padded
View(shape=(5, 4), strides=(2, 1), offset=-3, mask=((1, 4), (1, 3)), contiguous=False)
v_padded_reshaped = v_padded.reshape((20,)) # Linearize into a 1-d array
print(v_padded_reshaped)
None
Oops, we get a None, which means the operation could not be performed!
It makes sense, because we can’t specify the valid mask of the linearized result using just the start/stop indices.
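To see why, here is a small illustration (mine, not from the notebook): flatten the valid region of the padded 5x4 view and look at which flat positions are real. A 1-d mask can only mark a single (start, stop) interval, and these positions don’t form one:
mask = ((1, 4), (1, 3))  # from v_padded: valid rows [1, 4), valid cols [1, 3)
valid_flat = [r*4 + c for r in range(5) for c in range(4)
              if mask[0][0] <= r < mask[0][1] and mask[1][0] <= c < mask[1][1]]
print(valid_flat)  # [5, 6, 9, 10, 13, 14] -- not one contiguous run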
The ShapeTracker keeps a list of sequentially applied Views.
st = ShapeTracker((v,))
st
ShapeTracker(views=(View(shape=(3, 2), strides=(2, 1), offset=0, mask=None, contiguous=True),))
If some of the views can be merged together, it will do so.
In the next example, we reshape the 3x2 view to 2x3, flip it along the first axis, and pad it on all sides by 1. This can be represented with a single View:
st_rfp = st.reshape((2,3)).flip((True, False)).pad( ((1,1),(1,1)) )
st_rfp
ShapeTracker(views=(View(shape=(4, 5), strides=(-3, 1), offset=5, mask=((1, 3), (1, 4)), contiguous=False),))
If the transformation can not be represented with a single View, the ShapeTracker will keep the views separate:
st_rfp.reshape( (20,) )
ShapeTracker(views=(View(shape=(4, 5), strides=(-3, 1), offset=5, mask=((1, 3), (1, 4)), contiguous=False), View(shape=(20,), strides=(1,), offset=0, mask=None, contiguous=True)))
We can also generate the UOp trees that represent the expressions for indexing and validating memory accesses in the code:
idx, valid = st_rfp.to_indexed_uops()
idx
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.RANGE, dtypes.int, arg=0, src=(
x3:=UOp(Ops.CONST, dtypes.int, arg=0, src=()),
UOp(Ops.CONST, dtypes.int, arg=4, src=()),)),
UOp(Ops.CONST, dtypes.int, arg=-3, src=()),)),
UOp(Ops.RANGE, dtypes.int, arg=1, src=(
x3,
x7:=UOp(Ops.CONST, dtypes.int, arg=5, src=()),)),)),
x7,))
valid
UOp(Ops.AND, dtypes.bool, arg=None, src=(
UOp(Ops.AND, dtypes.bool, arg=None, src=(
UOp(Ops.AND, dtypes.bool, arg=None, src=(
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x4:=UOp(Ops.RANGE, dtypes.int, arg=0, src=(
x5:=UOp(Ops.CONST, dtypes.int, arg=0, src=()),
x6:=UOp(Ops.CONST, dtypes.int, arg=4, src=()),)),
x7:=UOp(Ops.CONST, dtypes.int, arg=1, src=()),)),
x8:=UOp(Ops.CONST, dtypes.bool, arg=True, src=()),)),
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x4,
UOp(Ops.CONST, dtypes.int, arg=3, src=()),)),)),
UOp(Ops.CMPNE, dtypes.bool, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x13:=UOp(Ops.RANGE, dtypes.int, arg=1, src=(
x5,
UOp(Ops.CONST, dtypes.int, arg=5, src=()),)),
x7,)),
x8,)),)),
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x13,
x6,)),))
Let’s use .render() to try and convert the UOps into equivalent C expressions:
Note: .render() is not how TinyGrad normally generates code, it’s for debugging purposes only.
idx.render()
'(((ridx0*-3)+ridx1)+5)'
And valid can be used to check the validity of input elements in an if statement:
valid.render()
'(((((ridx0<1)!=True)&(ridx0<3))&((ridx1<1)!=True))&(ridx1<4))'
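As a sanity check (my own sketch, not part of the original), we can evaluate the two rendered expressions by hand over the 4x5 shape of st_rfp, substituting 0 wherever valid is False:
base = list(range(6))  # the storage behind the original 3x2 view
for ridx0 in range(4):
    row = []
    for ridx1 in range(5):
        valid = (1 <= ridx0 < 3) and (1 <= ridx1 < 4)  # the rendered valid expression
        idx = ridx0*-3 + ridx1 + 5                     # the rendered idx expression
        row.append(base[idx] if valid else 0)
    print(row)
This prints the 3x2 data reshaped to 2x3, flipped along the first axis, and padded with zeros on all sides - recovered purely from the idx and valid expressions.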