Skip to content

Commit 9f603dd

Browse files
authored
fix: tb-gcp-uploader to show flags in "--help" correctly (#409)
1 parent 7b7c950 commit 9f603dd

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

google/cloud/aiplatform/tensorboard/uploader_main.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,24 @@
4949
flags.DEFINE_integer(
5050
"event_file_inactive_secs",
5151
None,
52-
"Age in seconds of last write after which an event file is considered " "inactive.",
52+
"Age in seconds of last write after which an event file is considered inactive.",
5353
)
5454
flags.DEFINE_string(
5555
"run_name_prefix",
5656
None,
5757
"If present, all runs created by this invocation will have their name "
5858
"prefixed by this value.",
5959
)
60+
flags.DEFINE_string(
61+
"api_uri",
62+
"aiplatform.googleapis.com",
63+
"The API URI for fetching Tensorboard metadata.",
64+
)
65+
flags.DEFINE_string(
66+
"web_server_uri",
67+
"tensorboard.googleusercontent.com",
68+
"The API URI for accessing the Tensorboard UI.",
69+
)
6070

6171
flags.DEFINE_multi_string(
6272
"allowed_plugins",
@@ -79,6 +89,7 @@ def main(argv):
7989
if len(argv) > 1:
8090
raise app.UsageError("Too many command-line arguments.")
8191

92+
aiplatform.constants.API_BASE_PATH = FLAGS.api_uri
8293
m = re.match(
8394
"projects/(.*)/locations/(.*)/tensorboards/.*", FLAGS.tensorboard_resource_name
8495
)
@@ -131,7 +142,7 @@ def main(argv):
131142
print(
132143
"View your Tensorboard at https://{}.{}/experiment/{}".format(
133144
region,
134-
"tensorboard.googleusercontent.com",
145+
FLAGS.web_server_uri,
135146
tb_uploader.get_experiment_resource_name().replace("/", "+"),
136147
)
137148
)
@@ -141,8 +152,16 @@ def main(argv):
141152
tb_uploader.start_uploading()
142153

143154

155+
def flags_parser(args):
156+
# Plumbs the flags defined in this file to the main module, mostly for the
157+
# console script wrapper tb-gcp-uploader.
158+
for flag in set(flags.FLAGS.get_key_flags_for_module(__name__)):
159+
flags.FLAGS.register_flag_by_module(args[0], flag)
160+
return app.parse_flags_with_usage(args)
161+
162+
144163
def run_main():
145-
app.run(main)
164+
app.run(main, flags_parser=flags_parser)
146165

147166

148167
if __name__ == "__main__":

0 commit comments

Comments
 (0)