lovely-jax

πŸ–ŒοΈ View as RGB images



rgb

 rgb (x:jax.Array, denorm:Any=None, cl:Any=True, gutter_px:int=3,
      frame_px:int=1, scale:int=1, view_width:int=966,
      ax:Optional[matplotlib.axes._axes.Axes]=None)
Parameter   Type      Default  Details
x           Array              Tensor to display. [[…], C,H,W] or [[…], H,W,C]
denorm      Any       None     Reverse per-channel normalization
cl          Any       True     Channel-last layout (False for channel-first)
gutter_px   int       3        If more than one tensor, tile with this gutter width
frame_px    int       1        If more than one tensor, tile with this frame width
scale       int       1        Scale up. Can’t scale down.
view_width  int       966      Target width of the image
ax          Optional  None     Use this Axes
Returns     RGBProxy
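The table describes tiling in a single line. As a rough mental model only, here is a minimal NumPy sketch of placing several channel-last images side by side, with a frame of `frame_px` around each and a `gutter_px` gap between them. `tile_horizontally`, `frame_value` and `gutter_value` are hypothetical names; the actual layout, colors and grid arrangement used by `rgb` may differ.

```python
import numpy as np

def tile_horizontally(images, gutter_px=3, frame_px=1, frame_value=1.0, gutter_value=0.0):
    """Hypothetical sketch: frame each [H, W, C] image, then join them left to right."""
    # Surround every image with a frame_px-wide border (spatial axes only).
    framed = [
        np.pad(img, ((frame_px, frame_px), (frame_px, frame_px), (0, 0)),
               mode="constant", constant_values=frame_value)
        for img in images
    ]
    h, _, c = framed[0].shape
    # A blank vertical strip that separates neighbouring tiles.
    gutter = np.full((h, gutter_px, c), gutter_value, dtype=framed[0].dtype)
    parts = [framed[0]]
    for img in framed[1:]:
        parts += [gutter, img]
    return np.concatenate(parts, axis=1)

images = np.zeros((2, 196, 196, 3), dtype=np.float32)  # stand-in for two_images
canvas = tile_horizontally(images, gutter_px=3, frame_px=1)
print(canvas.shape)  # (198, 399, 3)
```

The canvas width here is n * (W + 2 * frame_px) + (n - 1) * gutter_px, i.e. 2 * 198 + 3 = 399 pixels.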
rgb(image)

rgb(image, scale=2)
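The `scale` parameter only enlarges the image. A minimal sketch, assuming scaling works by integer pixel replication (nearest neighbour); `upscale_nearest` is a hypothetical helper, not part of lovely-jax, and the library may resample differently.

```python
import numpy as np

def upscale_nearest(img, scale=1):
    """Hypothetical sketch: integer nearest-neighbour upscaling of an [H, W, C] image."""
    if not isinstance(scale, int) or scale < 1:
        raise ValueError("scale must be a positive integer (no downscaling)")
    # Repeat every row, then every column, `scale` times.
    return np.repeat(np.repeat(img, scale, axis=0), scale, axis=1)

img = np.arange(12, dtype=np.float32).reshape(2, 2, 3)
big = upscale_nearest(img, scale=2)
print(big.shape)  # (4, 4, 3)
```

Each source pixel becomes a `scale` x `scale` block, so a 196x196 image at `scale=2` would become 392x392.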

two_images = jnp.stack([image]*2)
two_images
Array[2, 196, 196, 3] n=230496 x∈[-2.118, 2.640] μ=-0.388 σ=1.073 cpu:0
in_stats = (    (0.485, 0.456, 0.406),  # Mean
                (0.229, 0.224, 0.225) ) # std
rgb(two_images, denorm=in_stats)
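Normalization maps each channel to (x - mean) / std, so reversing it means x * std + mean, which is the same expression the `eight_images` example applies by hand. A minimal NumPy sketch, assuming `denorm` takes a `(mean, std)` pair shaped like `in_stats`; `denormalize` is a hypothetical helper, not the lovely-jax implementation.

```python
import numpy as np

in_stats = ((0.485, 0.456, 0.406),  # per-channel mean
            (0.229, 0.224, 0.225))  # per-channel std

def denormalize(x, stats, cl=True):
    """Hypothetical sketch: undo (x - mean) / std per channel, i.e. return x * std + mean."""
    mean, std = (np.asarray(s, dtype=np.float32) for s in stats)
    if not cl:  # channel-first: give the stats shape (C, 1, 1) so they broadcast over H, W
        mean, std = mean[:, None, None], std[:, None, None]
    return x * std + mean

rng = np.random.default_rng(0)
pixels = rng.random((196, 196, 3), dtype=np.float32)  # values in [0, 1)
mean, std = (np.asarray(s, dtype=np.float32) for s in in_stats)
normalized = (pixels - mean) / std
restored = denormalize(normalized, in_stats)
print(np.allclose(restored, pixels, atol=1e-6))  # True
```

Broadcasting against the trailing channel axis (or `(C, 1, 1)` when channel-first) also covers any leading batch axes, such as the stack of two images above.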

# Make 8 images with progressively higher brightness and stack them 2x2x2.

eight_images = (jnp.stack([image]*8) + jnp.linspace(-2, 2, 8)[:,None,None,None])
eight_images = (eight_images
                     *jnp.array(in_stats[1])
                     +jnp.array(in_stats[0])
                ).clip(0,1).reshape(2,2,2,196,196,3)

eight_images
Array[2, 2, 2, 196, 196, 3] n=921984 x∈[0., 1.000] μ=0.382 σ=0.319 cpu:0
rgb(eight_images)
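Because `reshape` works in row-major (C) order (the default for both NumPy and `jnp.reshape`), the brightness ramp survives the regrouping: image `n` of the flat stack of 8 ends up at index `[n // 4, (n // 2) % 2, n % 2]`. A small NumPy check of that mapping using only the per-image offsets; it says nothing about how `rgb` arranges the three leading axes on screen.

```python
import numpy as np

offsets = np.linspace(-2, 2, 8)  # one brightness offset per image, as in the example
grid = offsets.reshape(2, 2, 2)  # same C-order regrouping as eight_images

# In C (row-major) order, grid[i, j, k] holds flat image number i*4 + j*2 + k.
for i in range(2):
    for j in range(2):
        for k in range(2):
            assert grid[i, j, k] == offsets[i * 4 + j * 2 + k]

print(grid[0, 0, 0], grid[1, 1, 1])  # darkest and brightest offsets
```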

# You can do channel-first too:
rgb(image.transpose(2, 0, 1), cl=False)
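With `cl=False` the channel axis comes before height and width, as in the `[[…], C,H,W]` form from the table. A minimal sketch of the equivalent layout change, assuming only leading batch axes precede `C`; `to_channel_last` is a hypothetical helper, not part of lovely-jax.

```python
import numpy as np

def to_channel_last(x, cl=True):
    """Hypothetical helper: move the channel axis of a [..., C, H, W] array to the end."""
    if cl:
        return x
    return np.moveaxis(x, -3, -1)  # [..., C, H, W] -> [..., H, W, C]

hwc = np.zeros((196, 196, 3), dtype=np.float32)
chw = hwc.transpose(2, 0, 1)  # same transpose as in the example: (3, 196, 196)
print(chw.shape, to_channel_last(chw, cl=False).shape)  # (3, 196, 196) (196, 196, 3)
```

Using a negative source axis (`-3`) keeps the helper working unchanged for stacked inputs such as `[N, C, H, W]`.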
