Skip to content

Commit 2de3cba

Browse files
authored
feat: Support IAM account override (#160)
* feat: Support IAM account override * rename to `iam_account_email` * add integration test for iam override * add self * add sync fixture * clean up
1 parent 61413f1 commit 2de3cba

3 files changed

Lines changed: 64 additions & 6 deletions

File tree

integration.cloudbuild.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@ steps:
3232
- "INSTANCE_ID=$_INSTANCE_ID"
3333
- "DATABASE_ID=$_DATABASE_ID"
3434
- "REGION=$_REGION"
35-
secretEnv: ["DB_USER", "DB_PASSWORD"]
35+
secretEnv: ["DB_USER", "DB_PASSWORD", "IAM_ACCOUNT"]
3636

3737
availableSecrets:
3838
secretManager:
3939
- versionName: projects/$PROJECT_ID/secrets/langchain-test-pg-username/versions/1
4040
env: "DB_USER"
4141
- versionName: projects/$PROJECT_ID/secrets/langchain-test-pg-password/versions/1
4242
env: "DB_PASSWORD"
43+
- versionName: projects/$PROJECT_ID/secrets/service_account_email/versions/1
44+
env: "IAM_ACCOUNT"
4345

4446
substitutions:
4547
_DATABASE_ID: test-database

src/langchain_google_cloud_sql_pg/engine.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def from_instance(
136136
password: Optional[str] = None,
137137
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
138138
quota_project: Optional[str] = None,
139+
iam_account_email: Optional[str] = None,
139140
) -> PostgresEngine:
140141
"""Create a PostgresEngine from a Postgres instance.
141142
@@ -169,6 +170,7 @@ def from_instance(
169170
loop=loop,
170171
thread=thread,
171172
quota_project=quota_project,
173+
iam_account_email=iam_account_email,
172174
)
173175
return asyncio.run_coroutine_threadsafe(coro, loop).result()
174176

@@ -185,6 +187,7 @@ async def _create(
185187
loop: Optional[asyncio.AbstractEventLoop] = None,
186188
thread: Optional[Thread] = None,
187189
quota_project: Optional[str] = None,
190+
iam_account_email: Optional[str] = None,
188191
) -> PostgresEngine:
189192
"""Create a PostgresEngine instance.
190193
@@ -227,12 +230,15 @@ async def _create(
227230
db_user = user
228231
# otherwise use automatic IAM database authentication
229232
else:
230-
# get application default credentials
231-
credentials, _ = google.auth.default(
232-
scopes=["https://www.googleapis.com/auth/userinfo.email"]
233-
)
234-
db_user = await _get_iam_principal_email(credentials)
235233
enable_iam_auth = True
234+
if iam_account_email:
235+
db_user = iam_account_email
236+
else:
237+
# get application default credentials
238+
credentials, _ = google.auth.default(
239+
scopes=["https://www.googleapis.com/auth/userinfo.email"]
240+
)
241+
db_user = await _get_iam_principal_email(credentials)
236242

237243
# anonymous function to be used for SQLAlchemy 'creator' argument
238244
async def getconn() -> asyncpg.Connection:
@@ -264,6 +270,7 @@ async def afrom_instance(
264270
password: Optional[str] = None,
265271
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
266272
quota_project: Optional[str] = None,
273+
iam_account_email: Optional[str] = None,
267274
) -> PostgresEngine:
268275
"""Create a PostgresEngine from a Postgres instance.
269276
@@ -290,6 +297,7 @@ async def afrom_instance(
290297
user,
291298
password,
292299
quota_project=quota_project,
300+
iam_account_email=iam_account_email,
293301
)
294302

295303
@classmethod

tests/test_postgresql_engine.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ def user(self) -> str:
6565
def password(self) -> str:
6666
return get_env_var("DB_PASSWORD", "database password for cloud sql")
6767

68+
@pytest.fixture(scope="module")
69+
def iam_account(self) -> str:
70+
return get_env_var("IAM_ACCOUNT", "Cloud SQL IAM account email")
71+
6872
@pytest_asyncio.fixture
6973
async def engine(self, db_project, db_region, db_instance, db_name):
7074
engine = await PostgresEngine.afrom_instance(
@@ -176,6 +180,26 @@ async def test_column(self, engine):
176180
with pytest.raises(ValueError):
177181
Column(1, "INTEGER")
178182

183+
async def test_iam_account_override(
184+
self,
185+
db_project,
186+
db_instance,
187+
db_region,
188+
db_name,
189+
iam_account,
190+
):
191+
engine = await PostgresEngine.afrom_instance(
192+
project_id=db_project,
193+
instance=db_instance,
194+
region=db_region,
195+
database=db_name,
196+
iam_account_email=iam_account,
197+
)
198+
assert engine
199+
await engine._aexecute("SELECT 1")
200+
await engine._connector.close_async()
201+
await engine._engine.dispose()
202+
179203

180204
@pytest.mark.asyncio
181205
class TestEngineSync:
@@ -203,6 +227,10 @@ def user(self) -> str:
203227
def password(self) -> str:
204228
return get_env_var("DB_PASSWORD", "database password for cloud sql")
205229

230+
@pytest.fixture(scope="module")
231+
def iam_account(self) -> str:
232+
return get_env_var("IAM_ACCOUNT", "Cloud SQL IAM account email")
233+
206234
@pytest_asyncio.fixture
207235
def engine(self, db_project, db_region, db_instance, db_name):
208236
engine = PostgresEngine.from_instance(
@@ -285,3 +313,23 @@ async def test_engine_constructor_key(
285313
key = object()
286314
with pytest.raises(Exception):
287315
PostgresEngine(key, engine)
316+
317+
def test_iam_account_override(
318+
self,
319+
db_project,
320+
db_instance,
321+
db_region,
322+
db_name,
323+
iam_account,
324+
):
325+
engine = PostgresEngine.from_instance(
326+
project_id=db_project,
327+
instance=db_instance,
328+
region=db_region,
329+
database=db_name,
330+
iam_account_email=iam_account,
331+
)
332+
assert engine
333+
engine._execute("SELECT 1")
334+
engine._connector.close()
335+
engine._engine.dispose()

0 commit comments

Comments
 (0)