Skip to content

Commit 022fe8c

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
chore: offline store - pass credentials to FeatureGroup/Feature
PiperOrigin-RevId: 651081315
1 parent bbd4a49 commit 022fe8c

File tree

4 files changed

+317
-18
lines changed

4 files changed

+317
-18
lines changed

tests/unit/vertexai/test_feature_group.py

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from unittest import mock
2121
from unittest.mock import call, patch
2222

23+
from google.auth import credentials as auth_credentials
2324
from google.api_core import operation as ga_operation
2425
from google.cloud import aiplatform
2526
from google.cloud.aiplatform import base
@@ -158,6 +159,16 @@ def list_features_mock():
158159
yield list_features_mock
159160

160161

162+
@pytest.fixture()
163+
def mock_base_instantiate_client():
164+
with patch.object(
165+
aiplatform.base.VertexAiResourceNoun,
166+
"_instantiate_client",
167+
) as base_instantiate_client_mock:
168+
base_instantiate_client_mock.return_value = mock.MagicMock()
169+
yield base_instantiate_client_mock
170+
171+
161172
def fg_eq(
162173
fg_to_check: FeatureGroup,
163174
name: str,
@@ -399,6 +410,260 @@ def test_get_feature(get_fg_mock, get_feature_mock):
399410
)
400411

401412

413+
def test_get_feature_credentials_set_in_init(mock_base_instantiate_client):
414+
credentials = mock.MagicMock(spec=auth_credentials.Credentials)
415+
aiplatform.init(
416+
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=credentials
417+
)
418+
419+
mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1
420+
mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1
421+
422+
fg = FeatureGroup(_TEST_FG1_ID)
423+
mock_base_instantiate_client.assert_called_with(
424+
location=_TEST_LOCATION,
425+
credentials=credentials,
426+
appended_user_agent=None,
427+
)
428+
429+
feature = fg.get_feature(_TEST_FG1_F1_ID)
430+
mock_base_instantiate_client.assert_called_with(
431+
location=_TEST_LOCATION,
432+
credentials=credentials,
433+
appended_user_agent=None,
434+
)
435+
436+
feature_eq(
437+
feature,
438+
name=_TEST_FG1_F1_ID,
439+
resource_name=_TEST_FG1_F1_PATH,
440+
project=_TEST_PROJECT,
441+
location=_TEST_LOCATION,
442+
description=_TEST_FG1_F1_DESCRIPTION,
443+
labels=_TEST_FG1_F1_LABELS,
444+
point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT,
445+
)
446+
447+
448+
def test_get_feature_from_feature_group_with_explicit_credentials(
449+
mock_base_instantiate_client,
450+
):
451+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
452+
453+
mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1
454+
mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1
455+
456+
credentials = mock.MagicMock(spec=auth_credentials.Credentials)
457+
fg = FeatureGroup(_TEST_FG1_ID, credentials=credentials)
458+
mock_base_instantiate_client.assert_called_with(
459+
location=_TEST_LOCATION,
460+
credentials=credentials,
461+
appended_user_agent=None,
462+
)
463+
464+
feature = fg.get_feature(_TEST_FG1_F1_ID)
465+
mock_base_instantiate_client.assert_called_with(
466+
location=_TEST_LOCATION,
467+
credentials=credentials,
468+
appended_user_agent=None,
469+
)
470+
471+
feature_eq(
472+
feature,
473+
name=_TEST_FG1_F1_ID,
474+
resource_name=_TEST_FG1_F1_PATH,
475+
project=_TEST_PROJECT,
476+
location=_TEST_LOCATION,
477+
description=_TEST_FG1_F1_DESCRIPTION,
478+
labels=_TEST_FG1_F1_LABELS,
479+
point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT,
480+
)
481+
482+
483+
def test_get_feature_from_feature_group_with_explicit_credentials_overrides_init_credentials(
484+
mock_base_instantiate_client,
485+
):
486+
init_credentials = mock.MagicMock(spec=auth_credentials.Credentials)
487+
aiplatform.init(
488+
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=init_credentials
489+
)
490+
491+
mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1
492+
mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1
493+
494+
credentials = mock.MagicMock(spec=auth_credentials.Credentials)
495+
fg = FeatureGroup(_TEST_FG1_ID, credentials=credentials)
496+
mock_base_instantiate_client.assert_called_with(
497+
location=_TEST_LOCATION,
498+
credentials=credentials,
499+
appended_user_agent=None,
500+
)
501+
502+
feature = fg.get_feature(_TEST_FG1_F1_ID)
503+
mock_base_instantiate_client.assert_called_with(
504+
location=_TEST_LOCATION,
505+
credentials=credentials,
506+
appended_user_agent=None,
507+
)
508+
509+
feature_eq(
510+
feature,
511+
name=_TEST_FG1_F1_ID,
512+
resource_name=_TEST_FG1_F1_PATH,
513+
project=_TEST_PROJECT,
514+
location=_TEST_LOCATION,
515+
description=_TEST_FG1_F1_DESCRIPTION,
516+
labels=_TEST_FG1_F1_LABELS,
517+
point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT,
518+
)
519+
520+
521+
def test_get_feature_with_explicit_credentials(mock_base_instantiate_client):
522+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
523+
524+
mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1
525+
mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1
526+
527+
fg = FeatureGroup(_TEST_FG1_ID)
528+
mock_base_instantiate_client.assert_called_with(
529+
location=_TEST_LOCATION,
530+
credentials=mock.ANY,
531+
appended_user_agent=None,
532+
)
533+
534+
credentials = mock.MagicMock(spec=auth_credentials.Credentials)
535+
feature = fg.get_feature(_TEST_FG1_F1_ID, credentials=credentials)
536+
mock_base_instantiate_client.assert_called_with(
537+
location=_TEST_LOCATION,
538+
credentials=credentials,
539+
appended_user_agent=None,
540+
)
541+
542+
feature_eq(
543+
feature,
544+
name=_TEST_FG1_F1_ID,
545+
resource_name=_TEST_FG1_F1_PATH,
546+
project=_TEST_PROJECT,
547+
location=_TEST_LOCATION,
548+
description=_TEST_FG1_F1_DESCRIPTION,
549+
labels=_TEST_FG1_F1_LABELS,
550+
point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT,
551+
)
552+
553+
554+
def test_get_feature_with_explicit_credentials_overrides_init_credentials(
555+
mock_base_instantiate_client,
556+
):
557+
init_credentials = mock.MagicMock(spec=auth_credentials.Credentials)
558+
aiplatform.init(
559+
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=init_credentials
560+
)
561+
562+
mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1
563+
mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1
564+
565+
fg = FeatureGroup(_TEST_FG1_ID)
566+
mock_base_instantiate_client.assert_called_with(
567+
location=_TEST_LOCATION,
568+
credentials=init_credentials,
569+
appended_user_agent=None,
570+
)
571+
572+
credentials = mock.MagicMock(spec=auth_credentials.Credentials)
573+
feature = fg.get_feature(_TEST_FG1_F1_ID, credentials=credentials)
574+
mock_base_instantiate_client.assert_called_with(
575+
location=_TEST_LOCATION,
576+
credentials=credentials,
577+
appended_user_agent=None,
578+
)
579+
580+
feature_eq(
581+
feature,
582+
name=_TEST_FG1_F1_ID,
583+
resource_name=_TEST_FG1_F1_PATH,
584+
project=_TEST_PROJECT,
585+
location=_TEST_LOCATION,
586+
description=_TEST_FG1_F1_DESCRIPTION,
587+
labels=_TEST_FG1_F1_LABELS,
588+
point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT,
589+
)
590+
591+
592+
def test_get_feature_with_explicit_credentials_overrides_feature_group_credentials(
593+
mock_base_instantiate_client,
594+
):
595+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
596+
597+
mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1
598+
mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1
599+
600+
feature_group_credentials = mock.MagicMock(spec=auth_credentials.Credentials)
601+
fg = FeatureGroup(_TEST_FG1_ID, credentials=feature_group_credentials)
602+
mock_base_instantiate_client.assert_called_with(
603+
location=_TEST_LOCATION,
604+
credentials=feature_group_credentials,
605+
appended_user_agent=None,
606+
)
607+
608+
credentials = mock.MagicMock(spec=auth_credentials.Credentials)
609+
feature = fg.get_feature(_TEST_FG1_F1_ID, credentials=credentials)
610+
mock_base_instantiate_client.assert_called_with(
611+
location=_TEST_LOCATION,
612+
credentials=credentials,
613+
appended_user_agent=None,
614+
)
615+
616+
feature_eq(
617+
feature,
618+
name=_TEST_FG1_F1_ID,
619+
resource_name=_TEST_FG1_F1_PATH,
620+
project=_TEST_PROJECT,
621+
location=_TEST_LOCATION,
622+
description=_TEST_FG1_F1_DESCRIPTION,
623+
labels=_TEST_FG1_F1_LABELS,
624+
point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT,
625+
)
626+
627+
628+
def test_get_feature_with_explicit_credentials_overrides_init_and_feature_group_credentials(
629+
mock_base_instantiate_client,
630+
):
631+
init_credentials = mock.MagicMock(spec=auth_credentials.Credentials)
632+
aiplatform.init(
633+
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=init_credentials
634+
)
635+
636+
mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1
637+
mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1
638+
639+
feature_group_credentials = mock.MagicMock(spec=auth_credentials.Credentials)
640+
fg = FeatureGroup(_TEST_FG1_ID, credentials=feature_group_credentials)
641+
mock_base_instantiate_client.assert_called_with(
642+
location=_TEST_LOCATION,
643+
credentials=feature_group_credentials,
644+
appended_user_agent=None,
645+
)
646+
647+
credentials = mock.MagicMock(spec=auth_credentials.Credentials)
648+
feature = fg.get_feature(_TEST_FG1_F1_ID, credentials=credentials)
649+
mock_base_instantiate_client.assert_called_with(
650+
location=_TEST_LOCATION,
651+
credentials=credentials,
652+
appended_user_agent=None,
653+
)
654+
655+
feature_eq(
656+
feature,
657+
name=_TEST_FG1_F1_ID,
658+
resource_name=_TEST_FG1_F1_PATH,
659+
project=_TEST_PROJECT,
660+
location=_TEST_LOCATION,
661+
description=_TEST_FG1_F1_DESCRIPTION,
662+
labels=_TEST_FG1_F1_LABELS,
663+
point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT,
664+
)
665+
666+
402667
@pytest.mark.parametrize("create_request_timeout", [None, 1.0])
403668
@pytest.mark.parametrize("sync", [True, False])
404669
def test_create_feature(

tests/unit/vertexai/test_offline_store.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,4 +478,16 @@ def test_one_feature_with_explicit_credentials(
478478
index_col=bigframes.enums.DefaultIndexKind.NULL,
479479
)
480480

481+
# Ensure when getting the FeatureGroup and Feature, the credentials are
482+
# passed through.
483+
mock_fg.assert_called_once_with(
484+
FeatureGroup,
485+
"fake",
486+
project=None,
487+
credentials=credentials,
488+
)
489+
mock_fg.return_value.get_feature.assert_called_once_with(
490+
"my_feature",
491+
)
492+
481493
assert rsp == "SOME SQL QUERY OUTPUT"

vertexai/resources/preview/feature_store/feature_group.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,16 +238,30 @@ def delete(self, force: bool = False, sync: bool = True) -> None:
238238
lro.result()
239239
_LOGGER.log_delete_complete(self)
240240

241-
def get_feature(self, feature_id: str) -> Feature:
241+
def get_feature(
242+
self,
243+
feature_id: str,
244+
credentials: Optional[auth_credentials.Credentials] = None,
245+
) -> Feature:
242246
"""Retrieves an existing managed feature.
243247
244248
Args:
245249
feature_id: The ID of the feature.
250+
credentials:
251+
Custom credentials to use to retrieve the feature under this
252+
feature group. The order of which credentials are used is as
253+
follows: (1) this parameter (2) credentials passed to FeatureGroup
254+
constructor (3) credentials set in aiplatform.init.
246255
247256
Returns:
248257
Feature - the Feature resource object under this feature group.
249258
"""
250-
return Feature(f"{self.resource_name}/features/{feature_id}")
259+
credentials = (
260+
credentials or self.credentials or initializer.global_config.credentials
261+
)
262+
return Feature(
263+
f"{self.resource_name}/features/{feature_id}", credentials=credentials
264+
)
251265

252266
def create_feature(
253267
self,

0 commit comments

Comments
 (0)