搭配使用 PyTorch 和 Cloud TPU 訓練 Resnet50


這個教學課程將說明如何使用 PyTorch 在 Cloud TPU 裝置中訓練 ResNet-50 模型。如果您有其他已針對 TPU 完成最佳化處理的圖片分類模型,而且這些模型使用的是 PyTorch 和 ImageNet 資料集,您也可以按照這個教學課程中的步驟對其進行訓練。

本教學課程中的模型是以圖像識別的深度殘差學習為基礎,該論文首度提出殘差網路 (ResNet) 架構的概念。這個教學課程使用了含有 50 層架構的變化版本「ResNet-50」,並說明如何使用 PyTorch/XLA 訓練模型。

目標

  • 準備資料集。
  • 執行訓練工作。
  • 驗證輸出結果。

費用

In this document, you use the following billable components of Google Cloud:

  • Compute Engine
  • Cloud TPU

To generate a cost estimate based on your projected usage, use the pricing calculator. New Google Cloud users might be eligible for a free trial.

事前準備

開始學習這個教學課程之前,請先檢查您的 Google Cloud 專案設定是否正確。

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Google Cloud project.

  4. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  5. Make sure that billing is enabled for your Google Cloud project.

  6. 本逐步操作說明使用 Google Cloud的計費元件。請參閱 Cloud TPU 定價頁面來估算費用。使用完畢後,請務必清除您建立的資源,以免產生不必要的費用。

建立 TPU VM

  1. 開啟 Cloud Shell 視窗。

    開啟 Cloud Shell

  2. 建立 TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. 使用 SSH 連線至 TPU VM:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. 在 TPU VM 上安裝 PyTorch/XLA:

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. 複製 PyTorch/XLA GitHub 存放區

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. 使用偽資料執行訓練指令碼

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

清除所用資源

如要避免系統向您的 Google Cloud 帳戶收取本教學課程中所用資源的相關費用,請刪除含有該項資源的專案,或者保留專案但刪除個別資源。

  1. 中斷與 TPU VM 的連線:

    (vm) $ exit

    畫面上的提示現在應為 username@projectname,表示您正在 Cloud Shell 中。

  2. 刪除 TPU VM。

    $ gcloud compute tpus tpu-vm delete your-tpu-name \
       --zone=us-central1-a

後續步驟