
feat: MSSQL document loader and saver #10


Merged · 6 commits · Feb 13, 2024
feat: doc loader and saver
loeng2023 committed Feb 12, 2024
commit 4d0dbc892e83e147622707257b01d63256848528
12 changes: 10 additions & 2 deletions src/langchain_google_cloud_sql_mssql/__init__.py
@@ -16,6 +16,14 @@
MSSQLChatMessageHistory,
)
from langchain_google_cloud_sql_mssql.mssql_engine import MSSQLEngine
from langchain_google_cloud_sql_mssql.mssql_loader import MSSQLLoader
from langchain_google_cloud_sql_mssql.mssql_loader import (
MSSQLDocumentSaver,
MSSQLLoader,
)

__all__ = ["MSSQLChatMessageHistory", "MSSQLEngine", "MSSQLLoader"]
__all__ = [
"MSSQLChatMessageHistory",
"MSSQLEngine",
"MSSQLLoader",
"MSSQLDocumentSaver",
]
52 changes: 52 additions & 0 deletions src/langchain_google_cloud_sql_mssql/mssql_engine.py
@@ -148,3 +148,55 @@ def create_chat_history_table(self, table_name: str) -> None:
with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(create_table_query))
conn.commit()

def init_document_table(
self,
table_name: str,
metadata_columns: List[sqlalchemy.Column] = [],
store_metadata: bool = True,
) -> None:
"""
Create a table for saving langchain documents.

Args:
table_name (str): The MSSQL database table name.
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
to create for custom metadata. Optional.
store_metadata (bool): Whether to store extra document metadata in a
'langchain_metadata' JSON column for keys without a dedicated column (Default: True).
"""
columns = [
sqlalchemy.Column(
"page_content",
sqlalchemy.UnicodeText,
primary_key=False,
nullable=False,
)
]
columns += metadata_columns
if store_metadata:
columns.append(
sqlalchemy.Column(
"langchain_metadata",
sqlalchemy.JSON,
primary_key=False,
nullable=True,
)
)
sqlalchemy.Table(table_name, sqlalchemy.MetaData(), *columns).create(
self.engine
)

def _load_document_table(self, table_name: str) -> sqlalchemy.Table:
"""
Load the table schema from an existing table in the MSSQL database.

Args:
table_name (str): The MSSQL database table name.

Returns:
(sqlalchemy.Table): The loaded table.
"""
metadata = sqlalchemy.MetaData()
metadata.reflect(bind=self.engine, only=[table_name])
return metadata.tables[table_name]
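
For reference, a minimal usage sketch for the new table helper; it assumes an
MSSQLEngine.from_instance factory, and the connection parameter names shown here are
illustrative, not part of this diff:

import sqlalchemy

from langchain_google_cloud_sql_mssql import MSSQLEngine

# Hypothetical connection setup; adjust to the actual engine factory.
engine = MSSQLEngine.from_instance(
    project_id="my-project",
    region="us-central1",
    instance="my-instance",
    database="my-database",
)

# Create a document table with one dedicated metadata column ("source").
# With store_metadata=True, metadata keys without a dedicated column are kept
# in a langchain_metadata JSON column.
engine.init_document_table(
    table_name="my_documents",
    metadata_columns=[sqlalchemy.Column("source", sqlalchemy.String(1024))],
    store_metadata=True,
)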
190 changes: 148 additions & 42 deletions src/langchain_google_cloud_sql_mssql/mssql_loader.py
@@ -13,42 +13,48 @@
# limitations under the License.
import json
from collections.abc import Iterable
from typing import Any, List, Optional, Sequence
from typing import Any, Dict, Iterator, List, Optional, Sequence, cast

import pytds
import sqlalchemy
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document

from langchain_google_cloud_sql_mssql.mssql_engine import MSSQLEngine

DEFAULT_CONTENT_COL = "page_content"
DEFAULT_METADATA_COL = "langchain_metadata"


def _parse_doc_from_table(
content_columns: Iterable[str],
metadata_columns: Iterable[str],
column_names: Iterable[str],
rows: Sequence[Any],
) -> List[Document]:
docs = []
for row in rows:
page_content = " ".join(
str(getattr(row, column))
for column in content_columns
if column in column_names
)
metadata = {
column: getattr(row, column)
for column in metadata_columns
if column in column_names
}
if DEFAULT_METADATA_COL in metadata:
extra_metadata = json.loads(metadata[DEFAULT_METADATA_COL])
del metadata[DEFAULT_METADATA_COL]
metadata |= extra_metadata
doc = Document(page_content=page_content, metadata=metadata)
docs.append(doc)
return docs
def _parse_doc_from_row(
content_columns: Iterable[str], metadata_columns: Iterable[str], row: Dict
) -> Document:
page_content = " ".join(
str(row[column]) for column in content_columns if column in row
)
metadata: Dict[str, Any] = {}
# unnest metadata from langchain_metadata column
if DEFAULT_METADATA_COL in metadata_columns and row.get(DEFAULT_METADATA_COL):
for k, v in row[DEFAULT_METADATA_COL].items():
metadata[k] = v
# load metadata from other columns
for column in metadata_columns:
if column in row and column != DEFAULT_METADATA_COL:
metadata[column] = row[column]
return Document(page_content=page_content, metadata=metadata)


def _parse_row_from_doc(column_names: Iterable[str], doc: Document) -> Dict:
doc_metadata = doc.metadata.copy()
row: Dict[str, Any] = {DEFAULT_CONTENT_COL: doc.page_content}
for entry in doc.metadata:
if entry in column_names:
row[entry] = doc_metadata[entry]
del doc_metadata[entry]
# store extra metadata in langchain_metadata column in json format
if DEFAULT_METADATA_COL in column_names and len(doc_metadata) > 0:
row[DEFAULT_METADATA_COL] = doc_metadata
return row
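
# Example round-trip for the two helpers above, given a table with a custom
# "source" column:
#   row = {"page_content": "hi", "source": "a.txt", "langchain_metadata": {"author": "me"}}
#   doc = _parse_doc_from_row(["page_content"], ["source", "langchain_metadata"], row)
#   # doc == Document(page_content="hi", metadata={"author": "me", "source": "a.txt"})
#   _parse_row_from_doc(["page_content", "source", "langchain_metadata"], doc) == row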


class MSSQLLoader(BaseLoader):
@@ -57,51 +57,63 @@ class MSSQLLoader(BaseLoader):
def __init__(
self,
engine: MSSQLEngine,
query: str,
table_name: str = "",
query: str = "",
content_columns: Optional[List[str]] = None,
metadata_columns: Optional[List[str]] = None,
):
"""
Document page content defaults to the first column present in the query or table, and
metadata defaults to all other columns. Use content_columns to override the column
used for page content, and use metadata_columns to select specific metadata columns
rather than using all remaining columns.

If multiple content columns are specified, page_content defaults to a
space-separated concatenation of their string values.

Args:
engine (MSSQLEngine): MSSQLEngine object to connect to the MSSQL database.
query (str): The query to execute in MSSQL format.
table_name (str): The MSSQL database table name. (OneOf: table_name, query).
query (str): The query to execute in MSSQL format. (OneOf: table_name, query).
content_columns (List[str]): The columns to write into the `page_content`
of the document. Optional.
metadata_columns (List[str]): The columns to write into the `metadata` of the document.
Optional.
"""
self.engine = engine
self.table_name = table_name
self.query = query
self.content_columns = content_columns
self.metadata_columns = metadata_columns
if not self.table_name and not self.query:
raise ValueError("One of 'table_name' or 'query' must be specified.")
if self.table_name and self.query:
raise ValueError(
"Cannot specify both 'table_name' and 'query'. Specify 'table_name' to load "
"entire table or 'query' to load a specific query."
)
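
# Illustrative defaulting example (assumed columns): for a table with columns
# (id, content, source), MSSQLLoader(engine, table_name="t") yields Documents with
# page_content=str(id) and metadata {"content": ..., "source": ...}, while
# MSSQLLoader(engine, table_name="t", content_columns=["content"]) yields
# page_content=content and metadata {"id": ..., "source": ...}.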

def load(self) -> List[Document]:
"""
Load langchain documents from a Cloud SQL MSSQL database.

Document page content defaults to the first columns present in the query or table and
metadata defaults to all other columns. Use with content_columns to overwrite the column
used for page content. Use metadata_columns to select specific metadata columns rather
than using all remaining columns.
Returns:
(List[langchain_core.documents.Document]): a list of Documents with metadata from
specific columns.
"""
return list(self.lazy_load())

If multiple content columns are specified, page_content’s string format will default to
space-separated string concatenation.
def lazy_load(self) -> Iterator[Document]:
"""
Lazily load langchain documents from a Cloud SQL MSSQL database. Use lazy_load to
avoid caching all documents in memory at once.

Returns:
(List[langchain_core.documents.Document]): a list of Documents with metadata from
(Iterator[langchain_core.documents.Document]): an iterator over Documents with metadata from
specific columns.
"""
if self.query:
stmt = sqlalchemy.text(self.query)
else:
stmt = sqlalchemy.text(f'select * from "{self.table_name}";')
with self.engine.connect() as connection:
result_proxy = connection.execute(sqlalchemy.text(self.query))
result_proxy = connection.execute(stmt)
column_names = list(result_proxy.keys())
results = result_proxy.fetchall()
content_columns = self.content_columns or [column_names[0]]
metadata_columns = self.metadata_columns or [
col for col in column_names if col not in content_columns
]
return _parse_doc_from_table(
content_columns,
metadata_columns,
column_names,
results,
while True:
row = result_proxy.fetchone()
if not row:
break
# Handle default metadata field
row_data = {}
for column in column_names:
value = getattr(row, column)
if column == DEFAULT_METADATA_COL:
row_data[column] = json.loads(value)
else:
row_data[column] = value
yield _parse_doc_from_row(content_columns, metadata_columns, row_data)


class MSSQLDocumentSaver:
"""A class for saving langchain documents into a Cloud SQL MSSQL database table."""

def __init__(
self,
engine: MSSQLEngine,
table_name: str,
):
"""
MSSQLDocumentSaver allows for saving langchain documents into a database table. The
table must already exist and include a 'page_content' column; MSSQLEngine.init_document_table
can create a table with the default schema:
- page_content (type: text)
- langchain_metadata (type: JSON)

Args:
engine: MSSQLEngine object to connect to the MSSQL database.
table_name: The name of table for saving documents.
"""
self.engine = engine
self.table_name = table_name
self._table = self.engine._load_document_table(table_name)
if DEFAULT_CONTENT_COL not in self._table.columns.keys():
raise ValueError(
f"Missing '{DEFAULT_CONTENT_COL}' field in table {table_name}."
)

def add_documents(self, docs: List[Document]) -> None:
"""
Save documents in the DocumentSaver table. A document's metadata is written to matching
columns where they exist; any remaining metadata is stored in the langchain_metadata
JSON column.

Args:
docs (List[langchain_core.documents.Document]): a list of documents to be saved.
"""
with self.engine.connect() as conn:
for doc in docs:
row = _parse_row_from_doc(self._table.columns.keys(), doc)
for k, v in row.items():
if isinstance(v, dict):
row[k] = json.dumps(v)
conn.execute(sqlalchemy.insert(self._table).values(row))
conn.commit()

def delete(self, docs: List[Document]) -> None:
"""
Delete all instances of a document from the DocumentSaver table by matching the entire Document
object.

Args:
docs (List[langchain_core.documents.Document]): a list of documents to be deleted.
"""
with self.engine.connect() as conn:
for doc in docs:
row = _parse_row_from_doc(self._table.columns.keys(), doc)
for k, v in row.items():
if isinstance(v, dict):
row[k] = json.dumps(v)
# delete by matching all fields of document
where_conditions = []
for col in self._table.columns:
    # use row.get() so a column missing from this row (e.g. an empty
    # langchain_metadata) compares against NULL instead of raising KeyError
    where_conditions.append(col == row.get(col.name))
conn.execute(
sqlalchemy.delete(self._table).where(
sqlalchemy.and_(*where_conditions)
)
)
conn.commit()
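
Taken together, a minimal end-to-end sketch for the loader and saver, reusing the
hypothetical engine from the earlier sketch (table and column names are illustrative):

from langchain_core.documents import Document

from langchain_google_cloud_sql_mssql import MSSQLDocumentSaver, MSSQLLoader

saver = MSSQLDocumentSaver(engine=engine, table_name="my_documents")
# "source" lands in its dedicated column; "topic" has no column of its own and
# falls back into the langchain_metadata JSON column.
saver.add_documents(
    [Document(page_content="hello", metadata={"source": "a.txt", "topic": "greeting"})]
)

# Load the whole table; page_content comes from the named content column and
# all other columns become metadata.
loader = MSSQLLoader(
    engine=engine,
    table_name="my_documents",
    content_columns=["page_content"],
)
docs = loader.load()

# Or stream documents instead of materializing them all at once.
for doc in loader.lazy_load():
    print(doc.page_content, doc.metadata)

# Delete the previously saved documents by matching entire rows.
saver.delete(docs)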