Run a batch workload with Pathways

For the purpose of this document, batch workloads are defined as JAX workloads that execute to completion and are deployed within the same GKE cluster as the Pathways cluster, specifically alongside the Pathways controller components (IFRT proxy server and Pathways resource manager). Completion of the JAX workload also terminates the Pathways cluster components. This guide uses a JAX training workload to demonstrate this.

Before you begin

Make sure you have:

Build a training image using MaxText

MaxText is an open-source large language model (LLM) project developed by Google. It's written in JAX and designed to be highly performant and scalable, running efficiently on Google Cloud TPUs and GPUs.

To build a MaxText Docker image by using the latest version of stable JAX from the OSS GitHub repository, run the following commands:

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner

These commands push the MaxText Docker image to gcr.io/PROJECT/USER_runner. You can use this image to run training on TPUs with the Pathways backend.

Run a batch workload using the PathwaysJob API

The following manifest deploys the Pathways components and runs a MaxText workload using the PathwaysJob API. The workload is encapsulated in the main container and exercises train.py.

Copy the following YAML into a file named pathways-job-batch-training.yaml and update the editable values.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: "gs://BUCKET_NAME"
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

Replace the following:

  • USER : your Google Cloud user ID
  • MAX_RESTARTS : the maximum number of times the Job can be restarted
  • TPU_MACHINE_TYPE : the TPU machine type
  • TOPOLOGY : the TPU v4 or later topology. For more information about TPU versions and supported topologies, see TPU versions
  • WORKLOAD_NODEPOOL_COUNT : the number of node pools used by a Pathways workload
  • BUCKET_NAME : a Cloud Storage bucket for storing temporary files
  • PROJECT : your Google Cloud project ID
  • RUN_NAME : a user-assigned name to identify the workflow run
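If you prefer not to edit the placeholders by hand, a small sed filter can fill them in. The following is a minimal sketch rather than part of the documented workflow: the function name render_manifest and the PW_* variable names are hypothetical, and it assumes the values contain no "|", "&", or "\" characters and do not themselves contain any of the placeholder names.

```shell
# Hypothetical helper (not part of Pathways or kubectl): reads a manifest
# template on stdin and writes it to stdout with the placeholders replaced.
# Values come from PW_* shell variables, named this way to avoid clobbering
# the shell's own USER variable.
# Assumption: values contain no "|", "&", or "\" and no placeholder names.
render_manifest() {
  sed \
    -e "s|MAX_RESTARTS|${PW_MAX_RESTARTS}|g" \
    -e "s|TPU_MACHINE_TYPE|${PW_TPU_MACHINE_TYPE}|g" \
    -e "s|TOPOLOGY|${PW_TOPOLOGY}|g" \
    -e "s|WORKLOAD_NODEPOOL_COUNT|${PW_NUM_SLICES}|g" \
    -e "s|BUCKET_NAME|${PW_BUCKET_NAME}|g" \
    -e "s|PROJECT|${PW_PROJECT}|g" \
    -e "s|RUN_NAME|${PW_RUN_NAME}|g" \
    -e "s|USER|${PW_USER}|g"
}

# Example usage (file name template.yaml is illustrative):
#   set the PW_* variables, then run
#   render_manifest < template.yaml > pathways-job-batch-training.yaml
```

Keeping an untouched template and rendering it into pathways-job-batch-training.yaml makes it easy to regenerate the manifest with different values later.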

You can deploy the PathwaysJob YAML as follows:

kubectl apply -f pathways-job-batch-training.yaml

To verify that the previous command created the PathwaysJob instance, run:

kubectl get pathwaysjob

The output should look like this:

NAME             AGE
pathways-trial   9s

To modify an attribute of the PathwaysJob instance, delete the PathwaysJob instance, modify the YAML and apply it to create a new PathwaysJob instance.
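The delete-then-apply cycle described above can be wrapped in a short shell function. This is a sketch under the assumption that kubectl is already configured for your GKE cluster; the function name redeploy_pathwaysjob is hypothetical.

```shell
# Hypothetical helper: recreate a PathwaysJob from an edited manifest.
# Assumes kubectl already points at the target GKE cluster.
redeploy_pathwaysjob() {
  manifest="$1"
  # --ignore-not-found keeps the delete from failing if the job is already gone.
  # The apply only runs if the delete succeeded.
  kubectl delete -f "$manifest" --ignore-not-found &&
    kubectl apply -f "$manifest"
}

# Example usage:
#   redeploy_pathwaysjob pathways-job-batch-training.yaml
```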

To follow the progress of your workload, open the Logs Explorer and choose main under the Container Name filter to show the logs for your JAX container.

You should see logs like the following, which indicate that training is progressing. The workload completes after 20 steps.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776
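If you save or stream these log lines to a terminal, a short filter can show how far training has progressed. The sketch below assumes the log lines have exactly the format shown above; the function names last_completed_step and step_losses are hypothetical, and how you retrieve the logs is left to you.

```shell
# Hypothetical helper: print the highest "completed step" number seen on stdin.
# Prints nothing if no step line is present.
last_completed_step() {
  sed -n 's/^.*completed step: \([0-9][0-9]*\),.*$/\1/p' | sort -n | tail -n 1
}

# Hypothetical helper: print one "<step> <loss>" pair per training log line.
step_losses() {
  sed -n 's/^.*completed step: \([0-9][0-9]*\),.*loss: \([0-9.][0-9.]*\).*$/\1 \2/p'
}

# Example usage (training.log is an illustrative file name):
#   last_completed_step < training.log
#   step_losses < training.log
```

Comparing the result of last_completed_step with the steps value in the training command is a quick way to confirm that the run finished.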

To delete the PathwaysJob instance, you can use the following command:

kubectl delete -f pathways-job-batch-training.yaml

Run a batch workload using XPK

Now you can submit the prebuilt MaxText Docker image using XPK, with the same MaxText training options used in the PathwaysJob manifest.

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

Replace the following:

  • WORKLOAD : a unique name to identify your workload
  • CLUSTER : the name of your GKE cluster
  • WORKLOAD_NODEPOOL_COUNT : the number of node pools used by a Pathways workload
  • TPU_TYPE : the TPU type, which specifies the version and size of the Cloud TPU you want to create. For more information about supported TPU types for each TPU version, see TPU versions
  • PROJECT : your Google Cloud project ID
  • ZONE : the zone where you plan to run your workload
  • USER : your Google Cloud user ID
  • BUCKET_NAME : a Cloud Storage bucket for storing temporary files
  • RUN_NAME : a user-assigned name to identify the workflow run
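Because the value passed to --command is a single long line, it can be easier to maintain when assembled from one option per line. The following sketch prints the same training command shown above; the function name build_maxtext_command and its two positional parameters are hypothetical.

```shell
# Hypothetical helper: print the MaxText training command as a single line.
#   $1 = Cloud Storage bucket name
#   $2 = run name (the "-pathways-job" suffix is appended, as in the command above)
build_maxtext_command() {
  # printf repeats the '%s ' format for each argument; sed drops the final space.
  printf '%s ' \
    "python3 -m MaxText.train MaxText/configs/base.yml" \
    "base_output_directory=gs://$1" \
    "per_device_batch_size=1" \
    "enable_checkpointing=false" \
    "remat_policy=full" \
    "global_parameter_scale=1" \
    "steps=20" \
    "max_target_length=2048" \
    "use_iota_embed=true" \
    "reuse_example_batch=1" \
    "dataset_type=synthetic" \
    "attention=flash" \
    "gcs_metrics=True" \
    "enable_single_controller=True" \
    "run_name=$2-pathways-job" | sed 's/ $//'
}

# Example usage inside the xpk invocation:
#   --command="$(build_maxtext_command BUCKET_NAME RUN_NAME)"
```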

You should see output like the following:

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

Use the link in the output from the previous XPK command to follow the progress of your workload. You can filter the logs for your JAX container by choosing jax-tpu under the Container Name filter.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

The workload completes after the specified number of steps. To terminate it earlier, use the following command:

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

What's next