Exploring PyTorch and JAX

In this post, we hope to concisely introduce PyTorch and JAX, two prominent frameworks for deep learning. [colab] [repo]

We also hoped to cover TensorFlow, but for reasons unclear to us, we were unable to install it with either Python 3.12 or Python 3.11. Given that TensorFlow and JAX are both developed within Google, and given JAX's recent popularity, I wouldn't be surprised if TensorFlow is eventually deprecated. For completeness, it is worth noting that TensorFlow is the oldest of the three: it was first released in 2015, building on Google's earlier DistBelief system from 2011.

PyTorch was first introduced in 2016 by Adam Paszke, Soumith Chintala, and others at FAIR [1]. Features that made PyTorch stand out from TensorFlow at the time include its dynamic computational graph, its Pythonic design, and its extensive ecosystem, notably torchvision, torchaudio, and torchtext.

JAX was first introduced in 2018 by Roy Frostig, Matthew James Johnson, and Chris Leary at Google Brain [2]; its README describes it as Autograd and XLA brought together. Some distinctive features of JAX include just-in-time (JIT) compilation through XLA (Accelerated Linear Algebra), automatic vectorization (via vmap), and data parallelism across multiple devices (via pmap). JAX is known for its computational efficiency on hardware accelerators like GPUs and TPUs.
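A minimal sketch of jit and vmap, assuming only that jax is installed. The function dense_relu and the shapes are hypothetical, chosen to mirror the first layer of the networks below; pmap is left out because it needs several devices to show anything interesting.

```python
import jax
import jax.numpy as jnp

# A plain Python function written for a single example: one dense layer plus ReLU.
def dense_relu(w, b, x):
    return jnp.maximum(jnp.dot(w, x) + b, 0.0)

# jax.jit traces the function once per input shape/dtype and compiles it with XLA.
fast_dense_relu = jax.jit(dense_relu)

# jax.vmap maps the single-example function over the leading axis of x,
# while w and b (in_axes=None) are shared across the whole batch.
batched_dense_relu = jax.jit(jax.vmap(dense_relu, in_axes=(None, None, 0)))

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (10, 784))
b = jnp.zeros(10)
xs = jax.random.normal(key_x, (64, 784))  # a batch of 64 flattened "images"

out_batched = batched_dense_relu(w, b, xs)                    # shape (64, 10)
out_loop = jnp.stack([fast_dense_relu(w, b, x) for x in xs])  # same result, one example at a time
```

The vectorized version never writes the batch loop by hand; vmap turns the per-example dot products into a single batched matrix multiply.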

What PyTorch and JAX have in common is automatic differentiation (torch.autograd in PyTorch, jax.grad in JAX). JAX code can often run faster, however, largely thanks to the JIT compilation and automatic vectorization mentioned earlier. What makes the two frameworks fundamentally different is their programming paradigm: PyTorch is object-oriented (models are stateful nn.Module objects), while JAX is functional (transformations such as jax.grad and jax.jit operate on pure functions, and parameters are passed around explicitly).
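To make the contrast concrete, here is a minimal sketch differentiating the same toy function, f(x) = x**2 + 3x, in both frameworks (the function and the point x = 2 are illustrative, not part of the examples below). PyTorch records operations on a tensor and stores the gradient on that tensor, while jax.grad transforms the function itself:

```python
import torch
import jax

# One toy function, written once; ** and * work on both torch tensors and JAX arrays.
def f(x):
    return x ** 2 + 3 * x

# PyTorch: flag the input tensor, run the computation (which records a graph),
# call backward(), then read the gradient stored on the tensor itself.
x = torch.tensor(2.0, requires_grad=True)
f(x).backward()
torch_grad = x.grad.item()

# JAX: jax.grad turns f into a new function that returns df/dx;
# nothing is stored on the input.
df_dx = jax.grad(f)
jax_grad = float(df_dx(2.0))

print(torch_grad, jax_grad)  # 7.0 7.0  (f'(x) = 2x + 3, so f'(2) = 7)
```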

Let us look at some examples.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import numpy as np
import matplotlib.pyplot as plt
from torchviz import make_dot
# Define neural net architecture class (784+128+10 neurons, where the second and third layer are fully connected)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
net = Net()
# Load dataset (FashionMNIST from torchvision)
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
trainloader = DataLoader(training_data, batch_size=64, shuffle=True)

# Choose loss function and optimizer (cross-entropy loss, stochastic gradient descent)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Train the network
for epoch in range(2):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data               # get data stored as [inputs, labels] in trainloader
        optimizer.zero_grad()               # zero the parameter gradients
        outputs = net(inputs)               # forward pass
        loss = criterion(outputs, labels)   # compute loss function
        loss.backward()                     # backward pass
        optimizer.step()                    # update weights
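# Load test dataset
# Caution: unlike training_data (ToTensor() only), this adds Normalize((0.5,), (0.5,)),
# which shifts pixels from [0, 1] to [-1, 1]. This train/test mismatch likely hurts the
# accuracy reported below; use transform=ToTensor() here for a like-for-like evaluation.
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])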
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Evaluate the model
def evaluate_model(model, data_loader):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # No need to track gradients for evaluation
        for images, labels in data_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

# Calculate the accuracy on the test dataset
accuracy = evaluate_model(net, test_loader)
print(f'Accuracy of the model on the test images: {accuracy}%')
Accuracy of the model on the test images: 57.59%

So we've trained a simple neural net with two fully connected layers (784+128+10 neurons) on FashionMNIST (a clothing classification dataset available through torchvision), reaching a seemingly lacklustre accuracy of 57.59 percent. In any case, we would like to visualize the trained weights. For this, we will try two tools: matplotlib and torchviz. Let's first look at matplotlib.

# Extract the weights of the first fully connected layer 'fc1' (a 128x784 matrix)
weights = net.fc1.weight.detach().numpy()
# Show each of the 128 rows as a 28x28 grayscale image on an 11x12 grid (132 cells, 4 left empty)
fig, axes = plt.subplots(11, 12, figsize=(15, 15))
for i, ax in enumerate(axes.flat):
    if i < len(weights):
        ax.imshow(weights[i].reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.show()

What we see above is a matplotlib visualization of the trained weight matrix of the first fully connected layer: each of its 128 neurons has 784 = 28*28 weights, shown as a 28x28 grayscale image. Let us now try torchviz, a tool for visualizing PyTorch autograd graphs.
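Before moving on to torchviz, a quick check of why each row of fc1's weights can be read as an image. A minimal sketch, assuming torch is installed; the random image and the neuron index i = 5 are arbitrary. Because torch.flatten is row-major, neuron i's pre-activation equals its weight row, reshaped to 28x28, multiplied pixel-by-pixel with the image and summed, plus the bias:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1 = nn.Linear(784, 128)
image = torch.rand(1, 1, 28, 28)  # one fake 28x28 grayscale image, batch of 1

# nn.Linear stores its weight as (out_features, in_features) = (128, 784),
# so row i holds the 784 weights feeding hidden neuron i.
pre_activation = fc1(torch.flatten(image, 1))  # forward pass as in Net, shape (1, 128)

# Same value for neuron i: overlay row i (reshaped to 28x28) on the image,
# multiply elementwise, sum over all pixels, and add the bias.
i = 5
template = fc1.weight[i].reshape(28, 28)
manual = (image[0, 0] * template).sum() + fc1.bias[i]
```

So bright and dark regions in each weight image mark the pixels that push that neuron's activation up or down.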

# Create a dummy input: one flattened 28x28 image (forward() also accepts (N, 1, 28, 28) batches)
dummy_input = torch.randn(1, 784)
# Perform a forward pass to get the output
output = net(dummy_input)
# Visualize the autograd graph behind the (scalar) mean of the output
graph = make_dot(output.mean(), params=dict(net.named_parameters()))
graph.render('network_graph', format='png')  # saves network_graph.png; needs the Graphviz binaries installed
graph  # in a notebook, this displays the graph inline

Above is the torchviz visualization of the network's computational graph: two fully connected layers, fc1 and fc2, where fc1 has 128 neurons with 784 weights each and fc2 has 10 neurons with 128 weights each. Each input is a 28*28 image.
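The same structure can be inspected without Graphviz. torchviz essentially walks the autograd graph through each node's grad_fn.next_functions; the sketch below (assuming torch is installed; autograd_node_names, demo_net and demo_output are hypothetical names) does a text-only version of that walk and also checks the parameter arithmetic for the 784-128-10 network:

```python
import torch
import torch.nn as nn

class Net(nn.Module):  # same architecture as the PyTorch example above
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

def autograd_node_names(tensor):
    """Hypothetical helper: class names of all autograd nodes reachable from tensor.grad_fn."""
    names, stack, seen = [], [tensor.grad_fn], set()
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        names.append(type(node).__name__)
        stack.extend(next_node for next_node, _ in node.next_functions)
    return names

demo_net = Net()  # a fresh, untrained copy, so the trained `net` is left untouched

# Parameter count, weights plus biases:
# fc1: 784 * 128 + 128 = 100,480;  fc2: 128 * 10 + 10 = 1,290;  total 101,770
num_params = sum(p.numel() for p in demo_net.parameters())

# Nodes behind output.mean(), e.g. a mean node, one addmm node per nn.Linear,
# a ReLU node, and an AccumulateGrad node for each parameter tensor.
demo_output = demo_net(torch.randn(1, 784))
names = autograd_node_names(demo_output.mean())
print(num_params, names)
```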

Let us now try to implement the same example in JAX.

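import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader       # PyTorch is still used for data loading
from torchvision import datasets
from torchvision.transforms import ToTensor
import jax
import jax.numpy as jnp
from flax import linen as nn                  # note: shadows torch.nn imported earlier
from flax.training import train_state
import optax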

# Define neural net architecture class
class SimpleNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = jnp.reshape(x, (x.shape[0], -1))  # Flatten input
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x
# Collate function so the DataLoader yields NumPy arrays instead of torch tensors
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)              # split (image, label) pairs into two groups
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)                # ints, and torch tensors via Tensor.__array__
def prepare_dataloader(dataset, *args, **kwargs):
    return DataLoader(dataset, *args, collate_fn=numpy_collate, **kwargs)
# Load train & test data (FashionMNIST from torchvision again)
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
trainloader = prepare_dataloader(training_data, batch_size=64, shuffle=True)
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())
testloader = prepare_dataloader(test_data, batch_size=64, shuffle=False)

# Define cross-entropy loss
def cross_entropy_loss(*, logits, labels):
    labels_onehot = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(labels_onehot * jax.nn.log_softmax(logits), axis=-1))
# Define update step
@jax.jit
def train_step(state, batch):
    inputs, labels = batch
    inputs = jnp.array(inputs).reshape(inputs.shape[0], -1)
    labels = jnp.array(labels)
    def loss_fn(params):
        logits = SimpleNN().apply({'params': params}, inputs)
        loss = cross_entropy_loss(logits=logits, labels=labels)
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss, logits
    
# Train the network
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
model = SimpleNN()
params = model.init(init_rng, jnp.ones([1, 28 * 28]))['params']

tx = optax.sgd(learning_rate=0.001, momentum=0.9)   # Stochastic Gradient Descent optimizer provided by optax
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

for epoch in range(2):
    for batch in trainloader:
        state, loss, logits = train_step(state, batch)

# Evaluate the network's performance
def accuracy(logits, labels):
    return jnp.mean(jnp.argmax(logits, -1) == labels)
@jax.jit
def eval_step(params, batch):
    inputs, labels = batch
    inputs = jnp.array(inputs).reshape(inputs.shape[0], -1)
    labels = jnp.array(labels)
    logits = model.apply({'params': params}, inputs)
    return accuracy(logits, labels)
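accuracies = []
for batch in testloader:
    accuracies.append(eval_step(state.params, batch))
# Mean of per-batch accuracies; the final batch holds only 16 of the 10,000 test images,
# so it is weighted slightly more than its size warrants
print('Test set accuracy:', np.mean(accuracies))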

Test set accuracy: 0.79677546

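Interestingly, while PyTorch reached a poor 57.6 percent test accuracy, JAX reached 79.7 percent. This is despite using the same optimizer (SGD with learning rate 0.001 and momentum 0.9) and the same loss function (cross-entropy). So either the test dataset is created differently, or JAX is genuinely more accurate than PyTorch. The first explanation looks more likely: the PyTorch test set applies transforms.Normalize((0.5,), (0.5,)) on top of ToTensor(), so the model is evaluated on pixels in [-1, 1] after being trained on pixels in [0, 1], whereas both JAX datasets use ToTensor() alone. The two networks also start from different default weight initializations. I don't know whether JAX is more accurate in general, but I guess we'll learn more as I continue experimenting in the future.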

Let's go ahead and visualize the weights as before. The matplotlib approach carries over directly, but torchviz does not, since it only works on PyTorch autograd graphs; in my exploration, I didn't come across a simple equivalent of torchviz for JAX. Leaving that aside, let's go ahead and visualize the weights.
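Before plotting, a short aside on the torchviz point: while there may be no drop-in graph picture, jax.make_jaxpr prints the traced program as text, listing each primitive operation with its shapes. A minimal sketch, assuming only jax is installed; forward and toy_params are hypothetical stand-ins for SimpleNN written in plain jax.numpy, so no Flax is needed:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for SimpleNN's forward pass: dense (784 -> 128), ReLU, dense (128 -> 10).
def forward(params, x):
    h = jnp.maximum(x @ params["w1"] + params["b1"], 0.0)
    return h @ params["w2"] + params["b2"]

toy_params = {
    "w1": jnp.zeros((784, 128)), "b1": jnp.zeros(128),
    "w2": jnp.zeros((128, 10)), "b2": jnp.zeros(10),
}
x = jnp.ones((1, 784))

# make_jaxpr traces forward with these argument shapes and returns its intermediate
# representation: the matrix multiplies show up as dot_general, the ReLU as max.
jaxpr = jax.make_jaxpr(forward)(toy_params, x)
print(jaxpr)
```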

# Flax stores Dense kernels as (in_features, out_features), here 784x128
weights = np.asarray(state.params['Dense_0']['kernel'])
# Transpose to 128x784 so each row holds one hidden neuron's weights, as in PyTorch
weights = weights.T

# Reshape each row to 28x28 and plot
fig, axes = plt.subplots(11, 12, figsize=(15, 15))
for i, ax in enumerate(axes.flat):
    if i < weights.shape[0]:  # 132 grid cells but only 128 neurons
        ax.imshow(weights[i].reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.show()

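Interesting. While we can't learn much at this low a resolution, a careful comparison of the PyTorch and JAX weights shows that the JAX weights have more fine-grained variation, which from a distance looks more random than the PyTorch weights. Clearly the two models have learnt different weights, but some of that is expected regardless of framework: they start from different random seeds and different default initializations (PyTorch's nn.Linear uses a Kaiming-uniform scheme, Flax's nn.Dense a LeCun-normal one with a larger spread), and after only two epochs of SGD at learning rate 0.001 much of that initial noise may still be visible. Whether JAX's higher accuracy stems from picking up finer spatial variations in the input images is hard to tell from these pictures alone.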
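One framework-level reason the two sets of weight images look different is the default initialization of the first layer. A minimal sketch comparing the two defaults for a 784x128 layer, assuming torch and jax are installed; jax.nn.initializers.lecun_normal is the initializer that Flax's nn.Dense uses for its kernel by default:

```python
import math
import torch
import jax

fan_in, fan_out = 784, 128

# PyTorch: nn.Linear uses kaiming_uniform_ with a=sqrt(5), i.e. weights uniform in
# [-1/sqrt(fan_in), 1/sqrt(fan_in)] = [-1/28, 1/28], so std = (1/28) / sqrt(3).
torch.manual_seed(0)
torch_w = torch.nn.Linear(fan_in, fan_out).weight.detach()
torch_std = torch_w.std().item()

# Flax: nn.Dense defaults to LeCun-normal, a truncated normal rescaled to std 1/sqrt(fan_in).
# Flax kernels are laid out (in_features, out_features), hence the .T in the plotting code.
init = jax.nn.initializers.lecun_normal()
flax_w = init(jax.random.PRNGKey(0), (fan_in, fan_out))
flax_std = float(flax_w.std())

print(f"PyTorch init std: {torch_std:.4f} (theory {1 / (28 * math.sqrt(3)):.4f})")
print(f"Flax init std:    {flax_std:.4f} (theory {1 / 28:.4f})")
```

The Flax starting weights are about sqrt(3) times more spread out, which is consistent with, though not proof of, the noisier look of the JAX weight images.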


References

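[1] A. Paszke et al., "PyTorch: An Imperative Style, High-Performance Deep Learning Library", NeurIPS 2019 (FAIR)
[2] R. Frostig, M. J. Johnson, and C. Leary, "Compiling machine learning programs via high-level tracing", SysML 2018, now MLSys (Google Brain)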

Assisted by perplexity.ai

Written on March 19, 2024