For the purposes of this document, batch workloads are JAX workloads that run to completion and are deployed in the same GKE cluster as the Pathways cluster, alongside the Pathways controller components (the IFRT proxy server and the Pathways resource manager). When the JAX workload completes, the Pathways cluster components also terminate. This guide demonstrates this with a JAX training workload.
Before you begin
Make sure you have:
- Created a GKE cluster.
- Installed XPK.
- Installed Kubernetes tools.
- Enabled the TPU API.
- Enabled the Google Kubernetes Engine API.
- Ensured that your Google Cloud project is allowlisted for Pathways.
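If you still need to enable the required APIs or connect kubectl to your cluster, the following gcloud commands are a minimal sketch; replace PROJECT, CLUSTER, and ZONE with your own values.
# Enable the Cloud TPU and GKE APIs (no effect if they are already enabled).
gcloud services enable tpu.googleapis.com container.googleapis.com --project=PROJECT
# Fetch cluster credentials so that kubectl and XPK can reach the cluster.
gcloud container clusters get-credentials CLUSTER --zone=ZONE --project=PROJECT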
Build a training image using MaxText
MaxText is an open-source large language model (LLM) project developed by Google. It's written in JAX and designed to be highly performant and scalable, running efficiently on Google Cloud TPUs and GPUs.
To build a MaxText Docker image by using the latest version of stable JAX from the OSS GitHub repository, run the following commands:
git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner
These commands push the MaxText Docker image to gcr.io/$PROJECT/${USER}_runner. You can use this image to run training on TPUs with the Pathways backend.
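Before referencing the image in a workload, you can optionally confirm that the push succeeded; the commands below are a sketch that assumes the default registry path used by the upload script.
# List images in your project's Container Registry; the USER_runner image should appear.
gcloud container images list --repository=gcr.io/PROJECT
# Inspect the tags on the uploaded image.
gcloud container images list-tags gcr.io/PROJECT/USER_runner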
Run a batch workload using the PathwaysJob API
The following manifest deploys the Pathways components and runs a MaxText workload using the PathwaysJob API. The workload runs in the main container and executes train.py.
Copy the following YAML into a file named pathways-job-batch-training.yaml and update the editable values.
apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: "gs://BUCKET_NAME"
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train MaxText/configs/base.yml \
              base_output_directory=gs://BUCKET_NAME \
              run_name=RUN_NAME \
              per_device_batch_size=1 \
              enable_checkpointing=false \
              remat_policy=full \
              global_parameter_scale=1 \
              steps=20 \
              max_target_length=2048 \
              use_iota_embed=true \
              reuse_example_batch=1 \
              dataset_type=synthetic \
              attention=flash \
              gcs_metrics=True \
              enable_single_controller=True
Replace the following:
- USER: your Google Cloud user ID.
- MAX_RESTARTS: the maximum number of times the job can be restarted.
- TPU_MACHINE_TYPE: the TPU machine type.
- TOPOLOGY: the TPU v4 or later topology. For more information about TPU versions and supported topologies, see TPU versions.
- WORKLOAD_NODEPOOL_COUNT: the number of node pools used by a Pathways workload.
- BUCKET_NAME: a Cloud Storage bucket for storing temporary files.
- PROJECT: your Google Cloud project ID.
- RUN_NAME: a user-assigned name to identify the workflow run.
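If you prefer not to edit the file by hand, you can substitute these placeholders with sed; the snippet below is a sketch that uses example values, and any placeholders it does not cover still need to be edited manually.
# Example values only; substitute your own before running.
export PROJECT=my-project USER_ID=my-user BUCKET=my-bucket
sed -i "s/PROJECT/${PROJECT}/g; s/USER/${USER_ID}/g; s/BUCKET_NAME/${BUCKET}/g" pathways-job-batch-training.yaml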
You can deploy the PathwaysJob YAML as follows:
kubectl apply -f pathways-job-batch-training.yaml
To view the PathwaysJob instance created by the previous command, run:
kubectl get pathwaysjob
The output should look like this:
NAME             AGE
pathways-trial   9s
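To inspect the resources behind the PathwaysJob instance, such as the controller components (the IFRT proxy server and the resource manager) and the TPU worker pods, standard kubectl commands are sufficient:
# Show status, conditions, and events for the PathwaysJob instance.
kubectl describe pathwaysjob pathways-USER
# List the pods created for the Pathways controller and TPU workers.
kubectl get pods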
To modify an attribute of the PathwaysJob instance, delete the existing instance, modify the YAML, and apply it again to create a new PathwaysJob instance.
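For example, after editing pathways-job-batch-training.yaml, the update is a delete followed by a fresh apply:
kubectl delete -f pathways-job-batch-training.yaml
kubectl apply -f pathways-job-batch-training.yaml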
You can follow the progress of your workload in the Logs Explorer. To view the logs for your JAX container, choose main under the Container Name filter.
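If you prefer the command line to the Logs Explorer, you can stream the same logs with kubectl; POD_NAME is a placeholder for the pod that runs the main container, which you can look up with kubectl get pods.
# Find the pod that runs the main container, then stream its logs.
kubectl get pods
kubectl logs -f POD_NAME -c main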
You should see logs like the following, which indicate that training is progressing. The workload completes after the configured number of steps (20 in this example).
completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776
To delete the PathwaysJob instance, use the following command:
kubectl delete -f pathways-job-batch-training.yaml
Run a batch workload using XPK
Now you can submit the prebuilt MaxText Docker image by using XPK with the following command:
xpk workload create-pathways \
  --workload=WORKLOAD \
  --cluster=CLUSTER \
  --num-slices=WORKLOAD_NODEPOOL_COUNT \
  --tpu-type=TPU_TYPE \
  --project=PROJECT \
  --zone=ZONE \
  --docker-image='gcr.io/PROJECT/USER_runner' \
  --command="python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"
Replace the following:
- WORKLOAD: a unique name to identify your workload.
- CLUSTER: the name of your GKE cluster.
- WORKLOAD_NODEPOOL_COUNT: the number of node pools used by the Pathways workload.
- TPU_TYPE: the TPU type, which specifies the version and size of the Cloud TPU you want to create. For more information about supported TPU types for each TPU version, see TPU versions.
- PROJECT: your Google Cloud project ID.
- ZONE: the zone where you plan to run your workload.
- USER: your Google Cloud user ID.
- RUN_NAME: a user-assigned name to identify the workflow run.
You should see output like the following:
[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT
Use the link in the output from the previous XPK command to follow the progress of your workload. You can filter the logs for your JAX container by choosing jax-tpu under the Container Name filter. You should see logs like the following:
completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776
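You can also check the state of the workload from the command line; the following xpk invocation is a sketch that assumes the same cluster, project, and zone values as the create command.
xpk workload list --cluster=CLUSTER --project=PROJECT --zone=ZONE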
The workload completes after the specified number of steps. However, if you want to terminate it early, use the following command:
xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE