@@ -1061,6 +1061,9 @@ def from_local_script(
         accelerator_count: int = 0,
         boot_disk_type: str = "pd-ssd",
         boot_disk_size_gb: int = 100,
+        reduction_server_replica_count: int = 0,
+        reduction_server_machine_type: Optional[str] = None,
+        reduction_server_container_uri: Optional[str] = None,
         base_output_dir: Optional[str] = None,
         project: Optional[str] = None,
         location: Optional[str] = None,
@@ -1127,6 +1130,13 @@ def from_local_script(
             boot_disk_size_gb (int):
                 Optional. Size in GB of the boot disk, default is 100GB.
                 boot disk size must be within the range of [100, 64000].
+            reduction_server_replica_count (int):
+                The number of reduction server replicas, default is 0.
+            reduction_server_machine_type (str):
+                Optional. The type of machine to use for reduction server.
+            reduction_server_container_uri (str):
+                Optional. The Uri of the reduction server container image.
+                See details: https://cloud.google.com/vertex-ai/docs/training/distributed-training#reduce_training_time_with_reduction_server
             base_output_dir (str):
                 Optional. GCS output directory of job. If not provided a
                 timestamped directory in the staging directory will be used.
@@ -1181,6 +1191,8 @@ def from_local_script(
             accelerator_type=accelerator_type,
             boot_disk_type=boot_disk_type,
             boot_disk_size_gb=boot_disk_size_gb,
+            reduction_server_replica_count=reduction_server_replica_count,
+            reduction_server_machine_type=reduction_server_machine_type,
         ).pool_specs
 
         python_packager = source_utils._TrainingScriptPythonPackager(
@@ -1191,21 +1203,33 @@ def from_local_script(
             gcs_staging_dir=staging_bucket, project=project, credentials=credentials,
         )
 
-        for spec in worker_pool_specs:
-            spec["python_package_spec"] = {
-                "executor_image_uri": container_uri,
-                "python_module": python_packager.module_name,
-                "package_uris": [package_gcs_uri],
-            }
-
-            if args:
-                spec["python_package_spec"]["args"] = args
-
-            if environment_variables:
-                spec["python_package_spec"]["env"] = [
-                    {"name": key, "value": value}
-                    for key, value in environment_variables.items()
-                ]
+        for spec_order, spec in enumerate(worker_pool_specs):
+
+            if not spec:
+                continue
+
+            if (
+                spec_order == worker_spec_utils._SPEC_ORDERS["server_spec"]
+                and reduction_server_replica_count > 0
+            ):
+                spec["container_spec"] = {
+                    "image_uri": reduction_server_container_uri,
+                }
+            else:
+                spec["python_package_spec"] = {
+                    "executor_image_uri": container_uri,
+                    "python_module": python_packager.module_name,
+                    "package_uris": [package_gcs_uri],
+                }
+
+                if args:
+                    spec["python_package_spec"]["args"] = args
+
+                if environment_variables:
+                    spec["python_package_spec"]["env"] = [
+                        {"name": key, "value": value}
+                        for key, value in environment_variables.items()
+                    ]
 
         return cls(
             display_name=display_name,
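For reference, a minimal usage sketch of the new parameters. This is not part of the change itself; the project, bucket, script path, machine shapes, and image URIs below are hypothetical placeholders, and the reduction server container image URI in particular must be supplied by the caller per the linked documentation.

    from google.cloud import aiplatform

    # Placeholder project/bucket values for illustration only.
    aiplatform.init(
        project="my-project",
        location="us-central1",
        staging_bucket="gs://my-staging-bucket",
    )

    job = aiplatform.CustomJob.from_local_script(
        display_name="distributed-training-with-reduction-server",
        script_path="task.py",
        container_uri="gcr.io/my-project/my-training-image:latest",
        replica_count=4,
        machine_type="n1-standard-16",
        accelerator_type="NVIDIA_TESLA_V100",
        accelerator_count=2,
        # New in this change: dedicated reduction server replicas that
        # aggregate gradients for the worker pools.
        reduction_server_replica_count=2,
        reduction_server_machine_type="n1-highcpu-16",
        reduction_server_container_uri="<reduction-server-image-uri>",
    )

    job.run()

With a nonzero reduction_server_replica_count, the server worker pool is given a container_spec pointing at the reduction server image instead of the python_package_spec used by the chief and worker pools, as shown in the loop above.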