LLVM 22.0.0git
JumpTableToSwitch.cpp
Go to the documentation of this file.
//===- JumpTableToSwitch.cpp ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <limits>

27using namespace llvm;
28
30 JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
31 cl::desc("Only split jump tables with size less or "
32 "equal than JumpTableSizeThreshold."),
33 cl::init(10));
34
35// TODO: Consider adding a cost model for profitability analysis of this
36// transformation. Currently we replace a jump table with a switch if all the
37// functions in the jump table are smaller than the provided threshold.
39 "jump-table-to-switch-function-size-threshold", cl::Hidden,
40 cl::desc("Only split jump tables containing functions whose sizes are less "
41 "or equal than this threshold."),
42 cl::init(50));
43
45
46#define DEBUG_TYPE "jump-table-to-switch"
47
48namespace {
49struct JumpTableTy {
50 Value *Index;
52};
53} // anonymous namespace
54
55static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
56 PointerType *PtrTy) {
57 Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());
58 if (!Ptr)
59 return std::nullopt;
60
62 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
63 return std::nullopt;
64
65 Function &F = *GEP->getParent()->getParent();
66 const DataLayout &DL = F.getDataLayout();
67 const unsigned BitWidth =
68 DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
70 APInt ConstantOffset(BitWidth, 0);
71 if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
72 return std::nullopt;
73 if (VariableOffsets.size() != 1)
74 return std::nullopt;
75 // TODO: consider supporting more general patterns
76 if (!ConstantOffset.isZero())
77 return std::nullopt;
78 APInt StrideBytes = VariableOffsets.front().second;
79 const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType());
80 if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
81 return std::nullopt;
82 const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
84 return std::nullopt;
85
86 JumpTableTy JumpTable;
87 JumpTable.Index = VariableOffsets.front().first;
88 JumpTable.Funcs.reserve(N);
89 for (uint64_t Index = 0; Index < N; ++Index) {
90 // ConstantOffset is zero.
91 APInt Offset = Index * StrideBytes;
92 Constant *C =
94 auto *Func = dyn_cast_or_null<Function>(C);
95 if (!Func || Func->isDeclaration() ||
96 Func->getInstructionCount() > FunctionSizeThreshold)
97 return std::nullopt;
98 JumpTable.Funcs.push_back(Func);
99 }
100 return JumpTable;
101}
102
103static BasicBlock *
104expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
107 GetGuidForFunction) {
108 const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
109
111 BasicBlock *BB = CB->getParent();
112 BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,
113 BB->getName() + Twine(".tail"));
114 DTUpdates.push_back({DominatorTree::Delete, BB, Tail});
116
117 Function &F = *BB->getParent();
118 BasicBlock *BBUnreachable = BasicBlock::Create(
119 F.getContext(), "default.switch.case.unreachable", &F, Tail);
120 IRBuilder<> BuilderUnreachable(BBUnreachable);
121 BuilderUnreachable.CreateUnreachable();
122
123 IRBuilder<> Builder(BB);
124 SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);
125 DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});
126
127 IRBuilder<> BuilderTail(CB);
128 PHINode *PHI =
129 IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
130 const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof);
131
132 SmallVector<uint64_t> BranchWeights;
134 const bool HadProfile = isValueProfileMD(ProfMD);
135 if (HadProfile) {
136 // The assumptions, coming in, are that the functions in JT.Funcs are
137 // defined in this module (from parseJumpTable).
139 JT.Funcs, [](const Function *F) { return F && !F->isDeclaration(); }));
140 BranchWeights.reserve(JT.Funcs.size() + 1);
141 // The first is the default target, which is the unreachable block created
142 // above.
143 BranchWeights.push_back(0U);
144 uint64_t TotalCount = 0;
145 auto Targets = getValueProfDataFromInst(
146 *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
147 std::numeric_limits<uint32_t>::max(), TotalCount);
148
149 for (const auto &[G, C] : Targets) {
150 [[maybe_unused]] auto It = GuidToCounter.insert({G, C});
151 assert(It.second);
152 }
153 }
154 for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
155 BasicBlock *B = BasicBlock::Create(Func->getContext(),
156 "call." + Twine(Index), &F, Tail);
157 DTUpdates.push_back({DominatorTree::Insert, BB, B});
158 DTUpdates.push_back({DominatorTree::Insert, B, Tail});
159
161 // The MD_prof metadata (VP kind), if it existed, can be dropped, it doesn't
162 // make sense on a direct call. Note that the values are used for the branch
163 // weights of the switch.
164 Call->setMetadata(LLVMContext::MD_prof, nullptr);
165 Call->setCalledFunction(Func);
166 Call->insertInto(B, B->end());
167 Switch->addCase(
168 cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
169 GlobalValue::GUID FctID = GetGuidForFunction(*Func);
170 // It'd be OK to _not_ find target functions in GuidToCounter, e.g. suppose
171 // just some of the jump targets are taken (for the given profile).
172 BranchWeights.push_back(FctID == 0U ? 0U
173 : GuidToCounter.lookup_or(FctID, 0U));
175 if (PHI)
176 PHI->addIncoming(Call, B);
177 }
178 DTU.applyUpdates(DTUpdates);
179 ORE.emit([&]() {
180 return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
181 << "expanded indirect call into switch";
182 });
183 if (HadProfile && !ProfcheckDisableMetadataFixes) {
184 // At least one of the targets must've been taken.
185 assert(llvm::any_of(BranchWeights, [](uint64_t V) { return V != 0; }));
186 setBranchWeights(*Switch, downscaleWeights(BranchWeights),
187 /*IsExpected=*/false);
188 } else
190 if (PHI)
192 CB->eraseFromParent();
193 return Tail;
194}
195
202 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
203 bool Changed = false;
204 InstrProfSymtab Symtab;
205 if (auto E = Symtab.create(*F.getParent()))
206 F.getContext().emitError(
207 "Could not create indirect call table, likely corrupted IR" +
208 toString(std::move(E)));
210 for (const auto &[G, FPtr] : Symtab.getIDToNameMap())
211 FToGuid.insert({FPtr, G});
212
213 for (BasicBlock &BB : make_early_inc_range(F)) {
214 BasicBlock *CurrentBB = &BB;
215 while (CurrentBB) {
216 BasicBlock *SplittedOutTail = nullptr;
217 for (Instruction &I : make_early_inc_range(*CurrentBB)) {
218 auto *Call = dyn_cast<CallInst>(&I);
219 if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
220 continue;
221 auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());
222 // Skip atomic or volatile loads.
223 if (!L || !L->isSimple())
224 continue;
225 auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());
226 if (!GEP)
227 continue;
228 auto *PtrTy = dyn_cast<PointerType>(L->getType());
229 assert(PtrTy && "call operand must be a pointer");
230 std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
231 if (!JumpTable)
232 continue;
233 SplittedOutTail = expandToSwitch(
234 Call, *JumpTable, DTU, ORE, [&](const Function &Fct) {
236 return AssignGUIDPass::getGUID(Fct);
237 return FToGuid.lookup_or(&Fct, 0U);
238 });
239 Changed = true;
240 break;
241 }
242 CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr;
243 }
244 }
245
246 if (!Changed)
247 return PreservedAnalyses::all();
248
250 if (DT)
252 if (PDT)
254 return PA;
255}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Rewrite undef for PHI
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
This file defines the DenseSet and SmallDenseSet classes.
#define DEBUG_TYPE
Hexagon Common GEP
static cl::opt< unsigned > FunctionSizeThreshold("jump-table-to-switch-function-size-threshold", cl::Hidden, cl::desc("Only split jump tables containing functions whose sizes are less " "or equal than this threshold."), cl::init(50))
static BasicBlock * expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU, OptimizationRemarkEmitter &ORE, llvm::function_ref< GlobalValue::GUID(const Function &)> GetGuidForFunction)
static cl::opt< unsigned > JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden, cl::desc("Only split jump tables with size less or " "equal than JumpTableSizeThreshold."), cl::init(10))
static std::optional< JumpTableTy > parseJumpTable(GetElementPtrInst *GEP, PointerType *PtrTy)
#define F(x, y, z)
Definition MD5.cpp:55
#define I(x, y, z)
Definition MD5.cpp:58
#define G(x, y, z)
Definition MD5.cpp:56
This file contains the declarations for profiling metadata utility functions.
This file contains some templates that are useful if you are working with the STL at all.
cl::opt< bool > ProfcheckDisableMetadataFixes("profcheck-disable-metadata-fixes", cl::Hidden, cl::init(false), cl::desc("Disable metadata propagation fixes discovered through Issue #147390"))
This file defines the SmallVector class.
Class for arbitrary precision integers.
Definition APInt.h:78
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1540
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:380
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
static LLVM_ABI uint64_t getGUID(const Function &F)
static LLVM_ABI const char * GUIDMetadataName
LLVM Basic Block Representation.
Definition BasicBlock.h:62
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition BasicBlock.h:206
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string, and methods for querying it.
Definition DataLayout.h:63
ValueT lookup_or(const_arg_type_t< KeyT > Val, U &&Default) const
Definition DenseMap.h:197
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:214
Analysis pass which computes a DominatorTree.
Definition Dominators.h:284
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:165
void applyUpdates(ArrayRef< UpdateT > Updates)
Submit updates to all available trees.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
MDNode * getMetadata(unsigned KindID) const
Get the current metadata attachments for the given kind, if any.
Definition Value.h:576
uint64_t GUID
Declare a type to represent a global unique identifier for a global value.
Type * getValueType() const
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
bool isConstant() const
If the value is a global constant, its value is immutable throughout the runtime execution of the pro...
bool hasDefinitiveInitializer() const
hasDefinitiveInitializer - Whether the global variable has an initializer, and any other instances of...
UnreachableInst * CreateUnreachable()
Definition IRBuilder.h:1339
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition IRBuilder.h:2494
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2780
A symbol table used for function [IR]PGO name look-up with keys (such as pointers,...
Definition InstrProf.h:506
const std::vector< std::pair< uint64_t, Function * > > & getIDToNameMap() const
Definition InstrProf.h:668
LLVM_ABI Error create(object::SectionRef &Section)
Create InstrProfSymtab from an object file section which contains function PGO names.
LLVM_ABI Instruction * clone() const
Create a copy of 'this' instruction that is identical in all ways except the following:
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
size_type size() const
Definition MapVector.h:56
std::pair< KeyT, ValueT > & front()
Definition MapVector.h:79
The optimization diagnostic interface.
LLVM_ABI void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Diagnostic information for applied optimization remarks.
Analysis pass which computes a PostDominatorTree.
PostDominatorTree Class - Concrete subclass of DominatorTree that is used to compute the post-dominat...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
void reserve(size_type N)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Multiway switch.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition Twine.h:82
static LLVM_ABI Type * getVoidTy(LLVMContext &C)
Definition Type.cpp:281
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:546
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1101
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
CallInst * Call
Changed
@ Tail
Attempts to make calls as fast as possible while guaranteeing that tail call optimization can always b...
Definition CallingConv.h:76
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
@ Offset
Definition DWP.cpp:477
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1707
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition STLExtras.h:2454
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:649
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition STLExtras.h:626
LLVM_ABI void setExplicitlyUnknownBranchWeights(Instruction &I, StringRef PassName)
Specify that the branch weights for this terminator cannot be known at compile time.
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:759
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1714
LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef< uint32_t > Weights, bool IsExpected)
Create a new branch_weights metadata node and add or overwrite a prof metadata reference to instructi...
LLVM_ABI Constant * ConstantFoldLoadFromConst(Constant *C, Type *Ty, const APInt &Offset, const DataLayout &DL)
Extract value of C at the given Offset reinterpreted as Ty.
LLVM_ABI SmallVector< InstrProfValueData, 4 > getValueProfDataFromInst(const Instruction &Inst, InstrProfValueKind ValueKind, uint32_t MaxNumValueData, uint64_t &TotalC, bool GetNoICPValue=false)
Extract the value profile data from Inst and returns them if Inst is annotated with value profile dat...
LLVM_ABI bool isValueProfileMD(const MDNode *ProfileData)
Checks if an MDNode contains value profiling Metadata.
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)
constexpr unsigned BitWidth
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:565
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI SmallVector< uint32_t > downscaleWeights(ArrayRef< uint64_t > Weights, std::optional< uint64_t > KnownMaxCount=std::nullopt)
downscale the given weights preserving the ratio.
#define N
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Run the pass over the function.
A MapVector that performs no allocations if smaller than a certain size.
Definition MapVector.h:249