在 Docker 容器中執行 TPU 工作負載

Docker 容器可將程式碼和所有必要的依附元件合併至單一可散發套件,讓您更輕鬆地設定應用程式。您可以在 TPU VM 中執行 Docker 容器,簡化 Cloud TPU 應用程式的設定和共用作業。本文件說明如何為 Cloud TPU 支援的每個 ML 架構設定 Docker 容器。

在 Docker 容器中訓練 PyTorch 模型

TPU 裝置

  1. 建立 Cloud TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. 使用 SSH 連線至 TPU VM

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=europe-west4-a
  3. 請確認您的 Google Cloud 使用者已獲授「Artifact Registry Reader」角色。詳情請參閱「授予 Artifact Registry 角色」。

  4. 使用 PyTorch/XLA 夜間映像檔,在 TPU VM 中啟動容器

    sudo docker run --net=host -ti --rm --name your-container-name --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. 設定 TPU 執行階段

    PyTorch/XLA 執行階段有兩種選項:PJRT 和 XRT。除非您有使用 XRT 的理由,否則建議您使用 PJRT。如要進一步瞭解不同的執行階段設定,請參閱 PJRT 執行階段說明文件

    PJRT

    export PJRT_DEVICE=TPU

    XRT

    export XRT_TPU_CONFIG="localservice;0;localhost:51011"
  6. 複製 PyTorch XLA 存放區

    git clone --recursive https://github.com/pytorch/xla.git
  7. 訓練 ResNet50

    python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

訓練指令碼執行完畢後,請務必清理資源。

  1. 輸入 exit 即可退出 Docker 容器
  2. 輸入 exit 即可退出 TPU VM
  3. 刪除 TPU VM

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU 配量

在 TPU 配量上執行 PyTorch 程式碼時,您必須同時在所有 TPU worker 上執行程式碼。其中一種方法是使用 gcloud compute tpus tpu-vm ssh 指令,並加上 --worker=all--command 旗標。以下程序說明如何建立 Docker 映像檔,以便輕鬆設定每個 TPU worker。

  1. 建立 TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-32 \
    --version=tpu-ubuntu2204-base
  2. 將目前使用者新增至 Docker 群組

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=us-central2-b \
    --worker=all \
    --command='sudo usermod -a -G docker $USER'
  3. 複製 PyTorch XLA 存放區

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="git clone --recursive https://github.com/pytorch/xla.git"
  4. 在所有 TPU 工作站上執行容器中的訓練指令碼

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="docker run --rm --privileged --net=host  -v ~/xla:/xla -e PJRT_DEVICE=TPU us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"

    Docker 指令標記:

    • --rm 會在程序終止後移除容器。
    • --privileged 會將 TPU 裝置公開給容器。
    • --net=host 會將所有容器的通訊埠繫結至 TPU VM,以便 Pod 中的主機彼此通訊。
    • -e 會設定環境變數。

訓練指令碼執行完畢後,請務必清理資源。

使用下列指令刪除 TPU VM:

gcloud compute tpus tpu-vm delete your-tpu-name \
--zone=us-central2-b

在 Docker 容器中訓練 JAX 模型

TPU 裝置

  1. 建立 TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. 使用 SSH 連線至 TPU VM

    gcloud compute tpus tpu-vm ssh your-tpu-name  --zone=europe-west4-a
  3. 在 TPU VM 中啟動 Docker 精靈

    sudo systemctl start docker
  4. 啟動 Docker 容器

    sudo docker run --net=host -ti --rm --name your-container-name \
    --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. 安裝 JAX

    pip install jax[tpu]
  6. 安裝 FLAX

    pip install --upgrade clu
    git clone https://github.com/google/flax.git
    pip install --user -e flax
  7. 安裝 tensorflowtensorflow-dataset 套件

    pip install tensorflow
    pip install tensorflow-datasets
  8. 執行 FLAX MNIST 訓練指令碼

    cd flax/examples/mnist
    python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

訓練指令碼執行完畢後,請務必清理資源。

  1. 輸入 exit 即可退出 Docker 容器
  2. 輸入 exit 即可退出 TPU VM
  3. 刪除 TPU VM

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU 配量

在 TPU 配量上執行 JAX 程式碼時,您必須同時在所有 TPU 工作站上執行 JAX 程式碼。其中一種方法是使用 gcloud compute tpus tpu-vm ssh 指令,並加上 --worker=all--command 旗標。以下程序說明如何建立 Docker 映像檔,以便輕鬆設定每個 TPU worker。

  1. 在目前目錄中建立名為 Dockerfile 的檔案,然後貼上以下文字

    FROM python:3.10
    RUN pip install jax[tpu]
    RUN pip install --upgrade clu
    RUN git clone https://github.com/google/flax.git
    RUN pip install --user -e flax
    RUN pip install tensorflow
    RUN pip install tensorflow-datasets
    WORKDIR ./flax/examples/mnist
  2. 準備 Artifact Registry

    gcloud artifacts repositories create your-repo \
    --repository-format=docker \
    --location=europe-west4 --description="Docker repository" \
    --project=your-project
    
    gcloud artifacts repositories list \
    --project=your-project
    
    gcloud auth configure-docker europe-west4-docker.pkg.dev
  3. 建構 Docker 映像檔

    docker build -t your-image-name .
  4. 請先為 Docker 映像檔新增標記,再推送至 Artifact Registry。如要進一步瞭解如何使用 Artifact Registry,請參閱「使用容器映像檔」。

    docker tag your-image-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  5. 將 Docker 映像檔推送至 Artifact Registry

    docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  6. 建立 TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  7. 從所有 TPU worker 的 Artifact Registry 提取 Docker 映像檔

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command='sudo usermod -a -G docker ${USER}'
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
  8. 在所有 TPU 工作站上執行容器

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"
  9. 在所有 TPU 工作站上執行訓練指令碼

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5"

訓練指令碼執行完畢後,請務必清理資源。

  1. 關閉所有 worker 上的容器

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker kill your-container-name"
  2. 刪除 TPU VM

    gcloud compute tpus tpu-vm delete your-tpu-name \
    --zone=europe-west4-a

使用 JAX Stable Stack 在 Docker 容器中訓練 JAX 模型

您可以使用 JAX Stable Stack 基礎映像檔建構 MaxTextMaxDiffusion Docker 映像檔。

JAX Stable Stack 會將 JAX 與 orbaxflaxoptaxlibtpu.so 等核心套件捆綁,為 MaxText 和 MaxDiffusion 提供一致的環境。這些程式庫已通過測試,確保相容性,並提供穩定的基礎,可用於建構及執行 MaxText 和 MaxDiffusion。這樣一來,就能避免因不相容的套件版本而產生潛在衝突。

JAX Stable Stack 包含已發布且符合資格的 libtpu.so,這是驅動 TPU 程式編譯、執行和 ICI 網路設定的核心程式庫。libtpu 版本會取代先前 JAX 使用的每晚版本,並透過 HLO/StableHLO IR 中的 PJRT 級別資格測試,確保 TPU 上的 XLA 運算功能一致。

如要使用 JAX Stable Stack 建構 MaxText 和 MaxDiffusion Docker 映像檔,請在執行 docker_build_dependency_image.sh 指令碼時,將 MODE 變數設為 stable_stack,並將 BASEIMAGE 變數設為要使用的基礎映像檔。

docker_build_dependency_image.sh 位於 MaxDiffusion GitHub 存放區MaxText GitHub 存放區中。複製要使用的存放區,然後執行該存放區中的 docker_build_dependency_image.sh 指令碼,以建構 Docker 映像檔。

git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
git clone https://github.com/AI-Hypercomputer/maxtext.git

下列指令會使用 us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 做為基本映像檔,產生可與 MaxText 和 MaxDiffusion 搭配使用的 Docker 映像檔。

sudo bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

如需可用的 JAX Stable Stack 基本映像檔清單,請參閱「Artifact Registry 中的 JAX Stable Stack 映像檔」。

後續步驟