{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "g_nWetWWd_ns" }, "source": [ "##### Copyright 2023 The TensorFlow Datasets Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2024-12-14T12:46:27.548697Z", "iopub.status.busy": "2024-12-14T12:46:27.548265Z", "iopub.status.idle": "2024-12-14T12:46:27.552716Z", "shell.execute_reply": "2024-12-14T12:46:27.552058Z" }, "id": "2pHVBk_seED1" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "M7vSdG6sAIQn" }, "source": [ "# TFDS for Jax and PyTorch" ] }, { "cell_type": "markdown", "metadata": { "id": "fwc5GKHBASdc" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View source on GitHub\n", " \n", " Download notebook\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "9ee074e4" }, "source": [ "TFDS has always been framework-agnostic. For instance, you can easily load\n", "datasets in\n", "[NumPy format](https://www.tensorflow.org/datasets/api_docs/python/tfds/as_numpy)\n", "for usage in Jax and PyTorch.\n", "\n", "TensorFlow and its data loading solution\n", "([`tf.data`](https://www.tensorflow.org/guide/data)) are first-class citizens in\n", "our API by design.\n", "\n", "We extended TFDS to support TensorFlow-less NumPy-only data loading. This can\n", "be convenient for usage in ML frameworks such as Jax and PyTorch. Indeed,\n", "for the latter users, TensorFlow can:\n", "\n", "- reserve GPU/TPU memory;\n", "- increase build time in CI/CD;\n", "- take time to import at runtime.\n", "\n", "TensorFlow is no longer a dependency to read datasets.\n", "\n", "ML pipelines need a data loader to load examples, decode them, and present\n", "them to the model. Data loaders use the\n", "\"source/sampler/loader\" paradigm:\n", "\n", "```\n", " TFDS dataset ┌────────────────┐\n", " on disk │ │\n", " ┌──────────►│ Data │\n", "|..|... │ | │ source ├─┐\n", "├──┼────┴─────┤ │ │ │\n", "│12│image12 │ └────────────────┘ │ ┌────────────────┐\n", "├──┼──────────┤ │ │ │\n", "│13│image13 │ ├───►│ Data ├───► ML pipeline\n", "├──┼──────────┤ │ │ loader │\n", "│14│image14 │ ┌────────────────┐ │ │ │\n", "├──┼──────────┤ │ │ │ └────────────────┘\n", "|..|... | │ Index ├─┘\n", " │ sampler │\n", " │ │\n", " └────────────────┘\n", "```\n", "\n", "- The data source is responsible for accessing and decoding examples from a TFDS\n", "dataset on the fly.\n", "- The index sampler is responsible for determining the order in which records\n", "are processed. 
This is important to implement global transformations (e.g.,\n", "global shuffling, sharding, repeating for multiple epochs) before reading any\n", "records.\n", "- The data loader orchestrates the loading by leveraging the data source and the\n", "index sampler. It allows performance optimization (e.g., pre-fetching,\n", "multiprocessing or multithreading).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "UaWdLA3fQDK2" }, "source": [ "## TL;DR\n", "\n", "`tfds.data_source` is an API to create data sources:\n", "\n", "1. for fast prototyping in pure-Python pipelines;\n", "2. to manage data-intensive ML pipelines at scale." ] }, { "cell_type": "markdown", "metadata": { "id": "aLho3l_Vd0a5" }, "source": [ "## Setup\n", "\n", "Let's install and import the needed dependencies:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:46:27.556129Z", "iopub.status.busy": "2024-12-14T12:46:27.555620Z", "iopub.status.idle": "2024-12-14T12:46:40.604206Z", "shell.execute_reply": "2024-12-14T12:46:40.603374Z" }, "id": "c4COEsqIdvYH" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: array_record in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (0.5.1)\r\n", "Requirement already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from array_record) (2.1.0)\r\n", "Requirement already satisfied: etils[epath] in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from array_record) (1.5.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: fsspec in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[epath]->array_record) (2024.10.0)\r\n", "Requirement already satisfied: importlib_resources in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[epath]->array_record) (6.4.5)\r\n", "Requirement already satisfied: typing_extensions in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[epath]->array_record) (4.12.2)\r\n", "Requirement already satisfied: zipp in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[epath]->array_record) (3.21.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting grain-nightly\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading grain_nightly-0.0.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)\r\n", "Requirement already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from grain-nightly) (2.1.0)\r\n", "Requirement already satisfied: array-record in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from grain-nightly) (0.5.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting cloudpickle (from grain-nightly)\r\n", " Downloading cloudpickle-3.1.0-py3-none-any.whl.metadata (7.0 kB)\r\n", "Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from grain-nightly) (0.1.8)\r\n", "Requirement already satisfied: etils[epath,epy] in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from grain-nightly) (1.5.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting jaxtyping (from grain-nightly)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading jaxtyping-0.2.36-py3-none-any.whl.metadata (6.5 kB)\r\n", "Collecting more-itertools>=9.1.0 (from grain-nightly)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading more_itertools-10.5.0-py3-none-any.whl.metadata (36 kB)\r\n", "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from grain-nightly) (2.0.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: fsspec in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[epath,epy]->grain-nightly) (2024.10.0)\r\n", 
"Requirement already satisfied: importlib_resources in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[epath,epy]->grain-nightly) (6.4.5)\r\n", "Requirement already satisfied: typing_extensions in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[epath,epy]->grain-nightly) (4.12.2)\r\n", "Requirement already satisfied: zipp in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[epath,epy]->grain-nightly) (3.21.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading grain_nightly-0.0.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (419 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading more_itertools-10.5.0-py3-none-any.whl (60 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading cloudpickle-3.1.0-py3-none-any.whl (22 kB)\r\n", "Downloading jaxtyping-0.2.36-py3-none-any.whl (55 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: more-itertools, jaxtyping, cloudpickle, grain-nightly\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed cloudpickle-3.1.0 grain-nightly-0.0.9 jaxtyping-0.2.36 more-itertools-10.5.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting jax\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading jax-0.4.30-py3-none-any.whl.metadata (22 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting jaxlib\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading jaxlib-0.4.30-cp39-cp39-manylinux2014_x86_64.whl.metadata (1.0 kB)\r\n", "Requirement already satisfied: ml-dtypes>=0.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax) (0.4.1)\r\n", "Requirement already satisfied: numpy>=1.22 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax) (2.0.2)\r\n", "Requirement already satisfied: opt-einsum in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax) (3.4.0)\r\n", "Requirement already satisfied: scipy>=1.9 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax) (1.13.1)\r\n", "Requirement already satisfied: importlib-metadata>=4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax) (8.5.0)\r\n", "Requirement already satisfied: zipp>=3.20 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.6->jax) (3.21.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading jax-0.4.30-py3-none-any.whl (2.0 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading jaxlib-0.4.30-cp39-cp39-manylinux2014_x86_64.whl (79.6 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: jaxlib, jax\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed jax-0.4.30 jaxlib-0.4.30\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting tfds-nightly\r\n", " Using cached tfds_nightly-4.9.3.dev202311230044-py3-none-any.whl.metadata (9.3 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (2.1.0)\r\n", "Requirement already satisfied: click in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (8.1.7)\r\n", "Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (0.1.8)\r\n", "Requirement already satisfied: etils>=0.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[enp,epath,etree]>=0.9.0->tfds-nightly) (1.5.2)\r\n", "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (2.0.2)\r\n", "Requirement already satisfied: promise in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (2.3)\r\n", 
"Requirement already satisfied: protobuf>=3.20 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (3.20.3)\r\n", "Requirement already satisfied: psutil in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (6.1.0)\r\n", "Requirement already satisfied: requests>=2.19.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (2.32.3)\r\n", "Requirement already satisfied: tensorflow-metadata in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (1.16.1)\r\n", "Requirement already satisfied: termcolor in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (2.5.0)\r\n", "Requirement already satisfied: toml in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (0.10.2)\r\n", "Requirement already satisfied: tqdm in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (4.67.1)\r\n", "Requirement already satisfied: wrapt in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (1.17.0)\r\n", "Requirement already satisfied: array-record>=0.5.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tfds-nightly) (0.5.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: fsspec in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[enp,epath,etree]>=0.9.0->tfds-nightly) (2024.10.0)\r\n", "Requirement already satisfied: importlib_resources in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[enp,epath,etree]>=0.9.0->tfds-nightly) (6.4.5)\r\n", "Requirement already satisfied: typing_extensions in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[enp,epath,etree]>=0.9.0->tfds-nightly) (4.12.2)\r\n", "Requirement already satisfied: zipp in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[enp,epath,etree]>=0.9.0->tfds-nightly) (3.21.0)\r\n", "Requirement already satisfied: charset-normalizer<4,>=2 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests>=2.19.0->tfds-nightly) (3.4.0)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests>=2.19.0->tfds-nightly) (3.10)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests>=2.19.0->tfds-nightly) (2.2.3)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests>=2.19.0->tfds-nightly) (2024.8.30)\r\n", "Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from promise->tfds-nightly) (1.17.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using cached tfds_nightly-4.9.3.dev202311230044-py3-none-any.whl (5.0 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: tfds-nightly\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed tfds-nightly-4.9.3.dev202311230044\r\n" ] } ], "source": [ "!pip install array_record\n", "!pip install grain-nightly\n", "!pip install jax jaxlib\n", "!pip install tfds-nightly\n", "\n", "import os\n", "os.environ.pop('TFDS_DATA_DIR', None)\n", "\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": { "id": "CjEJeF1Id_JM" }, "source": [ "## Data sources\n", "\n", "Data sources are basically Python sequences. 
So they need to implement the\n", "following protocol:\n", "\n", "```python\n", "from typing import Any, Protocol, SupportsIndex\n", "\n", "class RandomAccessDataSource(Protocol):\n", " \"\"\"Interface for datasources where storage supports efficient random access.\"\"\"\n", "\n", " def __len__(self) -> int:\n", " \"\"\"Number of records in the dataset.\"\"\"\n", "\n", " def __getitem__(self, key: SupportsIndex) -> Any:\n", " \"\"\"Retrieves the record for the given key.\"\"\"\n", "```\n", "\n", "The underlying file format needs to support efficient random access. At the\n", "moment, TFDS relies on [`array_record`](https://github.com/google/array_record).\n", "\n", "[`array_record`](https://github.com/google/array_record) is a new file format\n", "derived from [Riegeli](https://github.com/google/riegeli), achieving a new\n", "frontier of IO efficiency. In particular, ArrayRecord supports parallel read,\n", "write, and random access by record index. ArrayRecord builds on top of Riegeli\n", "and supports the same compression algorithms.\n", "\n", "[`fashion_mnist`](https://www.tensorflow.org/datasets/catalog/fashion_mnist) is\n", "a common dataset for computer vision. 
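Any Python object that satisfies this protocol can serve as a data source. As a minimal illustration (a hypothetical toy class for this guide, not part of TFDS or ArrayRecord), a plain in-memory sequence already qualifies:

```python
from typing import Any, SupportsIndex

class InMemorySource:
    """Toy data source: any object with __len__ and __getitem__ qualifies."""

    def __init__(self, records: list):
        self._records = records

    def __len__(self) -> int:
        # Number of records in the dataset.
        return len(self._records)

    def __getitem__(self, key: SupportsIndex) -> Any:
        # Random access by record index, as an ArrayRecord-backed source provides.
        return self._records[int(key)]

source = InMemorySource([{"label": i} for i in range(5)])
print(len(source))   # 5
print(source[3])     # {'label': 3}
```

Real TFDS data sources implement the same two methods, but read and decode records on the fly from ArrayRecord files on disk instead of holding them in memory.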
To retrieve an ArrayRecord-based data\n", "source with TFDS, simply use:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:46:40.608004Z", "iopub.status.busy": "2024-12-14T12:46:40.607298Z", "iopub.status.idle": "2024-12-14T12:47:13.022237Z", "shell.execute_reply": "2024-12-14T12:47:13.021314Z" }, "id": "9Tslzx0_eEWx" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-12-14 12:46:40.889814: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1734180400.913496 788159 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1734180400.920664 788159 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1mDownloading and preparing dataset 29.45 MiB (download: 29.45 MiB, generated: 36.42 MiB, total: 65.87 MiB) to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1...\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2024-12-14 12:46:47.047041: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1mDataset fashion_mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1. 
Subsequent calls will reuse this data.\u001b[0m\n" ] } ], "source": [ "ds = tfds.data_source('fashion_mnist')" ] }, { "cell_type": "markdown", "metadata": { "id": "AlaRrD_SeHLY" }, "source": [ "`tfds.data_source` is a convenient wrapper. It is equivalent to:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:13.025717Z", "iopub.status.busy": "2024-12-14T12:47:13.025042Z", "iopub.status.idle": "2024-12-14T12:47:13.037443Z", "shell.execute_reply": "2024-12-14T12:47:13.036786Z" }, "id": "duHDKzXReIKB" }, "outputs": [], "source": [ "builder = tfds.builder('fashion_mnist', file_format='array_record')\n", "builder.download_and_prepare()\n", "ds = builder.as_data_source()" ] }, { "cell_type": "markdown", "metadata": { "id": "rlyIsd0ueKjQ" }, "source": [ "This outputs a dictionary of data sources:\n", "\n", "```\n", "{\n", " 'train': DataSource(name=fashion_mnist, split='train', decoders=None),\n", " 'test': DataSource(name=fashion_mnist, split='test', decoders=None),\n", "}\n", "```\n", "\n", "Once `download_and_prepare` has run and the record files have been generated, we\n", "
Everything will happen in Python/NumPy!\n", "\n", "Let's check this by uninstalling TensorFlow and re-loading the data source\n", "in another subprocess:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:13.040229Z", "iopub.status.busy": "2024-12-14T12:47:13.039715Z", "iopub.status.idle": "2024-12-14T12:47:14.891889Z", "shell.execute_reply": "2024-12-14T12:47:14.890912Z" }, "id": "mTfSzvaQkSd9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found existing installation: tensorflow 2.18.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Uninstalling tensorflow-2.18.0:\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Successfully uninstalled tensorflow-2.18.0\r\n" ] } ], "source": [ "!pip uninstall -y tensorflow" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:14.895546Z", "iopub.status.busy": "2024-12-14T12:47:14.894849Z", "iopub.status.idle": "2024-12-14T12:47:14.901068Z", "shell.execute_reply": "2024-12-14T12:47:14.900410Z" }, "id": "3sT5AN7neNT9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Writing no_tensorflow.py\n" ] } ], "source": [ "%%writefile no_tensorflow.py\n", "import os\n", "os.environ.pop('TFDS_DATA_DIR', None)\n", "\n", "import tensorflow_datasets as tfds\n", "\n", "try:\n", " import tensorflow as tf\n", "except ImportError:\n", " print('No TensorFlow found...')\n", "\n", "ds = tfds.data_source('fashion_mnist')\n", "print('...but the data source could still be loaded...')\n", "ds['train'][0]\n", "print('...and the records can be decoded.')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:14.903623Z", "iopub.status.busy": "2024-12-14T12:47:14.903188Z", "iopub.status.idle": "2024-12-14T12:47:15.787722Z", "shell.execute_reply": "2024-12-14T12:47:15.786727Z" 
}, "id": "FxohFdb3kSxh" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No TensorFlow found...\r\n", "...but the data source could still be loaded...\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images...\r\n", "...and the records can be decoded.\r\n" ] } ], "source": [ "!python no_tensorflow.py" ] }, { "cell_type": "markdown", "metadata": { "id": "1o8n-BhhePYY" }, "source": [ "In future versions, we are also going to make the dataset preparation\n", "TensorFlow-free.\n", "\n", "A data source has a length:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:15.791640Z", "iopub.status.busy": "2024-12-14T12:47:15.790994Z", "iopub.status.idle": "2024-12-14T12:47:15.799499Z", "shell.execute_reply": "2024-12-14T12:47:15.798837Z" }, "id": "qtfl17SQeQ7F" }, "outputs": [ { "data": { "text/plain": [ "60000" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(ds['train'])" ] }, { "cell_type": "markdown", "metadata": { "id": "a-UFBu8leSMp" }, "source": [ "Accessing the first element of the dataset:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:15.802381Z", "iopub.status.busy": "2024-12-14T12:47:15.801846Z", "iopub.status.idle": "2024-12-14T12:47:20.571873Z", "shell.execute_reply": "2024-12-14T12:47:20.571063Z" }, "id": "tFvT2Sx2eToh" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:OpenCV is not installed. We recommend using OpenCV because it is faster according to our benchmarks. Defaulting to PIL to decode images...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "583 µs ± 4.1 µs per loop (mean ± std. dev. 
of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "ds['train'][0]" ] }, { "cell_type": "markdown", "metadata": { "id": "VTgZskyZeU_D" }, "source": [ "...is just as cheap as accessing any other element. This is the definition of\n", "[random access](https://en.wikipedia.org/wiki/Random_access):" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:20.575165Z", "iopub.status.busy": "2024-12-14T12:47:20.574452Z", "iopub.status.idle": "2024-12-14T12:47:25.339440Z", "shell.execute_reply": "2024-12-14T12:47:25.338639Z" }, "id": "cPJFa6aIeWcY" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "587 µs ± 3.88 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" ] } ], "source": [ "%%timeit\n", "ds['train'][1000]" ] }, { "cell_type": "markdown", "metadata": { "id": "fs3kafYheX6N" }, "source": [ "Features now use NumPy DTypes (rather than TensorFlow DTypes). You can inspect\n", "the features with:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:25.342677Z", "iopub.status.busy": "2024-12-14T12:47:25.342006Z", "iopub.status.idle": "2024-12-14T12:47:25.350244Z", "shell.execute_reply": "2024-12-14T12:47:25.349580Z" }, "id": "q7x5AEEaeZja" }, "outputs": [], "source": [ "features = tfds.builder('fashion_mnist').info.features" ] }, { "cell_type": "markdown", "metadata": { "id": "VOnLqAZOeiBi" }, "source": [ "You'll find more information about\n", "[the features in our documentation](https://www.tensorflow.org/datasets/api_docs/python/tfds/features).\n", "Here we can notably retrieve the shape of the images, and the number of classes:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:25.353124Z", "iopub.status.busy": "2024-12-14T12:47:25.352532Z", "iopub.status.idle": "2024-12-14T12:47:25.356229Z", "shell.execute_reply": 
"2024-12-14T12:47:25.355568Z" }, "id": "Xk8Vc-y0edlb" }, "outputs": [], "source": [ "shape = features['image'].shape\n", "num_classes = features['label'].num_classes" ] }, { "cell_type": "markdown", "metadata": { "id": "eFh8pytVemsu" }, "source": [ "## Use in pure Python\n", "\n", "You can consume data sources in Python by iterating over them:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:25.358789Z", "iopub.status.busy": "2024-12-14T12:47:25.358366Z", "iopub.status.idle": "2024-12-14T12:47:25.366289Z", "shell.execute_reply": "2024-12-14T12:47:25.365631Z" }, "id": "ULjO-JDVefNf" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'image': array([[[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 18],\n", " [ 77],\n", " [227],\n", " [227],\n", " [208],\n", " [210],\n", " [225],\n", " [216],\n", " [ 85],\n", " [ 32],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 61],\n", " [100],\n", " [ 97],\n", " [ 80],\n", " [ 57],\n", " [117],\n", " [227],\n", " [238],\n", " [115],\n", " [ 49],\n", " [ 78],\n", " [106],\n", " [108],\n", " [ 71],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 81],\n", " [105],\n", " [ 80],\n", " [ 69],\n", " [ 72],\n", " [ 64],\n", " [ 44],\n", " [ 21],\n", " [ 13],\n", " [ 44],\n", " [ 69],\n", " [ 75],\n", " [ 75],\n", " [ 80],\n", " [114],\n", " [ 80],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 26],\n", " [ 92],\n", " [ 69],\n", " [ 68],\n", " [ 75],\n", " [ 75],\n", " [ 71],\n", " [ 74],\n", " [ 83],\n", 
" [ 75],\n", " [ 77],\n", " [ 78],\n", " [ 74],\n", " [ 74],\n", " [ 83],\n", " [ 77],\n", " [108],\n", " [ 34],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 55],\n", " [ 92],\n", " [ 69],\n", " [ 74],\n", " [ 74],\n", " [ 71],\n", " [ 71],\n", " [ 77],\n", " [ 69],\n", " [ 66],\n", " [ 75],\n", " [ 74],\n", " [ 77],\n", " [ 80],\n", " [ 80],\n", " [ 78],\n", " [ 94],\n", " [ 63],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 63],\n", " [ 95],\n", " [ 66],\n", " [ 68],\n", " [ 72],\n", " [ 72],\n", " [ 69],\n", " [ 72],\n", " [ 74],\n", " [ 74],\n", " [ 74],\n", " [ 75],\n", " [ 75],\n", " [ 77],\n", " [ 80],\n", " [ 77],\n", " [106],\n", " [ 61],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 80],\n", " [108],\n", " [ 71],\n", " [ 69],\n", " [ 72],\n", " [ 71],\n", " [ 69],\n", " [ 72],\n", " [ 75],\n", " [ 75],\n", " [ 72],\n", " [ 72],\n", " [ 75],\n", " [ 78],\n", " [ 72],\n", " [ 85],\n", " [128],\n", " [ 64],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 88],\n", " [120],\n", " [ 75],\n", " [ 74],\n", " [ 77],\n", " [ 75],\n", " [ 72],\n", " [ 77],\n", " [ 74],\n", " [ 74],\n", " [ 77],\n", " [ 78],\n", " [ 83],\n", " [ 83],\n", " [ 66],\n", " [111],\n", " [123],\n", " [ 78],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 85],\n", " [134],\n", " [ 74],\n", " [ 85],\n", " [ 69],\n", " [ 75],\n", " [ 75],\n", " [ 74],\n", " [ 75],\n", " [ 74],\n", " [ 75],\n", " [ 75],\n", " [ 81],\n", " [ 75],\n", " [ 61],\n", " [151],\n", " [115],\n", " [ 91],\n", " [ 12],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 
0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 10],\n", " [ 85],\n", " [153],\n", " [ 83],\n", " [ 80],\n", " [ 68],\n", " [ 77],\n", " [ 75],\n", " [ 74],\n", " [ 75],\n", " [ 74],\n", " [ 75],\n", " [ 77],\n", " [ 80],\n", " [ 68],\n", " [ 61],\n", " [162],\n", " [122],\n", " [ 78],\n", " [ 6],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 30],\n", " [ 75],\n", " [154],\n", " [ 85],\n", " [ 80],\n", " [ 71],\n", " [ 80],\n", " [ 72],\n", " [ 77],\n", " [ 75],\n", " [ 75],\n", " [ 77],\n", " [ 78],\n", " [ 77],\n", " [ 75],\n", " [ 49],\n", " [191],\n", " [132],\n", " [ 72],\n", " [ 15],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 58],\n", " [ 66],\n", " [174],\n", " [115],\n", " [ 66],\n", " [ 77],\n", " [ 80],\n", " [ 72],\n", " [ 78],\n", " [ 75],\n", " [ 77],\n", " [ 78],\n", " [ 78],\n", " [ 77],\n", " [ 66],\n", " [ 49],\n", " [222],\n", " [131],\n", " [ 77],\n", " [ 37],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 69],\n", " [ 55],\n", " [179],\n", " [139],\n", " [ 55],\n", " [ 92],\n", " [ 74],\n", " [ 74],\n", " [ 78],\n", " [ 74],\n", " [ 78],\n", " [ 77],\n", " [ 75],\n", " [ 80],\n", " [ 64],\n", " [ 55],\n", " [242],\n", " [111],\n", " [ 95],\n", " [ 44],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 74],\n", " [ 57],\n", " [159],\n", " [180],\n", " [ 55],\n", " [ 92],\n", " [ 64],\n", " [ 72],\n", " [ 74],\n", " [ 74],\n", " [ 77],\n", " [ 75],\n", " [ 77],\n", " [ 78],\n", " [ 55],\n", " [ 66],\n", " [255],\n", " [ 97],\n", " [108],\n", " [ 49],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 74],\n", " [ 66],\n", " [145],\n", " [153],\n", " [ 72],\n", " [ 83],\n", " [ 58],\n", " [ 78],\n", " [ 77],\n", " [ 75],\n", " [ 
75],\n", " [ 75],\n", " [ 72],\n", " [ 80],\n", " [ 30],\n", " [132],\n", " [255],\n", " [ 37],\n", " [122],\n", " [ 60],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 80],\n", " [ 69],\n", " [142],\n", " [180],\n", " [142],\n", " [ 57],\n", " [ 64],\n", " [ 78],\n", " [ 74],\n", " [ 75],\n", " [ 75],\n", " [ 75],\n", " [ 72],\n", " [ 85],\n", " [ 21],\n", " [185],\n", " [227],\n", " [ 37],\n", " [143],\n", " [ 63],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 83],\n", " [ 71],\n", " [136],\n", " [194],\n", " [126],\n", " [ 46],\n", " [ 69],\n", " [ 75],\n", " [ 72],\n", " [ 75],\n", " [ 75],\n", " [ 75],\n", " [ 74],\n", " [ 78],\n", " [ 38],\n", " [139],\n", " [185],\n", " [ 60],\n", " [151],\n", " [ 58],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 4],\n", " [ 81],\n", " [ 74],\n", " [145],\n", " [177],\n", " [ 78],\n", " [ 49],\n", " [ 74],\n", " [ 77],\n", " [ 75],\n", " [ 75],\n", " [ 75],\n", " [ 75],\n", " [ 74],\n", " [ 72],\n", " [ 63],\n", " [ 80],\n", " [156],\n", " [117],\n", " [153],\n", " [ 55],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 10],\n", " [ 80],\n", " [ 72],\n", " [157],\n", " [163],\n", " [ 61],\n", " [ 55],\n", " [ 75],\n", " [ 77],\n", " [ 75],\n", " [ 77],\n", " [ 75],\n", " [ 75],\n", " [ 75],\n", " [ 77],\n", " [ 71],\n", " [ 60],\n", " [ 98],\n", " [156],\n", " [132],\n", " [ 58],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 13],\n", " [ 77],\n", " [ 74],\n", " [157],\n", " [143],\n", " [ 43],\n", " [ 61],\n", " [ 72],\n", " [ 75],\n", " [ 77],\n", " [ 75],\n", " [ 74],\n", " [ 77],\n", " [ 77],\n", " [ 75],\n", " [ 71],\n", " [ 58],\n", " [ 80],\n", " [157],\n", " [120],\n", " [ 66],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", 
"\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 18],\n", " [ 81],\n", " [ 74],\n", " [156],\n", " [114],\n", " [ 35],\n", " [ 72],\n", " [ 71],\n", " [ 75],\n", " [ 78],\n", " [ 72],\n", " [ 66],\n", " [ 80],\n", " [ 78],\n", " [ 77],\n", " [ 75],\n", " [ 64],\n", " [ 63],\n", " [165],\n", " [119],\n", " [ 68],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 23],\n", " [ 85],\n", " [ 81],\n", " [177],\n", " [ 57],\n", " [ 52],\n", " [ 77],\n", " [ 71],\n", " [ 78],\n", " [ 80],\n", " [ 72],\n", " [ 75],\n", " [ 74],\n", " [ 77],\n", " [ 77],\n", " [ 75],\n", " [ 64],\n", " [ 37],\n", " [173],\n", " [ 95],\n", " [ 72],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 26],\n", " [ 81],\n", " [ 86],\n", " [160],\n", " [ 20],\n", " [ 75],\n", " [ 77],\n", " [ 77],\n", " [ 80],\n", " [ 78],\n", " [ 80],\n", " [ 89],\n", " [ 78],\n", " [ 81],\n", " [ 83],\n", " [ 80],\n", " [ 74],\n", " [ 20],\n", " [177],\n", " [ 77],\n", " [ 74],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 49],\n", " [ 77],\n", " [ 91],\n", " [200],\n", " [ 0],\n", " [ 83],\n", " [ 95],\n", " [ 86],\n", " [ 88],\n", " [ 88],\n", " [ 89],\n", " [ 88],\n", " [ 89],\n", " [ 88],\n", " [ 83],\n", " [ 89],\n", " [ 86],\n", " [ 0],\n", " [191],\n", " [ 78],\n", " [ 80],\n", " [ 24],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 54],\n", " [ 71],\n", " [108],\n", " [165],\n", " [ 0],\n", " [ 24],\n", " [ 57],\n", " [ 52],\n", " [ 57],\n", " [ 60],\n", " [ 60],\n", " [ 60],\n", " [ 63],\n", " [ 63],\n", " [ 77],\n", " [ 89],\n", " [ 52],\n", " [ 0],\n", " [211],\n", " [ 97],\n", " [ 77],\n", " [ 61],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 68],\n", " [ 91],\n", " [117],\n", " [137],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 
0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 18],\n", " [216],\n", " [ 94],\n", " [ 97],\n", " [ 57],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 54],\n", " [115],\n", " [105],\n", " [185],\n", " [ 0],\n", " [ 0],\n", " [ 1],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [153],\n", " [ 78],\n", " [106],\n", " [ 37],\n", " [ 0],\n", " [ 0],\n", " [ 0]],\n", "\n", " [[ 0],\n", " [ 0],\n", " [ 0],\n", " [ 18],\n", " [ 61],\n", " [ 41],\n", " [103],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [ 0],\n", " [106],\n", " [ 47],\n", " [ 69],\n", " [ 23],\n", " [ 0],\n", " [ 0],\n", " [ 0]]], dtype=uint8), 'label': 2}\n" ] } ], "source": [ "for example in ds['train']:\n", " print(example)\n", " break" ] }, { "cell_type": "markdown", "metadata": { "id": "gZRHZNOkenPb" }, "source": [ "If you inspect elements, you will also notice that all features are already\n", "decoded as NumPy arrays. Behind the scenes, we use [OpenCV](https://opencv.org)\n", "by default because it is fast. If you don't have OpenCV installed, we default\n", "to [Pillow](https://python-pillow.org) to provide lightweight and fast image\n", "decoding.\n", "\n", "```\n", "{\n", " 'image': array([[[0], [0], ..., [0]],\n", " [[0], [0], ..., [0]]], dtype=uint8),\n", " 'label': 2,\n", "}\n", "```\n", "\n", "**Note**: Currently, the feature is only available for `Tensor`, `Image`, and\n", "`Scalar` features. The `Audio` and `Video` features will come soon. Stay tuned!" ] }, { "cell_type": "markdown", "metadata": { "id": "8kLyK5j1enhc" }, "source": [ "## Use with PyTorch\n", "\n", "PyTorch uses the source/sampler/loader paradigm. 
In Torch, \"data sources\" are\n", "called \"datasets\".\n", "[`torch.utils.data`](https://pytorch.org/docs/stable/data.html) contains all the\n", "details you need to know to build efficient input pipelines in Torch.\n", "\n", "TFDS data sources can be used as regular\n", "[map-style datasets](https://pytorch.org/docs/stable/data.html#map-style-datasets).\n", "\n", "First we install and import Torch:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:47:25.369088Z", "iopub.status.busy": "2024-12-14T12:47:25.368639Z", "iopub.status.idle": "2024-12-14T12:48:53.730125Z", "shell.execute_reply": "2024-12-14T12:48:53.729263Z" }, "id": "3aKol1fDeyoK" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting torch\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading torch-2.5.1-cp39-cp39-manylinux1_x86_64.whl.metadata (28 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting filelock (from torch)\r\n", " Downloading filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)\r\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from torch) (4.12.2)\r\n", "Requirement already satisfied: networkx in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from torch) (3.2.1)\r\n", "Requirement already satisfied: jinja2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from torch) (3.1.4)\r\n", "Requirement already satisfied: fsspec in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from torch) (2024.10.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\r\n", " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\r\n", " Downloading 
nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed filelock-3.16.1 mpmath-1.3.0 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 
nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nccl-cu12-2.21.5 nvidia-nvjitlink-cu12-12.4.127 nvidia-nvtx-cu12-12.4.127 sympy-1.13.1 torch-2.5.1 triton-3.1.0\r\n" ] } ], "source": [ "!pip install torch\n", "\n", "from tqdm import tqdm\n", "import torch" ] }, { "cell_type": "markdown", "metadata": { "id": "HKdJvYywe0YC" }, "source": [ "We already defined data sources for training and testing (respectively,\n", "`ds['train']` and `ds['test']`). We can now define the sampler and the loaders:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:48:53.734116Z", "iopub.status.busy": "2024-12-14T12:48:53.733379Z", "iopub.status.idle": "2024-12-14T12:48:53.738651Z", "shell.execute_reply": "2024-12-14T12:48:53.737898Z" }, "id": "_4P2JIrie23f" }, "outputs": [], "source": [ "batch_size = 128\n", "train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)\n", "train_loader = torch.utils.data.DataLoader(\n", " ds['train'],\n", " sampler=train_sampler,\n", " batch_size=batch_size,\n", ")\n", "test_loader = torch.utils.data.DataLoader(\n", " ds['test'],\n", " sampler=None,\n", " batch_size=batch_size,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "EVhofOm4e53O" }, "source": [ "Using PyTorch, we train and evaluate a simple logistic regression on the first\n", "examples:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:48:53.741662Z", "iopub.status.busy": "2024-12-14T12:48:53.741059Z", "iopub.status.idle": "2024-12-14T12:48:58.587176Z", "shell.execute_reply": "2024-12-14T12:48:58.586185Z" }, "id": "HcAmvMa-e42p" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/40 [00:00 dict[str, np.ndarray]:\n", " 
label = element[\"label\"]\n", " text = self.LABEL_TO_TEXT[label]\n", " element[\"text\"] = text\n", " return element\n", "\n", "# You can chain transformations in a list:\n", "operations = [ImageToText()]" ] }, { "cell_type": "markdown", "metadata": { "id": "53U1d8Yj4IM9" }, "source": [ "Finally, the data loader takes care of orchestrating the loading. You can scale\n", "up with multiprocessing to enjoy both the flexibility of Python and the\n", "performance of a data loader:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2024-12-14T12:48:58.665163Z", "iopub.status.busy": "2024-12-14T12:48:58.664612Z", "iopub.status.idle": "2024-12-14T12:48:58.750328Z", "shell.execute_reply": "2024-12-14T12:48:58.749635Z" }, "id": "SQEP8Bhp4IM9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "two\n", "one\n", "one\n", "height\n", "four\n" ] } ], "source": [ "loader = pygrain.DataLoader(\n", " data_source=data_source,\n", " operations=operations,\n", " sampler=sampler,\n", " worker_count=0, # Scale to multiple workers in multiprocessing\n", ")\n", "\n", "for element in loader:\n", " print(element[\"text\"])" ] }, { "cell_type": "markdown", "metadata": { "id": "JvLEtCWRvvy8" }, "source": [ "## Read more\n", "\n", "For more information, please refer to [`tfds.data_source`](https://www.tensorflow.org/datasets/api_docs/python/tfds/data_source) API doc." ] } ], "metadata": { "colab": { "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }