Skip to content

Commit b181f65

Browse files
feat: add PostgreSQLEngine (#13)
* feat: add PostgreSQLEngine * lint * add header * Update src/langchain_google_cloud_sql_pg/postgresql_engine.py Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> * Update src/langchain_google_cloud_sql_pg/postgresql_engine.py Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> * Update src/langchain_google_cloud_sql_pg/postgresql_engine.py Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> * Update src/langchain_google_cloud_sql_pg/postgresql_engine.py Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> * Update src/langchain_google_cloud_sql_pg/postgresql_engine.py Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> * add json column * Update comments * fix * Update pyproject.toml * clean up * Update pyproject.toml * lint --------- Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
1 parent 7334226 commit b181f65

5 files changed

Lines changed: 352 additions & 2 deletions

File tree

integration.cloudbuild.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,13 @@ steps:
2222
name: python:3.11
2323
entrypoint: python
2424
args: ["-m", "pytest"]
25+
env:
26+
- "PROJECT_ID=$PROJECT_ID"
27+
- "INSTANCE_ID=$_INSTANCE_ID"
28+
- "DATABASE_ID=$_DATABASE_ID"
29+
- "REGION=$_REGION"
30+
31+
# Default values for the user-defined substitutions referenced by the
# pytest step's env block ($_INSTANCE_ID, $_DATABASE_ID, $_REGION).
substitutions:
  _INSTANCE_ID: test-instance
  _DATABASE_ID: test-database
  _REGION: us-central1

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ license = {file = "LICENSE"}
77
requires-python = ">=3.8"
88
dependencies = [
99
"langchain==0.1.1",
10-
"SQLAlchemy==2.0.7",
11-
"cloud-sql-python-connector[asyncpg]==1.5.0"
10+
"SQLAlchemy>=2.0.25",
11+
"cloud-sql-python-connector[asyncpg]>=1.6.0",
12+
"pgvector>=0.2.5"
1213
]
1314

1415
[project.urls]

src/langchain_google_cloud_sql_pg/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from langchain_google_cloud_sql_pg.postgresql_engine import Column, PostgreSQLEngine
16+
17+
__all__ = ["PostgreSQLEngine", "Column"]
src/langchain_google_cloud_sql_pg/postgresql_engine.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import asyncio
18+
from dataclasses import dataclass
19+
from threading import Thread
20+
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, TypeVar
21+
22+
import aiohttp
23+
import google.auth # type: ignore
24+
import google.auth.transport.requests # type: ignore
25+
from google.cloud.sql.connector import Connector, create_async_connector
26+
from sqlalchemy import text # Column,
27+
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
28+
29+
if TYPE_CHECKING:
30+
import asyncpg # type: ignore
31+
import google.auth.credentials # type: ignore
32+
33+
T = TypeVar("T")
34+
35+
36+
async def _get_iam_principal_email(
    credentials: google.auth.credentials.Credentials,
) -> str:
    """Get email address associated with current authenticated IAM principal.

    Email will be used for automatic IAM database authentication to Cloud SQL.

    Args:
        credentials (google.auth.credentials.Credentials):
            The credentials object to use in finding the associated IAM
            principal email address.

    Returns:
        email (str):
            The email address associated with the current authenticated IAM
            principal.

    Raises:
        ValueError: If the tokeninfo response contains no "email" field.
        aiohttp.ClientResponseError: If the tokeninfo request returns an HTTP
            error status (the request is made with ``raise_for_status=True``).
    """
    # refresh credentials if they are not valid
    if not credentials.valid:
        # NOTE(review): this refresh uses the `requests` transport and looks
        # synchronous, so it would block the event loop while it runs --
        # confirm whether that matters for callers.
        request = google.auth.transport.requests.Request()
        credentials.refresh(request)
    # call OAuth2 api to get IAM principal email associated with OAuth2 token
    url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
    async with aiohttp.ClientSession() as client:
        # Enter the response as a context manager so it is released
        # as soon as the JSON body has been read.
        async with client.get(url, raise_for_status=True) as response:
            response_json: Dict = await response.json()
    email = response_json.get("email")
    if email is None:
        raise ValueError(
            "Failed to automatically obtain authenticated IAM principal's "
            "email address using environment's ADC credentials!"
        )
    return email
69+
70+
71+
@dataclass
class Column:
    """Description of a single table column: its name, SQL type and nullability."""

    # Column identifier.
    name: str
    # SQL data type, given as a string (e.g. "TEXT").
    data_type: str
    # Whether the column accepts NULL values; defaults to nullable.
    nullable: bool = True
76+
77+
78+
class PostgreSQLEngine:
    """A class for managing connections to a Cloud SQL for Postgres database."""

    # Connector cached on the class and reused by every engine; it is created
    # lazily by the first call to `_create`.
    # NOTE(review): the connector is created inside whichever event loop runs
    # the first `_create` call and is then reused by engines living on other
    # loops -- confirm the connector tolerates that.
    _connector: Optional[Connector] = None

    def __init__(
        self,
        engine: AsyncEngine,
        loop: Optional[asyncio.AbstractEventLoop],
        thread: Optional[Thread],
    ):
        """Wrap an existing async engine.

        Args:
            engine: The SQLAlchemy async engine used to run queries.
            loop: Background event loop used by `run_as_sync`, or None when
                the engine was created from async code.
            thread: Thread running `loop`, or None.
        """
        self._engine = engine
        self._loop = loop
        self._thread = thread

    @classmethod
    def from_instance(
        cls,
        project_id: str,
        region: str,
        instance: str,
        database: str,
    ) -> PostgreSQLEngine:
        """Create an engine from synchronous code.

        Args:
            project_id: Google Cloud project id of the Cloud SQL instance.
            region: Region of the Cloud SQL instance.
            instance: Name of the Cloud SQL instance.
            database: Name of the database to connect to.

        Returns:
            An engine that holds the background loop and thread, so that
            `run_as_sync` can be used on it.
        """
        # Running a loop in a background thread allows us to support
        # async methods from non-async environments
        loop = asyncio.new_event_loop()
        thread = Thread(target=loop.run_forever, daemon=True)
        thread.start()
        # Hand the loop and thread to `_create`; going through
        # `afrom_instance` would drop them and leave `run_as_sync` unusable.
        coro = cls._create(
            project_id, region, instance, database, loop=loop, thread=thread
        )
        return asyncio.run_coroutine_threadsafe(coro, loop).result()

    @classmethod
    async def _create(
        cls,
        project_id: str,
        region: str,
        instance: str,
        database: str,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        thread: Optional[Thread] = None,
    ) -> PostgreSQLEngine:
        """Build the async engine using IAM database authentication.

        Args:
            project_id: Google Cloud project id of the Cloud SQL instance.
            region: Region of the Cloud SQL instance.
            instance: Name of the Cloud SQL instance.
            database: Name of the database to connect to.
            loop: Background event loop to store on the engine, if any.
            thread: Thread running `loop`, if any.

        Returns:
            A new PostgreSQLEngine wrapping the created async engine.
        """
        credentials, _ = google.auth.default(
            scopes=["https://www.googleapis.com/auth/userinfo.email"]
        )
        iam_database_user = await _get_iam_principal_email(credentials)
        if cls._connector is None:
            cls._connector = await create_async_connector()

        # async function to be used for SQLAlchemy 'async_creator' argument
        async def getconn() -> asyncpg.Connection:
            conn = await cls._connector.connect_async(  # type: ignore
                f"{project_id}:{region}:{instance}",
                "asyncpg",
                user=iam_database_user,
                db=database,
                enable_iam_auth=True,
            )
            return conn

        engine = create_async_engine(
            "postgresql+asyncpg://",
            async_creator=getconn,
        )
        return cls(engine, loop, thread)

    @classmethod
    async def afrom_instance(
        cls,
        project_id: str,
        region: str,
        instance: str,
        database: str,
    ) -> PostgreSQLEngine:
        """Create an engine from async code (no background loop or thread).

        Args:
            project_id: Google Cloud project id of the Cloud SQL instance.
            region: Region of the Cloud SQL instance.
            instance: Name of the Cloud SQL instance.
            database: Name of the database to connect to.

        Returns:
            An engine whose loop and thread are None; `run_as_sync` raises
            on such an engine.
        """
        return await cls._create(project_id, region, instance, database)

    async def _aexecute(self, query: str):
        """Execute a SQL query and commit."""
        async with self._engine.connect() as conn:
            await conn.execute(text(query))
            await conn.commit()

    async def _afetch(self, query: str):
        """Fetch results from a SQL query.

        Returns:
            All rows of the result, as mappings keyed by column name.
        """
        async with self._engine.connect() as conn:
            result = await conn.execute(text(query))
            result_map = result.mappings()
            result_fetch = result_map.fetchall()

        return result_fetch

    def run_as_sync(self, coro: Awaitable[T]) -> T:
        """Run `coro` on the background loop and block until it finishes.

        Args:
            coro: The coroutine to run.

        Returns:
            The result of `coro`.

        Raises:
            Exception: If the engine has no background loop, i.e. it was
                created from async code.
        """
        if not self._loop:
            raise Exception("Engine was initialized async.")
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    async def init_vectorstore_table(
        self,
        table_name: str,
        vector_size: int,
        content_column: str = "content",
        embedding_column: str = "embedding",
        metadata_columns: Optional[List[Column]] = None,
        metadata_json_columns: str = "langchain_metadata",
        id_column: str = "langchain_id",
        overwrite_existing: bool = False,
        store_metadata: bool = True,
    ) -> None:
        """Create a table for storing embeddings.

        Table and column names are interpolated directly into the SQL text,
        so they must come from a trusted source.

        Args:
            table_name: Name of the table to create.
            vector_size: Dimension of the `vector` embedding column.
            content_column: Name of the TEXT NOT NULL content column.
            embedding_column: Name of the embedding column.
            metadata_columns: Extra columns to add; None or empty adds none.
            metadata_json_columns: Name of the JSON metadata column.
            id_column: Name of the UUID primary key column.
            overwrite_existing: Drop an existing table of the same name first.
            store_metadata: Whether to add the JSON metadata column.
        """
        await self._aexecute("CREATE EXTENSION IF NOT EXISTS vector")

        if overwrite_existing:
            await self._aexecute(f"DROP TABLE IF EXISTS {table_name}")

        query = f"""CREATE TABLE {table_name}(
            {id_column} UUID PRIMARY KEY,
            {content_column} TEXT NOT NULL,
            {embedding_column} vector({vector_size}) NOT NULL"""
        for column in metadata_columns or []:
            # The leading space keeps the constraint separate from the type
            # (otherwise the SQL would read e.g. "TEXTNOT NULL").
            nullability = "" if column.nullable else " NOT NULL"
            query += f",\n{column.name} {column.data_type}{nullability}"
        if store_metadata:
            query += f",\n{metadata_json_columns} JSON"
        query += "\n);"

        await self._aexecute(query)

tests/test_postgresql_engine.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import os
17+
import uuid
18+
from typing import List
19+
20+
import pytest
21+
import pytest_asyncio
22+
from langchain_community.embeddings import FakeEmbeddings
23+
24+
from langchain_google_cloud_sql_pg import Column, PostgreSQLEngine
25+
26+
DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
27+
CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
28+
VECTOR_SIZE = 768
29+
30+
31+
class FakeEmbeddingsWithDimension(FakeEmbeddings):
    """Deterministic fake embeddings of length VECTOR_SIZE, for tests."""

    size: int = VECTOR_SIZE

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one vector per text: all ones, ending with the text's index."""
        ones = [1.0] * (VECTOR_SIZE - 1)
        vectors: List[List[float]] = []
        for position, _ in enumerate(texts):
            vectors.append(ones + [float(position)])
        return vectors

    def embed_query(self, text: str = "default") -> List[float]:
        """Return a fixed vector: all ones, ending with 0.0."""
        vector = [1.0] * (VECTOR_SIZE - 1)
        vector.append(0.0)
        return vector
45+
46+
47+
embeddings_service = FakeEmbeddingsWithDimension()
48+
49+
50+
def get_env_var(key: str, desc: str) -> str:
    """Return the value of environment variable `key`.

    Args:
        key: Name of the environment variable to read.
        desc: Human-readable description used in the error message.

    Raises:
        ValueError: If the variable is not set.
    """
    value = os.environ.get(key)
    if value is not None:
        return value
    raise ValueError(f"Must set env var {key} to: {desc}")
55+
56+
57+
@pytest.mark.asyncio
class TestEngineAsync:
    """Integration tests for PostgreSQLEngine.

    Connection settings are read from the PROJECT_ID, REGION, INSTANCE_ID and
    DATABASE_ID environment variables. `test_fetch` reads the table that
    `test_init_table` creates and populates, so it relies on that test
    having run first.
    """

    @pytest.fixture(scope="module")
    def db_project(self) -> str:
        return get_env_var("PROJECT_ID", "project id for google cloud")

    @pytest.fixture(scope="module")
    def db_region(self) -> str:
        return get_env_var("REGION", "region for cloud sql instance")

    @pytest.fixture(scope="module")
    def db_instance(self) -> str:
        return get_env_var("INSTANCE_ID", "instance for cloud sql")

    @pytest.fixture(scope="module")
    def db_name(self) -> str:
        return get_env_var("DATABASE_ID", "database name for cloud sql instance")

    @pytest_asyncio.fixture
    async def engine(self, db_project, db_region, db_instance, db_name):
        """Yield an engine created through the async constructor."""
        engine = await PostgreSQLEngine.afrom_instance(
            project_id=db_project,
            instance=db_instance,
            region=db_region,
            database=db_name,
        )
        yield engine

    async def test_execute(self, engine):
        await engine._aexecute("SELECT 1")

    async def test_init_table(self, engine):
        """Create the default table and insert one row into it."""
        await engine.init_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE)
        row_id = str(uuid.uuid4())
        content = "coffee"
        embedding = await embeddings_service.aembed_query(content)
        stmt = f"INSERT INTO {DEFAULT_TABLE} (langchain_id, content, embedding) VALUES ('{row_id}', '{content}','{embedding}');"
        await engine._aexecute(stmt)

    async def test_fetch(self, engine):
        """Read back the row inserted by test_init_table, then drop the table."""
        results = await engine._afetch(f"SELECT * FROM {DEFAULT_TABLE}")
        assert len(results) > 0
        await engine._aexecute(f"DROP TABLE {DEFAULT_TABLE}")

    async def test_init_table_custom(self, engine):
        """Create a table with custom column names and check its schema."""
        await engine.init_vectorstore_table(
            CUSTOM_TABLE,
            VECTOR_SIZE,
            id_column="uuid",
            content_column="mycontent",
            embedding_column="myembedding",
            metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")],
            store_metadata=True,
        )
        stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{CUSTOM_TABLE}';"
        results = await engine._afetch(stmt)
        expected = [
            {"column_name": "uuid", "data_type": "uuid"},
            {"column_name": "myembedding", "data_type": "USER-DEFINED"},
            {"column_name": "langchain_metadata", "data_type": "json"},
            {"column_name": "mycontent", "data_type": "text"},
            {"column_name": "page", "data_type": "text"},
            {"column_name": "source", "data_type": "text"},
        ]
        # Without the count check, the loop below would pass vacuously
        # if the query returned no rows at all.
        assert len(results) == len(expected)
        for row in results:
            assert row in expected

        await engine._aexecute(f"DROP TABLE {CUSTOM_TABLE}")

    def test_sync_engine(self, db_project, db_region, db_instance, db_name):
        """Check that the synchronous constructor returns an engine."""
        engine = PostgreSQLEngine.from_instance(
            project_id=db_project,
            instance=db_instance,
            region=db_region,
            database=db_name,
        )
        assert engine

0 commit comments

Comments
 (0)