Skip to content

Commit bdfdf4f

Browse files
feilong-liufacebook-github-bot
authored andcommitted
[Presto] Handle RPC functions inside TRY expressions in RpcFunctionOptimizer
Summary: `TRY(rpc_func(...))` in SQL is lowered to `$internal$try(BIND(vars, lambda))`, wrapping the RPC function inside a lambda. The `RpcFunctionOptimizer` skips lambda bodies (to avoid breaking per-element semantics in `transform()`/`filter()` lambdas), so RPC functions inside TRY were never rewritten to use `RPCNode`. This caused `VeloxRuntimeError` at runtime because the stub function was called directly. This diff adds special handling for `$internal$try(BIND(..., lambda))` in the optimizer: 1. Detects the `$internal$try` pattern in `rewriteCall` 2. Unwraps the BIND+lambda structure 3. Substitutes lambda parameters with bound variables (so the body uses plan-level variables) 4. Rewrites the substituted body to extract RPC functions into `RPCNode` Follows the same pattern as `PlanRemoteProjections.processInternalTry` which solves the identical problem for remote function planning. Differential Revision: D105342187
1 parent 65547f4 commit bdfdf4f

2 files changed

Lines changed: 360 additions & 1 deletion

File tree

presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RpcFunctionOptimizer.java

Lines changed: 162 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,19 @@
2424
import com.facebook.presto.spi.plan.ProjectNode;
2525
import com.facebook.presto.spi.relation.CallExpression;
2626
import com.facebook.presto.spi.relation.ConstantExpression;
27+
import com.facebook.presto.spi.relation.InputReferenceExpression;
2728
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
2829
import com.facebook.presto.spi.relation.RowExpression;
30+
import com.facebook.presto.spi.relation.RowExpressionVisitor;
31+
import com.facebook.presto.spi.relation.SpecialFormExpression;
2932
import com.facebook.presto.spi.relation.VariableReferenceExpression;
3033
import com.facebook.presto.sql.planner.TypeProvider;
3134
import com.facebook.presto.sql.planner.plan.RPCNode;
3235
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
3336
import com.fasterxml.jackson.databind.JsonNode;
3437
import com.fasterxml.jackson.databind.ObjectMapper;
3538
import com.google.common.collect.ImmutableList;
39+
import com.google.common.collect.ImmutableMap;
3640
import com.google.common.collect.ImmutableSet;
3741
import io.airlift.slice.Slice;
3842

@@ -232,6 +236,15 @@ public RowExpression rewriteCall(
232236
RpcExtractionContext context,
233237
RowExpressionTreeRewriter<RpcExtractionContext> treeRewriter)
234238
{
239+
if (node.getDisplayName().equals("$internal$try")
240+
&& node.getArguments().size() == 1
241+
&& isLambdaOrBindWithLambda(node.getArguments().get(0))) {
242+
RowExpression result = rewriteTryWithRpcFunction(node, context);
243+
if (result != null) {
244+
return result;
245+
}
246+
}
247+
235248
if (!rpcFunctionNames.contains(
236249
node.getDisplayName().toLowerCase(Locale.ENGLISH))) {
237250
return null;
@@ -258,12 +271,160 @@ public RowExpression rewriteLambda(
258271
RowExpressionTreeRewriter<RpcExtractionContext> treeRewriter)
259272
{
260273
// Do not traverse into lambda bodies — RPC calls inside lambdas
261-
// have per-element semantics incompatible with RPCNode batching
274+
// have per-element semantics incompatible with RPCNode batching.
275+
// TRY lambdas are handled specially in rewriteCall above.
262276
return node;
263277
}
264278
};
265279
}
266280

281+
private static boolean isLambdaOrBindWithLambda(RowExpression expression)
282+
{
283+
if (expression instanceof LambdaDefinitionExpression) {
284+
return true;
285+
}
286+
if (expression instanceof SpecialFormExpression
287+
&& ((SpecialFormExpression) expression).getForm() == SpecialFormExpression.Form.BIND) {
288+
List<RowExpression> bindArgs = ((SpecialFormExpression) expression).getArguments();
289+
return bindArgs.get(bindArgs.size() - 1) instanceof LambdaDefinitionExpression;
290+
}
291+
return false;
292+
}
293+
294+
private RowExpression rewriteTryWithRpcFunction(
295+
CallExpression tryCall,
296+
RpcExtractionContext context)
297+
{
298+
RowExpression tryArgument = tryCall.getArguments().get(0);
299+
LambdaDefinitionExpression lambda;
300+
List<RowExpression> bindVariables;
301+
302+
if (tryArgument instanceof LambdaDefinitionExpression) {
303+
lambda = (LambdaDefinitionExpression) tryArgument;
304+
bindVariables = ImmutableList.of();
305+
}
306+
else {
307+
SpecialFormExpression bind = (SpecialFormExpression) tryArgument;
308+
List<RowExpression> bindArgs = bind.getArguments();
309+
lambda = (LambdaDefinitionExpression) bindArgs.get(bindArgs.size() - 1);
310+
bindVariables = bindArgs.subList(0, bindArgs.size() - 1);
311+
}
312+
313+
RowExpression body = lambda.getBody();
314+
if (!bindVariables.isEmpty()) {
315+
List<String> lambdaParams = lambda.getArguments();
316+
ImmutableMap.Builder<String, RowExpression> substitutions = ImmutableMap.builder();
317+
for (int i = 0; i < lambdaParams.size(); i++) {
318+
substitutions.put(lambdaParams.get(i), bindVariables.get(i));
319+
}
320+
body = VariableSubstitutor.substitute(body, substitutions.build());
321+
}
322+
323+
if (!containsRpcFunction(body)) {
324+
return null;
325+
}
326+
327+
RowExpression rewrittenBody = RowExpressionTreeRewriter.rewriteWith(
328+
createRpcRewriter(), body, context);
329+
330+
return rewrittenBody;
331+
}
332+
333+
private boolean containsRpcFunction(RowExpression expression)
334+
{
335+
if (expression instanceof CallExpression) {
336+
CallExpression call = (CallExpression) expression;
337+
if (rpcFunctionNames.contains(call.getDisplayName().toLowerCase(Locale.ENGLISH))) {
338+
return true;
339+
}
340+
return call.getArguments().stream().anyMatch(this::containsRpcFunction);
341+
}
342+
if (expression instanceof SpecialFormExpression) {
343+
return ((SpecialFormExpression) expression).getArguments().stream()
344+
.anyMatch(this::containsRpcFunction);
345+
}
346+
return false;
347+
}
348+
349+
private static class VariableSubstitutor
350+
implements RowExpressionVisitor<RowExpression, Void>
351+
{
352+
private final Map<String, RowExpression> substitutions;
353+
354+
VariableSubstitutor(Map<String, RowExpression> substitutions)
355+
{
356+
this.substitutions = substitutions;
357+
}
358+
359+
static RowExpression substitute(RowExpression expression, Map<String, RowExpression> substitutions)
360+
{
361+
return expression.accept(new VariableSubstitutor(substitutions), null);
362+
}
363+
364+
@Override
365+
public RowExpression visitCall(CallExpression call, Void context)
366+
{
367+
ImmutableList.Builder<RowExpression> newArgs = ImmutableList.builder();
368+
boolean changed = false;
369+
for (RowExpression arg : call.getArguments()) {
370+
RowExpression newArg = arg.accept(this, null);
371+
newArgs.add(newArg);
372+
changed |= newArg != arg;
373+
}
374+
return changed
375+
? new CallExpression(call.getSourceLocation(), call.getDisplayName(), call.getFunctionHandle(), call.getType(), newArgs.build())
376+
: call;
377+
}
378+
379+
@Override
380+
public RowExpression visitInputReference(InputReferenceExpression reference, Void context)
381+
{
382+
return reference;
383+
}
384+
385+
@Override
386+
public RowExpression visitConstant(ConstantExpression literal, Void context)
387+
{
388+
return literal;
389+
}
390+
391+
@Override
392+
public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
393+
{
394+
RowExpression newBody = lambda.getBody().accept(this, null);
395+
if (newBody.equals(lambda.getBody())) {
396+
return lambda;
397+
}
398+
return new LambdaDefinitionExpression(
399+
lambda.getSourceLocation(),
400+
lambda.getArgumentTypes(),
401+
lambda.getArguments(),
402+
newBody);
403+
}
404+
405+
@Override
406+
public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context)
407+
{
408+
RowExpression replacement = substitutions.get(reference.getName());
409+
return replacement != null ? replacement : reference;
410+
}
411+
412+
@Override
413+
public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Void context)
414+
{
415+
ImmutableList.Builder<RowExpression> newArgs = ImmutableList.builder();
416+
boolean changed = false;
417+
for (RowExpression arg : specialForm.getArguments()) {
418+
RowExpression newArg = arg.accept(this, null);
419+
newArgs.add(newArg);
420+
changed |= newArg != arg;
421+
}
422+
return changed
423+
? new SpecialFormExpression(specialForm.getForm(), specialForm.getType(), newArgs.build())
424+
: specialForm;
425+
}
426+
}
427+
267428
private static class RpcExtractionContext
268429
{
269430
private final List<ExtractedRpcCall> extractedCalls = new ArrayList<>();

0 commit comments

Comments
 (0)