JAX (software)
Machine learning framework designed for parallelization and automatic differentiation (autograd).

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It is developed by Google with contributions from Nvidia and other community contributors.

It is described as bringing together a modified version of autograd (which automatically obtains the gradient function by differentiating a function) and OpenXLA's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch. The primary features of JAX are:

  1. Providing a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
  2. Built-in just-in-time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.
  3. Efficient evaluation of gradients via its automatic differentiation transformations.
  4. Automatic vectorization of functions so they can be efficiently mapped over arrays representing batches of inputs (see the combined sketch after this list).
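As a brief illustration of how these transformations compose, the following minimal sketch applies grad, jit, and vmap to a small NumPy-style function; the function squared_error, its arguments, and the example values are chosen here purely for illustration.

# imports
from jax import grad, jit, vmap
import jax.numpy as jnp

# a small NumPy-style function: sum of squared differences between a prediction and a target
def squared_error(prediction, target):
    return jnp.sum((prediction - target) ** 2)

# differentiate with respect to the first argument and compile the result with XLA
fast_grad = jit(grad(squared_error))

# vectorize the compiled gradient over a batch of (prediction, target) pairs
batched_grad = vmap(fast_grad)

predictions = jnp.array([[1.0, 2.0], [3.0, 4.0]])
targets = jnp.array([[0.0, 2.0], [2.0, 2.0]])
print(batched_grad(predictions, targets))  # per-example gradients 2 * (prediction - target)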

grad

Main article: Automatic differentiation

The code below demonstrates the grad function's automatic differentiation.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)

The final line should output:

0.19661194
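As a quick check of this value (added here for clarity), the analytic derivative of the logistic function is σ(x)(1 − σ(x)), which can be evaluated directly at x = 1:

# imports
import jax.numpy as jnp

# the analytic derivative of the logistic function is sigma(x) * (1 - sigma(x))
sigma = jnp.exp(1.0) / (jnp.exp(1.0) + 1)  # sigma(1) ≈ 0.7310586
print(sigma * (1 - sigma))                 # ≈ 0.19661194, matching grad_logistic(1.0)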

jit

The code below demonstrates the jit function's optimization through operation fusion.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

The computation time for jit_cube should be noticeably shorter than that for cube once jit_cube has been compiled on its first call. Increasing the size of the array created by jnp.ones will further widen the difference.
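One way to make the comparison concrete is sketched below; the timing code is illustrative, and block_until_ready() is used because JAX dispatches work asynchronously, so timing a call without it would measure dispatch rather than execution.

# imports
import timeit
from jax import jit
import jax.numpy as jnp

# define the cube function and its jit-compiled version, as above
def cube(x):
    return x * x * x

x = jnp.ones((10000, 10000))
jit_cube = jit(cube)

# call once so that tracing and compilation are not included in the timing below
jit_cube(x).block_until_ready()

# block_until_ready() waits for the asynchronously dispatched computation to finish
print("cube:    ", timeit.timeit(lambda: cube(x).block_until_ready(), number=10))
print("jit_cube:", timeit.timeit(lambda: jit_cube(x).block_until_ready(), number=10))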

vmap

Main article: Array programming

The code below demonstrates the vmap function's vectorization.

# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# compute per-example gradients for a batch of inputs
def grads(self, inputs):
    # bind the network parameters as the first argument of the per-example gradient function
    in_grad_partial = partial(self._net_grads, self._net_params)
    # vectorize the per-example gradient function over the leading (batch) axis of the inputs
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    # flatten the structured per-example gradients into a 2-D array (batch size x parameter count)
    flat_grads = jnp.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

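Because the snippet above relies on class attributes (self._net_grads, self._net_params, self._flatten_batch) that are not shown, a self-contained sketch may be clearer; the dot function and the example arrays below are illustrative.

# imports
from jax import vmap
import jax.numpy as jnp

# a function written for a single pair of vectors
def dot(a, b):
    return jnp.dot(a, b)

# vmap maps dot over the leading (batch) axis of both arguments
batched_dot = vmap(dot)

a = jnp.arange(6.0).reshape(2, 3)  # a batch of 2 vectors of length 3
b = jnp.ones((2, 3))
print(batched_dot(a, b))           # [ 3. 12.]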

pmap

The code below demonstrates the pmap function's parallelization for matrix multiplication.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

The final line should print the values:

[1.1566595 1.1805978]
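Note that pmap maps each element of the leading axis onto its own device, so the example above needs at least two devices; on a machine with fewer, it typically fails with an error about the number of available devices. A minimal check (illustrative) is:

# imports
import jax

# pmap assigns one device per element of the mapped axis,
# so the example above requires at least two local devices
print(jax.local_device_count())
print(jax.devices())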
