lovely-jax

📺 View channels



chans

 chans (x:jax.Array, cmap:str='twilight', cm_below:str='blue',
        cm_above:str='red', cm_ninf:str='cyan', cm_pinf:str='fuchsia',
        cm_nan:str='yellow', view_width:int=966, gutter_px:int=3,
        frame_px:int=1, scale:int=1, cl:Any=True,
        ax:Optional[matplotlib.axes._axes.Axes]=None)

Map tensor values to colors. The RGB[A] color is added as the last dimension (channel-last).

| | Type | Default | Details |
|---|---|---|---|
| x | Array | | Input, shape=([…], H, W) |
| cmap | str | twilight | Use the matplotlib colormap with this name |
| cm_below | str | blue | Color for values below -1 |
| cm_above | str | red | Color for values above 1 |
| cm_ninf | str | cyan | Color for -inf values |
| cm_pinf | str | fuchsia | Color for +inf values |
| cm_nan | str | yellow | Color for NaN values |
| view_width | int | 966 | Try to produce an image at most this wide |
| gutter_px | int | 3 | Draw white gutters when tiling the images |
| frame_px | int | 1 | Draw black frame around each image |
| scale | int | 1 | |
| cl | Any | True | |
| ax | Optional | None | |
| Returns | ChanProxy | | |
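The color rules in the table can be read as a per-value decision. The sketch below is only an illustration of those rules in plain Python, not the library's implementation: `pick_color` is a hypothetical helper, and treating the in-range interval as the closed range [-1, 1] is an assumption based on the "below -1" and "above 1" wording.

```python
import math

def pick_color(v, cm_below="blue", cm_above="red", cm_ninf="cyan",
               cm_pinf="fuchsia", cm_nan="yellow"):
    """Hypothetical helper: name the color rule that applies to one value."""
    if math.isnan(v):
        return cm_nan
    if math.isinf(v):
        return cm_pinf if v > 0 else cm_ninf
    if v < -1:
        return cm_below
    if v > 1:
        return cm_above
    # Assumed: anything in [-1, 1] is colored by the colormap (`cmap`).
    return "cmap"

values = [0.3, -1.1, 1.1, float("nan"), float("-inf"), float("inf")]
colors = [pick_color(v) for v in values]
print(colors)
```

Only values that fall through every special case are looked up in the colormap; everything else gets one of the fixed marker colors, which is what makes out-of-range and non-finite regions easy to spot in the rendered image.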
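# Setup assumed by the examples below (not shown on the original page)
import jax.numpy as jnp
from lovely_jax import chans, monkey_patch
monkey_patch()  # enables the .rgb property used below

# ImageNet per-channel (mean, std)
in_stats = ( (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) )

# Load a normalized channel-first image, move channels last, undo the normalization
image = jnp.load("mysteryman.npy").transpose(1,2,0)
image = (image * jnp.array(in_stats[1]))
image += jnp.array(in_stats[0])

image.rgb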

chans(image)

# In R
image = image.at[0:32,32:64,0].set(-1.1) # Below min
image = image.at[0:32,96:128,0].set(1.1) # Above max
# In G
image = image.at[0:32,64:96,1].set(float("nan"))
# In B
image = image.at[0:32,0:32,2].set(float("-inf"))
image = image.at[0:32,128:128+32,2].set(float("+inf"))

chans(image, cmap="viridis", cm_below="black", cm_above="white")
/ssd/xl0/work/projects/lovely-numpy/lovely_numpy/utils/colormap.py:64: RuntimeWarning: invalid value encountered in cast
  lut_idxs = (vals * cmax).astype(np.int64)
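The patches above are written with `.at[...].set(...)` because JAX arrays are immutable: each call returns an updated copy rather than modifying the array in place. The following sketch uses a zero-filled stand-in for the image (an assumption, since `mysteryman.npy` is not available here) to show this and to count the non-finite values that were injected. The NaN values are presumably what triggers the `RuntimeWarning` above when they are cast to integers.

```python
import jax.numpy as jnp

# Stand-in for the real image: zeros in the same channel-last layout (H, W, C).
image = jnp.zeros((196, 196, 3))

# .at[...].set(...) returns a new array, so the result must be bound to a name.
patched = image.at[0:32, 64:96, 1].set(float("nan"))
patched = patched.at[0:32, 0:32, 2].set(float("-inf"))
patched = patched.at[0:32, 128:160, 2].set(float("+inf"))

# Count the special values in the patched copy.
n_nan = int(jnp.isnan(patched).sum())
n_ninf = int(jnp.isneginf(patched).sum())
n_pinf = int(jnp.isposinf(patched).sum())

# The original array is left unchanged.
untouched = bool(jnp.all(image == 0))
print(n_nan, n_ninf, n_pinf, untouched)
```

Each 32x32 patch touches a single channel, so every count should equal the patch area.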

# 4 images, stacked 2x2
chans(jnp.stack([image]*4).reshape(2,2,196,196,3))
