$\textbf{Problem}$: What is JAX?
$\textbf{Solution}$: JAX = Autograd + XLA, where Autograd refers to automatic differentiation, and XLA stands for Accelerated Linear Algebra, a compiler developed by Google that optimizes code to run fast on GPUs/TPUs.
At a high level, JAX is just NumPy on steroids (and indeed, much of the syntax is identical). At a lower level, JAX is a framework for composable function transformations.
$\textbf{Problem}$: What are the $4$ most important JAX transformations?
$\textbf{Solution}$: jit (just-in-time compilation), grad (gradient), vmap (vectorization), and pmap (parallelization).
The philosophy here is that one can write standard Python functions, and JAX will transform them into GPU/TPU-optimized versions.
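As a quick illustration of that composability, here is a minimal sketch (the toy loss function is made up for illustration, not part of the original notes): per-example gradients of a function written for a single data point, JIT-compiled.
import jax
import jax.numpy as jnp

def toy_loss(w, x):
    return jnp.sum((w * x) ** 2)  # hypothetical per-example loss

# Per-example gradients w.r.t. w over a batch of x's, JIT-compiled:
per_example_grads = jax.jit(jax.vmap(jax.grad(toy_loss), in_axes=(None, 0)))
print(per_example_grads(jnp.ones(3), jnp.arange(6.0).reshape(2, 3)).shape)  # (2, 3)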
$\textbf{Problem}$: Compare JAX vs. PyTorch vs. TensorFlow.
$\textbf{Solution}$: JAX uses a functional programming paradigm, whereas PyTorch is object-oriented/imperative and TensorFlow is mixed (Keras is OOP, while TF Core is graph-based). JAX therefore has the steepest learning curve, as it requires unlearning OOP habits. However, its XLA-compiled performance is arguably the fastest of the three, making it an essential tool for research.
$\textbf{Problem}$: What higher-level libraries built on top of JAX are typically used to build neural networks?
$\textbf{Solution}$: Flax, Haiku, Equinox, Optax, etc.
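For a rough sense of what these look like in practice, here is a minimal sketch (assuming Flax's linen API and Optax are installed; the model, shapes, and learning rate are made up for illustration):
import jax, jax.numpy as jnp
import flax.linen as nn
import optax

class TinyMLP(nn.Module):  # hypothetical model, for illustration only
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(nn.relu(nn.Dense(32)(x)))

model = TinyMLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))  # pure init: params are returned, not stored
opt = optax.adam(1e-3)
opt_state = opt.init(params)

def loss_fn(params, x, y):
    return jnp.mean((model.apply(params, x) - y) ** 2)  # pure apply: params passed in explicitly

grads = jax.grad(loss_fn)(params, jnp.ones((4, 8)), jnp.zeros((4, 1)))
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)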
# JAX's syntax is (for the most part) same as NumPy.
# There is also SciPy API support (jax.scipy)
import jax.numpy as jnp
import numpy as np
# 4 key transform functions
from jax import jit, grad, vmap, pmap
# JAX's low-level API
from jax import lax # anagram of XLA
from jax import make_jaxpr
from jax import random
from jax import device_put
import matplotlib.pyplot as plt
import jax
jax.devices()
[CudaDevice(id=0)]
# Fact #1: JAX syntax is very similar to NumPy!
x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np)
[<matplotlib.lines.Line2D at 0x7f394ad1d180>]
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp)
[<matplotlib.lines.Line2D at 0x7f392ef5d1b0>]
# Fact 2: JAX arrays are immutable! (embrace the functional programming paradigm!)
size = 10
index = 0
value = 23
# NumPy: mutable arrays
x = np.arange(size)
print(x)
x[index] = value
print(x)
[0 1 2 3 4 5 6 7 8 9] [23 1 2 3 4 5 6 7 8 9]
# JAX: immutable arrays
x = jnp.arange(size)
print(x)
x[index] = value
print(x)
[0 1 2 3 4 5 6 7 8 9]
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) Cell In[9], line 4 2 x = jnp.arange(size) 3 print(x) ----> 4 x[index] = value 5 print(x) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:596, in _unimplemented_setitem(self, i, x) 592 def _unimplemented_setitem(self, i, x): 593 msg = ("JAX arrays are immutable and do not support in-place item assignment." 594 " Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method:" 595 " https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html") --> 596 raise TypeError(msg.format(type(self))) TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html
# Solution:
y = x.at[index].set(value)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9] [23 1 2 3 4 5 6 7 8 9]
# Fact 3: JAX handles random numbers differently (cf. NumPy)
seed = 0
key = random.PRNGKey(seed)
print(key)
x = random.normal(key,(10,)) # have to pass key here explicitly!
print(type(x),x)
[0 0] <class 'jaxlib._jax.ArrayImpl'> [ 1.6226422 2.0252647 -0.43359444 -0.07861735 0.1760909 -0.97208923 -0.49529874 0.4943786 0.6643493 -0.9501635 ]
# Fact #4: JAX is AI accelerator agnostic (same code runs everywhere!)
size = 3000
# Data is automatically pushed to the AI accelerator.
# No more need for ".to(device)" (PyTorch syntax)
x_jnp = random.normal(key, (size, size), dtype=jnp.float32)
x_np = np.random.normal(size=(size, size)).astype(np.float32) # some diff in API
# block_until_ready() --> wait for the result, so we time the computation itself rather than just the asynchronous dispatch
%timeit jnp.dot(x_jnp, x_jnp.T).block_until_ready() # on AI accelerator (e.g. GPU) - fast
%timeit np.dot(x_np, x_np.T) # on CPU - slow (NumPy only works with CPUs)
%timeit jnp.dot(x_np, x_np.T).block_until_ready() # on AI accelerator (e.g. GPU) with overhead of transferring np to jnp
x_np_device = device_put(x_np)
%timeit jnp.dot(x_np_device, x_np_device.T).block_until_ready()
7.9 ms ± 90.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
101 ms ± 5.27 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
22.5 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
7.96 ms ± 39.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
$\textbf{Problem}$: What does jit do?
$\textbf{Solution}$: jit is how JAX compiles functions into optimized XLA kernels. On the first call, JAX traces the function to record the computation it performs, compiles that trace with XLA, and caches the result, which makes subsequent calls of the function very fast.
# Simpler visualizer
def visualize_fn(fn, l=-10, r=10, n=1000):
x = np.linspace(l, r, num=n)
y = fn(x)
plt.plot(x, y); plt.show()
def selu(x, alpha=1.67, lmbda=1.05): #a type of activation function
return lmbda * jnp.where(x>0, x, alpha*jnp.exp(x)-alpha)
visualize_fn(selu)
selu_jit = jit(selu)
data = random.normal(key, (1000000,))
print("Non-jit version")
%timeit selu(data).block_until_ready()
print("Jit version:")
%timeit selu_jit(data).block_until_ready()
Non-jit version 1.66 ms ± 85.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) Jit version: 545 μs ± 117 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
$\textbf{Problem}$: What does grad() do?
$\textbf{Solution}$: Differentiation can be done manually, symbolically, or numerically, but grad() does it automatically: it returns a new function that evaluates the gradient of the original one.
def L(x):
    return jnp.sum(x**3) # this function takes in a vector
# and spits out the sum of cubes of the vector components
x = jnp.arange(0, 6.0, 1)
print(x)
print(L(x))
print(grad(L)(x)) # manually, the gradient is (3x_1^2, 3x_2^2, ..., 3x_n^2)
[0. 1. 2. 3. 4. 5.] 225.0 [ 0. 3. 12. 27. 48. 75.]
# Numeric diff (to check that autodiff works)
def finite_diff(f, x):
eps = 1e-3
return jnp.array([(f(x+eps*v)-f(x-eps*v))/(2*eps)
for v in jnp.eye(len(x))])
finite_diff(L, x)
Array([ 0. , 3.0517578, 11.901855 , 26.931763 , 47.98889 ,
75.07324 ], dtype=float32)
# Example of autodiff
x = 1.
f = lambda x: x**2 + x + 4
visualize_fn(f, l=-1, r=2, n=100)
dfdx = grad(f) #2x+1
d2fdx2 = grad(dfdx) #2
d3fdx3 = grad(d2fdx2) #0
print(f(x), dfdx(x), d2fdx2(x), d3fdx3(x))
# More flexible than PyTorch's backward(): grad returns a function, so it can be composed and applied repeatedly
6.0 3.0 2.0 0.0
# Modifying the above example to take 2 inputs (by default,
# grad takes the partial derivative w.r.t. the 1st argument)
x = 1.5329423
y = -2.3
f = lambda x, y: jnp.exp(-jnp.pi*x**2) + 3*x**3 - jnp.cos(jnp.pi*y**2)
dfdx = grad(f) #same as grad(f, argnums=0), etc.
d2fdx2 = grad(dfdx)
d3fdx3 = grad(d2fdx2)
dfdy = grad(f, argnums=1)
d2fdy2 = grad(dfdy, argnums=1)
print(f(x, y), dfdx(x, y), d2fdx2(x, y), d3fdx3(x, y), dfdy(x, y), d2fdy2(x, y))
# Can also get df/dx and df/dy at the same time
grads = grad(f, argnums=(0, 1))
print(grads(x, y))
11.420368 21.143219 27.64676 17.557095 11.4187975 -132.96454 (Array(21.143219, dtype=float32, weak_type=True), Array(11.4187975, dtype=float32, weak_type=True))
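A related convenience (not used in the cell above): jax.value_and_grad returns the function value and its gradient in a single pass, which is handy e.g. for logging a loss while training. Reusing the same f, x, y:
from jax import value_and_grad

val, (gx, gy) = value_and_grad(f, argnums=(0, 1))(x, y)
print(val, gx, gy)  # should match f(x, y), dfdx(x, y), dfdy(x, y) printed above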
# JAX autodiff works not only for scalar fields, but also
# vector-valued functions (e.g. Jacobians)
from jax import jacfwd, jacrev
x = 1.
y = 1.
f = lambda x, y: x**2 + y**2 #paraboloid
#df/dx = 2x
#df/dy = 2y
#grad = [df/dx, df/dy]
#d2f/dx2 = 2
#d2f/dy2 = 2
#d2f/dxdy = d2f/dydx = 0
#H = [[d2f/dx2, d2f/dxdy], [d2f/dydx, d2f/dy2]]
def hessian(f):
return jit(jacfwd(jacrev(f, argnums=(0, 1)), argnums=(0, 1)))
print(jacrev(f, argnums=(0, 1))(x, y))
print(hessian(f)(x, y))
(Array(2., dtype=float32, weak_type=True), Array(2., dtype=float32, weak_type=True)) ((Array(2., dtype=float32, weak_type=True), Array(0., dtype=float32, weak_type=True)), (Array(0., dtype=float32, weak_type=True), Array(2., dtype=float32, weak_type=True)))
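JAX also ships a ready-made jax.hessian transformation, built from essentially the same jacfwd/jacrev composition as the helper above; a quick sanity check, reusing the same f:
print(jax.hessian(f, argnums=(0, 1))(x, y))  # should match hessian(f)(x, y) above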
# Edge case of non-differentiable function
f = lambda x: abs(x)
visualize_fn(f)
dfdx = grad(f)
print(dfdx(0.)) # technically non-differentiable at x=0,
# but JAX returns the one-sided derivative as x --> 0+ (i.e. 1.0)
1.0
$\textbf{Problem}$: What does vmap() do in JAX?
$\textbf{Solution}$: Standing for vectorized map, it handles batch dimensions: one writes a function as if it operated on a single data point, and vmap automatically transforms it to work on a batch of data points. This saves a lot of headache compared with writing explicit for loops or working out complex broadcasting dimensions in NumPy.
W = random.normal(key, (150, 100)) #e.g. treat as weights of linear NN layer
batched_x = random.normal(key, (10, 100)) # e.g. a batch of 10 flattened 4x25 images
print(W)
print(batched_x)
def apply_matrix(x):
return jnp.dot(W, x)
[[ 1.6226422   2.0252647  -0.43359444 ... -0.91352165  1.370097   -0.7800775 ]
 ...
 [-0.3662531  -0.77806586  0.59587073 ...  1.1635658   0.475208    0.6065461 ]]
[[ 1.62264216e+00  2.02526474e+00 -4.33594435e-01 ... -9.13521647e-01  1.37009704e+00 -7.80077517e-01]
 ...
 [ 7.08347976e-01  8.58157128e-02  8.75563264e-01 ... -5.41782714e-02 -1.50115645e+00 -2.67956465e-01]]
(full printouts of the 150x100 weight matrix and the 10x100 batch truncated)
def naively_batched_apply_matrix(batched_x):
return jnp.stack([apply_matrix(x) for x in batched_x])
print("Naively batched time:")
%timeit naively_batched_apply_matrix(batched_x).block_until_ready( )
Naively batched time: 1.77 ms ± 93.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
@jit
def batched_apply_matrix(batched_x):
return jnp.dot(batched_x, W.T) #had to completely rewrite function to vectorize it
print("Manually batched time:")
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched time: 223 μs ± 49.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
@jit # note: one can arbitrarily compose JAX transformations
def vmap_batched_apply_matrix(batched_x):
return vmap(apply_matrix)(batched_x) #much simpler! and very efficient!
print("Automatically vectorized with vmap time:")
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Automatically vectorized with vmap time: 157 μs ± 18.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
def apply_matrix(W, x): #accept both W and x args, in line with func prog paradigm
return jnp.dot(W, x)
@jit # note: we can arbitrarily compose JAX transformations
def vmap_batched_apply_matrix(W, batched_x):
return vmap(apply_matrix)(W, batched_x) #much simpler! and very efficient!
print("Automatically vectorized with vmap time:")
%timeit vmap_batched_apply_matrix(W, batched_x).block_until_ready()
# crashes because W does not have a batch dimension.
Automatically vectorized with vmap time:
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[167], line 9 6 return vmap(apply_matrix)(W, batched_x) #much simpler! and very efficient! 8 print("Automatically vectorized with vmap time:") ----> 9 get_ipython().run_line_magic('timeit', 'vmap_batched_apply_matrix(W, batched_x).block_until_ready()') File ~/jax_linus/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2482, in InteractiveShell.run_line_magic(self, magic_name, line, _stack_depth) 2480 kwargs['local_ns'] = self.get_local_scope(stack_depth) 2481 with self.builtin_trap: -> 2482 result = fn(*args, **kwargs) 2484 # The code below prevents the output from being displayed 2485 # when using magics with decorator @output_can_be_silenced 2486 # when the last Python token in the expression is a ';'. 2487 if getattr(fn, magic.MAGIC_OUTPUT_CAN_BE_SILENCED, False): File ~/jax_linus/lib/python3.10/site-packages/IPython/core/magics/execution.py:1209, in ExecutionMagics.timeit(self, line, cell, local_ns) 1207 for index in range(0, 10): 1208 number = 10 ** index -> 1209 time_number = timer.timeit(number) 1210 if time_number >= 0.2: 1211 break File ~/jax_linus/lib/python3.10/site-packages/IPython/core/magics/execution.py:174, in Timer.timeit(self, number) 172 gc.disable() 173 try: --> 174 timing = self.inner(it, self.timer) 175 finally: 176 if gcold: File <magic-timeit>:1, in inner(_it, _timer) [... skipping hidden 14 frame] Cell In[167], line 6, in vmap_batched_apply_matrix(W, batched_x) 4 @jit # note: we can arbitrarily compose JAX transformations 5 def vmap_batched_apply_matrix(W, batched_x): ----> 6 return vmap(apply_matrix)(W, batched_x) [... skipping hidden 2 frame] File ~/jax_linus/lib/python3.10/site-packages/jax/_src/api.py:1248, in _mapped_axis_size(fn, tree, vals, dims, name) 1246 else: 1247 msg.append(f" * some axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n") -> 1248 raise ValueError(''.join(msg)[:-2]) ValueError: vmap got inconsistent sizes for array axes to be mapped: * one axis had size 150: axis 0 of argument W of type float32[150,100]; * one axis had size 10: axis 0 of argument x of type float32[10,100]
# Solution: use the in_axes arg for vmap
def apply_matrix(W, x): #accept both W and x args, in line with func prog paradigm
return jnp.dot(W, x)
@jit # note: we can arbitrarily compose JAX transformations
def vmap_batched_apply_matrix(W, batched_x):
return vmap(apply_matrix, in_axes=(None, 0))(W, batched_x) #much simpler! and very efficient!
print("Automatically vectorized with vmap time:")
%timeit vmap_batched_apply_matrix(W, batched_x).block_until_ready()
Automatically vectorized with vmap time: 197 μs ± 29.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
$\textbf{Problem}$: Why is JAX’s API layering similar to an onion?
$\textbf{Solution}$: The highest layer of the API is the NumPy-like jax.numpy, below it sits jax.lax, and at the bottom is the XLA compiler. The lax API is stricter than the NumPy layer but also more powerful; it is essentially a thin Python wrapper around XLA operations.
# Example 1: Lax is stricter than NumPy
print(jnp.add(1, 1.0))
print(lax.add(1, 1.0))
2.0
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) Cell In[171], line 4 1 # Example 1: Lax is stricter than NumPy 3 print(jnp.add(1, 1.0)) ----> 4 print(lax.add(1, 1.0)) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/lax/lax.py:1196, in add(x, y) 1176 r"""Elementwise addition: :math:`x + y`. 1177 1178 This function lowers directly to the `stablehlo.add`_ operation. (...) 1193 .. _stablehlo.add: https://openxla.org/stablehlo/spec#add 1194 """ 1195 x, y = core.standard_insert_pvary(x, y) -> 1196 return add_p.bind(x, y) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/core.py:536, in Primitive.bind(self, *args, **params) 534 def bind(self, *args, **params): 535 args = args if self.skip_canonicalization else map(canonicalize_value, args) --> 536 return self._true_bind(*args, **params) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/core.py:552, in Primitive._true_bind(self, *args, **params) 550 trace_ctx.set_trace(eval_trace) 551 try: --> 552 return self.bind_with_trace(prev_trace, args, params) 553 finally: 554 trace_ctx.set_trace(prev_trace) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/core.py:562, in Primitive.bind_with_trace(self, trace, args, params) 559 with set_current_trace(trace): 560 return self.to_lojax(*args, **params) # type: ignore --> 562 return trace.process_primitive(self, args, params) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/core.py:1066, in EvalTrace.process_primitive(self, primitive, args, params) 1064 args = map(full_lower, args) 1065 check_eval_args(args) -> 1066 return primitive.impl(*args, **params) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/dispatch.py:91, in apply_primitive(prim, *args, **params) 89 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False) 90 try: ---> 91 outs = fun(*args) 92 finally: 93 lib.jax_jit.swap_thread_local_state_disable_jit(prev) [... skipping hidden 24 frame] File ~/jax_linus/lib/python3.10/site-packages/jax/_src/lax/lax.py:8799, in check_same_dtypes(name, *avals) 8797 equiv = _JNP_FUNCTION_EQUIVALENTS[name] 8798 msg += f" (Tip: jnp.{equiv} is a similar function that does automatic type promotion on inputs)." -> 8799 raise TypeError(msg.format(name, ", ".join(str(a.dtype) for a in avals))) TypeError: lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).
# Example 2: Lax is more powerful (tradeoff: less user-friendly)
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
# NumPy API
convolution_jnp = jnp.convolve(x, y)
# Lax API
convolution_lax = lax.conv_general_dilated(
x.reshape(1, 1, 3).astype(float),
y.reshape(1, 1, 10),
window_strides=(1,),
padding=[(len(y)-1, len(y)-1)]
)
print(convolution_jnp)
print(convolution_lax)
print(convolution_lax[0][0]) # returns batched result, hence need for this indexing
# see:
[1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.] [[[1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.]]] [1. 3. 4. 4. 4. 4. 4. 4. 4. 4. 3. 1.]
# another JIT example
def norm(X):
X = X - X.mean(0)
return X/X.std(0)
norm_compiled = jit(norm)
X = random.normal(key, (10000, 100), dtype=jnp.float32)
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
1.9 ms ± 153 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) 1.91 ms ± 52.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# Example of a failure: array shapes must be static.
def get_negatives(x):
return x[x < 0]
x = random.normal(key, (10,), dtype=jnp.float32)
print(get_negatives(x))
[-0.43359444 -0.07861735 -0.97208923 -0.49529874 -0.9501635 ]
# but this fails:
print(jit(get_negatives)(x))
--------------------------------------------------------------------------- NonConcreteBooleanIndexError Traceback (most recent call last) Cell In[199], line 2 1 # but this fails: ----> 2 print(jit(get_negatives)(x)) [... skipping hidden 14 frame] Cell In[198], line 3, in get_negatives(x) 2 def get_negatives(x): ----> 3 return x[x < 0] File ~/jax_linus/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:1083, in _forward_operator_to_aval.<locals>.op(self, *args) 1082 def op(self, *args): -> 1083 return getattr(self.aval, f"_{name}")(self, *args) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:657, in _getitem(self, item) 656 def _getitem(self, item): --> 657 return indexing.rewriting_take(self, item) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/numpy/indexing.py:645, in rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value, out_sharding) 639 if (isinstance(aval, core.DShapedArray) and aval.shape == () and 640 dtypes.issubdtype(aval.dtype, np.integer) and 641 not dtypes.issubdtype(aval.dtype, dtypes.bool_) and 642 isinstance(arr.shape[0], int)): 643 return lax.dynamic_index_in_dim(arr, idx, keepdims=False) --> 645 treedef, static_idx, dynamic_idx = split_index_for_jit(idx, arr.shape) 646 internal_gather = partial( 647 _gather, treedef=treedef, static_idx=static_idx, 648 indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, 649 mode=mode, fill_value=fill_value) 650 if out_sharding is not None: File ~/jax_linus/lib/python3.10/site-packages/jax/_src/numpy/indexing.py:738, in split_index_for_jit(idx, shape) 734 raise TypeError(f"JAX does not support string indexing; got {idx=}") 736 # Expand any (concrete) boolean indices. We can then use advanced integer 737 # indexing logic to handle them. --> 738 idx = _expand_bool_indices(idx, shape) 740 leaves, treedef = tree_flatten(idx) 741 dynamic = [None] * len(leaves) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/numpy/indexing.py:1059, in _expand_bool_indices(idx, shape) 1055 abstract_i = core.get_aval(i) 1057 if not core.is_concrete(i): 1058 # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete -> 1059 raise errors.NonConcreteBooleanIndexError(abstract_i) 1060 elif np.ndim(i) == 0: 1061 out.append(bool(i)) NonConcreteBooleanIndexError: Array boolean indices must be concrete; got bool[10] See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
# So why does this happen? --> tracing on different levels of abstraction
@jit
def f(x, y):
print("Running f().")
print(f"x={x}")
print(f"y={y}")
result = jnp.dot(x+1, y+1)
print(f"result = {result}")
return result
x = np.random.randn(3, 4)
y = np.random.randn(4)
print(f(x, y))
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
print("Second Call (notice no more print statement side effects):")
print(f(x2, y2))
# Note: calling a jitted function again with the same shapes + types is fast, because the traced
# computation has already been XLA-compiled and cached. S + T! shapes + types!
x3 = np.random.randn(3, 5)
y3 = np.random.randn(5)
print(f(x3, y3)) # notice how jit has to retrace it
Running f(). x=Traced<float32[3,4]>with<DynamicJaxprTrace> y=Traced<float32[4]>with<DynamicJaxprTrace> result = Traced<float32[3]>with<DynamicJaxprTrace> [1.0699788 4.141973 2.184647 ] Second Call (notice no more print statement side effects): [9.221527 5.545151 2.8010619] Running f(). x=Traced<float32[3,5]>with<DynamicJaxprTrace> y=Traced<float32[5]>with<DynamicJaxprTrace> result = Traced<float32[3]>with<DynamicJaxprTrace> [5.1654205 4.1818676 6.113046 ]
def f(x, y):
# same function as above but w/o the print statements, as would be written in practice
return jnp.dot(x + 1, y + 1)
print(make_jaxpr(f)(x, y))
# prints the jaxpr: the intermediate representation that jit records in the background while tracing
# see the JAX docs to get a better understanding of this grammar
{ lambda ; a:f32[3,4] b:f32[4]. let
c:f32[3,4] = add a 1.0:f32[]
d:f32[4] = add b 1.0:f32[]
e:f32[3] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] c d
in (e,) }
@jit
def f(x, neg):
    return -x if neg else x
f(1, True)
--------------------------------------------------------------------------- TracerBoolConversionError Traceback (most recent call last) Cell In[14], line 5 1 @jit 2 def f(x, neg): 3 return -x if neg else X ----> 5 f(1, True) [... skipping hidden 14 frame] Cell In[14], line 3, in f(x, neg) 1 @jit 2 def f(x, neg): ----> 3 return -x if neg else X [... skipping hidden 1 frame] File ~/jax_linus/lib/python3.10/site-packages/jax/_src/core.py:1661, in concretization_function_error.<locals>.error(self, arg) 1660 def error(self, arg): -> 1661 raise TracerBoolConversionError(arg) TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]. The error occurred while tracing the function f at /tmp/ipykernel_424021/236938057.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg. See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
$\textbf{Problem}$: Why does running the above code cell lead to the error message shown? What can be done to debug this error?
$\textbf{Solution}$: When you jit a function, there is always a $2$-step procedure:
- Tracing (Recipe Phase): JAX runs the function once using placeholder objects called tracers instead of the actual data. This is meant purely to record the computational graph of the function.
- Execution (Cooking Phase): JAX compiles the graph into XLA code to run on GPU/TPU.
So the conflict here is that the standard Python control flow (the if/else conditional in the function above) runs during the $1^{\text{st}}$ step of the jit compilation. To evaluate the conditional, the Python interpreter needs the concrete boolean value of neg, but at trace time neg is still a tracer object, which effectively tells the interpreter “I don’t know what value I hold yet; I represent a value that will only exist later, when the compiled XLA code runs”.
To debug this error, one can instead make neg a static argument (see below).
from functools import partial
@partial(jit, static_argnums=(1, ))
def f(x, neg):
print(x)
return -x if neg else x
print(f(1, True))
print(f(2, True)) # no more tracing here b/c still True which was already traced
print(f(2, False))
print(f(23, False)) # same comment here
Traced<~int32[]>with<DynamicJaxprTrace> -1 -2 Traced<~int32[]>with<DynamicJaxprTrace> 2 23
@jit
def f(x):
print(x) #tracer object
print(x.shape) #concrete value
print(jnp.array(x.shape).prod()) #tracer object again
    return x.reshape(jnp.array(x.shape).prod()) #fails because the new shape is a traced value; shapes must be concrete at trace time
x = jnp.ones((2, 3))
print(f(x))
Traced<float32[2,3]>with<DynamicJaxprTrace> (2, 3) Traced<int32[]>with<DynamicJaxprTrace>
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) Cell In[39], line 9 6 return x.reshape(jnp.array(x.shape).prod()) #fail b/c try to apply reshape to a tracer 8 x = jnp.ones((2, 3)) ----> 9 print(f(x)) [... skipping hidden 14 frame] Cell In[39], line 6, in f(x) 4 print(x.shape) #concrete value 5 print(jnp.array(x.shape).prod()) #tracer object again ----> 6 return x.reshape(jnp.array(x.shape).prod()) [... skipping hidden 2 frame] File ~/jax_linus/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:454, in _compute_newshape(arr, newshape) 452 except: 453 newshape = [newshape] --> 454 newshape = core.canonicalize_shape(newshape) # type: ignore[arg-type] 455 neg1s = [i for i, d in enumerate(newshape) if type(d) is int and d == -1] 456 if len(neg1s) > 1: File ~/jax_linus/lib/python3.10/site-packages/jax/_src/core.py:1864, in canonicalize_shape(shape, context) 1862 except TypeError: 1863 pass -> 1864 raise _invalid_shape_error(shape, context) TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<int32[]>with<DynamicJaxprTrace>]. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions. The error occurred while tracing the function f at /tmp/ipykernel_424021/3415433522.py:1 for jit. This value became a tracer due to JAX operations on these lines: operation a:i32[] = reduce_prod[axes=(0,)] b from line /tmp/ipykernel_424021/3415433522.py:6 (f)
@jit
def f(x):
    return x.reshape(np.prod(x.shape)) #np.prod works because x.shape is a concrete tuple, so it returns a concrete value rather than a tracer
print(f(x))
[1. 1. 1. 1. 1. 1.]
$\textbf{Problem}$: What does it mean that JAX is only designed to work with “pure functions”?
$\textbf{Solution}$: Informally, a function is pure iff:
- All inputs are passed through the function arguments, all outputs are passed through the return statement of the function.
- When passed with the same input arguments, the same output result is always obtained.
# Example: identity function
def impure_print_side_effect(x):
print("Executing function:") #violates #1
return x
print("First call (tracing)", jit(impure_print_side_effect)(4))
print("Second call, no more side effects b/c traced and XLA-compiled already", jit(impure_print_side_effect)(5))
print("Third call, need to trace computation graph again, type change:", jit(impure_print_side_effect)(jnp.array([5,])))
Executing function: First call (tracing) 4 Second call, no more side effects b/c traced and XLA-compiled already 5 Executing function: Third call, need to trace computation graph again, type change: [5]
# Example: don't interfere with global variable values!
g = 0
def impure_uses_globals(x):
return x + g
print("First call (caches value of g=0):", jit(impure_uses_globals)(4))
# Now do the crime of updating a global variable:
g = 10
print("Second call looks wrong:", jit(impure_uses_globals)(5))
# If we change the type/shape of the function argument, JAX retraces, and the new trace reads the
# most recent value of the global variable
print("Third call:", jit(impure_uses_globals)(jnp.array([4])))
First call (caches value of g=0): 4 Second call looks wrong: 5 Third call: [14]
# Example: valid pure function (Haiku/Flax built on this idea)
def pure_uses_internal_state(x):
state = dict(even=0, odd=0)
for i in range(10):
state["even" if i%2==0 else "odd"] += x
return state["even"], state["odd"]
print(jit(pure_uses_internal_state)(7))
(Array(35, dtype=int32, weak_type=True), Array(35, dtype=int32, weak_type=True))
# Example: iterators are forbidden!
# lax.fori_loop: similarly for lax.scan, lax.cond, etc.
array = jnp.arange(10)
print("Correct Answer:", lax.fori_loop(0, 10, lambda i, x: x+array[i], 0))
# however this one breaks b/c an iterator is used.
# Iterators are stateful objects, so they violate the purity constraint of JAX functions
iterator = iter(range(10))
print("Incorrect Answer:", lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0))
Correct Answer: 45 Incorrect Answer: 0
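The comment above also mentions lax.scan; a minimal sketch of using it to compute a running sum over the same array (a carry value is threaded through the loop and the per-step outputs are stacked):
def step(carry, xi):
    carry = carry + xi
    return carry, carry  # (new carry, per-step output)

total, running_sum = lax.scan(step, 0, array)
print(total, running_sum)  # 45 [ 0  1  3  6 10 15 21 28 36 45]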
$\textbf{Problem}$: Given a JAX NumPy array, how should one update the value of an element of the array (given that in-place assignment is not supported)?
$\textbf{Solution}$: Using the .at[].set() syntax, which returns an updated copy, for example:
jax_arr = jnp.zeros(shape=(3, 3), dtype=jnp.float32)
updated_jax_arr = jax_arr.at[1,:].set(jnp.pi)
print(jax_arr)
print(updated_jax_arr)
# This may seem wasteful, but inside jit-compiled code XLA is smart enough to figure out that
# if the input array jax_arr isn't used afterwards, it doesn't need to allocate
# a separate memory buffer for the output array: it simply reuses the input buffer
# and performs the update in place, despite not appearing to do so
# at this high-level API.
[[0. 0. 0.] [0. 0. 0.] [0. 0. 0.]] [[0. 0. 0. ] [3.1415927 3.1415927 3.1415927] [0. 0. 0. ]]
# However, one still has access to the expressiveness of NumPy!!!
# Example:
print(jax_arr)
another_updated_jax_arr = jax_arr.at[::2, 1:].add(7)
print(another_updated_jax_arr)
[[0. 0. 0.] [0. 0. 0.] [0. 0. 0.]] [[0. 7. 7.] [0. 0. 0.] [0. 7. 7.]]
$\textbf{Problem}$: Since JAX wants to be accelerator agnostic, how does JAX handle out-of-bounds indexing?
$\textbf{Solution}$: Rather than raising an error or exception (which is awkward to do from code dispatched to an accelerator), out-of-bounds reads are clamped to the last valid index, and out-of-bounds .at[] updates are silently dropped.
# NumPy behavior
try:
np.arange(10)[11]
except Exception as e:
print(f"Exception {e}")
Exception index 11 is out of bounds for axis 0 with size 10
# JAX behavior is disturbing and likely to be the cause of many bugs!!!
# similar to NaN behavior when doing invalid floating point arithmetic
print(jnp.arange(10).at[11].add(823942))
print(jnp.arange(10)[11])
[0 1 2 3 4 5 6 7 8 9] 9
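If this silent clamping is undesirable, the behavior of out-of-bounds reads can be made explicit through the .at[] interface; a small sketch using its get() method with a fill value:
print(jnp.arange(10).at[11].get(mode="fill", fill_value=-1))  # -1 instead of silently clamping to index 9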
# Another "gotcha" of JAX:
# NumPy
print(np.sum([1, 2, 3]))
# JAX
try:
jnp.sum([1, 2, 3])
except Exception as e:
print(f"TypeError: {e}")
6 TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
# The reason for this behavior can, as with any JAX behavior, be dissected via jaxpr:
def permissive_sum(x):
return jnp.sum(jnp.array(x))
x = list(range(10)) #[0,..., 9]
print(make_jaxpr(permissive_sum)(x))
# Thus, JAX is a good fit for researchers looking for highly optimized and flexible programs;
# for a beginner, PyTorch is friendlier
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
j:i32[]. let
k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
l:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] k
m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
n:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] m
o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
p:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] o
q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
r:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] q
s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
t:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] s
u:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
v:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] u
w:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
x:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] w
y:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
z:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] y
ba:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
bb:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] ba
bc:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
bd:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] bc
be:i32[10] = concatenate[dimension=0] l n p r t v x z bb bd
bf:i32[] = reduce_sum[axes=(0,)] be
in (bf,) }
# NumPy - PRNG is stateful!
print(np.random.random()) # PRNG state is advanced
print(np.random.random()) # thus consuming entropy of PRNG
# in both cases, PRNG state change hidden from us...let's expose!
np.random.seed(seed=0) #seed is INITIAL PRNG state!!!
rng_state = np.random.get_state() #get state of PRNG
print(rng_state[2:]) #print this metadata from state
_ = np.random.uniform()
rng_state = np.random.get_state()
print(rng_state[2:])
_ = np.random.uniform()
rng_state = np.random.get_state()
print(rng_state[2:])
# Mersenne Twister PRNG known to have problems (NumPy's implementation of PRNG)
0.6027633760716439 0.5448831829968969 (624, 0, 0.0) (2, 0, 0.0) (4, 0, 0.0)
# In functional programming paradigm, JAX random functions can't modify PRNG state
key = random.PRNGKey(seed=0)
print(key)
print(random.normal(key, shape=(1,)))
print(key)
print(random.normal(key, shape=(1,)))
print(key)
# clearly not random! How to deal with it?
[0 0] [1.6226422] [0 0] [1.6226422] [0 0]
# Solution: splitting! Not just a key, but key + subkey!
print("Old key:", key)
key, subkey = random.split(key) # can also split into > 1 subkeys
normal_pseudorandom = random.normal(subkey, shape=(1,))
print("Key:", key)
print("Subkey:", subkey)
print("Random number:", normal_pseudorandom)
# functionally, key and subkey are indistinguishable, only a convention
# try running this cell multiple times!
Old key: [ 197075234 2075500836] Key: [2234728722 1518742019] Subkey: [3499959921 3652298783] Random number: [-0.2507795]
# Why this design?
# Answer: code is reproducible, parallelizable and vectorizable
np.random.seed(seed=0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
# b/c we start from the same seed state, calling foo() gives the same answer each time.
# However, if we were to jit/parallelize foo(), JAX might execute bar() on one core and baz()
# on another; with a hidden shared PRNG state, the order in which they consume entropy would
# change the result (e.g. 0.3 + 2*0.4 vs. 0.4 + 2*0.3), which is not reproducible
1.9791922366721637
# NumPy
np.random.seed(seed=0)
print("Individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(seed=0)
print("Simultaneously:", np.random.uniform(size=3))
print("They're the same!")
# JAX
key = random.PRNGKey(seed=0)
subkeys = random.split(key, 3) # creating 3 subkeys
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("Individually:", sequence)
key = random.PRNGKey(seed=0)
print("Simultaneously:", random.normal(key, shape=(3,)))
print("They're different!")
Individually: [0.5488135 0.71518937 0.60276338] Simultaneously: [0.5488135 0.71518937 0.60276338] They're the same! Individually: [ 1.0040143 -2.4424558 1.2956359] Simultaneously: [ 1.6226422 2.0252647 -0.43359444] They're different!
# Python control flow + grad() --> no problems!
def f(x):
if x < 3:
return 3* x**2
else:
return -4*x
x = np.linspace(-10, 10, 1000)
y = [f(_) for _ in x]
plt.plot(x, y); plt.show()
print(grad(f)(2.)) #correct!
print(grad(f)(4.)) #correct!
12.0 -4.0
# but if you also want to jit f, then because we're conditioning on x
# in the function scope (if x < 3), it must be passed as a static argument
f_jit = jit(f, static_argnums=(0,))
x = 2
print(make_jaxpr(f_jit, static_argnums=(0,))(x))
print(f_jit(x))
{ lambda ; . let
a:i32[] = pjit[name=f jaxpr={ lambda ; . let in (12:i32[],) }]
in (a,) }
12
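Note that marking x as static forces a retrace for every new value of x. A common alternative (a sketch, not from the original notes) is to express the branch with jnp.where or lax.cond, both of which are traceable without static arguments:
f_where = jit(lambda x: jnp.where(x < 3, 3 * x**2, -4 * x))                       # evaluates both branches, selects elementwise
f_cond = jit(lambda x: lax.cond(x < 3, lambda v: 3 * v**2, lambda v: -4 * v, x))  # traces both branches, runs one at runtime
print(f_where(2.0), f_cond(4.0))  # 12.0 -16.0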
def f(x, n):
y = 0
for i in range(n):
y = y + x[i]
return y
f_jit = jit(f, static_argnums=(1,))
x = (jnp.array([2, 3, 4]), 15)
print(x)
print(*x)
print(make_jaxpr(f_jit, static_argnums=(1,))(*x))
print(f_jit(*x))
print(2+3+4*13) # sanity check: indices i >= 3 clamp to x[2] = 4, so the sum is 2 + 3 + 13*4 = 57
(Array([2, 3, 4], dtype=int32), 15)
[2 3 4] 15
{ lambda ; a:i32[3]. let
b:i32[] = pjit[
name=f
jaxpr={ lambda ; a:i32[3]. let
c:i32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] a
d:i32[] = squeeze[dimensions=(0,)] c
e:i32[] = add 0:i32[] d
f:i32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] a
g:i32[] = squeeze[dimensions=(0,)] f
h:i32[] = add e g
i:i32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] a
j:i32[] = squeeze[dimensions=(0,)] i
k:i32[] = add h j
l:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 3:i32[]
m:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a l
n:i32[] = add k m
o:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 4:i32[]
p:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a o
q:i32[] = add n p
r:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 5:i32[]
s:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a r
t:i32[] = add q s
u:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 6:i32[]
v:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a u
w:i32[] = add t v
x:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 7:i32[]
y:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a x
z:i32[] = add w y
ba:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 8:i32[]
bb:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a ba
bc:i32[] = add z bb
bd:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 9:i32[]
be:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a bd
bf:i32[] = add bc be
bg:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 10:i32[]
bh:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a bg
bi:i32[] = add bf bh
bj:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 11:i32[]
bk:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a bj
bl:i32[] = add bi bk
bm:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 12:i32[]
bn:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a bm
bo:i32[] = add bl bn
bp:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 13:i32[]
bq:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a bp
br:i32[] = add bo bq
bs:i32[1] = broadcast_in_dim[
broadcast_dimensions=()
shape=(1,)
sharding=None
] 14:i32[]
bt:i32[] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), operand_batching_dims=(), start_indices_batching_dims=())
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(1,)
unique_indices=True
] a bs
b:i32[] = add br bt
in (b,) }
] a
in (b,) }
57
57
# A better (though less readable) solution is to use the low-level lax API
def f_fori(x, n):
body_fun = lambda i,val: val+x[i]
return lax.fori_loop(0, n, body_fun, 0)
f_fori_jit = jit(f_fori)
print(make_jaxpr(f_fori_jit)(*x))
print(f_fori_jit(*x))
{ lambda ; a:i32[3] b:i32[]. let
c:i32[] = pjit[
name=f_fori
jaxpr={ lambda ; a:i32[3] b:i32[]. let
_:i32[] _:i32[] c:i32[] = while[
body_jaxpr={ lambda ; d:i32[3] e:i32[] f:i32[] g:i32[]. let
h:i32[] = add e 1:i32[]
i:bool[] = lt e 0:i32[]
j:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
k:i32[] = add j 3:i32[]
l:i32[] = select_n i e k
m:i32[1] = dynamic_slice[slice_sizes=(1,)] d l
n:i32[] = squeeze[dimensions=(0,)] m
o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
p:i32[] = add o n
in (h, f, p) }
body_nconsts=1
cond_jaxpr={ lambda ; q:i32[] r:i32[] s:i32[]. let
t:bool[] = lt q r
in (t,) }
cond_nconsts=0
] a 0:i32[] b 0:i32[]
in (c,) }
] a b
in (c,) }
57
# Conditioning on data dimensionality is permitted
def log2_if_rank_2(x):
if x.ndim == 2:
ln_x = jnp.log(x)
ln_2 = jnp.log(2)
return ln_x/ln_2
else:
return x
print(make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
# because the array has rank 1 (not rank 2), the function just returns its input unchanged
{ lambda ; a:i32[3]. let in (a,) }
jnp.divide(0, 0)
Array(nan, dtype=float32, weak_type=True)
from jax import config
config.update("jax_debug_nans", True) # enable nan-checking first
jnp.divide(0, 0) # now the division raises instead of silently returning nan
Invalid nan value encountered in the output of a jax.jit function. Calling the de-optimized version.
--------------------------------------------------------------------------- FloatingPointError Traceback (most recent call last) Cell In[160], line 1 ----> 1 jnp.divide(0, 0) 2 from jax import config 3 config.update("jax_debug_nans", True) File ~/jax_linus/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py:2481, in divide(x1, x2) 2478 @export 2479 def divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: 2480 """Alias of :func:`jax.numpy.true_divide`.""" -> 2481 return true_divide(x1, x2) [... skipping hidden 4 frame] File ~/jax_linus/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py:2473, in true_divide(x1, x2) 2471 x1, x2 = promote_args_inexact("true_divide", x1, x2) 2472 jnp_error._set_error_if_divide_by_zero(x2) -> 2473 out = lax.div(x1, x2) 2474 jnp_error._set_error_if_nan(out) 2475 return out [... skipping hidden 8 frame] File ~/jax_linus/lib/python3.10/site-packages/jax/_src/pjit.py:181, in _python_pjit_helper(fun, jit_info, *args, **kwargs) 179 except api_util.InternalFloatingPointError as e: 180 if getattr(fun, '_apply_primitive', False): --> 181 raise FloatingPointError(f"invalid value ({e.ty}) encountered in {fun.__qualname__}") from None 182 api_util.maybe_recursive_nan_check(e, fun, args, kwargs) 184 if p.box_data: FloatingPointError: invalid value (nan) encountered in div
# JAX defaults to single precision (float32), though there are ways around this (see below).
x = random.uniform(key, shape=(1000,), dtype=jnp.float64)
print(x.dtype)
print("__________________")
print(type(x))
float32 __________________ <class 'jaxlib._jax.ArrayImpl'>
/tmp/ipykernel_424021/1243311004.py:3: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more. x = random.uniform(key, shape=(1000,), dtype=jnp.float64)
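The warning itself points to the escape hatch: enabling 64-bit mode via the jax_enable_x64 flag (a sketch; ideally this is set at startup, before any arrays are created):
import jax
jax.config.update("jax_enable_x64", True)  # or set the JAX_ENABLE_X64 environment variable
x64 = random.uniform(random.PRNGKey(0), shape=(1000,), dtype=jnp.float64)
print(x64.dtype)  # float64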