{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Tce3stUlHN0L" }, "source": [ "##### Copyright 2023 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2025-02-26T12:05:34.142530Z", "iopub.status.busy": "2025-02-26T12:05:34.142096Z", "iopub.status.idle": "2025-02-26T12:05:34.145903Z", "shell.execute_reply": "2025-02-26T12:05:34.145343Z" }, "id": "tuOe1ymfHZPu" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "36EdAGhThQov" }, "source": [ "# Uplifting with Decision Forests\n", "\n", "\n", " \n", " \n", " \n", " \n", "
\n", " View on TensorFlow.org\n", " \n", " Run in Google Colab\n", " \n", " View on GitHub\n", " \n", " Download notebook\n", "
\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2j8GzKvfVvF8" }, "source": [ "Welcome to the *Uplifting* Tutorial for TensorFlow Decision Forests (TF-DF). In this tutorial, you will learn what uplifting is, why it is important, and how to do it in TF-DF.\n", "\n", "This tutorial assumes you are familiar with the fundamentals of TF-DF, in particular the installation procedure. The [beginner tutorial](https://www.tensorflow.org/decision_forests/tutorials/beginner_colab) is a great place to start learning about TF-DF.\n", "\n", "In this colab, you will:\n", "\n", "- Learn what uplift modeling is.\n", "- Train an Uplift Random Forest model on the **Hillstrom Email Marketing** dataset.\n", "- Evaluate the quality of this model.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "MQIPhTQVW19g" }, "source": [ "## Installing TensorFlow Decision Forests\n", "\n", "Install TF-DF by running the following cell.\n", "\n", "[Wurlitzer](https://pypi.org/project/wurlitzer/) is needed to display the detailed training logs in Colab (when using `verbose=2` in the model constructor).\n", "\n", "TensorFlow Datasets is needed to download the dataset used in this tutorial." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2025-02-26T12:05:34.148710Z", "iopub.status.busy": "2025-02-26T12:05:34.148172Z", "iopub.status.idle": "2025-02-26T12:05:36.720368Z", "shell.execute_reply": "2025-02-26T12:05:36.719589Z" }, "id": "oiz5HmMyWxgd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting tensorflow_decision_forests\r\n", " Using cached tensorflow_decision_forests-1.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting wurlitzer\r\n", " Using cached wurlitzer-3.1.1-py3-none-any.whl.metadata (2.5 kB)\r\n", "Requirement already satisfied: tensorflow-datasets in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (4.9.3)\r\n", "Requirement already satisfied: numpy in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.0.2)\r\n", "Requirement already satisfied: pandas in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.2.3)\r\n", "Requirement already satisfied: tensorflow==2.18.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.18.0)\r\n", "Requirement already satisfied: six in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (1.17.0)\r\n", "Requirement already satisfied: absl-py in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.1.0)\r\n", "Requirement already satisfied: wheel in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (0.45.1)\r\n", "Requirement already satisfied: tf-keras~=2.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow_decision_forests) (2.18.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting ydf (from tensorflow_decision_forests)\r\n", " Using cached 
ydf-0.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.5 kB)\r\n", "Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (1.6.3)\r\n", "Requirement already satisfied: flatbuffers>=24.3.25 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (25.2.10)\r\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (0.6.0)\r\n", "Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (0.2.0)\r\n", "Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (18.1.1)\r\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (3.4.0)\r\n", "Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (24.2)\r\n", "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (3.20.3)\r\n", "Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (2.32.3)\r\n", "Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (75.8.1)\r\n", "Requirement already satisfied: termcolor>=1.1.0 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (2.5.0)\r\n", "Requirement already satisfied: typing-extensions>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (4.12.2)\r\n", "Requirement already satisfied: wrapt>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (1.17.2)\r\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (1.71.0rc2)\r\n", "Requirement already satisfied: tensorboard<2.19,>=2.18 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (2.18.0)\r\n", "Requirement already satisfied: keras>=3.5.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (3.8.0)\r\n", "Requirement already satisfied: h5py>=3.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (3.13.0)\r\n", "Requirement already satisfied: ml-dtypes<0.5.0,>=0.4.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (0.4.1)\r\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow==2.18.0->tensorflow_decision_forests) (0.37.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: array-record in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-datasets) (0.5.1)\r\n", "Requirement already satisfied: click in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-datasets) (8.1.8)\r\n", "Requirement already satisfied: dm-tree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-datasets) 
(0.1.8)\r\n", "Requirement already satisfied: etils>=0.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[enp,epath,etree]>=0.9.0->tensorflow-datasets) (1.5.2)\r\n", "Requirement already satisfied: promise in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-datasets) (2.3)\r\n", "Requirement already satisfied: psutil in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-datasets) (7.0.0)\r\n", "Requirement already satisfied: tensorflow-metadata in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-datasets) (1.16.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: toml in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-datasets) (0.10.2)\r\n", "Requirement already satisfied: tqdm in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-datasets) (4.67.1)\r\n", "Requirement already satisfied: fsspec in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[enp,epath,etree]>=0.9.0->tensorflow-datasets) (2025.2.0)\r\n", "Requirement already satisfied: importlib_resources in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[enp,epath,etree]>=0.9.0->tensorflow-datasets) (6.5.2)\r\n", "Requirement already satisfied: zipp in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from etils[enp,epath,etree]>=0.9.0->tensorflow-datasets) (3.21.0)\r\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow==2.18.0->tensorflow_decision_forests) (3.4.1)\r\n", "Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow==2.18.0->tensorflow_decision_forests) (3.10)\r\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from 
requests<3,>=2.21.0->tensorflow==2.18.0->tensorflow_decision_forests) (2.3.0)\r\n", "Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorflow==2.18.0->tensorflow_decision_forests) (2025.1.31)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: python-dateutil>=2.8.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2.9.0.post0)\r\n", "Requirement already satisfied: pytz>=2020.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2025.1)\r\n", "Requirement already satisfied: tzdata>=2022.7 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pandas->tensorflow_decision_forests) (2025.1)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO: pip is looking at multiple versions of ydf to determine which version is compatible with other requirements. 
This could take a while.\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " Downloading ydf-0.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.2 kB)\r\n", "Requirement already satisfied: rich in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.5.0->tensorflow==2.18.0->tensorflow_decision_forests) (13.9.4)\r\n", "Requirement already satisfied: namex in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.5.0->tensorflow==2.18.0->tensorflow_decision_forests) (0.0.8)\r\n", "Requirement already satisfied: optree in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from keras>=3.5.0->tensorflow==2.18.0->tensorflow_decision_forests) (0.14.0)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.19,>=2.18->tensorflow==2.18.0->tensorflow_decision_forests) (3.7)\r\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.19,>=2.18->tensorflow==2.18.0->tensorflow_decision_forests) (0.7.2)\r\n", "Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.19,>=2.18->tensorflow==2.18.0->tensorflow_decision_forests) (3.1.3)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: importlib-metadata>=4.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown>=2.6.8->tensorboard<2.19,>=2.18->tensorflow==2.18.0->tensorflow_decision_forests) (8.6.1)\r\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.19,>=2.18->tensorflow==2.18.0->tensorflow_decision_forests) (3.0.2)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: markdown-it-py>=2.2.0 in 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.5.0->tensorflow==2.18.0->tensorflow_decision_forests) (3.0.0)\r\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from rich->keras>=3.5.0->tensorflow==2.18.0->tensorflow_decision_forests) (2.19.1)\r\n", "Requirement already satisfied: mdurl~=0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from markdown-it-py>=2.2.0->rich->keras>=3.5.0->tensorflow==2.18.0->tensorflow_decision_forests) (0.1.2)\r\n", "Using cached tensorflow_decision_forests-1.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.9 MB)\r\n", "Using cached wurlitzer-3.1.1-py3-none-any.whl (8.6 kB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading ydf-0.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.5 MB)\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Installing collected packages: ydf, wurlitzer, tensorflow_decision_forests\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed tensorflow_decision_forests-1.11.0 wurlitzer-3.1.1 ydf-0.9.0\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "pip install tensorflow_decision_forests wurlitzer tensorflow-datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "2LIE3UDMXeB4" }, "source": [ "## Importing libraries" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2025-02-26T12:05:36.723749Z", "iopub.status.busy": "2025-02-26T12:05:36.723235Z", "iopub.status.idle": "2025-02-26T12:05:39.940194Z", "shell.execute_reply": "2025-02-26T12:05:39.939424Z" }, "id": "ue7Q-ysiPOmG" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-02-26 12:05:36.978454: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to 
register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1740571537.000097 9083 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1740571537.006625 9083 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "data": { "text/html": [ "\n", "

🌲 Try YDF, the successor of\n", " TensorFlow\n", " Decision Forests using the same algorithms but with more features and faster\n", " training!\n", "

\n", "
\n", "
\n", " \n", " Old code

\n", "
\n",
       "import tensorflow_decision_forests as tfdf\n",
       "\n",
       "tf_ds = tfdf.keras.pd_dataframe_to_tf_dataset(ds, label=\"l\")\n",
       "model = tfdf.keras.RandomForestModel(label=\"l\")\n",
       "model.fit(tf_ds)\n",
       "
\n", "
\n", "
\n", "
\n", " \n", " New code

\n", "
\n",
       "import ydf\n",
       "\n",
       "model = ydf.RandomForestLearner(label=\"l\").train(ds)\n",
       "
\n", "
\n", "
\n", "

(Learn more in the migration\n", " guide)

\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import tensorflow_decision_forests as tfdf\n", "\n", "import os\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "import math\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": { "id": "bN7quUfTXjaA" }, "source": [ "The hidden code cell limits the output height in colab.\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "cellView": "form", "execution": { "iopub.execute_input": "2025-02-26T12:05:39.943454Z", "iopub.status.busy": "2025-02-26T12:05:39.942847Z", "iopub.status.idle": "2025-02-26T12:05:39.951028Z", "shell.execute_reply": "2025-02-26T12:05:39.950382Z" }, "id": "nFP4KJ79Xl3J" }, "outputs": [], "source": [ "#@title\n", "\n", "from IPython.core.magic import register_line_magic\n", "from IPython.display import Javascript\n", "from IPython.display import display as ipy_display\n", "\n", "# Some of the model training logs can cover the full\n", "# screen if not compressed to a smaller viewport.\n", "# This magic allows setting a max height for a cell.\n", "@register_line_magic\n", "def set_cell_height(size):\n", " ipy_display(\n", " Javascript(\"google.colab.output.setIframeHeight(0, true, {maxHeight: \" +\n", " str(size) + \"})\"))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2025-02-26T12:05:39.953701Z", "iopub.status.busy": "2025-02-26T12:05:39.953146Z", "iopub.status.idle": "2025-02-26T12:05:39.956821Z", "shell.execute_reply": "2025-02-26T12:05:39.956173Z" }, "id": "jnpiCdRKXvir" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found TensorFlow Decision Forests v1.11.0\n" ] } ], "source": [ "# Check the version of TensorFlow Decision Forests\n", "print(\"Found TensorFlow Decision Forests v\" + tfdf.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "9SqXMEGLX0ry" }, "source": [ "## What is 
uplift modeling?\n", "\n", "[Uplift modeling](https://en.wikipedia.org/wiki/Uplift_modelling) is a statistical modeling technique to predict the **incremental impact of an action** on a subject. The action is often referred to as a **treatment** that may or may not be applied.\n", "\n", "Uplift modeling is often used in targeted marketing campaigns to predict the increase in the likelihood of a person making a purchase (or any other desired action) based on the marketing exposure they receive.\n", "\n", "For example, uplift modeling can predict the **effect** of an email. The effect is defined as the difference between two **conditional probabilities**:\n", "\\begin{align}\n", "\\text{effect}(\\text{email}) = &\\Pr(\\text{outcome}=\\text{purchase}\\ \\vert\\ \\text{treatment}=\\text{with email})\\\\ &- \\Pr(\\text{outcome}=\\text{purchase}\\ \\vert\\ \\text{treatment}=\\text{no email}),\n", "\\end{align}\n", "where $\\Pr(\\text{outcome}=\\text{purchase}\\ \\vert\\ ...)$\n", "is the probability of a purchase given whether or not the customer received an email.\n", "\n", "Compare this to a classification model, which can only predict the probability of a purchase. However, customers with a high purchase probability are likely to spend money in the store regardless of whether or not they received an email.\n", "\n", "Similarly, one can use **numerical uplifting** to predict the numerical **increase in spend** when receiving an email.
In comparison, a regression model can only predict the expected spend, which is a less useful quantity in many cases.\n", "\n", "### Defining uplift models in TF-DF\n", "\n", "TF-DF expects uplifting datasets to be presented in a \"flat\" format.\n", "A dataset of customers might look like this:\n", "\n", "treatment | outcome | feature_1 | feature_2\n", "--------- | ------- | --------- | ---------\n", "0 | 1 | 0.1 | blue \n", "0 | 0 | 0.2 | blue \n", "1 | 1 | 0.3 | blue \n", "1 | 1 | 0.4 | blue \n", "\n", "\n", "The **treatment** is a binary variable indicating whether or not the example has received treatment. In the above example, the treatment indicates whether or not the customer has received an email. The **outcome** (label) indicates the status of the example after receiving the treatment (or not). TF-DF supports categorical outcomes for categorical uplifting and numerical outcomes for numerical uplifting.\n", "\n", "**Note**: Uplifting is also frequently used in medical contexts. There, the *treatment* can be a medical treatment (e.g. administering a vaccine) and the label can be an indicator of quality of life (e.g. whether the patient got sick). This also explains the nomenclature of uplift modeling." ] }, { "cell_type": "markdown", "metadata": { "id": "kVaDog4ldPEY" }, "source": [ "## Training an uplifting model\n", "\n", "In this example, we will use the *Hillstrom Email Marketing dataset*.\n", "\n", "This dataset contains 64,000 customers who last purchased within twelve months. The customers were involved in an e-mail test:\n", "\n", "- 1/3 were randomly chosen to receive an e-mail campaign featuring men's merchandise.\n", "- 1/3 were randomly chosen to receive an e-mail campaign featuring women's merchandise.\n", "- 1/3 were randomly chosen to not receive an e-mail campaign.\n", "\n", "During a period of two weeks following the e-mail campaign, results were tracked.
The task is to tell whether the men's or the women's e-mail campaign was successful.\n", "\n", "Read more about the dataset [in its documentation](https://blog.minethatdata.com/2008/03/minethatdata-e-mail-analytics-and-data.html). This tutorial uses the dataset as curated by [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/hillstrom)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2025-02-26T12:05:39.959559Z", "iopub.status.busy": "2025-02-26T12:05:39.959130Z", "iopub.status.idle": "2025-02-26T12:05:43.623758Z", "shell.execute_reply": "2025-02-26T12:05:43.623127Z" }, "id": "1veZ9nJZPGsv" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "I0000 00:00:1740571542.697544 9083 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13638 MB memory: -> device: 0, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5\n", "I0000 00:00:1740571542.699822 9083 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13756 MB memory: -> device: 1, name: Tesla T4, pci bus id: 0000:00:06.0, compute capability: 7.5\n", "I0000 00:00:1740571542.701960 9083 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 13756 MB memory: -> device: 2, name: Tesla T4, pci bus id: 0000:00:07.0, compute capability: 7.5\n", "I0000 00:00:1740571542.703996 9083 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 13756 MB memory: -> device: 3, name: Tesla T4, pci bus id: 0000:00:08.0, compute capability: 7.5\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-02-26 12:05:43.599785: W tensorflow/core/kernels/data/cache_dataset_ops.cc:914] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded.
This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
channelconversionhistoryhistory_segmentmensnewbierecencysegmentspendvisitwomenszip_code
0b'Web'029.990000b'1) $0 - $100'106b'Womens E-Mail'0.000b'Surburban'
1b'Web'0150.380005b'2) $100 - $200'019b'Womens E-Mail'0.001b'Surburban'
2b'Phone'0602.960022b'5) $500 - $750'114b'Womens E-Mail'0.000b'Surburban'
3b'Multichannel'0341.010010b'3) $200 - $350'009b'Womens E-Mail'0.011b'Urban'
4b'Phone'097.180000b'1) $0 - $100'013b'Womens E-Mail'0.011b'Surburban'
5b'Web'083.269997b'1) $0 - $100'105b'Mens E-Mail'0.000b'Urban'
6b'Web'0331.170013b'3) $200 - $350'108b'Womens E-Mail'0.000b'Surburban'
7b'Multichannel'0628.400024b'5) $500 - $750'119b'No E-Mail'0.010b'Surburban'
8b'Phone'0134.610001b'2) $100 - $200'106b'No E-Mail'0.010b'Rural'
9b'Web'0141.210007b'2) $100 - $200'019b'Mens E-Mail'0.011b'Surburban'
\n", "
" ], "text/plain": [ " channel conversion history history_segment mens newbie \\\n", "0 b'Web' 0 29.990000 b'1) $0 - $100' 1 0 \n", "1 b'Web' 0 150.380005 b'2) $100 - $200' 0 1 \n", "2 b'Phone' 0 602.960022 b'5) $500 - $750' 1 1 \n", "3 b'Multichannel' 0 341.010010 b'3) $200 - $350' 0 0 \n", "4 b'Phone' 0 97.180000 b'1) $0 - $100' 0 1 \n", "5 b'Web' 0 83.269997 b'1) $0 - $100' 1 0 \n", "6 b'Web' 0 331.170013 b'3) $200 - $350' 1 0 \n", "7 b'Multichannel' 0 628.400024 b'5) $500 - $750' 1 1 \n", "8 b'Phone' 0 134.610001 b'2) $100 - $200' 1 0 \n", "9 b'Web' 0 141.210007 b'2) $100 - $200' 0 1 \n", "\n", " recency segment spend visit womens zip_code \n", "0 6 b'Womens E-Mail' 0.0 0 0 b'Surburban' \n", "1 9 b'Womens E-Mail' 0.0 0 1 b'Surburban' \n", "2 4 b'Womens E-Mail' 0.0 0 0 b'Surburban' \n", "3 9 b'Womens E-Mail' 0.0 1 1 b'Urban' \n", "4 3 b'Womens E-Mail' 0.0 1 1 b'Surburban' \n", "5 5 b'Mens E-Mail' 0.0 0 0 b'Urban' \n", "6 8 b'Womens E-Mail' 0.0 0 0 b'Surburban' \n", "7 9 b'No E-Mail' 0.0 1 0 b'Surburban' \n", "8 6 b'No E-Mail' 0.0 1 0 b'Rural' \n", "9 9 b'Mens E-Mail' 0.0 1 1 b'Surburban' " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the dataset: the first 80% of the examples for training, the last 20% for testing.\n", "import tensorflow_datasets as tfds\n", "raw_train, raw_test = tfds.load('hillstrom', split=['train[:80%]', 'train[80%:]'])\n", "\n", "# Display the first 10 examples of the test fold.\n", "pd.DataFrame(list(raw_test.batch(10).take(1))[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "5stnFbyKaIgn" }, "source": [ "### Dataset preprocessing\n", "\n", "Since TF-DF currently only supports binary treatments, combine the \"Men's Email\" and the \"Women's Email\" campaigns. This tutorial uses the binary variable `conversion` as the outcome. This means that the problem is a **Categorical Uplifting** problem. If we were using the numerical variable `spend`, the problem would be a **Numerical Uplifting** problem."
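, "\n", "\n", "As a sketch in the notation of the effect definition above (the symbol $x$, standing for a customer's input features, and the subscripts are introduced here only for illustration and do not describe TF-DF internals), categorical uplifting on `conversion` compares conversion probabilities between treated and control customers, while numerical uplifting on `spend` compares expected spend:\n", "\\begin{align}\n", "\\text{uplift}_{\\text{cat}}(x) &= \\Pr(\\text{conversion}=1\\ \\vert\\ x, \\text{treatment}=1) - \\Pr(\\text{conversion}=1\\ \\vert\\ x, \\text{treatment}=0),\\\\\n", "\\text{uplift}_{\\text{num}}(x) &= \\mathbb{E}[\\text{spend}\\ \\vert\\ x, \\text{treatment}=1] - \\mathbb{E}[\\text{spend}\\ \\vert\\ x, \\text{treatment}=0].\n", "\\end{align}\n"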
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2025-02-26T12:05:43.627514Z", "iopub.status.busy": "2025-02-26T12:05:43.626902Z", "iopub.status.idle": "2025-02-26T12:05:43.730947Z", "shell.execute_reply": "2025-02-26T12:05:43.730329Z" }, "id": "dLpAw7jibIrh" }, "outputs": [], "source": [ "def prepare_dataset(example):\n", " # Use a binary treatment class.\n", " example['treatment'] = 1 if example['segment'] == b'Mens E-Mail' or example['segment'] == b'Womens E-Mail' else 0\n", " outcome = example['conversion']\n", " # Restrict the dataset to the input features.\n", " input_features = ['channel', 'history', 'mens', 'womens', 'newbie', 'recency', 'zip_code', 'treatment']\n", " example = {feature: example[feature] for feature in input_features}\n", " return example, outcome\n", "\n", "train_ds = raw_train.map(prepare_dataset).batch(100)\n", "test_ds = raw_test.map(prepare_dataset).batch(100)" ] }, { "cell_type": "markdown", "metadata": { "id": "Z-mtKmd-RoOu" }, "source": [ "### Model training\n", "\n", "Finally, train and evaluate the model as usual. Note that TF-DF only supports Random Forest models for uplifting." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2025-02-26T12:05:43.734085Z", "iopub.status.busy": "2025-02-26T12:05:43.733710Z", "iopub.status.idle": "2025-02-26T12:05:52.260169Z", "shell.execute_reply": "2025-02-26T12:05:52.259536Z" }, "id": "-OZN8t8LRn38" }, "outputs": [ { "data": { "application/javascript": [ "google.colab.output.setIframeHeight(0, true, {maxHeight: 300})" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. 
Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Use /tmpfs/tmp/tmphsli6pat as temporary training directory\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Reading training dataset...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training tensor examples:\n", "Features: {'channel': , 'history': , 'mens': , 'womens': , 'newbie': , 'recency': , 'zip_code': , 'treatment': }\n", "Label: Tensor(\"data_8:0\", shape=(None,), dtype=int64)\n", "Weights: None\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Normalized tensor features:\n", " {'channel': SemanticTensor(semantic=, tensor=), 'history': SemanticTensor(semantic=, tensor=), 'mens': SemanticTensor(semantic=, tensor=), 'womens': SemanticTensor(semantic=, tensor=), 'newbie': SemanticTensor(semantic=, tensor=), 'recency': SemanticTensor(semantic=, tensor=), 'zip_code': SemanticTensor(semantic=, tensor=)}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training dataset read in 0:00:05.088312. Found 51200 examples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Standard output detected as not visible to the user e.g. running in a notebook. Creating a training log redirection. 
If training gets stuck, try calling tfdf.keras.set_training_logs_redirection(False).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model trained in 0:00:02.392275\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Compiling model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model compiled.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%set_cell_height 300\n", "\n", "# Configure the model and its hyper-parameters.\n", "model = tfdf.keras.RandomForestModel(\n", " verbose=2,\n", " task=tfdf.keras.Task.CATEGORICAL_UPLIFT,\n", " uplift_treatment='treatment'\n", ")\n", "\n", "# Train the model.\n", "model.fit(train_ds)" ] }, { "cell_type": "markdown", "metadata": { "id": "XKhtZuLhGtv_" }, "source": [ "# Evaluating uplift models\n", "\n", "## Metrics for uplift models\n", "\n", "The two most important metrics for evaluating uplift models are the **AUUC** (Area Under the Uplift Curve) metric and the **Qini** (Area Under the Qini Curve) metric. They play a role similar to AUC and accuracy for classification problems. For both metrics, larger values are better.\n", "\n", "Both AUUC and Qini are **not** normalized metrics. This means that the best possible value of the metric can vary from dataset to dataset. This is different from, for example, the AUC metric, which always lies between 0 and 1.\n", "\n", "The AUUC is explained in more detail below. For more information about these metrics, see [Guelman](https://diposit.ub.edu/dspace/bitstream/2445/65123/1/Leo%20Guelman_PhD_THESIS.pdf) and [Betlei et al.](https://arxiv.org/pdf/2012.09897.pdf)" ] }, { "cell_type": "markdown", "metadata": { "id": "AMSpNTTZmuzv" }, "source": [ "## Model Self-Evaluation\n", "\n", "TF-DF Random Forest models perform self-evaluation on the out-of-bag examples of the training dataset. For uplift models, they expose the AUUC and Qini metrics.
You can directly retrieve the two metrics on the training dataset through the model inspector.\n", "\n", "Later, we are going to recompute the AUUC metric \"manually\" on the test dataset. Note that the two values are not expected to be exactly equal (out-of-bag on train vs. test), since they are computed on different examples and the AUUC is not a normalized metric." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2025-02-26T12:05:52.263559Z", "iopub.status.busy": "2025-02-26T12:05:52.262962Z", "iopub.status.idle": "2025-02-26T12:05:52.268670Z", "shell.execute_reply": "2025-02-26T12:05:52.267994Z" }, "id": "OsN1R9mT_8T6" }, "outputs": [ { "data": { "text/plain": [ "Evaluation(num_examples=51200, accuracy=None, loss=None, rmse=None, ndcg=None, aucs=None, auuc=0.0022032308892709586, qini=-0.00017325819500263418)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The self-evaluation is available through the model inspector\n", "insp = model.make_inspector()\n", "insp.evaluation()" ] }, { "cell_type": "markdown", "metadata": { "id": "WErGZZ27HWJN" }, "source": [ "## Manually computing the AUUC\n", "\n", "In this section, we manually compute the AUUC and plot the uplift curves.\n", "\n", "The next few paragraphs explain the AUUC metric in more detail and may be skipped.\n", "\n", "### Computing the AUUC\n", "\n", "Suppose you have a labeled dataset with $|T|$ examples with treatment and $|C|$ examples without treatment, called *control* examples. For each example, the uplift model $f$ produces the conditional probability that a treatment on the example will yield a positive outcome.\n", "\n", "Suppose a decision-maker needs to use an uplift model $f$ to decide which clients should receive an email. The model produces a (conditional) probability that the email will result in a conversion. 
The decision-maker might therefore pick the number $k$ of emails to send and send those $k$ emails to the clients with the highest predicted probability.\n", "\n", "Using a labeled test dataset, it is possible to study the impact of $k$ on the success of the campaign. First, we are interested in the ratio $\\frac{|S_k \\cap Y \\cap T|}{|T|}$, where $T$ is the set of clients that received an email, $Y$ is the set of clients that converted, and $S_k$ is the set of the $k$ clients with the highest predictions. In other words, this is the number of clients among the top $k$ that received an email and converted, divided by the total number of clients that received an email. We plot this ratio against $k$.\n", "\n", "Ideally, this curve increases steeply. This would mean that the model prioritizes sending emails to the clients that will convert when they receive an email." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2025-02-26T12:05:52.271460Z", "iopub.status.busy": "2025-02-26T12:05:52.270937Z", "iopub.status.idle": "2025-02-26T12:05:57.534141Z", "shell.execute_reply": "2025-02-26T12:05:57.533530Z" }, "id": "xUGNWKkkkl-s" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "512/512 [==============================] - 3s 6ms/step\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Compute all predictions on the test dataset\n", "predictions = model.predict(test_ds).flatten()\n", "# Extract outcomes and treatments\n", "outcomes = np.concatenate([outcome.numpy() for _, outcome in test_ds])\n", "treatment = np.concatenate([example['treatment'].numpy() for example, _ in test_ds])\n", "control = 1 - treatment\n", "\n", "num_treatments = np.sum(treatment)\n", "# Clients without treatment form the 'control' group\n", "num_control = np.sum(control)\n", "num_examples = len(predictions)\n", "\n", "# Sort outcomes, treatments and controls by prediction, in descending order\n", "prediction_order = predictions.argsort()[::-1]\n", "outcomes_sorted = outcomes[prediction_order]\n", "treatment_sorted = treatment[prediction_order]\n", "control_sorted = control[prediction_order]\n", "# Treated conversions among the top k, divided by the total number of treated clients\n", "ratio_treatment = np.cumsum(np.multiply(outcomes_sorted, treatment_sorted), axis=0)/num_treatments\n", "\n", "fig, ax = plt.subplots()\n", "ax.plot(ratio_treatment, label='Conversion ratio of treatment')\n", "ax.set_xlabel('k')\n", "ax.set_ylabel('Ratio of conversion')\n", "ax.legend()" ] }, { "cell_type": "markdown", "metadata": { "id": "97IFpq5epHsx" }, "source": [ "Similarly, we can compute and plot the conversion ratio of the clients that did not receive an email, called the *control group*. Ideally, this curve is initially flat: this would mean that the model does not prioritize sending emails to clients that would convert even **without** receiving an email." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2025-02-26T12:05:57.536830Z", "iopub.status.busy": "2025-02-26T12:05:57.536568Z", "iopub.status.idle": "2025-02-26T12:05:57.679805Z", "shell.execute_reply": "2025-02-26T12:05:57.679191Z" }, "id": "bIY-oA9alwzY" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ratio_control = np.cumsum(np.multiply(outcomes_sorted, control_sorted), axis=0)/num_control\n", "ax.plot(ratio_control, label='Conversion ratio of control')\n", "ax.legend()\n", "fig" ] }, { "cell_type": "markdown", "metadata": { "id": "q9MopM5MnCK0" }, "source": [ "The AUUC metric measures the signed area between these two curves, after normalizing the x-axis (the number $k$ of targeted clients) to the interval [0, 1]. Regions where the treatment curve lies below the control curve count negatively." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2025-02-26T12:05:57.682362Z", "iopub.status.busy": "2025-02-26T12:05:57.682121Z", "iopub.status.idle": "2025-02-26T12:05:58.034921Z", "shell.execute_reply": "2025-02-26T12:05:58.034292Z" }, "id": "99XXGsq7nQgN" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The AUUC on the test dataset is 0.007513949426065613\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x = np.linspace(0, 1, num_examples)\n", "plt.plot(x, ratio_treatment, label='Conversion ratio of treatment')\n", "plt.plot(x, ratio_control, label='Conversion ratio of control')\n", "plt.fill_between(x, ratio_treatment, ratio_control, where=(ratio_treatment > ratio_control), color='C0', alpha=0.3)\n", "plt.fill_between(x, ratio_treatment, ratio_control, where=(ratio_treatment < ratio_control), color='C1', alpha=0.3)\n", "plt.xlabel('k')\n", "plt.ylabel('Ratio of conversion')\n", "plt.legend()\n", "\n", "# Approximate the integral of the difference between the two curves.\n", "# `np.trapz` is deprecated since NumPy 2.0; use `np.trapezoid` when it is available.\n", "trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz\n", "auuc = trapezoid(ratio_treatment - ratio_control, dx=1/num_examples)\n", "print(f'The AUUC on the test dataset is {auuc}')" ] } ], "metadata": { "colab": { "name": "uplift_colab.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }