# Render the same random 10x10x3 array through fig_rgb three times,
# stored as uint8, float16 and bool, in three side-by-side axes.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(6, 2))
# Detach the figure from pyplot's figure manager; the axes remain usable
# for drawing below. Presumably this suppresses automatic display of the
# figure at this point -- confirm against how `fig` is shown later.
plt.close(fig)

# Hide tick marks on every panel (same call order as x/y per axes).
for _ax in (ax1, ax2, ax3):
    _ax.set_xticks([])
    _ax.set_yticks([])
del _ax

# Fixed seed so the random image is reproducible across runs.
np.random.seed(1337)
r = np.random.rand(10, 10, 3)  # floats drawn from [0, 1)

# Panel 1: scale [0, 1) up to [0, 256) and truncate to integers 0..255.
x = (r*256).astype(np.uint8)
fig_rgb(x, scale=10, ax=ax1)

# Panel 2: the same [0, 1) values kept as half-precision floats.
x = r.astype(np.float16)
fig_rgb(x, scale=10, ax=ax2)

# Panel 3: elementwise boolean mask, True where a channel exceeds 0.5.
fig_rgb(r > 0.5, scale=10, ax=ax3)


