@@ -46,6 +46,8 @@ def get_pathways_worker_args(args) -> str:
46
46
- --resource_manager_address={rm_address}
47
47
- --gcs_scratch_location={args.pathways_gcs_location}"""
48
48
if args .use_pathways :
49
+ if args .custom_pathways_worker_args :
50
+ yaml = append_custom_pathways_args (yaml , args .custom_pathways_worker_args )
49
51
return yaml .format (args = args , rm_address = get_rm_address (args ))
50
52
else :
51
53
return ''
@@ -64,6 +66,10 @@ def get_pathways_proxy_args(args) -> str:
64
66
- --gcs_scratch_location={args.pathways_gcs_location}"""
65
67
66
68
if args .use_pathways :
69
+ if args .custom_pathways_proxy_server_args :
70
+ yaml = append_custom_pathways_args (
71
+ yaml , args .custom_pathways_proxy_server_args
72
+ )
67
73
return yaml .format (args = args , rm_address = get_rm_address (args ))
68
74
else :
69
75
return ''
@@ -233,6 +239,8 @@ def get_pathways_rm_args(args, system: SystemCharacteristics) -> str:
233
239
- --instance_count={instance_count}
234
240
- --instance_type={instance_type}"""
235
241
if args .use_pathways :
242
+ if args .custom_pathways_server_args :
243
+ yaml = append_custom_pathways_args (yaml , args .custom_pathways_server_args )
236
244
return yaml .format (
237
245
args = args ,
238
246
instance_count = args .num_slices ,
@@ -242,6 +250,28 @@ def get_pathways_rm_args(args, system: SystemCharacteristics) -> str:
242
250
return ''
243
251
244
252
253
+ def append_custom_pathways_args (yaml , custom_args ) -> str :
254
+ """Append custom Pathways args to the YAML with proper indentation.
255
+
256
+ Args:
257
+ yaml (string): existing yaml containing args
258
+
259
+ Returns:
260
+ yaml (string): yaml with additional args appended.
261
+ """
262
+ second_line = yaml .split ('\n ' )[1 ]
263
+ if (
264
+ not second_line
265
+ ): # to cover edge case if only one arg remains, we would have to look at the entire YAML in this case.
266
+ return yaml
267
+ # Calculate the indentation based on the second line of existing YAML.
268
+ indentation = ' ' * (len (second_line ) - len (second_line .lstrip ()))
269
+ custom_args = custom_args .split (' ' )
270
+ for arg in custom_args :
271
+ yaml += '\n ' + indentation + '- ' + arg
272
+ return yaml
273
+
274
+
245
275
def get_user_workload_for_pathways (
246
276
args ,
247
277
system : SystemCharacteristics ,
0 commit comments