Introduction to JAX: PyTorch to JAX: A Mental Model Shift (Part 1)

March 20, 2026

JAX is a relatively recent library for numerical computing and machine learning, developed by Google Research. It combines a NumPy-like API with performance optimizations such as Just-In-Time (JIT) compilation, and it is built on functional programming principles. JAX rests on three fundamental pillars.

The first is immutability. Once an array is created, it cannot be modified in place the way a NumPy array can. This restriction enables compiler optimizations and makes behavior more predictable in parallel computing environments.

The second is automatic differentiation (autodiff), which computes gradients of functions written with JAX operations. Finite differences only approximate a derivative, and the step size must be chosen carefully to balance approximation error against rounding error. Autodiff avoids this trade-off: its derivatives are exact up to floating-point rounding. In reverse mode, it also computes a full gradient at a cost that is a small constant multiple of one function evaluation. Finite differences, by contrast, need extra evaluations for every input.

The third is advanced compilation, which accelerates computations.

JAX Fundamentals
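The sketch below previews the second and third pillars. It differentiates a small function with jax.grad and then compiles the resulting gradient function with jax.jit. The function cube and the step size h are illustrative choices and do not come from the original text. The finite-difference estimate is included only for comparison.

```python
import jax

def cube(x):
    # A pure function: the output depends only on the input x
    return x ** 3

# Autodiff: jax.grad returns a new function computing d(cube)/dx = 3x^2
d_cube = jax.grad(cube)

# Compilation: jax.jit returns a version of d_cube compiled with XLA
fast_d_cube = jax.jit(d_cube)

print(d_cube(2.0))       # 12.0
print(fast_d_cube(2.0))  # 12.0

# For comparison: a central finite difference only approximates the derivative
h = 1e-3
approx = (cube(2.0 + h) - cube(2.0 - h)) / (2 * h)
print(approx)  # close to, but not exactly, 12
```

For a cubic, the central difference has an error of exactly h squared, before rounding. The autodiff result has no such step-size error.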

Immutability

JAX arrays look and behave much like NumPy arrays, but with key differences. The most important is immutability, which is designed for functional programming. To "modify" an array, JAX provides an indexed-update syntax:

```python
import jax.numpy as jnp

jax_array = jnp.array([1, 2, 3])
updated_jax_array = jax_array.at[0].set(10)  # returns a new array; jax_array is unchanged
```
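The sketch below assumes a recent JAX release. It shows what the indexed-update syntax replaces: NumPy-style item assignment on a JAX array raises a TypeError. In contrast, .at[...].set and .at[...].add return new arrays and leave the original untouched. The names set_first and add_last are illustrative.

```python
import jax.numpy as jnp

jax_array = jnp.array([1, 2, 3])

# NumPy-style in-place assignment is rejected because JAX arrays are immutable
try:
    jax_array[0] = 10
    assignment_failed = False
except TypeError:
    assignment_failed = True
print(assignment_failed)  # True

# Functional updates build new arrays instead
set_first = jax_array.at[0].set(10)  # replace the element at index 0
add_last = jax_array.at[2].add(5)    # add 5 to the element at index 2

print(jax_array.tolist())  # [1, 2, 3]  (original unchanged)
print(set_first.tolist())  # [10, 2, 3]
print(add_last.tolist())   # [1, 2, 8]
```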

This approach creates a new array with the requested change instead of modifying the original.

Immutability helps programs that run on multiple threads avoid race conditions on shared arrays. A race condition happens when multiple threads access and modify shared data at the same time, leading to unpredictable or incorrect results. Because JAX arrays are immutable, threads cannot change the original data; they can only create new arrays. This prevents accidental overwrites and makes parallel code much safer.

The main downside is the memory needed to create copies of existing arrays. Functional languages commonly reduce this cost with copy-on-write or persistent data structures, where only the changed parts are copied rather than the entire array. In JAX specifically, an update such as x.at[idx].set(y) may not copy at all. If it is compiled with jax.jit and the original x is not used afterwards, the compiler can perform the update in place.

Pure functions

In functional programming, a pure function satisfies the following:

  1. Deterministic behavior: A pure function always produces the same output for the same inputs. It does not depend on external state and does not introduce variation between calls (for example, by incrementing a global counter).
  2. No side effects: A pure function does not modify anything outside itself. It makes no changes to global variables, does not modify its input arguments, and does not print to the console or write to disk. In other words, a pure function does not interact with external systems. Modifying input arguments in place is an easy side effect to introduce by accident, as in the following example:

```python
def inplace_modification(data_list, val):
    # Scales each element in place: this mutates the caller's list
    for i in range(len(data_list)):
        data_list[i] *= val
    return data_list

# Demonstrating the side effect
my_list = [1, 2, 3]
print(f"Original list: {my_list}")
modified_list = inplace_modification(my_list, 2)
print(f"List after inplace_modification: {my_list} (input modified)")
print(f"Returned list is same object: {modified_list is my_list}")
```
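For contrast, the sketch below shows a pure counterpart of inplace_modification, along with a deliberately impure function that breaks the first rule. The names pure_scale, impure_next, and call_count are hypothetical and chosen only for illustration.

```python
def pure_scale(data_list, val):
    # Builds and returns a new list; the caller's list is never touched
    return [x * val for x in data_list]

call_count = 0  # hypothetical global state, used only to illustrate impurity

def impure_next(x):
    # Violates deterministic behavior: the result depends on how often it ran
    global call_count
    call_count += 1
    return x + call_count

my_list = [1, 2, 3]
scaled = pure_scale(my_list, 2)
print(my_list, scaled)    # [1, 2, 3] [2, 4, 6]
print(scaled is my_list)  # False: a new object was returned

first, second = impure_next(10), impure_next(10)
print(first, second)      # 11 12: same input, different outputs
```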

This happens because Python passes each argument as a reference to the caller's object; no copy is made. Inside inplace_modification(), the loop updates data_list, which is just another name for my_list. In memory, both names point to the same list object. Because lists are mutable, a change made through one name is visible through the other.

In a purely functional language, data structures are immutable by default. A function that receives a data structure cannot change it. Instead, it builds and returns a new value with the changes applied, leaving the original unchanged. Because such functions cannot alter existing data or external state, they return new values based only on their inputs. This guarantees no side effects and preserves functional purity.

In Python, reference passing applies to every argument, including integers. Passing an int into a function does not copy it, so the parameter and the caller's variable initially point to the same int object. The difference is that ints are immutable. An operation such as n += 1 inside the function creates a new int object and rebinds only the local name, leaving the caller's value untouched. Mutable objects such as Python lists (or NumPy arrays) are also shared between the parameter and the caller's variable, but they can be changed in place. Changes made inside the function therefore also affect the original input.

Immutability is enforced by how each type is implemented in Python's core. For immutable types such as integers, tuples, and strings, Python provides no methods or operations that alter the object after it is created. Any operation that seems to modify an immutable object actually creates a new object in memory and leaves the original as is. Mutable types, in contrast, expose methods that change their contents in place. For example, lists define a core .append() method in their class, so any function that receives a list can call .append() and update it in place.
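You can check the difference between rebinding an immutable int and mutating a shared list directly with the is operator, which tests object identity. The helper names rebind_int and mutate_list are hypothetical, and the value 1000 is arbitrary.

```python
def rebind_int(n, original):
    same_on_entry = n is original   # True: passing an argument does not copy it
    n += 1                          # ints are immutable: a new int is created and n is rebound
    return same_on_entry, n is original

def mutate_list(items, original):
    same_on_entry = items is original  # True: the same list object is shared
    items.append(4)                    # mutates that shared object in place
    return same_on_entry, items is original

x = 1000
nums = [1, 2, 3]
int_result = rebind_int(x, x)
list_result = mutate_list(nums, nums)

print(int_result)   # (True, False)
print(list_result)  # (True, True)
print(x, nums)      # 1000 [1, 2, 3, 4]

# Mutable types expose in-place methods; immutable ones do not
print(hasattr(nums, "append"), hasattr((1, 2, 3), "append"))  # True False
```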
A function can call .append() on any list it receives because Python is dynamically typed and looks methods up on the object's type at runtime. Python searches a class and its base classes in a fixed order called the Method Resolution Order (MRO). When obj.method() is invoked, Python first looks for method on obj's class. If it is not found there, Python searches the parent classes in MRO order. If the method is still not found, Python raises an AttributeError. This flexibility enables powerful object-oriented programming, but it also makes it easy to write impure functions.

JAX relies heavily on pure functions because it traces and compiles functions for optimization. It therefore expects a function to behave deterministically every time it is called with the same inputs. Automatic differentiation works by analyzing the mathematical operations that map a function's inputs to its outputs. Side effects, such as updating a global variable, are not part of that mapping and have no mathematical derivative, so differentiating them would be meaningless. The same requirement extends to parallelization. When JAX distributes computations across multiple processes or devices, it needs to know that the functions will not interfere with each other through shared state, and pure functions guarantee this independence.
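The sketch below assumes a recent JAX release and shows why impurity causes trouble under jax.jit. Python side effects run only while JAX traces the function. A global value is captured at trace time, so the cached compiled version silently ignores later changes to it. The names trace_log, scale, impure_scale, and scale_by are hypothetical.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical global list used to observe when tracing happens
scale = 2.0     # global state read by the impure function below

@jax.jit
def impure_scale(x):
    trace_log.append("traced")  # Python side effect: runs only during tracing
    return x * scale            # the value of `scale` is baked in at trace time

@jax.jit
def scale_by(x, s):
    return x * s  # pure: everything the result depends on is an explicit input

x = jnp.array([1.0, 2.0, 3.0])
first = impure_scale(x)
scale = 10.0
second = impure_scale(x)  # same shape/dtype, so the cached compiled version is reused

print(first.tolist(), second.tolist())  # [2.0, 4.0, 6.0] [2.0, 4.0, 6.0]
print(len(trace_log))                   # 1: the side effect ran only once
print(scale_by(x, 10.0).tolist())       # [10.0, 20.0, 30.0]
```

Passing the scale factor as an argument, as scale_by does, makes the dependency explicit. JAX then treats it as a traced input rather than a frozen constant.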