Skip to content

Commit fa9bc39

Browse files
authored
feat: add tensorboard support to custom job and hyperparameter tuning job (#404)
1 parent aab9e58 commit fa9bc39

File tree

3 files changed

+234
-10
lines changed

3 files changed

+234
-10
lines changed

google/cloud/aiplatform/jobs.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@
4545
batch_prediction_job_v1 as gca_bp_job_v1,
4646
batch_prediction_job_v1beta1 as gca_bp_job_v1beta1,
4747
custom_job as gca_custom_job_compat,
48+
custom_job_v1beta1 as gca_custom_job_v1beta1,
4849
explanation_v1beta1 as gca_explanation_v1beta1,
4950
io as gca_io_compat,
5051
io_v1beta1 as gca_io_v1beta1,
5152
job_state as gca_job_state,
5253
hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
54+
hyperparameter_tuning_job_v1beta1 as gca_hyperparameter_tuning_job_v1beta1,
5355
machine_resources as gca_machine_resources_compat,
5456
machine_resources_v1beta1 as gca_machine_resources_v1beta1,
5557
study as gca_study_compat,
@@ -1132,6 +1134,7 @@ def run(
11321134
network: Optional[str] = None,
11331135
timeout: Optional[int] = None,
11341136
restart_job_on_worker_restart: bool = False,
1137+
tensorboard: Optional[str] = None,
11351138
sync: bool = True,
11361139
) -> None:
11371140
"""Run this configured CustomJob.
@@ -1152,6 +1155,21 @@ def run(
11521155
gets restarted. This feature can be used by
11531156
distributed training jobs that are not resilient
11541157
to workers leaving and joining a job.
1158+
tensorboard (str):
1159+
Optional. The name of an AI Platform
1160+
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
1161+
resource to which this CustomJob will upload Tensorboard
1162+
logs. Format:
1163+
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
1164+
1165+
The training script should write Tensorboard to following AI Platform environment
1166+
variable:
1167+
1168+
AIP_TENSORBOARD_LOG_DIR
1169+
1170+
`service_account` is required with provided `tensorboard`.
1171+
For more information on configuring your service account please visit:
1172+
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
11551173
sync (bool):
11561174
Whether to execute this method synchronously. If False, this method
11571175
will unblock and it will be executed in a concurrent Future.
@@ -1170,9 +1188,18 @@ def run(
11701188
restart_job_on_worker_restart=restart_job_on_worker_restart,
11711189
)
11721190

1191+
if tensorboard:
1192+
v1beta1_gca_resource = gca_custom_job_v1beta1.CustomJob()
1193+
v1beta1_gca_resource._pb.MergeFromString(
1194+
self._gca_resource._pb.SerializeToString()
1195+
)
1196+
self._gca_resource = v1beta1_gca_resource
1197+
self._gca_resource.job_spec.tensorboard = tensorboard
1198+
11731199
_LOGGER.log_create_with_lro(self.__class__)
11741200

1175-
self._gca_resource = self.api_client.create_custom_job(
1201+
version = "v1beta1" if tensorboard else "v1"
1202+
self._gca_resource = self.api_client.select_version(version).create_custom_job(
11761203
parent=self._parent, custom_job=self._gca_resource
11771204
)
11781205

@@ -1415,6 +1442,7 @@ def run(
14151442
network: Optional[str] = None,
14161443
timeout: Optional[int] = None, # seconds
14171444
restart_job_on_worker_restart: bool = False,
1445+
tensorboard: Optional[str] = None,
14181446
sync: bool = True,
14191447
) -> None:
14201448
"""Run this configured CustomJob.
@@ -1435,6 +1463,21 @@ def run(
14351463
gets restarted. This feature can be used by
14361464
distributed training jobs that are not resilient
14371465
to workers leaving and joining a job.
1466+
tensorboard (str):
1467+
Optional. The name of an AI Platform
1468+
[Tensorboard][google.cloud.aiplatform.v1beta1.Tensorboard]
1469+
resource to which this CustomJob will upload Tensorboard
1470+
logs. Format:
1471+
``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
1472+
1473+
The training script should write Tensorboard to following AI Platform environment
1474+
variable:
1475+
1476+
AIP_TENSORBOARD_LOG_DIR
1477+
1478+
`service_account` is required with provided `tensorboard`.
1479+
For more information on configuring your service account please visit:
1480+
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
14381481
sync (bool):
14391482
Whether to execute this method synchronously. If False, this method
14401483
will unblock and it will be executed in a concurrent Future.
@@ -1453,9 +1496,22 @@ def run(
14531496
restart_job_on_worker_restart=restart_job_on_worker_restart,
14541497
)
14551498

1499+
if tensorboard:
1500+
v1beta1_gca_resource = (
1501+
gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob()
1502+
)
1503+
v1beta1_gca_resource._pb.MergeFromString(
1504+
self._gca_resource._pb.SerializeToString()
1505+
)
1506+
self._gca_resource = v1beta1_gca_resource
1507+
self._gca_resource.trial_job_spec.tensorboard = tensorboard
1508+
14561509
_LOGGER.log_create_with_lro(self.__class__)
14571510

1458-
self._gca_resource = self.api_client.create_hyperparameter_tuning_job(
1511+
version = "v1beta1" if tensorboard else "v1"
1512+
self._gca_resource = self.api_client.select_version(
1513+
version
1514+
).create_hyperparameter_tuning_job(
14591515
parent=self._parent, hyperparameter_tuning_job=self._gca_resource
14601516
)
14611517

tests/unit/aiplatform/test_custom_job.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,18 @@
2929

3030
from google.cloud import aiplatform
3131
from google.cloud.aiplatform.compat.types import custom_job as gca_custom_job_compat
32+
from google.cloud.aiplatform.compat.types import (
33+
custom_job_v1beta1 as gca_custom_job_v1beta1,
34+
)
3235
from google.cloud.aiplatform.compat.types import io as gca_io_compat
3336
from google.cloud.aiplatform.compat.types import job_state as gca_job_state_compat
3437
from google.cloud.aiplatform.compat.types import (
3538
encryption_spec as gca_encryption_spec_compat,
3639
)
3740
from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
41+
from google.cloud.aiplatform_v1beta1.services.job_service import (
42+
client as job_service_client_v1beta1,
43+
)
3844

3945
_TEST_PROJECT = "test-project"
4046
_TEST_LOCATION = "us-central1"
@@ -44,6 +50,7 @@
4450
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
4551

4652
_TEST_CUSTOM_JOB_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"
53+
_TEST_TENSORBOARD_NAME = f"{_TEST_PARENT}/tensorboards/{_TEST_ID}"
4754

4855
_TEST_TRAINING_CONTAINER_IMAGE = "gcr.io/test-training/container:image"
4956

@@ -97,11 +104,20 @@
97104
)
98105

99106

100-
def _get_custom_job_proto(state=None, name=None, error=None):
107+
def _get_custom_job_proto(state=None, name=None, error=None, version="v1"):
101108
custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO)
102109
custom_job_proto.name = name
103110
custom_job_proto.state = state
104111
custom_job_proto.error = error
112+
113+
if version == "v1beta1":
114+
v1beta1_custom_job_proto = gca_custom_job_v1beta1.CustomJob()
115+
v1beta1_custom_job_proto._pb.MergeFromString(
116+
custom_job_proto._pb.SerializeToString()
117+
)
118+
custom_job_proto = v1beta1_custom_job_proto
119+
custom_job_proto.job_spec.tensorboard = _TEST_TENSORBOARD_NAME
120+
105121
return custom_job_proto
106122

107123

@@ -162,6 +178,19 @@ def create_custom_job_mock():
162178
yield create_custom_job_mock
163179

164180

181+
@pytest.fixture
182+
def create_custom_job_v1beta1_mock():
183+
with mock.patch.object(
184+
job_service_client_v1beta1.JobServiceClient, "create_custom_job"
185+
) as create_custom_job_mock:
186+
create_custom_job_mock.return_value = _get_custom_job_proto(
187+
name=_TEST_CUSTOM_JOB_NAME,
188+
state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
189+
version="v1beta1",
190+
)
191+
yield create_custom_job_mock
192+
193+
165194
class TestCustomJob:
166195
def setup_method(self):
167196
reload(aiplatform.initializer)
@@ -321,3 +350,43 @@ def test_create_from_local_script_raises_with_no_staging_bucket(
321350
script_path=test_training_jobs._TEST_LOCAL_SCRIPT_FILE_NAME,
322351
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
323352
)
353+
354+
@pytest.mark.parametrize("sync", [True, False])
355+
def test_create_custom_job_with_tensorboard(
356+
self, create_custom_job_v1beta1_mock, get_custom_job_mock, sync
357+
):
358+
359+
aiplatform.init(
360+
project=_TEST_PROJECT,
361+
location=_TEST_LOCATION,
362+
staging_bucket=_TEST_STAGING_BUCKET,
363+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
364+
)
365+
366+
job = aiplatform.CustomJob(
367+
display_name=_TEST_DISPLAY_NAME, worker_pool_specs=_TEST_WORKER_POOL_SPEC
368+
)
369+
370+
job.run(
371+
service_account=_TEST_SERVICE_ACCOUNT,
372+
tensorboard=_TEST_TENSORBOARD_NAME,
373+
network=_TEST_NETWORK,
374+
timeout=_TEST_TIMEOUT,
375+
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
376+
sync=sync,
377+
)
378+
379+
job.wait()
380+
381+
expected_custom_job = _get_custom_job_proto(version="v1beta1")
382+
383+
create_custom_job_v1beta1_mock.assert_called_once_with(
384+
parent=_TEST_PARENT, custom_job=expected_custom_job
385+
)
386+
387+
expected_custom_job = _get_custom_job_proto()
388+
389+
assert job.job_spec == expected_custom_job.job_spec
390+
assert (
391+
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
392+
)

tests/unit/aiplatform/test_hyperparameter_tuning_job.py

Lines changed: 106 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,13 @@
3131
)
3232
from google.cloud.aiplatform.compat.types import (
3333
hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
34+
hyperparameter_tuning_job_v1beta1 as gca_hyperparameter_tuning_job_v1beta1,
3435
)
3536
from google.cloud.aiplatform.compat.types import study as gca_study_compat
3637
from google.cloud.aiplatform_v1.services.job_service import client as job_service_client
38+
from google.cloud.aiplatform_v1beta1.services.job_service import (
39+
client as job_service_client_v1beta1,
40+
)
3741

3842
import test_custom_job
3943

@@ -122,12 +126,29 @@
122126
)
123127

124128

125-
def _get_hyperparameter_tuning_job_proto(state=None, name=None, error=None):
126-
custom_job_proto = copy.deepcopy(_TEST_BASE_HYPERPARAMETER_TUNING_JOB_PROTO)
127-
custom_job_proto.name = name
128-
custom_job_proto.state = state
129-
custom_job_proto.error = error
130-
return custom_job_proto
129+
def _get_hyperparameter_tuning_job_proto(
130+
state=None, name=None, error=None, version="v1"
131+
):
132+
hyperparameter_tuning_job_proto = copy.deepcopy(
133+
_TEST_BASE_HYPERPARAMETER_TUNING_JOB_PROTO
134+
)
135+
hyperparameter_tuning_job_proto.name = name
136+
hyperparameter_tuning_job_proto.state = state
137+
hyperparameter_tuning_job_proto.error = error
138+
139+
if version == "v1beta1":
140+
v1beta1_hyperparameter_tuning_job_proto = (
141+
gca_hyperparameter_tuning_job_v1beta1.HyperparameterTuningJob()
142+
)
143+
v1beta1_hyperparameter_tuning_job_proto._pb.MergeFromString(
144+
hyperparameter_tuning_job_proto._pb.SerializeToString()
145+
)
146+
hyperparameter_tuning_job_proto = v1beta1_hyperparameter_tuning_job_proto
147+
hyperparameter_tuning_job_proto.trial_job_spec.tensorboard = (
148+
test_custom_job._TEST_TENSORBOARD_NAME
149+
)
150+
151+
return hyperparameter_tuning_job_proto
131152

132153

133154
@pytest.fixture
@@ -187,7 +208,20 @@ def create_hyperparameter_tuning_job_mock():
187208
yield create_hyperparameter_tuning_job_mock
188209

189210

190-
class TestCustomJob:
211+
@pytest.fixture
212+
def create_hyperparameter_tuning_job_v1beta1_mock():
213+
with mock.patch.object(
214+
job_service_client_v1beta1.JobServiceClient, "create_hyperparameter_tuning_job"
215+
) as create_hyperparameter_tuning_job_mock:
216+
create_hyperparameter_tuning_job_mock.return_value = _get_hyperparameter_tuning_job_proto(
217+
name=_TEST_HYPERPARAMETERTUNING_JOB_NAME,
218+
state=gca_job_state_compat.JobState.JOB_STATE_PENDING,
219+
version="v1beta1",
220+
)
221+
yield create_hyperparameter_tuning_job_mock
222+
223+
224+
class TestHyperparameterTuningJob:
191225
def setup_method(self):
192226
reload(aiplatform.initializer)
193227
reload(aiplatform)
@@ -366,3 +400,68 @@ def test_get_hyperparameter_tuning_job(self, get_hyperparameter_tuning_job_mock)
366400
assert (
367401
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING
368402
)
403+
404+
@pytest.mark.parametrize("sync", [True, False])
405+
def test_create_hyperparameter_tuning_job_with_tensorboard(
406+
self,
407+
create_hyperparameter_tuning_job_v1beta1_mock,
408+
get_hyperparameter_tuning_job_mock,
409+
sync,
410+
):
411+
412+
aiplatform.init(
413+
project=_TEST_PROJECT,
414+
location=_TEST_LOCATION,
415+
staging_bucket=_TEST_STAGING_BUCKET,
416+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
417+
)
418+
419+
custom_job = aiplatform.CustomJob(
420+
display_name=test_custom_job._TEST_DISPLAY_NAME,
421+
worker_pool_specs=test_custom_job._TEST_WORKER_POOL_SPEC,
422+
)
423+
424+
job = aiplatform.HyperparameterTuningJob(
425+
display_name=_TEST_DISPLAY_NAME,
426+
custom_job=custom_job,
427+
metric_spec={_TEST_METRIC_SPEC_KEY: _TEST_METRIC_SPEC_VALUE},
428+
parameter_spec={
429+
"lr": hpt.DoubleParameterSpec(min=0.001, max=0.1, scale="log"),
430+
"units": hpt.IntegerParameterSpec(min=4, max=1028, scale="linear"),
431+
"activation": hpt.CategoricalParameterSpec(
432+
values=["relu", "sigmoid", "elu", "selu", "tanh"]
433+
),
434+
"batch_size": hpt.DiscreteParameterSpec(
435+
values=[16, 32], scale="linear"
436+
),
437+
},
438+
parallel_trial_count=_TEST_PARALLEL_TRIAL_COUNT,
439+
max_trial_count=_TEST_MAX_TRIAL_COUNT,
440+
max_failed_trial_count=_TEST_MAX_FAILED_TRIAL_COUNT,
441+
search_algorithm=_TEST_SEARCH_ALGORITHM,
442+
measurement_selection=_TEST_MEASUREMENT_SELECTION,
443+
)
444+
445+
job.run(
446+
service_account=_TEST_SERVICE_ACCOUNT,
447+
network=_TEST_NETWORK,
448+
timeout=_TEST_TIMEOUT,
449+
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
450+
tensorboard=test_custom_job._TEST_TENSORBOARD_NAME,
451+
sync=sync,
452+
)
453+
454+
job.wait()
455+
456+
expected_hyperparameter_tuning_job = _get_hyperparameter_tuning_job_proto(
457+
version="v1beta1"
458+
)
459+
460+
create_hyperparameter_tuning_job_v1beta1_mock.assert_called_once_with(
461+
parent=_TEST_PARENT,
462+
hyperparameter_tuning_job=expected_hyperparameter_tuning_job,
463+
)
464+
465+
assert (
466+
job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED
467+
)

0 commit comments

Comments
 (0)