# Show how lovely() changes its summary as the input grows.
for sample in (
    randoms[0],                   # first entry
    randoms[:2],                  # first two entries
    randoms[:6].reshape((2, 3)),  # More than 2 elements -> show statistics
    randoms[:11],                 # More than 10 -> suppress data output
):
    print(lovely(sample))
# This demo expects a 0.x release of JAX.
assert jax.__version_info__[0] == 0

from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental import mesh_utils

print("=== Test 1: NamedSharding with 2D mesh (4,2) and P('y', 'x') ===")
# Arrange 8 devices into a 4x2 grid. Mesh axis names follow the grid order,
# so 'y' spans 4 devices and 'x' spans 2.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('y', 'x'))
# Split array dim 0 across the 'y' mesh axis (4 ways) and dim 1 across 'x'
# (2 ways): every device holds a distinct (2048, 4096) tile.
sharding = NamedSharding(mesh, P('y', 'x'))
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
y = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(y)
print(lovely(y))

print("\n=== Test 2: NamedSharding with P('y', None) - replicate first dim ===")
# NOTE: despite the header, array dim 0 is *sharded* over 'y' (4 ways) and
# dim 1 is left unsplit. Each (2048, 8192) row band is therefore replicated
# across the 2 devices of the 'x' mesh axis.
sharding2 = NamedSharding(mesh, P('y', None))
y2 = jax.device_put(x, sharding2)
jax.debug.visualize_array_sharding(y2)
print(lovely(y2))

print("\n=== Test 3: NamedSharding with P(None, 'x') - replicate second dim ===")
# NOTE: despite the header, array dim 0 is left unsplit and dim 1 is
# *sharded* over 'x' (2 ways). Each (8192, 4096) column band is therefore
# replicated across the 4 devices of the 'y' mesh axis.
sharding3 = NamedSharding(mesh, P(None, 'x'))
y3 = jax.device_put(x, sharding3)
jax.debug.visualize_array_sharding(y3)
print(lovely(y3))

print("\n=== Test 4: 1D mesh with 8 devices ===")
# A flat mesh of all 8 devices under a single axis 'x'. Array dim 0 is split
# 8 ways into (1024, 8192) row bands, and dim 1 is left unsplit.
devices_1d = mesh_utils.create_device_mesh((8,))
mesh_1d = Mesh(devices_1d, axis_names=('x',))
sharding_1d = NamedSharding(mesh_1d, P('x', None))
y4 = jax.device_put(x, sharding_1d)
jax.debug.visualize_array_sharding(y4)
print(lovely(y4))
=== Test 1: NamedSharding with 2D mesh (4,2) and P('y', 'x') ===
CPU 0 CPU 1 CPU 2 CPU 3 CPU 4 CPU 5 CPU 6 CPU 7
Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=1.508e-05 σ=1.000 S[y,x] 4×2 cpu:0-7
=== Test 2: NamedSharding with P('y', None) - replicate first dim ===
CPU 0,1 CPU 2,3 CPU 4,5 CPU 6,7
Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=1.508e-05 σ=1.000 S[y,·] 4×2 cpu:0-7
=== Test 3: NamedSharding with P(None, 'x') - replicate second dim ===
CPU 0,2,4,6 CPU 1,3,5,7
Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=1.508e-05 σ=1.000 S[·,x] 4×2 cpu:0-7
=== Test 4: 1D mesh with 8 devices ===