{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ZAxHmnV88kcD" }, "source": [ "Project: /mediapipe/_project.yaml\n", "Book: /mediapipe/_book.yaml\n", "\n", "\n", "\n", "# Text classification model customization guide\n", "\n", "\u003ctable align=\"left\" class=\"buttons\"\u003e\n", " \u003ctd\u003e\n", " \u003ca href=\"https://colab.research.google.com/github/googlesamples/mediapipe/blob/main/examples/customization/text_classifier.ipynb\" target=\"_blank\"\u003e\n", " \u003cimg src=\"https://developers.google.com/static/mediapipe/solutions/customization/colab-logo-32px_1920.png\" alt=\"Colab logo\"\u003e Run in Colab\n", " \u003c/a\u003e\n", " \u003c/td\u003e\n", "\n", " \u003ctd\u003e\n", " \u003ca href=\"https://github.com/googlesamples/mediapipe/blob/main/examples/customization/text_classifier.ipynb\" target=\"_blank\"\u003e\n", " \u003cimg src=\"https://developers.google.com/static/mediapipe/solutions/customization/github-logo-32px_1920.png\" alt=\"GitHub logo\"\u003e\n", " View on GitHub\n", " \u003c/a\u003e\n", " \u003c/td\u003e\n", "\u003c/table\u003e" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JO1GUwC1_T2x" }, "outputs": [], "source": [ "#@title License information\n", "# Copyright 2023 The MediaPipe Authors.\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "#\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions 
and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "_7OpXsvDmBXz" }, "source": [ "The MediaPipe Model Maker package is a simple, low-code solution for customizing on-device machine learning (ML) models. This notebook shows the end-to-end process of customizing a text classification model for the specific use case of performing sentiment analysis on movie reviews." ] }, { "cell_type": "markdown", "metadata": { "id": "mzG8EHCUnAO2" }, "source": [ "## Prerequisites\n", "\n", "Install the MediaPipe Model Maker package." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fnyOyxCIWV5j" }, "outputs": [], "source": [ "!pip install --upgrade pip\n", "!pip install mediapipe-model-maker" ] }, { "cell_type": "markdown", "metadata": { "id": "8eGD7ryenJ1b" }, "source": [ "Import the required libraries." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YVoFJsZ5nMhK" }, "outputs": [], "source": [ "import os\n", "import tensorflow as tf\n", "assert tf.__version__.startswith('2')\n", "\n", "from mediapipe_model_maker import text_classifier" ] }, { "cell_type": "markdown", "metadata": { "id": "Don3e8Oc9LGK" }, "source": [ "## Get the Dataset\n", "\n", "The following code block downloads the [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) dataset, which contains 67,349 movie reviews for training and 872 movie reviews for testing. The dataset has two classes: positive and negative movie reviews. Positive reviews are labelled with `1` and negative reviews with `0`. We will use this dataset to train the two text classifiers featured in this tutorial.\n", "\n", "*Disclaimer: The dataset linked in this Colab is not owned or distributed by Google, and is made available by third parties. 
Please review the terms and conditions made available by the third parties before using the data.*" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XIiVqeUprx6v" }, "outputs": [], "source": [ "data_path = tf.keras.utils.get_file(\n", " fname='SST-2.zip',\n", " origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',\n", " extract=True)\n", "data_dir = os.path.join(os.path.dirname(data_path), 'SST-2') # folder name" ] }, { "cell_type": "markdown", "metadata": { "id": "NmvVIU_KlX_P" }, "source": [ "The SST-2 dataset is stored as a TSV file. The only difference between the TSV and CSV formats is that TSV uses a tab `\\t` character as its delimiter and CSV uses a comma `,`." ] }, { "cell_type": "markdown", "metadata": { "id": "bnlhd_Q26SWM" }, "source": [ "The following code block extracts the training and validation data from their TSV files using the `Dataset.from_csv` method.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z9bSysfY6R9A" }, "outputs": [], "source": [ "csv_params = text_classifier.CSVParams(\n", " text_column='sentence', label_column='label', delimiter='\\t')\n", "train_data = text_classifier.Dataset.from_csv(\n", " filename=os.path.join(data_dir, 'train.tsv'),\n", " csv_params=csv_params)\n", "validation_data = text_classifier.Dataset.from_csv(\n", " filename=os.path.join(data_dir, 'dev.tsv'),\n", " csv_params=csv_params)" ] }, { "cell_type": "markdown", "metadata": { "id": "wFPwFUTJrfqn" }, "source": [ "## Average Word Embedding Model\n", "\n", "Model Maker's Text Classifier supports two classifiers with distinct model architectures: an average word embedding model and a BERT model. The first demo classifier will use an average word embedding architecture." ] }, { "cell_type": "markdown", "metadata": { "id": "rxYAUbIf6x_e" }, "source": [ "To create and train a text classifier we need to set some `TextClassifierOptions`. 
These options require us to specify a `supported_model`, which can take the value `AVERAGE_WORD_EMBEDDING_CLASSIFIER` or `MOBILEBERT_CLASSIFIER`. We'll use `AVERAGE_WORD_EMBEDDING_CLASSIFIER` for now.\n", "\n", "For more information on the `TextClassifierOptions` class and its fields, see the TextClassifierOptions section below.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LBxLuS2Z60YF" }, "outputs": [], "source": [ "supported_model = text_classifier.SupportedModels.AVERAGE_WORD_EMBEDDING_CLASSIFIER\n", "hparams = text_classifier.AverageWordEmbeddingHParams(epochs=10, batch_size=32, learning_rate=0, export_dir=\"awe_exported_models\")\n", "options = text_classifier.TextClassifierOptions(supported_model=supported_model, hparams=hparams)" ] }, { "cell_type": "markdown", "metadata": { "id": "IUcA9fTr7WP2" }, "source": [ "Now we can use the training and validation data with the `TextClassifierOptions` we've defined to create and train a text classifier. To do so, we use the `TextClassifier.create` function." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Slivdb647VyL" }, "outputs": [], "source": [ "model = text_classifier.TextClassifier.create(train_data, validation_data, options)" ] }, { "cell_type": "markdown", "metadata": { "id": "hR9IgXGg8gOD" }, "source": [ "Evaluate the model on the validation data and print the loss and accuracy." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KjFcqOlo8hz5" }, "outputs": [], "source": [ "metrics = model.evaluate(validation_data)\n", "print(f'Test loss:{metrics[0]}, Test accuracy:{metrics[1]}')" ] }, { "cell_type": "markdown", "metadata": { "id": "wAB8E-CS9DXf" }, "source": [ "After training the model we can export it as a TFLite file for on-device applications. We can also export the labels used during training." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sRWMw1aqM1FB" }, "outputs": [], "source": [ "model.export_model()\n", "model.export_labels(export_dir=options.hparams.export_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "n0QKFRJu0kEy" }, "source": [ "You can download the TFLite model using the left sidebar of Colab." ] }, { "cell_type": "markdown", "metadata": { "id": "CatcJMktM3Ms" }, "source": [ "## BERT-classifier\n", "\n", "Now let's train a text classifier based on the [MobileBERT](https://arxiv.org/abs/2004.02984) model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8kBx9EggOiFB" }, "outputs": [], "source": [ "supported_model = text_classifier.SupportedModels.MOBILEBERT_CLASSIFIER\n", "hparams = text_classifier.BertHParams(epochs=2, batch_size=48, learning_rate=3e-5, export_dir=\"bert_exported_models\")\n", "options = text_classifier.TextClassifierOptions(supported_model=supported_model, hparams=hparams)" ] }, { "cell_type": "markdown", "metadata": { "id": "NpKkDmLHPwKz" }, "source": [ "Create and train the text classifier like we did with the average word embedding-based classifier.\n", "\n", "WARNING: This can take ~25 minutes on a GPU runtime and nearly 7 hours on CPU. We strongly recommend using a GPU runtime." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rhRswkG9Pw3f" }, "outputs": [], "source": [ "bert_model = text_classifier.TextClassifier.create(train_data, validation_data, options)" ] }, { "cell_type": "markdown", "metadata": { "id": "ovco6Y2mTjgL" }, "source": [ "Evaluate the model. Note the improved performance compared to the average word embedding-based classifier."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SvhFpBbwTjGS" }, "outputs": [], "source": [ "metrics = bert_model.evaluate(validation_data)\n", "print(f'Test loss:{metrics[0]}, Test accuracy:{metrics[1]}')" ] }, { "cell_type": "markdown", "metadata": { "id": "4-B_M19aT3ZZ" }, "source": [ "The MobileBERT model is over 100MB, so when we export the BERT-based classifier as a TFLite model, it helps to use quantization, which can bring the TFLite model size down to about 28MB." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UFC_XVZAT7F4" }, "outputs": [], "source": [ "from mediapipe_model_maker import quantization\n", "quantization_config = quantization.QuantizationConfig.for_dynamic()\n", "bert_model.export_model(quantization_config=quantization_config)\n", "bert_model.export_labels(export_dir=options.hparams.export_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "tZH3N-_P5XEB" }, "source": [ "## TextClassifierOptions\n", "\n", "We can configure text classifier training with `TextClassifierOptions`, which takes one required parameter:\n", "\n", "* `supported_model` which describes the model architecture that the text classifier is based on. It can be either an `AVERAGE_WORD_EMBEDDING_CLASSIFIER` or a `MOBILEBERT_CLASSIFIER`.\n", "\n", "`TextClassifierOptions` can also take two optional parameters:\n", "\n", "* `hparams` which describes hyperparameters used during model training. This takes an `HParams` object.\n", "* `model_options` which describes configurable parameters related to the model architecture or data preprocessing. For an average word-embedding classifier, this field takes an `AverageWordEmbeddingModelOptions` object. 
For a BERT-based classifier, this field takes a `BertModelOptions` object.\n", "\n", "If these fields aren't set, model creation and training will be run with predefined default values.\n", "\n", "`HParams` has the following list of customizable parameters which affect model accuracy:\n", "* `learning_rate`: The learning rate to use for gradient descent-based optimizers. Defaults to 3e-5 for the BERT-based classifier and 0 for the average word-embedding classifier because it does not need such an optimizer.\n", "* `batch_size`: Batch size for training. Defaults to 32 for the average word-embedding classifier and 48 for the BERT-based classifier.\n", "* `epochs`: Number of training iterations over the dataset. Defaults to 10 for the average word-embedding classifier and 3 for the BERT-based classifier.\n", "* `steps_per_epoch`: An optional integer that indicates the number of training steps per epoch. If not set, the training pipeline calculates the default steps per epoch as the training dataset size divided by batch size.\n", "* `shuffle`: True if the dataset is shuffled before training. Defaults to False.\n", "\n", "Additional `HParams` parameters that do not affect model accuracy:\n", "* `export_dir`: The location of the model checkpoint files and exported model files.\n", "\n", "`AverageWordEmbeddingModelOptions` has the following list of customizable parameters related to the model architecture:\n", "* `seq_len`: the length of the input sequence for the model. Defaults to 256.\n", "* `wordvec_dim`: the dimension of the word embeddings. Defaults to 16.\n", "* `dropout_rate`: The rate used in the model's dropout layer. Defaults to 0.2.\n", "\n", "It also has the following customizable parameters related to data preprocessing:\n", "* `do_lower_case`: whether text input is converted to lower case before training or inference. 
Defaults to True.\n", "* `vocab_size`: the maximum size of the vocab generated from the set of text data.\n", "\n", "`BertModelOptions` has the following list of customizable parameters related to the model architecture:\n", "* `seq_len`: the length of the input sequence for the BERT-encoder. Defaults to 128.\n", "* `dropout_rate`: the rate used in the classifier's dropout layer. Defaults to 0.1.\n", "* `do_fine_tuning`: whether the BERT-encoder is unfrozen and should be trainable along with the classifier layers. Defaults to True." ] }, { "cell_type": "markdown", "metadata": { "id": "1zS4thLtu7os" }, "source": [ "## Benchmarks\n", "\n", "Below is a summary of our benchmarking results for the average word-embedding and BERT-based classifiers featured in this tutorial. To optimize model performance for your use case, it's worthwhile to experiment with different model and training parameters in order to obtain the highest test accuracy. Refer to the TextClassifierOptions section for more information on customizing these parameters.\n", "\n", "\u003ctable\u003e\n", "\u003cthead\u003e\n", "\u003ctr\u003e\n", "\u003cth\u003eModel architecture\u003c/th\u003e\n", "\u003cth\u003eTest Accuracy\u003c/th\u003e\n", "\u003cth\u003eModel Size\u003c/th\u003e\n", "\u003cth\u003eCPU 1 Thread Latency on Pixel 6 (ms)\u003c/th\u003e\n", "\u003cth\u003eGPU Latency on Pixel 6 (ms)\u003c/th\u003e\n", "\u003c/tr\u003e\n", "\u003c/thead\u003e\n", "\u003ctbody\u003e\n", "\u003ctr\u003e\n", "\u003ctd\u003eAverage Word-Embedding\u003c/td\u003e\n", "\u003ctd\u003e81.0%\u003c/td\u003e\n", "\u003ctd\u003e776K\u003c/td\u003e\n", "\u003ctd\u003e0.03\u003c/td\u003e\n", "\u003ctd\u003e0.43\u003c/td\u003e\n", "\u003c/tr\u003e\n", "\u003ctr\u003e\n", "\u003ctd\u003eMobileBERT\u003c/td\u003e\n", "\u003ctd\u003e90.36%\u003c/td\u003e\n", "\u003ctd\u003e28.4MB\u003c/td\u003e\n", "\u003ctd\u003e53.38\u003c/td\u003e\n", "\u003ctd\u003e74.22\u003c/td\u003e\n", "\u003c/tr\u003e\n", "\u003c/tbody\u003e\n", 
"\u003c/table\u003e" ] } ], "metadata": { "accelerator": "GPU", "colab": { "last_runtime": { "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "kind": "private" }, "private_outputs": true, "provenance": [ { "file_id": "1TuyWrMT8DKKN9t0kN9mSYpMvAaiX1nWh", "timestamp": 1668478912102 }, { "file_id": "1L7gx9o6hhBxtYyCamkG4O2UslZKbE8fH", "timestamp": 1668406581227 } ], "toc_visible": true }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }