{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "wJcYs_ERTnnI" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2023-06-09T12:22:01.677780Z", "iopub.status.busy": "2023-06-09T12:22:01.677107Z", "iopub.status.idle": "2023-06-09T12:22:01.681299Z", "shell.execute_reply": "2023-06-09T12:22:01.680633Z" }, "id": "HMUDt0CiUJk9" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "77z2OchJTk0l" }, "source": [ "# Migrate from TPUEstimator to TPUStrategy\n", "\n", "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "meUTrR4I6m1C" }, "source": [ "This guide demonstrates how to migrate your workflows running on [TPUs](../../guide/tpu.ipynb) from TensorFlow 1's `TPUEstimator` API to TensorFlow 2's `TPUStrategy` API.\n", "\n", "- In TensorFlow 1, the `tf.compat.v1.estimator.tpu.TPUEstimator` API lets you train and evaluate a model, as well as perform inference and save your model (for serving) on (Cloud) TPUs.\n", "- In TensorFlow 2, to perform synchronous training on TPUs and TPU Pods (a collection of TPU devices connected by dedicated high-speed network interfaces), you need to use a TPU distribution strategy—`tf.distribute.TPUStrategy`. The strategy can work with the Keras APIs—including for model building (`tf.keras.Model`), optimizers (`tf.keras.optimizers.Optimizer`), and training (`Model.fit`)—as well as a custom training loop (with `tf.function` and `tf.GradientTape`).\n", "\n", "For end-to-end TensorFlow 2 examples, check out the [Use TPUs](../../guide/tpu.ipynb) guide—namely, the *Classification on TPUs* section—and the [Solve GLUE tasks using BERT on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial. You may also find the [Distributed training](../../guide/distributed_training.ipynb) guide useful, which covers all TensorFlow distribution strategies, including `TPUStrategy`." 
] }, { "cell_type": "markdown", "metadata": { "id": "YdZSoIXEbhg-" }, "source": [ "## Setup\n", "\n", "Start with imports and a simple dataset for demonstration purposes:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:01.684701Z", "iopub.status.busy": "2023-06-09T12:22:01.684240Z", "iopub.status.idle": "2023-06-09T12:22:04.060471Z", "shell.execute_reply": "2023-06-09T12:22:04.059686Z" }, "id": "iE0vSfMXumKI" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-06-09 12:22:02.963028: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import tensorflow as tf\n", "import tensorflow.compat.v1 as tf1" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:04.064697Z", "iopub.status.busy": "2023-06-09T12:22:04.063970Z", "iopub.status.idle": "2023-06-09T12:22:04.068049Z", "shell.execute_reply": "2023-06-09T12:22:04.067380Z" }, "id": "m7rnGxsXtDkV" }, "outputs": [], "source": [ "features = [[1., 1.5]]\n", "labels = [[0.3]]\n", "eval_features = [[4., 4.5]]\n", "eval_labels = [[0.8]]" ] }, { "cell_type": "markdown", "metadata": { "id": "4uXff1BEssdE" }, "source": [ "## TensorFlow 1: Drive a model on TPUs with TPUEstimator" ] }, { "cell_type": "markdown", "metadata": { "id": "BVWHEQj5a7rN" }, "source": [ "This section of the guide demonstrates how to perform training and evaluation with `tf.compat.v1.estimator.tpu.TPUEstimator` in TensorFlow 1.\n", "\n", "To use a `TPUEstimator`, first define a few functions: an input function for the training data, an evaluation input function for the evaluation data, and a model function that tells the `TPUEstimator` how the training op is defined with the features and labels:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:04.071368Z", 
"iopub.status.busy": "2023-06-09T12:22:04.070941Z", "iopub.status.idle": "2023-06-09T12:22:04.077016Z", "shell.execute_reply": "2023-06-09T12:22:04.076349Z" }, "id": "lqe9obf7suIj" }, "outputs": [], "source": [ "def _input_fn(params):\n", " dataset = tf1.data.Dataset.from_tensor_slices((features, labels))\n", " dataset = dataset.repeat()\n", " return dataset.batch(params['batch_size'], drop_remainder=True)\n", "\n", "def _eval_input_fn(params):\n", " dataset = tf1.data.Dataset.from_tensor_slices((eval_features, eval_labels))\n", " dataset = dataset.repeat()\n", " return dataset.batch(params['batch_size'], drop_remainder=True)\n", "\n", "def _model_fn(features, labels, mode, params):\n", " logits = tf1.layers.Dense(1)(features)\n", " loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)\n", " optimizer = tf1.train.AdagradOptimizer(0.05)\n", " train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())\n", " return tf1.estimator.tpu.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)" ] }, { "cell_type": "markdown", "metadata": { "id": "QYnP3Dszc-2R" }, "source": [ "With those functions defined, create a `tf.distribute.cluster_resolver.TPUClusterResolver` that provides the cluster information, and a `tf.compat.v1.estimator.tpu.RunConfig` object. Together with the model function you have defined, these let you create a `TPUEstimator`. Here, you will simplify the flow by skipping checkpoint saving. You will also specify the batch size for both training and evaluation when constructing the `TPUEstimator`." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:04.080161Z", "iopub.status.busy": "2023-06-09T12:22:04.079697Z", "iopub.status.idle": "2023-06-09T12:22:04.209482Z", "shell.execute_reply": "2023-06-09T12:22:04.208484Z" }, "id": "WAqyqawemlcl" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "All devices: []\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-06-09 12:22:04.200684: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" ] } ], "source": [ "cluster_resolver = tf1.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n", "print(\"All devices: \", tf1.config.list_logical_devices('TPU'))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:04.213456Z", "iopub.status.busy": "2023-06-09T12:22:04.212883Z", "iopub.status.idle": "2023-06-09T12:22:04.579109Z", "shell.execute_reply": "2023-06-09T12:22:04.578459Z" }, "id": "HsOpjW5plH9Q" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_15327/4277674569.py:1: The name tf.estimator.tpu.TPUConfig is deprecated. 
Please use tf.compat.v1.estimator.tpu.TPUConfig instead.\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_15327/4277674569.py:1: TPUConfig.__new__ (from tensorflow_estimator.python.estimator.tpu.tpu_config) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_15327/4277674569.py:2: RunConfig.__init__ (from tensorflow_estimator.python.estimator.tpu.tpu_config) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_config.py:268: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_15327/4277674569.py:6: TPUEstimator.__init__ (from tensorflow_estimator.python.estimator.tpu.tpu_estimator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Estimator's model_fn () includes params argument, but params are not passed to Estimator.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:2811: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" 
] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpfi5eso9h\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpfi5eso9h', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\n", "cluster_def {\n", " job {\n", " name: \"worker\"\n", " tasks {\n", " key: 0\n", " value: \"10.25.167.66:8470\"\n", " }\n", " }\n", "}\n", "isolate_session_state: true\n", ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.25.167.66:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.25.167.66:8470', '_evaluation_master': 'grpc://10.25.167.66:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=10, num_shards=None, num_cores_per_replica=None, per_host_input_for_training=2, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1, experimental_allow_per_host_v2_parallel_get_next=False, experimental_feed_hook=None), '_cluster': }\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:_TPUContext: eval_on_tpu True\n" ] } ], "source": [ "tpu_config = tf1.estimator.tpu.TPUConfig(iterations_per_loop=10)\n", "config = tf1.estimator.tpu.RunConfig(\n", " cluster=cluster_resolver,\n", " save_checkpoints_steps=None,\n", " tpu_config=tpu_config)\n", 
"estimator = tf1.estimator.tpu.TPUEstimator(\n", " model_fn=_model_fn,\n", " config=config,\n", " train_batch_size=8,\n", " eval_batch_size=8)" ] }, { "cell_type": "markdown", "metadata": { "id": "Uxw7tWrcepaZ" }, "source": [ "Call `TPUEstimator.train` to begin training the model:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:04.582283Z", "iopub.status.busy": "2023-06-09T12:22:04.582021Z", "iopub.status.idle": "2023-06-09T12:22:09.618438Z", "shell.execute_reply": "2023-06-09T12:22:09.617751Z" }, "id": "WZPKFOMAcyrP" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Querying Tensorflow master (grpc://10.25.167.66:8470) for TPU system metadata.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Found TPU system:\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Cores: 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Workers: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 1291425829812295795)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, -3325997117977499465)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 8477538334583946656)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 7772414489121114884)\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 3690613627338882953)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 7506155478907565168)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 2236267778124756088)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, -3746359292134531998)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, -7944284987887513631)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, -5726915818709471474)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, -2156150057367007347)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Variable.read_value. 
Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:2371: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:686: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Call initializer instance with the dtype argument instead of passing it to the constructor\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/tmp/ipykernel_15327/3404938034.py:16: TPUEstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.tpu.tpu_estimator) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-06-09 12:22:04.585870: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:374] GrpcSession::ListDevices will initialize the session with an empty graph 
and other defaults because the session has not yet been created.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:3328: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Bypassing TPUEstimator hook\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:3369: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1414: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:TPU job name worker\n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:760: Variable.load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Prefer Variable.assign which has equivalent behavior in 2.X.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Initialized dataset iterators in 0 seconds\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Installing graceful shutdown hook.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Creating heartbeat manager for ['/job:worker/replica:0/task:0/device:CPU:0']\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Configuring worker heartbeat: shutdown_mode: WAIT_FOR_COORDINATOR\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Init TPU system\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-06-09 12:22:04.975589: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:374] GrpcSession::ListDevices will initialize the session with an empty graph and other defaults because the session has not yet been created.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Initialized TPU in 4 seconds\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting infeed thread controller.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting outfeed thread controller.\n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Outfeed finished for iteration (0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:loss = 2.3149996, step = 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Stop infeed thread controller\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Shutting down InfeedController thread.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:InfeedController received shutdown signal, stopping.\n" ] }, { "name": "stdout", "output_type": 
"stream", "text": [ "INFO:tensorflow:Infeed thread finished, shutting down.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:infeed marked as finished\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Stop output thread controller\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Shutting down OutfeedController thread.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:OutfeedController received shutdown signal, stopping.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Outfeed thread finished, shutting down.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:outfeed marked as finished\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Shutdown TPU system.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Loss for final step: 2.3149996.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:training_loop marked as finished\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "estimator.train(_input_fn, steps=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "ev1vjIz9euIw" }, "source": [ "Then, call `TPUEstimator.evaluate` to evaluate the model using the evaluation data:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:09.621909Z", "iopub.status.busy": "2023-06-09T12:22:09.621388Z", "iopub.status.idle": "2023-06-09T12:22:14.915630Z", "shell.execute_reply": "2023-06-09T12:22:14.914961Z" }, "id": "bqiKRiwWc0cz" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Could not find trained model in model_dir: /tmpfs/tmp/tmpfi5eso9h, running initialization to evaluate.\n" ] }, { "name": "stdout", "output_type": "stream", 
"text": [ "INFO:tensorflow:Calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:1176: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Deprecated in favor of operator or tf.math.divide.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done calling model_fn.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting evaluation at 2023-06-09T12:22:09\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.keras instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:TPU job name worker\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Graph was finalized.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Done running local_init_op.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Init TPU system\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Initialized TPU in 4 seconds\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting infeed thread controller.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Starting outfeed thread controller.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Initialized dataset iterators in 0 seconds\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Outfeed finished for iteration (0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Evaluation [1/1]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Stop infeed thread controller\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Shutting down InfeedController thread.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:InfeedController received shutdown signal, stopping.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Infeed thread finished, shutting down.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:infeed marked as finished\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Stop output thread controller\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Shutting down OutfeedController thread.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:OutfeedController received shutdown signal, stopping.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Outfeed thread finished, shutting down.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:outfeed marked as finished\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Shutdown TPU system.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Inference Time : 5.10077s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Finished evaluation at 2023-06-09-12:22:14\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "INFO:tensorflow:Saving dict for global step 1: global_step = 1, loss = 5.8631864\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:evaluation_loop marked as finished\n" ] }, { "data": { "text/plain": [ "{'loss': 5.8631864, 'global_step': 1}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "estimator.evaluate(_eval_input_fn, steps=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "KEmzBjfnsxwT" }, "source": [ "## TensorFlow 2: Drive a model on TPUs with Keras Model.fit and TPUStrategy" ] }, { "cell_type": "markdown", "metadata": { "id": "UesuXNbShrbi" }, "source": [ "In TensorFlow 2, to train on the TPU workers, use `tf.distribute.TPUStrategy` together with the Keras APIs for model definition and training/evaluation. (Refer to the [Use TPUs](../../guide/tpu.ipynb) guide for more examples of training with Keras `Model.fit` and a custom training loop (with `tf.function` and `tf.GradientTape`).)\n", "\n", "Since you need to perform some initialization work to connect to the remote cluster and initialize the TPU workers, start by creating a `TPUClusterResolver` to provide the cluster information and connect to the cluster. 
(Learn more in the *TPU initialization* section of the [Use TPUs](../../guide/tpu.ipynb) guide.)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:14.919217Z", "iopub.status.busy": "2023-06-09T12:22:14.918702Z", "iopub.status.idle": "2023-06-09T12:22:20.056149Z", "shell.execute_reply": "2023-06-09T12:22:20.055419Z" }, "id": "_TgdPNgXoS63" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Initializing the TPU system: grpc://10.25.167.66:8470\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Initializing the TPU system: grpc://10.25.167.66:8470\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Finished initializing TPU system.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Finished initializing TPU system.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "All devices: [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]\n" ] } ], "source": [ "cluster_resolver = 
tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n", "tf.config.experimental_connect_to_cluster(cluster_resolver)\n", "tf.tpu.experimental.initialize_tpu_system(cluster_resolver)\n", "print(\"All devices: \", tf.config.list_logical_devices('TPU'))" ] }, { "cell_type": "markdown", "metadata": { "id": "R4EHXhN3CVmo" }, "source": [ "Next, once your data is prepared, create a `TPUStrategy`, then define a model, metrics, and an optimizer under the scope of this strategy.\n", "\n", "To achieve comparable training speed with `TPUStrategy`, make sure to set `steps_per_execution` in `Model.compile`. It specifies the number of batches to run during each `tf.function` call and is critical for performance. This argument is similar to `iterations_per_loop` used in a `TPUEstimator`. If you are using custom training loops, make sure multiple steps are run within the `tf.function`-ed training function. Go to the *Improving performance with multiple steps inside tf.function* section of the [Use TPUs](../../guide/tpu.ipynb) guide for more information.\n", "\n", "`tf.distribute.TPUStrategy` supports bounded dynamic shapes, that is, dynamic shapes whose upper bound can be inferred. However, dynamic shapes may introduce some performance overhead compared to static shapes, so it is generally recommended to make your input shapes static where possible, especially in training. One common op that returns a dynamic shape is `tf.data.Dataset.batch(batch_size)`, since the number of samples remaining in a stream might be less than the batch size. Therefore, when training on the TPU, use `tf.data.Dataset.batch(..., drop_remainder=True)` for best training performance."
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:20.059781Z", "iopub.status.busy": "2023-06-09T12:22:20.059239Z", "iopub.status.idle": "2023-06-09T12:22:20.143047Z", "shell.execute_reply": "2023-06-09T12:22:20.142286Z" }, "id": "atVciNgPs0fw" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Found TPU system:\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Found TPU system:\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Cores: 8\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Cores: 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Workers: 1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Workers: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: 
_DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n" ] } ], "source": [ "dataset = tf.data.Dataset.from_tensor_slices(\n", " (features, labels)).shuffle(10).repeat().batch(\n", " 8, drop_remainder=True).prefetch(2)\n", "eval_dataset = tf.data.Dataset.from_tensor_slices(\n", " (eval_features, eval_labels)).batch(1, drop_remainder=True)\n", "\n", "strategy = tf.distribute.TPUStrategy(cluster_resolver)\n", "with strategy.scope():\n", " model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])\n", " optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)\n", " model.compile(optimizer, \"mse\", 
steps_per_execution=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "FkM2VZyni98F" }, "source": [ "With that, you are ready to train the model with the training dataset:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:20.146519Z", "iopub.status.busy": "2023-06-09T12:22:20.146257Z", "iopub.status.idle": "2023-06-09T12:22:22.493547Z", "shell.execute_reply": "2023-06-09T12:22:22.492828Z" }, "id": "Kip65sYBlKiu" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "10/10 [==============================] - ETA: 0s - loss: 1.6286" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/10 [==============================] - 1s 117ms/step - loss: 1.6286\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "10/10 [==============================] - ETA: 0s - loss: 0.6109" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/10 [==============================] - 0s 4ms/step - loss: 0.6109\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 3/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "10/10 [==============================] - ETA: 0s - loss: 0.2844" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/10 [==============================] - 0s 4ms/step - loss: 0.2844\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4/5\n" ] }, { 
"name": "stdout", "output_type": "stream", "text": [ "\r", "10/10 [==============================] - ETA: 0s - loss: 0.1395" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/10 [==============================] - 0s 4ms/step - loss: 0.1395\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5/5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "10/10 [==============================] - ETA: 0s - loss: 0.0699" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "10/10 [==============================] - 0s 4ms/step - loss: 0.0699\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(dataset, epochs=5, steps_per_epoch=10)" ] }, { "cell_type": "markdown", "metadata": { "id": "r0AEK8sNjLOj" }, "source": [ "Finally, evaluate the model using the evaluation dataset:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2023-06-09T12:22:22.497099Z", "iopub.status.busy": "2023-06-09T12:22:22.496543Z", "iopub.status.idle": "2023-06-09T12:22:24.308366Z", "shell.execute_reply": "2023-06-09T12:22:24.307651Z" }, "id": "6tMRkyfKhqSL" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/1 [==============================] - ETA: 0s - loss: 1.2904" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "1/1 [==============================] - 2s 2s/step - loss: 1.2904\n" ] }, { "data": { "text/plain": [ "{'loss': 1.2903766632080078}" ] }, "execution_count": 12, "metadata": 
{}, "output_type": "execute_result" } ], "source": [ "model.evaluate(eval_dataset, return_dict=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "67ec4d3f35d6" }, "source": [ "## Next steps" ] }, { "cell_type": "markdown", "metadata": { "id": "gHx_RUL8xcJ3" }, "source": [ "To learn more about `TPUStrategy` in TensorFlow 2, consider the following resources:\n", "\n", "- Guide: [Use TPUs](../../guide/tpu.ipynb) (covering training with Keras `Model.fit`/a custom training loop with `tf.distribute.TPUStrategy`, as well as tips on improving the performance with `tf.function`)\n", "- Guide: [Distributed training with TensorFlow](../../guide/distributed_training.ipynb)\n", "\n", "To learn more about customizing your training, refer to:\n", "\n", "- Guide: [Customize what happens in Model.fit](../../guide/keras/customizing_what_happens_in_fit.ipynb)\n", "- Guide: [Writing a training loop from scratch](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)\n", "\n", "TPUs (Google's specialized ASICs for machine learning) are available through [Google Colab](https://colab.research.google.com/), the [TPU Research Cloud](https://sites.research.google/trc/), and [Cloud TPU](https://cloud.google.com/tpu)." ] } ], "metadata": { "accelerator": "TPU", "colab": { "collapsed_sections": [], "name": "tpu_estimator.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.17" } }, "nbformat": 4, "nbformat_minor": 0 }