JAX on TPU Guide

How to run JAX on Kubernetes with Kubeflow Trainer on Cloud TPU

This guide describes how to use TrainJob to train or fine-tune AI models with JAX on Cloud TPU on Google Kubernetes Engine (GKE).


Prerequisites

Before exploring this guide, make sure to follow:

apiVersion: cloud.google.com/v1
kind: ComputeClass
metadata:
  name: tpu-multihost-v5-8
spec:
  priorities:
  - tpu:
      type: tpu-v5-lite-podslice
      count: 4
      topology: 2x4
  nodePoolAutoCreation:
    enabled: true

JAX on TPU Overview

JAX on TPU requires a different runtime environment than GPU. Specifically:

  • Image: You must use a JAX image compatible with TPUs (e.g., us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu).
  • Resources: You must request google.com/tpu resources.
  • Node Selectors: You must specify GKE-specific node selectors and topology for TPU nodes.
  • Environment Variables: Some TPU-specific JAX environment variables might be required depending on your JAX version.

JAX Distributed Environment on TPU

Your training script must explicitly initialize the JAX distributed runtime.

from kubeflow.trainer import TrainerClient, CustomTrainer
from kubeflow.trainer.options import kubernetes as k8s_options

def get_jax_tpu_dist():
    import os
    import jax
    import jax.distributed as dist

    # Initialize distributed JAX.
    dist.initialize(
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
    )

    print("JAX Distributed Environment on TPU")
    print(f"Local devices: {jax.local_devices()}")
    print(f"Global device count: {jax.device_count()}")
    print(f"Process index: {jax.process_index()}")

    import jax.numpy as jnp

    # Use local_device_count for the leading axis of the local input to pmap
    x = jnp.ones((jax.local_device_count(),))
    
    # Pass process_index as a sharded argument to ensure SPMD consistency across all processes in the distributed job.
    p_idx = jnp.array([jax.process_index()] * jax.local_device_count())
    
    y = jax.pmap(lambda v, p: v * p)(x, p_idx)

    print("PMAP result:", y)

client = TrainerClient()

# Define TPU Node Selectors and Tolerations
# Replace with your GKE TPU configuration
node_selector = {
    "cloud.google.com/compute-class": "tpu-multihost-v5-8",
    "cloud.google.com/gke-tpu-accelerator": "tpu-v5-lite-podslice",
    "cloud.google.com/gke-tpu-topology": "2x4",
}

job_patch = k8s_options.RuntimePatch(
    training_runtime_spec=k8s_options.TrainingRuntimeSpecPatch(
        template=k8s_options.JobSetTemplatePatch(
            spec=k8s_options.JobSetSpecPatch(
                replicated_jobs=[
                    k8s_options.ReplicatedJobPatch(
                        name="node",
                        template=k8s_options.JobTemplatePatch(
                            spec=k8s_options.JobSpecPatch(
                                template=k8s_options.PodTemplatePatch(
                                    spec=k8s_options.PodSpecPatch(
                                        node_selector=node_selector,
                                        tolerations=[
                                            {
                                                "key": "google.com/tpu",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                            {
                                                "key": "cloud.google.com/compute-class",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                        ],
                                    )
                                )
                            )
                        )
                    )
                ]
            )
        )
    )
)

# Create TrainJob
job_id = client.train(
    runtime="jax-distributed",
    trainer=CustomTrainer(
        func=get_jax_tpu_dist,
        image="us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest",
        num_nodes=2,
        resources_per_node={
            "google.com/tpu": 4,
        },
        env={
            "JAX_PLATFORMS": "tpu,cpu",
            "ENABLE_PJRT_COMPATIBILITY": "true",
        }
    ),
    options=[job_patch],
)

# Wait until completion
client.wait_for_job_status(job_id)

# Logs are aggregated from node-0
print("\n".join(client.get_job_logs(name=job_id)))

End-to-end Training Example

The following example demonstrates how to train a simple CNN on the MNIST dataset using JAX on multihost TPUs.

from kubeflow.trainer import TrainerClient, CustomTrainer
from kubeflow.trainer.options import kubernetes as k8s_options

def train_mnist_jax():
    import os
    import jax
    import jax.numpy as jnp
    import jax.distributed as dist
    import optax
    from flax import linen as nn
    from flax.training import train_state
    import tensorflow_datasets as tfds
    import tensorflow as tf

    # Initialize distributed JAX
    dist.initialize(
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
    )

    # Prevent TF from grabbing TPU
    tf.config.set_visible_devices([], 'TPU')

    process_index = jax.process_index()
    local_device_count = jax.local_device_count()

    print(f"Process: {process_index}")
    print(f"Global devices: {jax.device_count()}")
    print(f"Local devices: {jax.local_devices()}")

    # Model definition
    class CNN(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Conv(features=32, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, (2, 2), (2, 2))
            x = x.reshape((x.shape[0], -1))
            x = nn.Dense(128)(x)
            x = nn.relu(x)
            x = nn.Dense(10)(x)
            return x

    # Dataset sharding:
    # In JAX's SPMD model, each process runs the same code but should handle
    # different data to increase throughput. Without sharding, every node
    # would process the same images, wasting compute.
    ds = tfds.load("mnist", split="train", as_supervised=True)
    ds = ds.shard(num_shards=jax.process_count(), index=process_index)

    def preprocess(image, label):
        image = tf.cast(image, tf.float32) / 255.0
        return image, label

    ds = ds.map(preprocess).batch(128).prefetch(1)
    ds = tfds.as_numpy(ds)

    # Training setup
    model = CNN()
    rng = jax.random.PRNGKey(0)

    params = model.init(rng, jnp.ones([1, 28, 28, 1]))["params"]

    tx = optax.adam(1e-3)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )

    # replicate state across local devices
    state = jax.device_put_replicated(state, jax.local_devices())

    # Training step
    def loss_fn(params, batch):
        images, labels = batch
        logits = model.apply({"params": params}, images)
        onehot = jax.nn.one_hot(labels, 10)
        loss = optax.softmax_cross_entropy(logits, onehot).mean()
        return loss

    grad_fn = jax.value_and_grad(loss_fn)

    def train_step(state, batch):
        loss, grads = grad_fn(state.params, batch)
        # Average gradients across all devices.
        # We must bind "batch" axis in jax.pmap for pmean to work.
        grads = jax.lax.pmean(grads, axis_name="batch")
        state = state.apply_gradients(grads=grads)
        return state, loss

    train_step = jax.pmap(train_step, axis_name="batch")

    # Training loop
    for epoch in range(5):
        for images, labels in ds:
            # Convert to jnp and shard batch per local device
            images = jnp.array(images).reshape(
                (local_device_count, -1, 28, 28, 1)
            )
            labels = jnp.array(labels).reshape(
                (local_device_count, -1)
            )

            state, loss = train_step(state, (images, labels))

        if process_index == 0:
            print(f"Epoch {epoch}, Loss: {loss.mean()}")

client = TrainerClient()

# Define TPU Node Selectors and Tolerations
node_selector = {
    "cloud.google.com/compute-class": "tpu-multihost-v5-8",
    "cloud.google.com/gke-tpu-accelerator": "tpu-v5-lite-podslice",
    "cloud.google.com/gke-tpu-topology": "2x4",
}

job_patch = k8s_options.RuntimePatch(
    training_runtime_spec=k8s_options.TrainingRuntimeSpecPatch(
        template=k8s_options.JobSetTemplatePatch(
            spec=k8s_options.JobSetSpecPatch(
                replicated_jobs=[
                    k8s_options.ReplicatedJobPatch(
                        name="node",
                        template=k8s_options.JobTemplatePatch(
                            spec=k8s_options.JobSpecPatch(
                                template=k8s_options.PodTemplatePatch(
                                    spec=k8s_options.PodSpecPatch(
                                        node_selector=node_selector,
                                        tolerations=[
                                            {
                                                "key": "google.com/tpu",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                            {
                                                "key": "cloud.google.com/compute-class",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                        ],
                                    )
                                )
                            )
                        )
                    )
                ]
            )
        )
    )
)

job_id = client.train(
    runtime="jax-distributed",
    trainer=CustomTrainer(
        func=train_mnist_jax,
        image="us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest",
        num_nodes=2,
        resources_per_node={
            "google.com/tpu": 4,
        },
        env={
            "JAX_PLATFORMS": "tpu,cpu",
            "ENABLE_PJRT_COMPATIBILITY": "true",
        },
        packages_to_install=[
            "tensorflow-datasets",
            "flax",
            "optax",
            "tensorflow",
        ],
    ),
    options=[job_patch],
)

client.wait_for_job_status(job_id)
print("\n".join(client.get_job_logs(name=job_id)))

TPU Specific Configurations

Node Selectors and Topology

When running on GKE, TPUs are managed via specific node pools and you must match the proper node_selector and tolerations to your TPU node pool labels. If you are using custom ComputeClasses, add the following node_selector and tolerations to your TPU node pool labels:

LabelExample Value
cloud.google.com/compute-classtpu-multihost-v5-8
cloud.google.com/gke-tpu-acceleratortpu-v5-lite-podslice
cloud.google.com/gke-tpu-topology2x4
Toleration KeyToleration OperatorToleration Effect
google.com/tpuExistsNoSchedule
cloud.google.com/compute-classExistsNoSchedule

Environment Variables

VariableDescription
JAX_PLATFORMSSet to tpu,cpu to ensure JAX uses the TPU backend.
ENABLE_PJRT_COMPATIBILITYSet to true for compatibility with newer JAX/LibTPU versions.

Next Steps

Feedback

Was this page helpful?