{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "R2-i8jBl9GRH" }, "source": [ "![Redis](https://redis.io/wp-content/uploads/2024/04/Logotype.svg?auto=webp&quality=85,75&width=120)\n", "\n", "# Advanced RAG example\n", "\n", "Now that you have a good foundation in Redis data structures, search capabilities, and basic RAG with the redisvl client from [/getting_started/02_redisvl](../getting_started/02_redisvl.ipynb).\n", "\n", "We will extend the basic RAG example with a few special topics/techniques:\n", "- Dense content representation\n", "- Query rewriting / expansion\n", "- Semantic caching\n", "- Conversational memory persistence\n", "\n", "## Let's Begin!\n", "\"Open\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Improve accuracy with dense content representations\n", "In the basic example, we took raw chunks of text from our pdf documents and generated embeddings for them to be stored in the vector database. This is okay but one technique we can use to improve the quality of retrieval is to leverage an LLM from OpenAI during ETL. We will prompt the LLM to summarize and decompose the raw pdf text into more discrete propositional phrases. This will enhance the clarity of the text and improve semantic retrieval for RAG.\n", "\n", "The goal is to utilize a preprocessing technique similar to what's outlined here:\n", "https://github.com/langchain-ai/langchain/blob/master/templates/propositional-retrieval/propositional_retrieval/proposal_chain.py\n", "\n", "If you already have a redis-stack instance running locally from before feel free to jump ahead but if not execute the following commands to get the environment properly setup." ] }, { "cell_type": "markdown", "metadata": { "id": "rT9HzsnQ1uiz" }, "source": [ "## Environment Setup\n", "\n", "### Pull Github Materials\n", "Because you are likely running this notebook in **Google Colab**, we need to first\n", "pull the necessary dataset and materials directly from GitHub.\n", "\n", "**If you are running this notebook locally**, FYI you may not need to perform this\n", "step at all." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AJJ2UW6M1ui0" }, "outputs": [], "source": [ "# NBVAL_SKIP\n", "!git clone https://github.com/redis-developer/redis-ai-resources.git temp_repo\n", "!mv temp_repo/python-recipes/RAG/resources .\n", "!rm -rf temp_repo" ] }, { "cell_type": "markdown", "metadata": { "id": "Z67mf6T91ui2" }, "source": [ "### Install Python Dependencies" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "DgxBQFXQ1ui2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" ] } ], "source": [ "%pip install -q \"redisvl>=0.4.1\" pandas \"unstructured[pdf]\" sentence-transformers langchain langchain-community \"openai>=1.57.0\" tqdm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Install Redis Stack\n", "\n", "Later in this tutorial, Redis will be used to store, index, and query vector\n", "embeddings created from PDF document chunks. 
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### For Colab\n", "Use the shell script below to download, extract, and install [Redis Stack](https://redis.io/docs/getting-started/install-stack/) directly\n", "from the Redis package archive." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NBVAL_SKIP\n", "%%sh\n", "curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg\n", "echo \"deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/redis.list\n", "sudo apt-get update > /dev/null 2>&1\n", "sudo apt-get install redis-stack-server > /dev/null 2>&1\n", "redis-stack-server --daemonize yes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### For Alternative Environments\n", "There are many ways to get the necessary redis-stack instance running:\n", "1. On cloud, deploy a [FREE instance of Redis in the cloud](https://redis.com/try-free/). Or, if you have your\n", "own version of Redis Enterprise running, that works too!\n", "2. Per OS, [see the docs](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/)\n", "3. With docker: `docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the Redis Connection URL\n", "\n", "By default, this notebook connects to the local instance of Redis Stack. **If you have your own Redis Enterprise instance**, replace the REDIS_PASSWORD, REDIS_HOST, and REDIS_PORT values with your own." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import warnings\n", "\n", "import nest_asyncio\n", "# Apply the nest_asyncio patch: lets us run async code in Jupyter\n", "nest_asyncio.apply()\n", "\n", "warnings.filterwarnings('ignore')\n", "\n", "# Replace values below with your own if using Redis Cloud instance\n", "REDIS_HOST = os.getenv(\"REDIS_HOST\", \"localhost\") # ex: \"redis-18374.c253.us-central1-1.gce.cloud.redislabs.com\"\n", "REDIS_PORT = os.getenv(\"REDIS_PORT\", \"6379\") # ex: 18374\n", "REDIS_PASSWORD = os.getenv(\"REDIS_PASSWORD\", \"\") # ex: \"1TNxTEdYRDgIDKM2gDfasupCADXXXX\"\n", "\n", "# If SSL is enabled on the endpoint, use rediss:// as the URL prefix\n", "REDIS_URL = f\"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Now that our environment is set up, we can again load our financial documents" ] }, { "cell_type": "markdown", "metadata": { "id": "KrtWWU4I1ui3" }, "source": [ "### Dataset Preparation (PDF Documents)\n", "\n", "To best demonstrate Redis as a vector database layer, we will load a single\nfinancial document (a 10-K filing) and preprocess it using some helpers from LangChain:\n", "\n", "- `PyPDFLoader` is just one of many document loader types that LangChain provides. Docs: https://python.langchain.com/docs/integrations/document_loaders/unstructured_file\n", "- `RecursiveCharacterTextSplitter` is what we use to create smaller chunks of text from the doc. 
Docs: https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/recursive_text_splitter" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "uijl2qFH1ui3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done preprocessing. Created 211 chunks of the original pdf resources/nke-10k-2023.pdf\n" ] } ], "source": [ "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", "from langchain_community.document_loaders import PyPDFLoader\n", "\n", "# pdf to load\n", "path = 'resources/nke-10k-2023.pdf'\n", "assert os.path.exists(path), f\"File not found: {path}\"\n", "\n", "# load and split\n", "loader = PyPDFLoader(path)\n", "pages = loader.load()\n", "text_splitter = RecursiveCharacterTextSplitter(chunk_size=2500, chunk_overlap=0)\n", "chunks = text_splitter.split_documents(pages)\n", "\n", "print(\"Done preprocessing. Created\", len(chunks), \"chunks of the original pdf\", path)\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Document(metadata={'source': 'resources/nke-10k-2023.pdf', 'page': 0, 'page_label': '1'}, page_content=\"Table of Contents\\nUNITED STATES\\nSECURITIES AND EXCHANGE COMMISSION\\nWashington, D.C. 20549\\nFORM 10-K\\n(Mark One)\\n☑ ANNUAL REPORT PURSUANT TO SECTION 13 OR 15(D) OF THE SECURITIES EXCHANGE ACT OF 1934\\nFOR THE FISCAL YEAR ENDED MAY 31, 2023\\nOR\\n☐ TRANSITION REPORT PURSUANT TO SECTION 13 OR 15(D) OF THE SECURITIES EXCHANGE ACT OF 1934\\nFOR THE TRANSITION PERIOD FROM TO .\\nCommission File No. 1-10635\\nNIKE, Inc.\\n(Exact name of Registrant as specified in its charter)\\nOregon 93-0584541\\n(State or other jurisdiction of incorporation) (IRS Employer Identification No.)\\nOne Bowerman Drive, Beaverton, Oregon 97005-6453\\n(Address of principal executive offices and zip code)\\n(503) 671-6453\\n(Registrant's telephone number, including area code)\\nSECURITIES REGISTERED PURSUANT TO SECTION 12(B) OF THE ACT:\\nClass B Common Stock NKE New York Stock Exchange\\n(Title of each class) (Trading symbol) (Name of each exchange on which registered)\\nSECURITIES REGISTERED PURSUANT TO SECTION 12(G) OF THE ACT:\\nNONE\\nIndicate by check mark: YES NO\\n• if the registrant is a well-known seasoned issuer, as defined in Rule 405 of the Securities Act. þ ¨ \\n• if the registrant is not required to file reports pursuant to Section 13 or Section 15(d) of the Act. ¨ þ \\n• whether the registrant (1) has filed all reports required to be filed by Section 13 or 15(d) of the Securities Exchange Act of 1934 during the preceding\\n12 months (or for such shorter period that the registrant was required to file such reports), and (2) has been subject to such filing requirements for the\\npast 90 days.\\nþ ¨ \\n• whether the registrant has submitted electronically every Interactive Data File required to be submitted pursuant to Rule 405 of Regulation S-T\\n(§232.405 of this chapter) during the preceding 12 months (or for such shorter period that the registrant was required to submit such files).\\nþ ¨ \\n• whether the registrant is a large accelerated filer, an accelerated filer, a non-accelerated filer, a smaller reporting company or an emerging growth company. 
See the definitions of “large accelerated filer,”\\n“accelerated filer,” “smaller reporting company,” and “emerging growth company” in Rule 12b-2 of the Exchange Act.\\nLarge accelerated filer þ Accelerated filer ☐ Non-accelerated filer ☐ Smaller reporting company ☐ Emerging growth company ☐\")" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chunks[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### In the previous example, we would have gone ahead and embedded the chunks as extracted here.\n", "\n", "Now we will instead leverage an LLM to create dense content representations to improve our retrieval accuracy." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup OpenAI as LLM" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import os\n", "import getpass\n", "import openai\n", "\n", "CHAT_MODEL = \"gpt-3.5-turbo-0125\"\n", "\n", "\n", "if \"OPENAI_API_KEY\" not in os.environ:\n", " os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OPENAI_API_KEY\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import tqdm\n", "import json\n", "\n", "\n", "def create_dense_props(chunk):\n", " \"\"\"Create dense representation of raw text content.\"\"\"\n", "\n", " # The system message here should be HEAVILY customized for your specific use case\n", " SYSTEM_PROMPT = \"\"\"\n", " You are a helpful PDF extractor tool. You will be presented with segments from\n", " raw PDF documents composed of 10k SEC filings information about public companies.\n", "\n", " Decompose and summarize the raw content into clear and simple propositions,\n", " ensuring they are interpretable out of context. Consider the following rules:\n", " 1. Split compound sentences into simpler dense phrases that retain existing\n", " meaning.\n", " 2. Simplify technical jargon or wording if possible while retaining existing\n", " meaning.\n", " 3. For any named entity that is accompanied by additional descriptive information,\n", " separate this information into its own distinct proposition.\n", " 4. Decontextualize the proposition by adding necessary modifiers to nouns or\n", " entire sentences and replacing pronouns (e.g., \"it\", \"he\", \"she\", \"they\", \"this\", \"that\")\n", " with the full name of the entities they refer to.\n", " 5. 
Present the results as a list of strings, formatted in JSON, under the key \"propositions\".\n", " \"\"\"\n", "\n", " response = openai.OpenAI().chat.completions.create(\n", " model=CHAT_MODEL,\n", " response_format={ \"type\": \"json_object\" },\n", " messages=[\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": f\"Decompose this raw content using the rules above:\\n{chunk.page_content} \"}\n", " ]\n", " )\n", " res = response.choices[0].message.content\n", "\n", " try:\n", " return json.loads(res)[\"propositions\"]\n", " except Exception as e:\n", " print(\"Failed to parse propositions:\", str(e), flush=True)\n", " # Retry (note: this recursion is unbounded -- cap the number of attempts in production)\n", " return create_dense_props(chunk)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create text propositions using OpenAI" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Load from disk to save time or regenerate as needed.\n", "try:\n", " with open(\"resources/propositions.json\", \"r\") as f:\n", " propositions = json.load(f)\n", "except (FileNotFoundError, json.JSONDecodeError):\n", " # create props\n", " propositions = [create_dense_props(chunk) for chunk in tqdm.tqdm(chunks)]\n", " propositions = [\" \".join(prop) for prop in propositions]\n", "\n", " # Save to disk for faster reload.\n", " with open(\"resources/propositions.json\", \"w\") as f:\n", " json.dump(propositions, f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Let's compare the proposition to the raw chunk" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"Registrant check: Well-known seasoned issuer (YES/NO) Registrant check: Required to file reports under Section 13 or 15(d) (YES/NO) Registrant check: Filed all reports required by Section 13 or 15(d) in the past 12 months (YES/NO) and subject to filing requirements for the past 90 days (YES/NO) Registrant check: Submitted all Interactive Data Files required by Rule 405 of Regulation S-T in the past 12 months (YES/NO) Registrant classification: Large accelerated filer (YES), Accelerated filer (NO), Non-accelerated filer (NO), Smaller reporting company (NO), Emerging growth company (NO) Emerging growth company check: Elected not to use extended transition period for new financial accounting standards (YES/NO) Registrant check: Filed a report and attestation on management's assessment of internal control over financial reporting under Section 404(b) of the Sarbanes-Oxley Act (YES/NO) Securities registered check: Registered under Section 12(b) and financial statements reflect correction of errors in previously issued financial statements (YES/NO) Error corrections check: Any restatements requiring recovery analysis of executive officers' incentive-based compensation during recovery period (YES/NO) Registrant check: Shell company status (YES/NO)\"" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "propositions[0]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Document(metadata={'source': 'resources/nke-10k-2023.pdf', 'page': 0, 'page_label': '1'}, page_content=\"Table of Contents\\nUNITED STATES\\nSECURITIES AND EXCHANGE COMMISSION\\nWashington, D.C. 
20549\\nFORM 10-K\\n(Mark One)\\n☑ ANNUAL REPORT PURSUANT TO SECTION 13 OR 15(D) OF THE SECURITIES EXCHANGE ACT OF 1934\\nFOR THE FISCAL YEAR ENDED MAY 31, 2023\\nOR\\n☐ TRANSITION REPORT PURSUANT TO SECTION 13 OR 15(D) OF THE SECURITIES EXCHANGE ACT OF 1934\\nFOR THE TRANSITION PERIOD FROM TO .\\nCommission File No. 1-10635\\nNIKE, Inc.\\n(Exact name of Registrant as specified in its charter)\\nOregon 93-0584541\\n(State or other jurisdiction of incorporation) (IRS Employer Identification No.)\\nOne Bowerman Drive, Beaverton, Oregon 97005-6453\\n(Address of principal executive offices and zip code)\\n(503) 671-6453\\n(Registrant's telephone number, including area code)\\nSECURITIES REGISTERED PURSUANT TO SECTION 12(B) OF THE ACT:\\nClass B Common Stock NKE New York Stock Exchange\\n(Title of each class) (Trading symbol) (Name of each exchange on which registered)\\nSECURITIES REGISTERED PURSUANT TO SECTION 12(G) OF THE ACT:\\nNONE\\nIndicate by check mark: YES NO\\n• if the registrant is a well-known seasoned issuer, as defined in Rule 405 of the Securities Act. þ ¨ \\n• if the registrant is not required to file reports pursuant to Section 13 or Section 15(d) of the Act. ¨ þ \\n• whether the registrant (1) has filed all reports required to be filed by Section 13 or 15(d) of the Securities Exchange Act of 1934 during the preceding\\n12 months (or for such shorter period that the registrant was required to file such reports), and (2) has been subject to such filing requirements for the\\npast 90 days.\\nþ ¨ \\n• whether the registrant has submitted electronically every Interactive Data File required to be submitted pursuant to Rule 405 of Regulation S-T\\n(§232.405 of this chapter) during the preceding 12 months (or for such shorter period that the registrant was required to submit such files).\\nþ ¨ \\n• whether the registrant is a large accelerated filer, an accelerated filer, a non-accelerated filer, a smaller reporting company or an emerging growth company. See the definitions of “large accelerated filer,”\\n“accelerated filer,” “smaller reporting company,” and “emerging growth company” in Rule 12b-2 of the Exchange Act.\\nLarge accelerated filer þ Accelerated filer ☐ Non-accelerated filer ☐ Smaller reporting company ☐ Emerging growth company ☐\")" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chunks[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create embeddings from propositions data" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from redisvl.utils.vectorize import HFTextVectorizer\n", "\n", "hf = HFTextVectorizer(\"sentence-transformers/all-MiniLM-L6-v2\")\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "\n", "prop_embeddings = hf.embed_many([\n", " proposition for proposition in propositions\n", "])\n", "\n", "# Check to make sure we've created enough embeddings, 1 per document chunk\n", "len(prop_embeddings) == len(propositions) == len(chunks)" ] }, { "cell_type": "markdown", "metadata": { "id": "5baI0xDQ1ui-" }, "source": [ "### Define a schema and create an index\n", "\n", "Below we connect to Redis and create an index that contains a text field, tag field, and vector field." 
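, "\n", "As a brief aside, the vector field's `attrs` control how the index is built. A hedged sketch of pinning the HNSW parameters explicitly (values are illustrative and simply mirror the defaults reported by `rvl index info` below, not tuned recommendations):\n", "\n", "```python\n", "# Hypothetical: higher m / ef_construction generally raises recall and build cost\n", "hnsw_attrs = {\n", "    \"dims\": 384,\n", "    \"distance_metric\": \"cosine\",\n", "    \"algorithm\": \"hnsw\",\n", "    \"datatype\": \"float32\",\n", "    \"m\": 16,                 # max edges per node in the HNSW graph\n", "    \"ef_construction\": 200   # candidate list size at graph build time\n", "}\n", "```"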
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "zB1EW_9n1ui-" }, "outputs": [], "source": [ "from redisvl.index import SearchIndex\n", "\n", "\n", "index_name = \"redisvl\"\n", "\n", "\n", "schema = {\n", " \"index\": {\n", " \"name\": index_name,\n", " \"prefix\": \"chunk\"\n", " },\n", " \"fields\": [\n", " {\n", " \"name\": \"chunk_id\",\n", " \"type\": \"tag\",\n", " \"attrs\": {\n", " \"sortable\": True\n", " }\n", " },\n", " {\n", " \"name\": \"proposition\",\n", " \"type\": \"text\"\n", " },\n", " {\n", " \"name\": \"text_embedding\",\n", " \"type\": \"vector\",\n", " \"attrs\": {\n", " \"dims\": hf.dims,\n", " \"distance_metric\": \"cosine\",\n", " \"algorithm\": \"hnsw\",\n", " \"datatype\": \"float32\"\n", " }\n", " }\n", " ]\n", "}" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "15:52:01 redisvl.index.index INFO Index already exists, overwriting.\n" ] } ], "source": [ "# create an index from schema and the client\n", "index = SearchIndex.from_dict(schema, redis_url=REDIS_URL)\n", "index.create(overwrite=True, drop=True)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "C70C-UWj1ujA" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Index Information:\n", "╭──────────────┬────────────────┬────────────┬─────────────────┬────────────╮\n", "│ Index Name │ Storage Type │ Prefixes │ Index Options │ Indexing │\n", "├──────────────┼────────────────┼────────────┼─────────────────┼────────────┤\n", "│ redisvl │ HASH │ ['chunk'] │ [] │ 0 │\n", "╰──────────────┴────────────────┴────────────┴─────────────────┴────────────╯\n", "Index Fields:\n", "╭────────────────┬────────────────┬────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n", "│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │\n", "├────────────────┼────────────────┼────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n", "│ chunk_id │ chunk_id │ TAG │ SEPARATOR │ , │ │ │ │ │ │ │ │ │ │ │\n", "│ proposition │ proposition │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │ │ │ │ │\n", "│ text_embedding │ text_embedding │ VECTOR │ algorithm │ HNSW │ data_type │ FLOAT32 │ dim │ 384 │ distance_metric │ COSINE │ M │ 16 │ ef_construction │ 200 │\n", "╰────────────────┴────────────────┴────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n" ] } ], "source": [ "# get info about the index\n", "# NBVAL_SKIP\n", "!rvl index info -i redisvl" ] }, { "cell_type": "markdown", "metadata": { "id": "Qrj-jeGmBRTL" }, "source": [ "### Process and load dataset\n", "Below we use the RedisVL index to simply load the list of document chunks to Redis db." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "Zsg09Keg1ujA" }, "outputs": [], "source": [ "# load expects an iterable of dictionaries\n", "from redisvl.redis.utils import array_to_buffer\n", "\n", "data = [\n", " {\n", " 'chunk_id': f'{i}',\n", " 'proposition': proposition,\n", " # For HASH -- must convert embeddings to bytes\n", " 'text_embedding': array_to_buffer(prop_embeddings[i], dtype=\"float32\")\n", " } for i, proposition in enumerate(propositions)\n", "]\n", "\n", "# RedisVL handles batching automatically\n", "keys = index.load(data, id_field=\"chunk_id\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup RedisVL AsyncSearchIndex" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from redisvl.index import AsyncSearchIndex\n", "\n", "index = AsyncSearchIndex.from_dict(schema, redis_url=REDIS_URL)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Test the updated RAG workflow" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from redisvl.query import VectorQuery\n", "from redisvl.index import AsyncSearchIndex\n", "\n", "\n", "def promptify(query: str, context: str) -> str:\n", " return f'''Use the provided context below derived from public financial\n", " documents to answer the user's question. If you can't answer the user's\n", " question, based on the context; do not guess. If there is no context at all,\n", " respond with \"I don't know\".\n", "\n", " User question:\n", "\n", " {query}\n", "\n", " Helpful context:\n", "\n", " {context}\n", "\n", " Answer:\n", " '''\n", "\n", "# Update the retrieval helper to use propositions\n", "async def retrieve_context(index: AsyncSearchIndex, query_vector) -> str:\n", " \"\"\"Fetch the relevant context from Redis using vector search\"\"\"\n", " print(\"Using dense content representation\", flush=True)\n", " results = await index.query(\n", " VectorQuery(\n", " vector=query_vector,\n", " vector_field_name=\"text_embedding\",\n", " return_fields=[\"proposition\"],\n", " num_results=3\n", " )\n", " )\n", " content = \"\\n\".join([result[\"proposition\"] for result in results])\n", " return content\n", "\n", "# Update the answer_question method\n", "async def answer_question(index: AsyncSearchIndex, query: str):\n", " \"\"\"Answer the user's question\"\"\"\n", "\n", " SYSTEM_PROMPT = \"\"\"You are a helpful financial analyst assistant that has access\n", " to public financial 10k documents in order to answer users questions about company\n", " performance, ethics, characteristics, and core information.\n", " \"\"\"\n", "\n", " query_vector = hf.embed(query)\n", " # Fetch context from Redis using vector search\n", " context = await retrieve_context(index, query_vector)\n", " # Generate contextualized prompt and feed to OpenAI\n", " response = await openai.AsyncClient().chat.completions.create(\n", " model=CHAT_MODEL,\n", " messages=[\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": promptify(query, context)}\n", " ],\n", " temperature=0.1,\n", " seed=42\n", " )\n", " # Response provided by LLM\n", " return response.choices[0].message.content" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Generate a list of questions\n", "questions = [\n", " \"What is the trend in the company's revenue and profit over the past few years?\",\n", " \"What are the company's primary revenue sources?\",\n", " \"How 
much debt does the company have, and what are its capital expenditure plans?\",\n", " \"What does the company say about its environmental, social, and governance (ESG) practices?\",\n", " \"What is the company's strategy for growth?\"\n", "]" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using dense content representation\n", "Using dense content representation\n", "Using dense content representation\n", "Using dense content representation\n", "Using dense content representation\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
questionanswer
0What is the trend in the company's revenue and...The company experienced revenue growth in fisc...
1What are the company's primary revenue sources?The company's primary revenue sources are from...
2How much debt does the company have, and what ...As of May 31, 2023, the company had Long-term ...
3What does the company say about its environmen...The company acknowledges the importance of env...
4What is the company's strategy for growth?The company's strategy for growth includes ide...
\n", "
" ], "text/plain": [ " question \\\n", "0 What is the trend in the company's revenue and... \n", "1 What are the company's primary revenue sources? \n", "2 How much debt does the company have, and what ... \n", "3 What does the company say about its environmen... \n", "4 What is the company's strategy for growth? \n", "\n", " answer \n", "0 The company experienced revenue growth in fisc... \n", "1 The company's primary revenue sources are from... \n", "2 As of May 31, 2023, the company had Long-term ... \n", "3 The company acknowledges the importance of env... \n", "4 The company's strategy for growth includes ide... " ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import asyncio\n", "import pandas as pd\n", "\n", "results = await asyncio.gather(*[\n", " answer_question(index, question) for question in questions\n", "])\n", "\n", "pd.DataFrame(columns=[\"question\", \"answer\"], data=list(zip(questions, results)))" ] }, { "cell_type": "markdown", "metadata": { "id": "TnkK0NwIIM9q" }, "source": [ "### Improve accuracy with query rewriting / expansion\n", "\n", "We can also use the power on an LLM to rewrite or expand an input question.\n", "\n", "Example: https://github.com/langchain-ai/langchain/blob/master/templates/rewrite-retrieve-read/rewrite_retrieve_read/chain.py" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using dense content representation\n" ] }, { "data": { "text/plain": [ "\"Based on the provided context, we can see that the company in question is NIKE, Inc. The company has a significant presence globally with subsidiaries in various jurisdictions such as Delaware, Netherlands, China, Mexico, Missouri, Japan, Korea, and Oregon. Additionally, the company's total revenues are substantial, with revenues in the United States alone amounting to $22,007 million in the fiscal year ended May 31, 2023. NIKE, Inc. also has a diverse range of financial assets, accounts receivable, inventories, and property, plant, and equipment across different regions, indicating a large and well-established company.\"" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# An example question that is a bit simplistic...\n", "await answer_question(index, \"How big is the company?\")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "Tg55HqLFIRXJ" }, "outputs": [], "source": [ "async def rewrite_query(query: str, prompt: str = None):\n", " \"\"\"Rewrite the user's original query\"\"\"\n", "\n", " SYSTEM_PROMPT = prompt if prompt else \"\"\"Given the user's input question below, find a better or\n", " more complete way to phrase this question in order to improve semantic search\n", " engine retrieval quality over a set of SEC 10K PDF docs. 
, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'What is the size of the company in terms of revenue, assets, and market capitalization?'" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Example Simple Query Rewritten\n", "await rewrite_query(\"How big is the company?\")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "9ubNQrJOYL42" }, "outputs": [], "source": [ "async def answer_question(index: AsyncSearchIndex, query: str, **kwargs):\n", " \"\"\"Answer the user's question\"\"\"\n", "\n", " SYSTEM_PROMPT = \"\"\"You are a helpful financial analyst assistant that has access\n", " to public financial 10k documents in order to answer users' questions about company\n", " performance, ethics, characteristics, and core information.\n", " \"\"\"\n", "\n", " # Rewrite the query using an LLM\n", " rewritten_query = await rewrite_query(query, **kwargs)\n", " print(\"User query updated to:\\n\", rewritten_query, flush=True)\n", "\n", " query_vector = hf.embed(rewritten_query)\n", " # Fetch context from Redis using vector search\n", " context = await retrieve_context(index, query_vector)\n", " print(\"Context retrieved\", flush=True)\n", "\n", " # Generate contextualized prompt and feed to OpenAI\n", " response = await openai.AsyncClient().chat.completions.create(\n", " model=CHAT_MODEL,\n", " messages=[\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": promptify(rewritten_query, context)}\n", " ],\n", " temperature=0.1,\n", " seed=42\n", " )\n", " # Response provided by LLM\n", " return response.choices[0].message.content" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "BIO_jW6KYsMU" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "User query updated to:\n", " What is the size of the company in terms of revenue, assets, and market capitalization?\n", "Using dense content representation\n", "Context retrieved\n" ] }, { "data": { "text/plain": [ "\"Based on the provided context, the company's revenue, assets, and market capitalization figures are not explicitly mentioned. The information mainly focuses on financial assets, investments, return on invested capital, EBIT, and other financial metrics. 
Without specific details on revenue, assets, and market capitalization, I am unable to provide the exact size of the company in those terms.\"" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# NBVAL_SKIP\n", "# Now try again with query re-writing enabled\n", "await answer_question(index, \"How big is the company?\")" ] }, { "cell_type": "markdown", "metadata": { "id": "p97uL4g9T6LQ" }, "source": [ "### Improve performance and cut costs with LLM caching" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "7geEAsYST6LQ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "15:53:30 redisvl.index.index INFO Index already exists, not overwriting.\n" ] } ], "source": [ "from redisvl.extensions.llmcache import SemanticCache\n", "\n", "llmcache = SemanticCache(\n", " name=\"llmcache\",\n", " vectorizer=hf,\n", " redis_url=REDIS_URL,\n", " ttl=120,\n", " distance_threshold=0.2\n", ")" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "id": "1ALcQXAqT6LQ" }, "outputs": [], "source": [ "from functools import wraps\n", "\n", "# Create an LLM caching decorator\n", "def cache(func):\n", " @wraps(func)\n", " async def wrapper(index, query_text, *args, **kwargs):\n", " query_vector = llmcache._vectorizer.embed(query_text)\n", "\n", " # Check the cache with the vector\n", " if result := llmcache.check(vector=query_vector):\n", " return result[0]['response']\n", "\n", " response = await func(index, query_text, query_vector=query_vector)\n", " llmcache.store(query_text, response, query_vector)\n", " return response\n", " return wrapper\n", "\n", "\n", "@cache\n", "async def answer_question(index: AsyncSearchIndex, query: str, **kwargs):\n", " \"\"\"Answer the user's question\"\"\"\n", "\n", " SYSTEM_PROMPT = \"\"\"You are a helpful financial analyst assistant that has access\n", " to public financial 10k documents in order to answer users questions about company\n", " performance, ethics, characteristics, and core information.\n", " \"\"\"\n", "\n", " context = await retrieve_context(index, kwargs[\"query_vector\"])\n", " response = await openai.AsyncClient().chat.completions.create(\n", " model=CHAT_MODEL,\n", " messages=[\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"user\", \"content\": promptify(query, context)}\n", " ],\n", " temperature=0.1,\n", " seed=42\n", " )\n", " # Response provided by GPT-3.5\n", " return response.choices[0].message.content" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "BXK_BXuhT6LQ" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using dense content representation\n" ] }, { "data": { "text/plain": [ "\"Nike's total revenue for the fiscal year 2023 was $27.4 billion from sales to wholesale customers and $21.3 billion through direct-to-consumer channels. Comparing this to the previous year, the total revenue for the fiscal year 2022 was not explicitly provided in the context.\"" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# NBVAL_SKIP\n", "query = \"What was Nike's revenue last year compared to this year??\"\n", "\n", "await answer_question(index, query)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "7mZpSpf9T6LQ" }, "outputs": [ { "data": { "text/plain": [ "\"Nike's total revenue for the fiscal year 2023 was $27.4 billion from sales to wholesale customers and $21.3 billion through direct-to-consumer channels. 
Comparing this to the previous year, the total revenue for the fiscal year 2022 was not explicitly provided in the context.\"" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# NBVAL_SKIP\n", "query = \"What was Nike's total revenue in the last year compared to now??\"\n", "\n", "await answer_question(index, query)\n", "\n", "# notice no HTTP request to OpenAI since this question is \"close enough\" to the last one" ] }, { "cell_type": "markdown", "metadata": { "id": "UaiF_ws7itsi" }, "source": [ "### Improve personalization by including chat session history\n", "\n", "In order to preserve state in the conversation, it's imperative to offload conversation history to a database that can handle high transaction throughput for writes/reads to limit system latency.\n", "\n", "We can store message history for a particular user session in a Redis List data type.\n" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "WMOF7fJQdhgN" }, "outputs": [], "source": [ "import json\n", "\n", "\n", "class ChatBot:\n", " def __init__(self, index: AsyncSearchIndex, user: str):\n", " self.index = index\n", " self.user = user\n", "\n", " async def get_messages(self) -> list:\n", " \"\"\"Get all messages associated with a session\"\"\"\n", " return [\n", " json.loads(msg) for msg in await self.index.client.lrange(f\"messages:{self.user}\", 0, -1)\n", " ]\n", "\n", " async def add_messages(self, messages: list):\n", " \"\"\"Add chat messages to a Redis List\"\"\"\n", " return await self.index.client.rpush(\n", " f\"messages:{self.user}\", *[json.dumps(msg) for msg in messages]\n", " )\n", "\n", " async def clear_history(self):\n", " \"\"\"Clear session chat\"\"\"\n", " # Use the instance's own index rather than the global\n", " await self.index.client.delete(f\"messages:{self.user}\")\n", "\n", " @staticmethod\n", " def promptify(query: str, context: str) -> str:\n", " return f'''Use the provided context below derived from public financial\n", " documents to answer the user's question. If you can't answer the user's\n", " question based on the context, do not guess. 
If there is no context at all,\n respond with \"I don't know\".\n", "\n", " User question:\n", "\n", " {query}\n", "\n", " Helpful context:\n", "\n", " {context}\n", "\n", " Answer:\n", " '''\n", "\n", " async def retrieve_context(self, query_vector) -> str:\n", " \"\"\"Fetch the relevant context from Redis using vector search\"\"\"\n", " results = await self.index.query(\n", " VectorQuery(\n", " vector=query_vector,\n", " vector_field_name=\"text_embedding\",\n", " return_fields=[\"proposition\"],\n", " num_results=3\n", " )\n", " )\n", " content = \"\\n\".join([result[\"proposition\"] for result in results])\n", " return content\n", "\n", " async def answer_question(self, query: str):\n", " \"\"\"Answer the user's question with historical context and caching baked-in\"\"\"\n", "\n", " SYSTEM_PROMPT = \"\"\"You are a helpful financial analyst assistant that has access\n", " to public financial 10k documents in order to answer users' questions about company\n", " performance, ethics, characteristics, and core information.\n", " \"\"\"\n", "\n", " # Create query vector\n", " query_vector = llmcache._vectorizer.embed(query)\n", "\n", " # TODO - implement semantic guardrails?\n", "\n", " # Check the cache with the vector\n", " if result := llmcache.check(vector=query_vector):\n", " answer = result[0]['response']\n", " else:\n", " # TODO - implement query rewriting?\n", " context = await self.retrieve_context(query_vector)\n", " session = await self.get_messages()\n", " # TODO - implement session summarization?\n", " messages = (\n", " [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}] +\n", " session +\n", " [{\"role\": \"user\", \"content\": self.promptify(query, context)}]\n", " )\n", " # Response provided by GPT-3.5\n", " response = await openai.AsyncClient().chat.completions.create(\n", " model=CHAT_MODEL,\n", " messages=messages,\n", " temperature=0.1,\n", " seed=42\n", " )\n", " answer = response.choices[0].message.content\n", " llmcache.store(query, answer, query_vector)\n", "\n", " # Add message history\n", " await self.add_messages([\n", " {\"role\": \"user\", \"content\": query},\n", " {\"role\": \"assistant\", \"content\": answer}\n", " ])\n", "\n", " return answer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test the entire RAG workflow" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "_Z3RUvyxdhiz" }, "outputs": [], "source": [ "# Setup Session\n", "chat = ChatBot(index, \"tyler\")\n", "await chat.clear_history()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run a simple chat\n", "stopterms = [\"exit\", \"quit\", \"end\", \"cancel\"]\n", "\n", "# Simple Chat\n", "# NBVAL_SKIP\n", "while True:\n", " user_query = input()\n", " if user_query.lower() in stopterms:\n", " break\n", " answer = await chat.answer_question(user_query)\n", " print(answer, flush=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZoPQMAShZ5Uy" }, "outputs": [], "source": [ "# NBVAL_SKIP\n", "await chat.get_messages()" ] }, { "cell_type": "markdown", "metadata": { "id": "5l4uEgKzljes" }, "source": [ "## Your Next Steps\n", "\n", "While a good start, there is still more to do. **For example**:\n", "- we could utilize message history to generate an updated and contextualized query to use for retrieval and answer generation (with an LLM). 
Otherwise, there can be a disconnect between what a user is asking (in context) and what they are asking in isolation.\n", "- we could utilize an LLM to summarize conversation history to use as context instead of passing the whole slew of messages to the Chat endpoint.\n", "- we could utilize semantic properties of the message history (or summaries) in order to fetch only relevant conversation bits (vector search).\n", "- we could utilize a technique like HyDE (a form of query rewriting) to improve the retrieval quality from raw user input to source documents OR try to break down user questions into sub-questions and fetch/join context based on the different searches.\n", "- we could incorporate semantic routing to take a broken-down question and route it to different data sources, indices, or query types.\n", "- we could add semantic guardrails on the front end or back end of the conversation I/O to ensure we are within bounds of approved topics." ] }, { "cell_type": "markdown", "metadata": { "id": "Wscs4Mvo1ujD" }, "source": [ "## Cleanup\n", "\n", "Clean up the database." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "On6yNuQn1ujD" }, "outputs": [], "source": [ "await index.client.flushall()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 0 }