2424import com .facebook .presto .spi .plan .ProjectNode ;
2525import com .facebook .presto .spi .relation .CallExpression ;
2626import com .facebook .presto .spi .relation .ConstantExpression ;
27+ import com .facebook .presto .spi .relation .InputReferenceExpression ;
2728import com .facebook .presto .spi .relation .LambdaDefinitionExpression ;
2829import com .facebook .presto .spi .relation .RowExpression ;
30+ import com .facebook .presto .spi .relation .RowExpressionVisitor ;
31+ import com .facebook .presto .spi .relation .SpecialFormExpression ;
2932import com .facebook .presto .spi .relation .VariableReferenceExpression ;
3033import com .facebook .presto .sql .planner .TypeProvider ;
3134import com .facebook .presto .sql .planner .plan .RPCNode ;
3235import com .facebook .presto .sql .planner .plan .SimplePlanRewriter ;
3336import com .fasterxml .jackson .databind .JsonNode ;
3437import com .fasterxml .jackson .databind .ObjectMapper ;
3538import com .google .common .collect .ImmutableList ;
39+ import com .google .common .collect .ImmutableMap ;
3640import com .google .common .collect .ImmutableSet ;
3741import io .airlift .slice .Slice ;
3842
@@ -232,6 +236,15 @@ public RowExpression rewriteCall(
232236 RpcExtractionContext context ,
233237 RowExpressionTreeRewriter <RpcExtractionContext > treeRewriter )
234238 {
239+ if (node .getDisplayName ().equals ("$internal$try" )
240+ && node .getArguments ().size () == 1
241+ && isLambdaOrBindWithLambda (node .getArguments ().get (0 ))) {
242+ RowExpression result = rewriteTryWithRpcFunction (node , context );
243+ if (result != null ) {
244+ return result ;
245+ }
246+ }
247+
235248 if (!rpcFunctionNames .contains (
236249 node .getDisplayName ().toLowerCase (Locale .ENGLISH ))) {
237250 return null ;
@@ -258,12 +271,160 @@ public RowExpression rewriteLambda(
258271 RowExpressionTreeRewriter <RpcExtractionContext > treeRewriter )
259272 {
260273 // Do not traverse into lambda bodies — RPC calls inside lambdas
261- // have per-element semantics incompatible with RPCNode batching
274+ // have per-element semantics incompatible with RPCNode batching.
275+ // TRY lambdas are handled specially in rewriteCall above.
262276 return node ;
263277 }
264278 };
265279 }
266280
281+ private static boolean isLambdaOrBindWithLambda (RowExpression expression )
282+ {
283+ if (expression instanceof LambdaDefinitionExpression ) {
284+ return true ;
285+ }
286+ if (expression instanceof SpecialFormExpression
287+ && ((SpecialFormExpression ) expression ).getForm () == SpecialFormExpression .Form .BIND ) {
288+ List <RowExpression > bindArgs = ((SpecialFormExpression ) expression ).getArguments ();
289+ return bindArgs .get (bindArgs .size () - 1 ) instanceof LambdaDefinitionExpression ;
290+ }
291+ return false ;
292+ }
293+
294+ private RowExpression rewriteTryWithRpcFunction (
295+ CallExpression tryCall ,
296+ RpcExtractionContext context )
297+ {
298+ RowExpression tryArgument = tryCall .getArguments ().get (0 );
299+ LambdaDefinitionExpression lambda ;
300+ List <RowExpression > bindVariables ;
301+
302+ if (tryArgument instanceof LambdaDefinitionExpression ) {
303+ lambda = (LambdaDefinitionExpression ) tryArgument ;
304+ bindVariables = ImmutableList .of ();
305+ }
306+ else {
307+ SpecialFormExpression bind = (SpecialFormExpression ) tryArgument ;
308+ List <RowExpression > bindArgs = bind .getArguments ();
309+ lambda = (LambdaDefinitionExpression ) bindArgs .get (bindArgs .size () - 1 );
310+ bindVariables = bindArgs .subList (0 , bindArgs .size () - 1 );
311+ }
312+
313+ RowExpression body = lambda .getBody ();
314+ if (!bindVariables .isEmpty ()) {
315+ List <String > lambdaParams = lambda .getArguments ();
316+ ImmutableMap .Builder <String , RowExpression > substitutions = ImmutableMap .builder ();
317+ for (int i = 0 ; i < lambdaParams .size (); i ++) {
318+ substitutions .put (lambdaParams .get (i ), bindVariables .get (i ));
319+ }
320+ body = VariableSubstitutor .substitute (body , substitutions .build ());
321+ }
322+
323+ if (!containsRpcFunction (body )) {
324+ return null ;
325+ }
326+
327+ RowExpression rewrittenBody = RowExpressionTreeRewriter .rewriteWith (
328+ createRpcRewriter (), body , context );
329+
330+ return rewrittenBody ;
331+ }
332+
333+ private boolean containsRpcFunction (RowExpression expression )
334+ {
335+ if (expression instanceof CallExpression ) {
336+ CallExpression call = (CallExpression ) expression ;
337+ if (rpcFunctionNames .contains (call .getDisplayName ().toLowerCase (Locale .ENGLISH ))) {
338+ return true ;
339+ }
340+ return call .getArguments ().stream ().anyMatch (this ::containsRpcFunction );
341+ }
342+ if (expression instanceof SpecialFormExpression ) {
343+ return ((SpecialFormExpression ) expression ).getArguments ().stream ()
344+ .anyMatch (this ::containsRpcFunction );
345+ }
346+ return false ;
347+ }
348+
349+ private static class VariableSubstitutor
350+ implements RowExpressionVisitor <RowExpression , Void >
351+ {
352+ private final Map <String , RowExpression > substitutions ;
353+
354+ VariableSubstitutor (Map <String , RowExpression > substitutions )
355+ {
356+ this .substitutions = substitutions ;
357+ }
358+
359+ static RowExpression substitute (RowExpression expression , Map <String , RowExpression > substitutions )
360+ {
361+ return expression .accept (new VariableSubstitutor (substitutions ), null );
362+ }
363+
364+ @ Override
365+ public RowExpression visitCall (CallExpression call , Void context )
366+ {
367+ ImmutableList .Builder <RowExpression > newArgs = ImmutableList .builder ();
368+ boolean changed = false ;
369+ for (RowExpression arg : call .getArguments ()) {
370+ RowExpression newArg = arg .accept (this , null );
371+ newArgs .add (newArg );
372+ changed |= newArg != arg ;
373+ }
374+ return changed
375+ ? new CallExpression (call .getSourceLocation (), call .getDisplayName (), call .getFunctionHandle (), call .getType (), newArgs .build ())
376+ : call ;
377+ }
378+
379+ @ Override
380+ public RowExpression visitInputReference (InputReferenceExpression reference , Void context )
381+ {
382+ return reference ;
383+ }
384+
385+ @ Override
386+ public RowExpression visitConstant (ConstantExpression literal , Void context )
387+ {
388+ return literal ;
389+ }
390+
391+ @ Override
392+ public RowExpression visitLambda (LambdaDefinitionExpression lambda , Void context )
393+ {
394+ RowExpression newBody = lambda .getBody ().accept (this , null );
395+ if (newBody .equals (lambda .getBody ())) {
396+ return lambda ;
397+ }
398+ return new LambdaDefinitionExpression (
399+ lambda .getSourceLocation (),
400+ lambda .getArgumentTypes (),
401+ lambda .getArguments (),
402+ newBody );
403+ }
404+
405+ @ Override
406+ public RowExpression visitVariableReference (VariableReferenceExpression reference , Void context )
407+ {
408+ RowExpression replacement = substitutions .get (reference .getName ());
409+ return replacement != null ? replacement : reference ;
410+ }
411+
412+ @ Override
413+ public RowExpression visitSpecialForm (SpecialFormExpression specialForm , Void context )
414+ {
415+ ImmutableList .Builder <RowExpression > newArgs = ImmutableList .builder ();
416+ boolean changed = false ;
417+ for (RowExpression arg : specialForm .getArguments ()) {
418+ RowExpression newArg = arg .accept (this , null );
419+ newArgs .add (newArg );
420+ changed |= newArg != arg ;
421+ }
422+ return changed
423+ ? new SpecialFormExpression (specialForm .getForm (), specialForm .getType (), newArgs .build ())
424+ : specialForm ;
425+ }
426+ }
427+
267428 private static class RpcExtractionContext
268429 {
269430 private final List <ExtractedRpcCall > extractedCalls = new ArrayList <>();
0 commit comments