import os
os.environ["CPU"] = "1"
os.environ["DEBUG"]="4"
from tinygrad import Tensor, dtypes
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite
3 - The Pattern Matcher
Our next TinyGrad abstraction is the Pattern Matcher (PM).
The PM is used all over TinyGrad for many different purposes.
a = (Tensor(2) * 5 + 1).lazydata
a
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CONST, dtypes.int, arg=2, src=(
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.DEVICE, dtypes.void, arg='CPU', src=()),)),)),
UOp(Ops.CONST, dtypes.int, arg=5, src=(
x2,)),)),
UOp(Ops.CONST, dtypes.int, arg=1, src=(
x2,)),))
The PM operates on a list of rules.
Each rule consists of a UPat (a pattern describing the UOps it should match) and a function that is called when the pattern matches part of the tree.
The return value of the function is the result of the match, or None if none of the rules matched:
test_rules = PatternMatcher([
    (UPat(Ops.DEVICE), lambda: "a DEVICE Uop"), # This rule matches any `DEVICE` UOp
    (UPat(Ops.CONST, name="x"), lambda x: f"Got a CONST dtype {x.dtype} arg {x.arg}"), # Can pass the matched UOp to the function
    (UPat(Ops.CONST), lambda: "Another rule for CONST"), # Oops, only one rule can match!
    (UPat((Ops.ADD, Ops.MUL)), lambda: "ADD or MUL"), # Can match more than one UOp type
    (UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE, src=UPat(Ops.CONST, arg=2)))),
     lambda: "Expand with reshape from a const with arg=2") # Can match a specific sub-tree.
    # Note: This one would only match an EXPAND of the 2, not of the 1 (and this tree has no EXPAND at all)
    # No match (here, the VIEW) - rewrite() returns None
])
[test_rules.rewrite(op) for op in a.toposort]
['a DEVICE Uop',
None,
'Got a CONST dtype dtypes.int arg 2',
'Got a CONST dtype dtypes.int arg 5',
'ADD or MUL',
'Got a CONST dtype dtypes.int arg 1',
'ADD or MUL']
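To make the matching rules above concrete, here is a minimal toy sketch in plain Python that does not use tinygrad at all. `Node`, `Pat` and `Matcher` are hypothetical stand-ins for UOp, UPat and PatternMatcher, made up for illustration; they mimic only the behaviour described above (match by op type, optional arg and sub-tree constraints, named captures passed to the function, first matching rule wins, None when nothing matches) and are not tinygrad's implementation.

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass(frozen=True)
class Node:                      # toy stand-in for a UOp: op name, arg, sources
    op: str
    arg: Any = None
    src: tuple = ()

@dataclass(frozen=True)
class Pat:                       # toy stand-in for a UPat
    ops: Optional[tuple] = None  # allowed op names; None matches any op
    arg: Any = None              # required arg; None means "don't care" (toy simplification)
    src: Optional[tuple] = None  # patterns for the sources; None means "don't check"
    name: Optional[str] = None   # if set, the matched node is passed to the function under this name

    def match(self, node: Node, store: dict) -> bool:
        if self.ops is not None and node.op not in self.ops:
            return False
        if self.arg is not None and node.arg != self.arg:
            return False
        if self.src is not None:
            if len(self.src) != len(node.src):
                return False
            if not all(p.match(s, store) for p, s in zip(self.src, node.src)):
                return False
        if self.name is not None:
            store[self.name] = node
        return True

class Matcher:                   # toy stand-in for a PatternMatcher
    def __init__(self, rules: list):
        self.rules = rules

    def rewrite(self, node: Node):
        for pat, fxn in self.rules:
            store: dict = {}
            # The first rule that matches and returns something other than None wins.
            if pat.match(node, store) and (ret := fxn(**store)) is not None:
                return ret
        return None              # no rule matched

two, five, one = Node("CONST", 2), Node("CONST", 5), Node("CONST", 1)
mul = Node("MUL", src=(two, five))
add = Node("ADD", src=(mul, one))
view = Node("VIEW")              # an op that no rule below mentions

rules = Matcher([
    # Sub-tree match: a MUL whose first source is the CONST 2; the second source is captured as "rhs".
    (Pat(("MUL",), src=(Pat(("CONST",), arg=2), Pat(name="rhs"))), lambda rhs: f"2 * {rhs.op}"),
    (Pat(("CONST",), name="x"), lambda x: f"CONST {x.arg}"),
    (Pat(("CONST",)), lambda: "never reached"),  # shadowed by the rule above
    (Pat(("ADD", "MUL")), lambda: "ADD or MUL"),
])

results = [rules.rewrite(n) for n in (two, five, mul, one, add, view)]
print(results)  # ['CONST 2', 'CONST 5', '2 * CONST', 'CONST 1', 'ADD or MUL', None]
```

As in the tinygrad example, the second CONST rule is shadowed, and the node that no rule mentions (the toy `VIEW`) yields None.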
Rewriting trees
A more interesting use is to replace the matched UOps with other UOps. We can then use graph_rewrite to apply the rules to a whole tree.
insanity = PatternMatcher([
    (UPat(Ops.ADD, name="x"), lambda x: UOp(Ops.SUB, dtype=x.dtype, arg=x.arg, src=x.src)), # Turn every ADD into a SUB
    (UPat(Ops.MUL, dtype=dtypes.ints, name="x"), lambda x: UOp(Ops.IDIV, dtype=x.dtype, src=x.src)) # Turn every integer MUL into an integer division
])
rewritten = graph_rewrite(a, insanity)
rewritten
UOp(Ops.SUB, dtypes.int, arg=None, src=(
UOp(Ops.IDIV, dtypes.int, arg=None, src=(
UOp(Ops.CONST, dtypes.int, arg=2, src=(
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.DEVICE, dtypes.void, arg='CPU', src=()),)),)),
UOp(Ops.CONST, dtypes.int, arg=5, src=(
x2,)),)),
UOp(Ops.CONST, dtypes.int, arg=1, src=(
x2,)),))
a.render(simplify=False)
'((2*5)+1)'
rewritten.render(simplify=False)
'((2//5)-1)'
int(rewritten)
-1
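The same idea applied to a whole tree can be sketched as a bottom-up rewrite: rewrite the sources first, then apply the rules to the rebuilt node, and keep rewriting whatever a rule returns. This is a simplified toy that assumes hypothetical names (`Node`, `graph_rewrite_toy`, `add_to_sub`, `int_mul_to_idiv`, `render`, `evaluate`); tinygrad's own graph_rewrite is more involved, so treat it only as an illustration of why ((2*5)+1) turns into ((2//5)-1) and evaluates to -1.

```python
import operator
from dataclasses import dataclass
from typing import Any, Callable, Optional

@dataclass(frozen=True)
class Node:                                  # toy stand-in for a UOp
    op: str
    dtype: str = "int"
    arg: Any = None
    src: tuple = ()

Rule = Callable[[Node], Optional[Node]]      # returns a replacement node, or None for "no match"

def graph_rewrite_toy(root: Node, rules: list[Rule]) -> Node:
    cache: dict = {}
    def rewrite(node: Node) -> Node:
        if node in cache:
            return cache[node]
        new_src = tuple(rewrite(s) for s in node.src)   # rewrite the sources first (bottom-up)
        new = node if new_src == node.src else Node(node.op, node.dtype, node.arg, new_src)
        for rule in rules:
            if (ret := rule(new)) is not None:
                new = rewrite(ret)                      # keep rewriting the replacement
                break
        cache[node] = new
        return new
    return rewrite(root)

def add_to_sub(n: Node) -> Optional[Node]:
    return Node("SUB", n.dtype, n.arg, n.src) if n.op == "ADD" else None

def int_mul_to_idiv(n: Node) -> Optional[Node]:
    return Node("IDIV", n.dtype, n.arg, n.src) if n.op == "MUL" and n.dtype == "int" else None

SYMBOLS = {"ADD": "+", "SUB": "-", "MUL": "*", "IDIV": "//"}
# Python's // floors; for the non-negative operands used here that equals truncation.
FUNCS = {"ADD": operator.add, "SUB": operator.sub, "MUL": operator.mul, "IDIV": operator.floordiv}

def render(n: Node) -> str:
    if n.op == "CONST":
        return str(n.arg)
    return f"({render(n.src[0])}{SYMBOLS[n.op]}{render(n.src[1])})"

def evaluate(n: Node) -> int:
    if n.op == "CONST":
        return n.arg
    return FUNCS[n.op](evaluate(n.src[0]), evaluate(n.src[1]))

tree = Node("ADD", src=(Node("MUL", src=(Node("CONST", arg=2), Node("CONST", arg=5))), Node("CONST", arg=1)))
rewritten = graph_rewrite_toy(tree, [add_to_sub, int_mul_to_idiv])
print(render(tree), render(rewritten), evaluate(rewritten))  # ((2*5)+1) ((2//5)-1) -1
```

Caching on the original node means a shared sub-tree is rewritten only once, which matters for UOp graphs where sources (like the VIEW above) are reused.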
PatternMatcher in TinyGrad
I think you get the idea. The Pattern Matcher is a powerful tool that is used throughout TinyGrad.
When we played with Tensor.schedule_with_vars() and lower_schedule_item() in the chapter on UOps, both functions made extensive use of Pattern Matchers. We will attempt a deep dive into their details in the next chapter.
TinyGrad spec
Another use for the Pattern Matcher is checking the validity of UOp trees according to the spec, found in tinygrad/spec.py.
It is quite possible to create UOp trees that are not valid in general, or not valid at certain stages of processing.
The spec contains rules that check for silly mistakes in different types of (sub-)trees.
For example, there is a tensor_uop_spec for sanity-checking the UOp trees created by tensor operations:
a
UOp(Ops.ADD, dtypes.int, arg=None, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CONST, dtypes.int, arg=2, src=(
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.DEVICE, dtypes.void, arg='CPU', src=()),)),)),
UOp(Ops.CONST, dtypes.int, arg=5, src=(
x2,)),)),
UOp(Ops.CONST, dtypes.int, arg=1, src=(
x2,)),))
from tinygrad.spec import type_verify, tensor_uop_spec
type_verify(list(a.toposort.keys()), tensor_uop_spec) # It throws on errors, no errors found!
Let’s make a broken tree by changing the dtype of the ADD UOp in a to float:
bad = a.replace(dtype=dtypes.float)
bad
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CONST, dtypes.int, arg=2, src=(
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.DEVICE, dtypes.void, arg='CPU', src=()),)),)),
UOp(Ops.CONST, dtypes.int, arg=5, src=(
x2,)),)),
UOp(Ops.CONST, dtypes.int, arg=1, src=(
x2,)),))
This is not a valid tree: we are adding two ints, yet the result is a float. There would need to be a cast there!
try:
    type_verify(list(bad.toposort.keys()), tensor_uop_spec)
except Exception as e:
    print(f"{type(e).__name__}: {' '.join(e.args)}")
0 Ops.DEVICE : dtypes.void [] CPU
1 Ops.VIEW : dtypes.void [0] ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),))
2 Ops.CONST : dtypes.int [1] 2
3 Ops.CONST : dtypes.int [1] 5
4 Ops.MUL : dtypes.int ['2', '5'] None
5 Ops.CONST : dtypes.int [1] 1
6 Ops.ADD : dtypes.float [4, '1'] None
RuntimeError: UOp verification failed at 6 on Ops.ADD dtypes.float 2 [<Ops.MUL: 50>, <Ops.CONST: 76>] None
Indeed, we caught the error. Let’s fix the tree by casting both ADD sources to float.
fixed = bad.replace(src=tuple([UOp(Ops.CAST, dtype=dtypes.float, src=(src,)) for src in bad.src]))
fixed
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.int, arg=None, src=(
UOp(Ops.CONST, dtypes.int, arg=2, src=(
x3:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.DEVICE, dtypes.void, arg='CPU', src=()),)),)),
UOp(Ops.CONST, dtypes.int, arg=5, src=(
x3,)),)),)),
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CONST, dtypes.int, arg=1, src=(
x3,)),)),))
type_verify(list(fixed.toposort.keys()), tensor_uop_spec)
Now it works!
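The spec check can be sketched the same way: each spec rule looks at one node and answers True (valid), False (invalid) or None (no opinion), and a verifier walks the nodes in topological order and raises on the first one that is not positively approved. This is a hypothetical toy (`Node`, `SPEC`, `verify` and `toposort` are made-up names, and the toy tree has no DEVICE or VIEW nodes, so the failing index differs from tinygrad's output); it is not how tinygrad/spec.py is written, only an illustration of the idea. Here `dataclasses.replace` plays the role of `UOp.replace`, returning a modified copy of an immutable node.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Optional

@dataclass(frozen=True)
class Node:                            # toy, immutable stand-in for a UOp
    op: str
    dtype: str
    arg: Any = None
    src: tuple = ()

# Spec rules: True = valid, False = invalid, None = "this rule has no opinion".
def const_rule(n: Node) -> Optional[bool]:
    return n.src == () if n.op == "CONST" else None

def alu_rule(n: Node) -> Optional[bool]:
    # ALU ops must produce the same dtype as all of their sources.
    return all(s.dtype == n.dtype for s in n.src) if n.op in ("ADD", "MUL") else None

def cast_rule(n: Node) -> Optional[bool]:
    return len(n.src) == 1 if n.op == "CAST" else None

SPEC = [const_rule, alu_rule, cast_rule]

def toposort(root: Node) -> list:
    order, seen = [], set()
    def visit(n: Node) -> None:
        if n in seen:
            return
        seen.add(n)
        for s in n.src:
            visit(s)
        order.append(n)                # a node is listed only after all of its sources
    visit(root)
    return order

def verify(nodes: list) -> None:
    for i, n in enumerate(nodes):
        # The first rule with an opinion decides; a node no rule recognizes is rejected too.
        verdict = next((v for rule in SPEC if (v := rule(n)) is not None), None)
        if verdict is not True:
            raise RuntimeError(f"verification failed at {i} on {n.op} {n.dtype}")

two, five, one = (Node("CONST", "int", arg=v) for v in (2, 5, 1))
good = Node("ADD", "int", src=(Node("MUL", "int", src=(two, five)), one))
bad = dataclasses.replace(good, dtype="float")  # like a.replace(dtype=dtypes.float)
fixed = dataclasses.replace(bad, src=tuple(Node("CAST", "float", src=(s,)) for s in bad.src))

verify(toposort(good))                 # no exception
try:
    verify(toposort(bad))
    error = None
except RuntimeError as e:
    error = str(e)
print(error)                           # verification failed at 4 on ADD float
verify(toposort(fixed))                # no exception: both ADD sources are now floats
```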