lovely-jax
  1. Data representations
  2. 🧾 View as a summary
  • 💘 Lovely JAX
  • Data representations
    • 🧾 View as a summary
    • 🖌️ View as RGB images
    • 📊 View as a histogram
    • 📺 View channels
  • Misc
    • 🤔 Config
    • 🙈 Monkey-patching
    • 🎭 Matplotlib integration
  1. Data representations
  2. 🧾 View as a summary

🧾 View as a summary

spicy = (randoms[:12].at[0].mul(10000)
                    .at[1].divide(10000)
                    .at[3].set(float('inf'))
                    .at[4].set(float('-inf'))
                    .at[5].set(float('nan'))
                    .reshape((2,6)))

source

jax_to_str_common

 jax_to_str_common (x:jax.Array, color=True, ddof=0)
Type Default Details
x Array Input
color bool True ANSI color highlighting
ddof int 0 For “std” unbiasing

source

lovely

 lovely (x:jax.Array, verbose=False, plain=False, depth=0, color=None)
Type Default Details
x Array Tensor of interest
verbose bool False Whether to show the full tensor
plain bool False Just print it exactly as before
depth int 0 Show stats in depth
color NoneType None Force color (True/False) or auto.

Examples

print(lovely(randoms[0]))
print(lovely(randoms[:2]))
print(lovely(randoms[:6].reshape((2, 3)))) # More than 2 elements -> show statistics
print(lovely(randoms[:11]))           # More than 10 -> suppress data output
Array cpu:0 1.623
Array[2] μ=1.824 σ=0.201 cpu:0 [1.623, 2.025]
Array[2, 3] n=6 x∈[-0.972, 2.025] μ=0.390 σ=1.080 cpu:0 [[1.623, 2.025, -0.434], [-0.079, 0.176, -0.972]]
Array[11] x∈[-0.972, 2.180] μ=0.385 σ=1.081 cpu:0
grad = jnp.array(1., dtype=jnp.float16)
print(lovely(grad)); print(lovely(grad+1))
Array f16 cpu:0 1.000
Array f16 cpu:0 2.000
# if torch.cuda.is_available():
#     print(lovely(torch.tensor(1., device=torch.device("cuda:0"))))
#     test_eq(str(lovely(torch.tensor(1., device=torch.device("cuda:0")))), "tensor cuda:0 1.000")

Do we have any floating point nasties? Is the tensor all zeros?

# Statistics and range are calculated on good values only, if there are at least 3 of them.
lovely(spicy)
Array[2, 6] n=12 x∈[-1.955, 1.623e+04] μ=1.803e+03 σ=5.099e+03 +Inf! -Inf! NaN! cpu:0
lovely(spicy, color=False)
Array[2, 6] n=12 x∈[-1.955, 1.623e+04] μ=1.803e+03 σ=5.099e+03 +Inf! -Inf! NaN! cpu:0
str(lovely(jnp.array([float("nan")]*11)))
'Array[11] \x1b[31mNaN!\x1b[0m cpu:0'
lovely(jnp.zeros(12))
Array[12] all_zeros cpu:0
lovely(jnp.array([], dtype=jnp.float16).reshape((0,0,0)))
Array[0, 0, 0] f16 empty cpu:0
lovely(jnp.array([1,2,3], dtype=jnp.int32))
Array[3] i32 x∈[1, 3] μ=2.000 σ=0.816 cpu:0 [1, 2, 3]
jnp.set_printoptions(linewidth=120, precision=2)
lovely(spicy, verbose=True)
Array[2, 6] n=12 x∈[-1.955, 1.623e+04] μ=1.803e+03 σ=5.099e+03 +Inf! -Inf! NaN! cpu:0
Array([[ 1.62e+04,  2.03e-04, -4.34e-01,       inf,      -inf,       nan],
       [-4.95e-01,  4.94e-01,  6.64e-01, -9.50e-01,  2.18e+00, -1.96e+00]], dtype=float32)
lovely(spicy, plain=True)
Array([[ 1.62e+04,  2.03e-04, -4.34e-01,       inf,      -inf,       nan],
       [-4.95e-01,  4.94e-01,  6.64e-01, -9.50e-01,  2.18e+00, -1.96e+00]], dtype=float32)
image = jnp.load("mysteryman.npy")
image = image.at[1,2,3].set(float('nan'))

lovely(image, depth=2) # Limited by set_config(deeper_lines=N)
Array[3, 196, 196] n=115248 (0.4Mb) x∈[-2.118, 2.640] μ=-0.388 σ=1.073 NaN! cpu:0
  Array[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036 cpu:0
    Array[196] x∈[-1.912, 2.249] μ=-0.673 σ=0.521 cpu:0
    Array[196] x∈[-1.861, 2.163] μ=-0.738 σ=0.417 cpu:0
    Array[196] x∈[-1.758, 2.198] μ=-0.806 σ=0.396 cpu:0
    Array[196] x∈[-1.656, 2.249] μ=-0.849 σ=0.368 cpu:0
    Array[196] x∈[-1.673, 2.198] μ=-0.857 σ=0.356 cpu:0
    Array[196] x∈[-1.656, 2.146] μ=-0.848 σ=0.371 cpu:0
    Array[196] x∈[-1.433, 2.215] μ=-0.784 σ=0.396 cpu:0
    Array[196] x∈[-1.279, 2.249] μ=-0.695 σ=0.485 cpu:0
    Array[196] x∈[-1.364, 2.249] μ=-0.637 σ=0.538 cpu:0
    ...
  Array[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973 NaN! cpu:0
    Array[196] x∈[-1.861, 2.411] μ=-0.529 σ=0.555 cpu:0
    Array[196] x∈[-1.826, 2.359] μ=-0.562 σ=0.472 cpu:0
    Array[196] x∈[-1.756, 2.376] μ=-0.622 σ=0.458 NaN! cpu:0
    Array[196] x∈[-1.633, 2.429] μ=-0.664 σ=0.429 cpu:0
    Array[196] x∈[-1.651, 2.376] μ=-0.669 σ=0.398 cpu:0
    Array[196] x∈[-1.633, 2.376] μ=-0.701 σ=0.390 cpu:0
    Array[196] x∈[-1.563, 2.429] μ=-0.670 σ=0.379 cpu:0
    Array[196] x∈[-1.475, 2.429] μ=-0.616 σ=0.385 cpu:0
    Array[196] x∈[-1.511, 2.429] μ=-0.593 σ=0.398 cpu:0
    ...
  Array[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178 cpu:0
    Array[196] x∈[-1.717, 2.396] μ=-0.982 σ=0.349 cpu:0
    Array[196] x∈[-1.752, 2.326] μ=-1.034 σ=0.313 cpu:0
    Array[196] x∈[-1.648, 2.379] μ=-1.086 σ=0.313 cpu:0
    Array[196] x∈[-1.630, 2.466] μ=-1.121 σ=0.304 cpu:0
    Array[196] x∈[-1.717, 2.448] μ=-1.120 σ=0.301 cpu:0
    Array[196] x∈[-1.717, 2.431] μ=-1.166 σ=0.313 cpu:0
    Array[196] x∈[-1.560, 2.448] μ=-1.124 σ=0.325 cpu:0
    Array[196] x∈[-1.421, 2.431] μ=-1.064 σ=0.382 cpu:0
    Array[196] x∈[-1.526, 2.396] μ=-1.047 σ=0.416 cpu:0
    ...
# We don't really support complex numbers yet
c = jnp.array([-0.4011-0.4035j,  1.1300+0.0788j, -0.0277+0.9978j, -0.4636+0.6064j, -1.1505-0.9865j])
lovely(c)
Array([-0.4 -0.4j ,  1.13+0.08j, -0.03+1.j  , -0.46+0.61j, -1.15-0.99j], dtype=complex64)
assert jax.__version_info__[0] == 0
if jax.__version_info__[1] >= 4:
    from jax.sharding import PositionalSharding
    from jax.experimental import mesh_utils
    sharding = PositionalSharding(mesh_utils.create_device_mesh((8,1)))
    x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
    y = jax.device_put(x, sharding)

    jax.debug.visualize_array_sharding(y)
else:
    # Note: Looks like ShardedDeviceArray needs an explicit device axis?
    x = jax.random.normal(jax.random.PRNGKey(0), (8, 1024, 8192))
    y = jax.device_put_sharded([x for x in x], jax.devices())

print(lovely(x))
print(lovely(y))
/tmp/ipykernel_487548/3921189633.py:4: DeprecationWarning: jax.sharding.PositionalSharding is deprecated. Use jax.NamedSharding instead.
  from jax.sharding import PositionalSharding
          CPU 0          
                         
          CPU 1          
                         
          CPU 2          
                         
          CPU 3          
                         
          CPU 4          
                         
          CPU 5          
                         
          CPU 6          
                         
          CPU 7          
                         
Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=1.508e-05 σ=1.000 cpu:0
Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=1.508e-05 σ=1.000 cpu:0,1,2,3,4,5,6,7
  • Report an issue