Skip to content

Commit e82e26a

Browse files
committed
chore: get method of pg vectorstore exact same as chroma
1 parent 20e2bfc commit e82e26a

File tree

6 files changed

+154
-71
lines changed

6 files changed

+154
-71
lines changed

examples/pg_vectorstore.ipynb

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -258,22 +258,22 @@
258258
"\n",
259259
"docs = [\n",
260260
" Document(\n",
261-
" id=uuid.uuid4(),\n",
261+
" id=str(uuid.uuid4()),\n",
262262
" page_content=\"there are cats in the pond\",\n",
263263
" metadata={\"likes\": 1, \"location\": \"pond\", \"topic\": \"animals\"},\n",
264264
" ),\n",
265265
" Document(\n",
266-
" id=uuid.uuid4(),\n",
266+
" id=str(uuid.uuid4()),\n",
267267
" page_content=\"ducks are also found in the pond\",\n",
268268
" metadata={\"likes\": 30, \"location\": \"pond\", \"topic\": \"animals\"},\n",
269269
" ),\n",
270270
" Document(\n",
271-
" id=uuid.uuid4(),\n",
271+
" id=str(uuid.uuid4()),\n",
272272
" page_content=\"fresh apples are available at the market\",\n",
273273
" metadata={\"likes\": 20, \"location\": \"market\", \"topic\": \"food\"},\n",
274274
" ),\n",
275275
" Document(\n",
276-
" id=uuid.uuid4(),\n",
276+
" id=str(uuid.uuid4()),\n",
277277
" page_content=\"the market also sells fresh oranges\",\n",
278278
" metadata={\"likes\": 5, \"location\": \"market\", \"topic\": \"food\"},\n",
279279
" ),\n",
@@ -287,9 +287,9 @@
287287
"cell_type": "markdown",
288288
"metadata": {},
289289
"source": [
290-
"### Get documents\n",
290+
"## Get collection\n",
291291
"\n",
292-
"Get documents from the vectorstore using filters and parameters."
292+
"Get collection from the vectorstore using filters and parameters."
293293
]
294294
},
295295
{
@@ -298,13 +298,11 @@
298298
"metadata": {},
299299
"outputs": [],
300300
"source": [
301-
"documents_with_apple = await vectorstore.aget({\"content\": {\"$ilike\": \"%apple%\"}})\n",
302-
"first_three_documents = await vectorstore.aget(limit=3)\n",
303-
"rest_of_documents = await vectorstore.aget(limit=5, offset=3)\n",
301+
"documents_with_apple = await vectorstore.aget(where_document={\"$ilike\": \"%apple%\"}, include=\"documents\")\n",
302+
"paginated_ids = await vectorstore.aget(limit=3, offset=3)\n",
304303
"\n",
305-
"print([doc.page_content for doc in documents_with_apple])\n",
306-
"print([doc.page_content for doc in first_three_documents])\n",
307-
"print([doc.page_content for doc in rest_of_documents])"
304+
"print(documents_with_apple[\"documents\"])\n",
305+
"print(paginated_ids[\"ids\"])"
308306
]
309307
},
310308
{

examples/pg_vectorstore_how_to.ipynb

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -331,9 +331,9 @@
331331
"cell_type": "markdown",
332332
"metadata": {},
333333
"source": [
334-
"### Get documents\n",
334+
"### Get collection\n",
335335
"\n",
336-
"Get documents from the vectorstore using filters and parameters."
336+
"Get collection from the vectorstore using filters and parameters."
337337
]
338338
},
339339
{
@@ -342,13 +342,11 @@
342342
"metadata": {},
343343
"outputs": [],
344344
"source": [
345-
"documents_with_apple = await store.aget({\"content\": {\"$ilike\": \"%apple%\"}})\n",
346-
"first_three_documents = await store.aget(limit=3)\n",
347-
"rest_of_documents = await store.aget(limit=5, offset=3)\n",
345+
"documents_with_apple = await store.aget(where_document={\"$ilike\": \"%apple%\"}, include=\"documents\")\n",
346+
"paginated_ids = await store.aget(limit=3, offset=3)\n",
348347
"\n",
349-
"print([doc.page_content for doc in documents_with_apple])\n",
350-
"print([doc.page_content for doc in first_three_documents])\n",
351-
"print([doc.page_content for doc in rest_of_documents])"
348+
"print(documents_with_apple[\"documents\"])\n",
349+
"print(paginated_ids[\"ids\"])"
352350
]
353351
},
354352
{

langchain_postgres/v2/async_vectorstore.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -678,18 +678,11 @@ async def __query_collection_with_filter(
678678
limit: Optional[int] = None,
679679
offset: Optional[int] = None,
680680
filter: Optional[dict] = None,
681+
columns: Optional[list[str]] = None,
681682
**kwargs: Any,
682683
) -> Sequence[RowMapping]:
683684
"""Asynchronously query the database collection using filters and parameters and return matching rows."""
684685

685-
columns = [
686-
self.id_column,
687-
self.content_column,
688-
self.embedding_column,
689-
] + self.metadata_columns
690-
if self.metadata_json_column:
691-
columns.append(self.metadata_json_column)
692-
693686
column_names = ", ".join(f'"{col}"' for col in columns)
694687

695688
safe_filter = None
@@ -1037,35 +1030,68 @@ async def is_valid_index(
10371030

10381031
async def aget(
10391032
self,
1040-
filter: Optional[dict] = None,
1033+
ids: Optional[Sequence[str]] = None,
1034+
where: Optional[dict] = None,
10411035
limit: Optional[int] = None,
10421036
offset: Optional[int] = None,
1037+
where_document: Optional[dict] = None,
1038+
include: Optional[list[str]] = None,
10431039
**kwargs: Any,
1044-
) -> list[Document]:
1040+
) -> dict[str, Any]:
10451041
"""Retrieve documents from the collection using filters and parameters."""
1042+
filter = {}
1043+
if ids:
1044+
filter.update({self.id_column: {"$in": ids}})
1045+
if where:
1046+
filter.update(where)
1047+
if where_document:
1048+
filter.update({self.content_column: where_document})
1049+
1050+
if include is None:
1051+
include = ["metadatas", "documents"]
1052+
1053+
fields_mapping = {
1054+
"embeddings": [self.embedding_column],
1055+
"metadatas": self.metadata_columns + [self.metadata_json_column]
1056+
if self.metadata_json_column
1057+
else self.metadata_columns,
1058+
"documents": [self.content_column],
1059+
}
1060+
1061+
included_fields = ["ids"]
1062+
columns = [self.id_column]
1063+
1064+
for field, cols in fields_mapping.items():
1065+
if field in include:
1066+
included_fields.append(field)
1067+
columns.extend(cols)
10461068

10471069
results = await self.__query_collection_with_filter(
1048-
limit=limit, offset=offset, filter=filter, **kwargs
1070+
limit=limit, offset=offset, filter=filter, columns=columns, **kwargs
10491071
)
10501072

1051-
documents = []
1073+
final_results = {field: [] for field in included_fields}
1074+
10521075
for row in results:
1053-
metadata = (
1054-
row[self.metadata_json_column]
1055-
if self.metadata_json_column and row[self.metadata_json_column]
1056-
else {}
1057-
)
1058-
for col in self.metadata_columns:
1059-
metadata[col] = row[col]
1060-
documents.append(
1061-
Document(
1062-
page_content=row[self.content_column],
1063-
metadata=metadata,
1064-
id=str(row[self.id_column]),
1065-
),
1066-
)
1076+
final_results["ids"].append(str(row[self.id_column]))
10671077

1068-
return documents
1078+
if "metadatas" in final_results:
1079+
metadata = (
1080+
row.get(self.metadata_json_column) or {}
1081+
if self.metadata_json_column
1082+
else {}
1083+
)
1084+
for col in self.metadata_columns:
1085+
metadata[col] = row[col]
1086+
final_results["metadatas"].append(metadata)
1087+
1088+
if "documents" in final_results:
1089+
final_results["documents"].append(row[self.content_column])
1090+
1091+
if "embeddings" in final_results:
1092+
final_results["embeddings"].append(row[self.embedding_column])
1093+
1094+
return final_results
10691095

10701096
async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
10711097
"""Get documents by ids."""
@@ -1323,11 +1349,14 @@ def _create_filter_clause(self, filters: Any) -> tuple[str, dict]:
13231349

13241350
def get(
13251351
self,
1326-
filter: Optional[dict] = None,
1352+
ids: Optional[Sequence[str]] = None,
1353+
where: Optional[dict] = None,
13271354
limit: Optional[int] = None,
13281355
offset: Optional[int] = None,
1356+
where_document: Optional[dict] = None,
1357+
include: Optional[list[str]] = None,
13291358
**kwargs: Any,
1330-
) -> list[Document]:
1359+
) -> dict[str, Any]:
13311360
raise NotImplementedError(
13321361
"Sync methods are not implemented for AsyncPGVectorStore. Use PGVectorStore interface instead."
13331362
)

langchain_postgres/v2/vectorstores.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -877,26 +877,48 @@ def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
877877

878878
async def aget(
879879
self,
880-
filter: Optional[dict] = None,
880+
ids: Optional[Sequence[str]] = None,
881+
where: Optional[dict] = None,
881882
limit: Optional[int] = None,
882883
offset: Optional[int] = None,
884+
where_document: Optional[dict] = None,
885+
include: Optional[list[str]] = None,
883886
**kwargs: Any,
884-
) -> list[Document]:
887+
) -> dict[str, Any]:
885888
"""Retrieve documents from the collection using filters and parameters."""
886889
return await self._engine._run_as_async(
887-
self.__vs.aget(filter=filter, limit=limit, offset=offset, **kwargs)
890+
self.__vs.aget(
891+
ids=ids,
892+
where=where,
893+
limit=limit,
894+
offset=offset,
895+
where_document=where_document,
896+
include=include,
897+
**kwargs,
898+
)
888899
)
889900

890901
def get(
891902
self,
892-
filter: Optional[dict] = None,
903+
ids: Optional[Sequence[str]] = None,
904+
where: Optional[dict] = None,
893905
limit: Optional[int] = None,
894906
offset: Optional[int] = None,
907+
where_document: Optional[dict] = None,
908+
include: Optional[list[str]] = None,
895909
**kwargs: Any,
896-
) -> list[Document]:
910+
) -> dict[str, Any]:
897911
"""Retrieve documents from the collection using filters and parameters."""
898912
return self._engine._run_as_sync(
899-
self.__vs.aget(filter=filter, limit=limit, offset=offset, **kwargs)
913+
self.__vs.aget(
914+
ids=ids,
915+
where=where,
916+
limit=limit,
917+
offset=offset,
918+
where_document=where_document,
919+
include=include,
920+
**kwargs,
921+
)
900922
)
901923

902924
def get_table_name(self) -> str:

tests/unit_tests/v2/test_async_pg_vectorstore_search.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,24 @@ async def test_vectorstore_with_metadata_filters(
370370
)
371371
assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter
372372

373+
async def test_async_vectorstore_get_ids(
374+
self,
375+
vs_custom_filter: AsyncPGVectorStore
376+
) -> None:
377+
"""Test end to end construction and filter."""
378+
379+
res = await vs_custom_filter.aget(ids=ids[:2])
380+
assert set(res["ids"]) == set(ids[:2])
381+
382+
async def test_async_vectorstore_get_docs(
383+
self,
384+
vs_custom_filter: AsyncPGVectorStore
385+
) -> None:
386+
"""Test end to end construction and filter."""
387+
388+
res = await vs_custom_filter.aget(where_document={"$in": texts[:2]})
389+
assert set(res["documents"]) == set(texts[:2])
390+
373391
@pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES)
374392
async def test_vectorstore_get(
375393
self,
@@ -378,8 +396,8 @@ async def test_vectorstore_get(
378396
expected_ids: list[str],
379397
) -> None:
380398
"""Test end to end construction and filter."""
381-
docs = await vs_custom_filter.aget(test_filter)
382-
assert set([doc.metadata["code"] for doc in docs]) == set(expected_ids), (
399+
res = await vs_custom_filter.aget(where=test_filter)
400+
assert set([r["code"] for r in res["metadatas"]]) == set(expected_ids), (
383401
test_filter
384402
)
385403

@@ -389,14 +407,14 @@ async def test_vectorstore_get_limit_offset(
389407
) -> None:
390408
"""Test limit and offset parameters of get method"""
391409

392-
all_docs = await vs_custom_filter.aget()
393-
docs_from_combining = (
394-
(await vs_custom_filter.aget(limit=1))
395-
+ (await vs_custom_filter.aget(limit=1, offset=1))
396-
+ (await vs_custom_filter.aget(offset=2))
410+
all_ids = (await vs_custom_filter.aget())["ids"]
411+
ids_from_combining = (
412+
(await vs_custom_filter.aget(limit=1))["ids"]
413+
+ (await vs_custom_filter.aget(limit=1, offset=1))["ids"]
414+
+ (await vs_custom_filter.aget(offset=2))["ids"]
397415
)
398416

399-
assert all_docs == docs_from_combining
417+
assert all_ids == ids_from_combining
400418

401419
async def test_asimilarity_hybrid_search(self, vs: AsyncPGVectorStore) -> None:
402420
results = await vs.asimilarity_search(

tests/unit_tests/v2/test_pg_vectorstore_search.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,24 @@ def test_sync_vectorstore_with_metadata_filters(
429429
docs = vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter)
430430
assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter
431431

432+
def test_sync_vectorstore_get_ids(
433+
self,
434+
vs_custom_filter_sync: PGVectorStore
435+
) -> None:
436+
"""Test end to end construction and filter."""
437+
438+
res = vs_custom_filter_sync.get(ids=ids[:2])
439+
assert set(res["ids"]) == set(ids[:2])
440+
441+
def test_sync_vectorstore_get_docs(
442+
self,
443+
vs_custom_filter_sync: PGVectorStore
444+
) -> None:
445+
"""Test end to end construction and filter."""
446+
447+
res = vs_custom_filter_sync.get(where_document={"$in": texts[:2]})
448+
assert set(res["documents"]) == set(texts[:2])
449+
432450
@pytest.mark.parametrize("test_filter, expected_ids", FILTERING_TEST_CASES)
433451
def test_sync_vectorstore_get(
434452
self,
@@ -438,8 +456,8 @@ def test_sync_vectorstore_get(
438456
) -> None:
439457
"""Test end to end construction and filter."""
440458

441-
docs = vs_custom_filter_sync.get(filter=test_filter)
442-
assert set([doc.metadata["code"] for doc in docs]) == set(expected_ids), (
459+
res = vs_custom_filter_sync.get(where=test_filter)
460+
assert set([r["code"] for r in res["metadatas"]]) == set(expected_ids), (
443461
test_filter
444462
)
445463

@@ -449,14 +467,14 @@ def test_sync_vectorstore_get_limit_offset(
449467
) -> None:
450468
"""Test limit and offset parameters of get method"""
451469

452-
all_docs = vs_custom_filter_sync.get()
453-
docs_from_combining = (
454-
vs_custom_filter_sync.get(limit=1)
455-
+ vs_custom_filter_sync.get(limit=1, offset=1)
456-
+ vs_custom_filter_sync.get(offset=2)
470+
all_ids = vs_custom_filter_sync.get()["ids"]
471+
ids_from_combining = (
472+
vs_custom_filter_sync.get(limit=1)["ids"]
473+
+ vs_custom_filter_sync.get(limit=1, offset=1)["ids"]
474+
+ vs_custom_filter_sync.get(offset=2)["ids"]
457475
)
458476

459-
assert all_docs == docs_from_combining
477+
assert all_ids == ids_from_combining
460478

461479
@pytest.mark.parametrize("test_filter", NEGATIVE_TEST_CASES)
462480
def test_metadata_filter_negative_tests(

0 commit comments

Comments
 (0)