spicy = (randoms[:12].at[0].mul(10000)
                     .at[1].divide(10000)
                     .at[3].set(float('inf'))
                     .at[4].set(float('-inf'))
                     .at[5].set(float('nan'))
                     .reshape((2, 6)))
source

jax_to_str_common

 jax_to_str_common (x:jax.Array, color=True, ddof=0)

|       | Type  | Default | Details                 |
|-------|-------|---------|-------------------------|
| x     | Array |         | Input                   |
| color | bool  | True    | ANSI color highlighting |
| ddof  | int   | 0       | For "std" unbiasing     |
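A quick sanity check of the common formatter; the import path below is an assumption (the function is defined in this module of lovely_jax), and ddof is passed through to the σ calculation:

import jax.numpy as jnp
from lovely_jax.repr_str import jax_to_str_common  # module path is an assumption

x = jnp.array([1., 2., 3., 4.])
print(jax_to_str_common(x, color=False))           # σ with ddof=0 (population std)
print(jax_to_str_common(x, color=False, ddof=1))   # unbiased σ, slightly larger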
source

lovely

 lovely (x:jax.Array, verbose=False, plain=False, depth=0, color=None)

|         | Type     | Default | Details                                 |
|---------|----------|---------|-----------------------------------------|
| x       | Array    |         | Tensor of interest                      |
| verbose | bool     | False   | Whether to show the full tensor         |
| plain   | bool     | False   | Print the plain repr, exactly as before |
| depth   | int      | 0       | Show stats in depth                     |
| color   | NoneType | None    | Force color (True/False) or auto        |
Examples
print(lovely(randoms[0]))
print(lovely(randoms[:2]))
print(lovely(randoms[:6].reshape((2, 3)))) # More than 2 elements -> show statistics
print(lovely(randoms[:11]))                # More than 10 -> suppress data output
Array cpu:0 -1.981
Array[2] μ=-0.466 σ=1.515 cpu:0 [-1.981, 1.048]
Array[2, 3] n=6 x∈[-1.981, 1.048] μ=-0.017 σ=1.113 cpu:0 [[-1.981, 1.048, 0.890], [0.035, -0.947, 0.851]]
Array[11] x∈[-1.981, 1.048] μ=-0.191 σ=0.899 cpu:0
grad = jnp.array(1., dtype=jnp.float16)
print(lovely(grad)); print(lovely(grad + 1))
Array f16 cpu:0 1.000
Array f16 cpu:0 2.000
# If a GPU is available, lovely reports it in the device field:
# if any(d.platform == "gpu" for d in jax.devices()):
#     print(lovely(jax.device_put(jnp.array(1.), jax.devices("gpu")[0])))  # expect "Array gpu:0 1.000"
Do we have any floating point nasties? Is the tensor all zeros?
# Statistics and range are calculated over the good (finite) values only, if there are at least 3 of them.
lovely(spicy)
Array[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 +Inf! -Inf! NaN! cpu:0
lovely(spicy, color=False)
Array[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 +Inf! -Inf! NaN! cpu:0
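As a sketch of what "good values" means here (not the library's actual code), the same μ/σ/range can be reproduced by masking out the non-finite entries:

good = spicy[jnp.isfinite(spicy)]   # drop +inf, -inf, and nan
print(good.min(), good.max())       # matches x∈[-1.981e+04, 0.890]
print(good.mean(), good.std())      # matches μ=-2.201e+03, σ=6.226e+03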
str(lovely(jnp.array([float("nan")] * 11)))
'Array[11] \x1b[31mNaN!\x1b[0m cpu:0'
lovely(jnp.zeros(12))
Array[12] all_zeros cpu:0
lovely(jnp.array([], dtype=jnp.float16).reshape((0, 0, 0)))
Array[0, 0, 0] f16 empty cpu:0
lovely(jnp.array([1, 2, 3], dtype=jnp.int32))
Array[3] i32 x∈[1, 3] μ=2.000 σ=0.816 cpu:0 [1, 2, 3]
jnp.set_printoptions(linewidth=120, precision=2)
lovely(spicy, verbose=True)
Array[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 +Inf! -Inf! NaN! cpu:0
Array([[-1.98e+04, 1.05e-04, 8.90e-01, inf, -inf, nan],
[ 3.12e-02, -3.90e-01, 1.32e-02, -4.21e-01, -1.23e+00, -1.25e+00]], dtype=float32)
lovely(spicy, plain=True)
Array([[-1.98e+04, 1.05e-04, 8.90e-01, inf, -inf, nan],
[ 3.12e-02, -3.90e-01, 1.32e-02, -4.21e-01, -1.23e+00, -1.25e+00]], dtype=float32)
image = jnp.load("mysteryman.npy")
image = image.at[1, 2, 3].set(float('nan'))
lovely(image, depth=2) # Limited by set_config(deeper_lines=N)
Array[3, 196, 196] n=115248 (0.4Mb) x∈[-2.118, 2.640] μ=-0.388 σ=1.073 NaN! cpu:0
Array[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036 cpu:0
Array[196] x∈[-1.912, 2.249] μ=-0.673 σ=0.521 cpu:0
Array[196] x∈[-1.861, 2.163] μ=-0.738 σ=0.417 cpu:0
Array[196] x∈[-1.758, 2.198] μ=-0.806 σ=0.396 cpu:0
Array[196] x∈[-1.656, 2.249] μ=-0.849 σ=0.368 cpu:0
Array[196] x∈[-1.673, 2.198] μ=-0.857 σ=0.356 cpu:0
Array[196] x∈[-1.656, 2.146] μ=-0.848 σ=0.371 cpu:0
Array[196] x∈[-1.433, 2.215] μ=-0.784 σ=0.396 cpu:0
Array[196] x∈[-1.279, 2.249] μ=-0.695 σ=0.485 cpu:0
Array[196] x∈[-1.364, 2.249] μ=-0.637 σ=0.538 cpu:0
...
Array[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973 NaN! cpu:0
Array[196] x∈[-1.861, 2.411] μ=-0.529 σ=0.555 cpu:0
Array[196] x∈[-1.826, 2.359] μ=-0.562 σ=0.472 cpu:0
Array[196] x∈[-1.756, 2.376] μ=-0.622 σ=0.458 NaN! cpu:0
Array[196] x∈[-1.633, 2.429] μ=-0.664 σ=0.429 cpu:0
Array[196] x∈[-1.651, 2.376] μ=-0.669 σ=0.398 cpu:0
Array[196] x∈[-1.633, 2.376] μ=-0.701 σ=0.390 cpu:0
Array[196] x∈[-1.563, 2.429] μ=-0.670 σ=0.379 cpu:0
Array[196] x∈[-1.475, 2.429] μ=-0.616 σ=0.385 cpu:0
Array[196] x∈[-1.511, 2.429] μ=-0.593 σ=0.398 cpu:0
...
Array[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178 cpu:0
Array[196] x∈[-1.717, 2.396] μ=-0.982 σ=0.349 cpu:0
Array[196] x∈[-1.752, 2.326] μ=-1.034 σ=0.313 cpu:0
Array[196] x∈[-1.648, 2.379] μ=-1.086 σ=0.313 cpu:0
Array[196] x∈[-1.630, 2.466] μ=-1.121 σ=0.304 cpu:0
Array[196] x∈[-1.717, 2.448] μ=-1.120 σ=0.301 cpu:0
Array[196] x∈[-1.717, 2.431] μ=-1.166 σ=0.313 cpu:0
Array[196] x∈[-1.560, 2.448] μ=-1.124 σ=0.325 cpu:0
Array[196] x∈[-1.421, 2.431] μ=-1.064 σ=0.382 cpu:0
Array[196] x∈[-1.526, 2.396] μ=-1.047 σ=0.416 cpu:0
...
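The number of sub-array lines shown per level is controlled by the deeper_lines setting referenced in the comment above; a minimal sketch, assuming set_config is exported from lovely_jax:

from lovely_jax import set_config   # export name inferred from the comment above

set_config(deeper_lines=3)          # show at most 3 sub-arrays per nesting level
print(lovely(image, depth=1))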
# We don't really support complex numbers yet
c = jnp.array([-0.4011 - 0.4035j, 1.1300 + 0.0788j, -0.0277 + 0.9978j, -0.4636 + 0.6064j, -1.1505 - 0.9865j])
lovely(c)
Array([-0.4 -0.4j , 1.13+0.08j, -0.03+1.j , -0.46+0.61j, -1.15-0.99j], dtype=complex64)
assert jax.__version_info__[0] == 0
if jax.__version_info__[1] >= 4:
    from jax.sharding import PositionalSharding
    from jax.experimental import mesh_utils

    sharding = PositionalSharding(mesh_utils.create_device_mesh((8, 1)))
    x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
    y = jax.device_put(x, sharding)
    jax.debug.visualize_array_sharding(y)
else:
    # Note: looks like ShardedDeviceArray needs an explicit device axis?
    x = jax.random.normal(jax.random.PRNGKey(0), (8, 1024, 8192))
    y = jax.device_put_sharded(list(x), jax.devices())

print(lovely(x))
print(lovely(y))
CPU 0
CPU 1
CPU 2
CPU 3
CPU 4
CPU 5
CPU 6
CPU 7
Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=-0.000 σ=1.000 cpu:0
Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=-0.000 σ=1.000 cpu:0,1,2,3,4,5,6,7
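The run above assumes 8 visible devices. On a single-CPU host you can reproduce it by asking XLA for virtual devices before JAX is imported; this is a standard JAX trick, not part of lovely:

# Must run before `import jax`: expose 8 virtual CPU devices to XLA
# so the (8, 1) device mesh above can be built on one machine.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
assert len(jax.devices()) == 8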