#include <forward_list>

#define LLE_OPTION "loop-load-elim"
#define DEBUG_TYPE LLE_OPTION
68 "runtime-check-per-loop-load-elim",
cl::Hidden,
69 cl::desc(
"Max number of memchecks allowed per eliminated load on average"),
74 cl::desc(
"The maximum number of SCEV checks allowed for Loop "

STATISTIC(NumLoopLoadEliminted, "Number of loads eliminated by LLE");

/// Represent a store-to-load forwarding candidate.
struct StoreToLoadForwardingCandidate {
  LoadInst *Load;
  StoreInst *Store;

  StoreToLoadForwardingCandidate(LoadInst *Load, StoreInst *Store)
      : Load(Load), Store(Store) {}

  /// Return true if the dependence from the store to the load has an absolute
  /// distance of one, i.e. the store writes the location that the load reads
  /// in the next iteration.
  bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE,
                                 Loop *L) const {
    Value *LoadPtr = Load->getPointerOperand();
    Value *StorePtr = Store->getPointerOperand();
    Type *LoadType = getLoadStoreType(Load);
    auto &DL = Load->getDataLayout();

    assert(DL.getTypeSizeInBits(LoadType) ==
               DL.getTypeSizeInBits(getLoadStoreType(Store)) &&
           "Should be a known dependence");

    int64_t StrideLoad = getPtrStride(PSE, LoadType, LoadPtr, L).value_or(0);
    int64_t StrideStore = getPtrStride(PSE, LoadType, StorePtr, L).value_or(0);
    if (!StrideLoad || !StrideStore || StrideLoad != StrideStore)
      return false;

    // Only unit strides are handled; other strides would need extra runtime
    // checks that tend to outweigh the benefit of the forwarding.
    if (std::abs(StrideLoad) != 1)
      return false;

    unsigned TypeByteSize = DL.getTypeAllocSize(LoadType);

    // The constant distance between the two pointers must be exactly one
    // element of the accessed type.
    auto *Dist = dyn_cast<SCEVConstant>(PSE.getSE()->getMinusSCEV(
        PSE.getSCEV(StorePtr), PSE.getSCEV(LoadPtr)));
    if (!Dist)
      return false;
    const APInt &Val = Dist->getAPInt();
    return Val == TypeByteSize * StrideLoad;
  }
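
  // Worked example (illustrative, hypothetical IR names): for a unit-stride
  // loop over i32 elements, TypeByteSize is 4 and StrideLoad is 1, so
  //
  //   %x = load i32, ptr %gep.i           ; &A[i]
  //   store i32 %y, ptr %gep.i.plus.1     ; &A[i+1]
  //
  // is accepted because StorePtrSCEV - LoadPtrSCEV == 4 == TypeByteSize *
  // StrideLoad, while a distance of 8 (A[i+2] = ... A[i]) is rejected: the
  // stored value would only be read two iterations later.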

  Value *getLoadPtr() const { return Load->getPointerOperand(); }

  friend raw_ostream &operator<<(raw_ostream &OS,
                                 const StoreToLoadForwardingCandidate &Cand) {
    OS << *Cand.Store << " -->\n";
    OS.indent(2) << *Cand.Load << "\n";
    return OS;
  }
};

/// Check that the store dominates all latches, so that as long as there is no
/// intervening store this value will be available in the next iteration.
static bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L,
                                         DominatorTree *DT) {
  SmallVector<BasicBlock *, 8> Latches;
  L->getLoopLatches(Latches);
  return all_of(Latches, [&](const BasicBlock *Latch) {
    return DT->dominates(StoreBlock, Latch);
  });
}

/// Return true if the load is not executed on all paths in the loop.
static bool isLoadConditional(LoadInst *Load, Loop *L) {
  return Load->getParent() != L->getHeader();
}
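
// Illustrative example for isLoadConditional() (a sketch, not taken from a
// test case): in
//
//   for (unsigned i = 0; i < n; i++)
//     if (C[i])
//       sum += A[i];
//
// the load of A[i] sits in a conditionally executed block rather than the
// header.  Hoisting its iteration-0 instance into the preheader would make it
// unconditional and could touch memory the original loop never accesses, so
// such loads are rejected as forwarding candidates.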

/// The per-loop class that does most of the work.
class LoadEliminationForLoop {
public:
  LoadEliminationForLoop(Loop *L, LoopInfo *LI, const LoopAccessInfo &LAI,
                         DominatorTree *DT, BlockFrequencyInfo *BFI,
                         ProfileSummaryInfo *PSI)
      : L(L), LI(LI), LAI(LAI), DT(DT), BFI(BFI), PSI(PSI), PSE(LAI.getPSE()) {}

  /// Look through the loop-carried and loop-independent dependences in this
  /// loop and find store->load dependences.  Loads that also have unknown
  /// dependences are disqualified.
  std::forward_list<StoreToLoadForwardingCandidate>
  findStoreToLoadDependences(const LoopAccessInfo &LAI) {
    std::forward_list<StoreToLoadForwardingCandidate> Candidates;

    const auto &DepChecker = LAI.getDepChecker();
    const auto *Deps = DepChecker.getDependences();
    if (!Deps)
      return Candidates;

    SmallPtrSet<Instruction *, 4> LoadsWithUnknownDependence;

    for (const auto &Dep : *Deps) {
      Instruction *Source = Dep.getSource(DepChecker);
      Instruction *Destination = Dep.getDestination(DepChecker);

      // If the dependence could not be analyzed, conservatively disqualify
      // any loads involved.
      if (Dep.Type == MemoryDepChecker::Dependence::Unknown ||
          Dep.Type == MemoryDepChecker::Dependence::IndirectUnsafe) {
        if (isa<LoadInst>(Source))
          LoadsWithUnknownDependence.insert(Source);
        if (isa<LoadInst>(Destination))
          LoadsWithUnknownDependence.insert(Destination);
        continue;
      }

      // Source/destination follow program order; make Source the store.
      if (Dep.isBackward())
        std::swap(Source, Destination);
      else
        assert(Dep.isForward() && "Needs to be a forward dependence");

      auto *Store = dyn_cast<StoreInst>(Source);
      if (!Store)
        continue;
      auto *Load = dyn_cast<LoadInst>(Destination);
      if (!Load)
        continue;

      // Only forward if the stored value is bit/pointer castable to the type
      // of the load.
      if (!CastInst::isBitOrNoopPointerCastable(getLoadStoreType(Store),
                                                getLoadStoreType(Load),
                                                Store->getDataLayout()))
        continue;

      Candidates.emplace_front(Load, Store);
    }

    if (!LoadsWithUnknownDependence.empty())
      Candidates.remove_if([&](const StoreToLoadForwardingCandidate &C) {
        return LoadsWithUnknownDependence.count(C.Load);
      });

    return Candidates;
  }

  /// Return the program-order index of \p Inst recorded in InstOrder.
  unsigned getInstrIndex(Instruction *Inst) {
    auto I = InstOrder.find(Inst);
    assert(I != InstOrder.end() && "No index for instruction");
    return I->second;
  }

  /// If a load has multiple candidate stores (i.e. it could forward from more
  /// than one store depending on control flow), remove it as a candidate.
  void removeDependencesFromMultipleStores(
      std::forward_list<StoreToLoadForwardingCandidate> &Candidates) {
    // A nullptr entry means that multiple stores were seen forwarding to the
    // load.
    using LoadToSingleCandT =
        DenseMap<LoadInst *, const StoreToLoadForwardingCandidate *>;
    LoadToSingleCandT LoadToSingleCand;

    for (const auto &Cand : Candidates) {
      bool NewElt;
      LoadToSingleCandT::iterator Iter;

      std::tie(Iter, NewElt) =
          LoadToSingleCand.insert(std::make_pair(Cand.Load, &Cand));
      if (!NewElt) {
        const StoreToLoadForwardingCandidate *&OtherCand = Iter->second;
        // Already marked as forwarded to by multiple stores.
        if (OtherCand == nullptr)
          continue;

        // Handle the simple case where both stores are in the same block: the
        // later one forwards, as long as both have a dependence distance of
        // one to the load.
        if (Cand.Store->getParent() == OtherCand->Store->getParent() &&
            Cand.isDependenceDistanceOfOne(PSE, L) &&
            OtherCand->isDependenceDistanceOfOne(PSE, L)) {
          if (getInstrIndex(OtherCand->Store) < getInstrIndex(Cand.Store))
            OtherCand = &Cand;
        } else
          OtherCand = nullptr;
      }
    }

    Candidates.remove_if([&](const StoreToLoadForwardingCandidate &Cand) {
      if (LoadToSingleCand[Cand.Load] != &Cand) {
        LLVM_DEBUG(
            dbgs() << "Removing from candidates: \n"
                   << Cand
                   << "  The load may have multiple stores forwarding to "
                      "it\n");
        return true;
      }
      return false;
    });
  }

  /// Given two pointers by their RuntimePointerChecking indices, return true
  /// if they require an alias check: one must be a candidate load pointer and
  /// the other a pointer written on the forwarding path.
  bool needsChecking(unsigned PtrIdx1, unsigned PtrIdx2,
                     const SmallPtrSetImpl<Value *> &PtrsWrittenOnFwdingPath,
                     const SmallPtrSetImpl<Value *> &CandLoadPtrs) {
    Value *Ptr1 =
        LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx1).PointerValue;
    Value *Ptr2 =
        LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx2).PointerValue;
    return ((PtrsWrittenOnFwdingPath.count(Ptr1) && CandLoadPtrs.count(Ptr2)) ||
            (PtrsWrittenOnFwdingPath.count(Ptr2) && CandLoadPtrs.count(Ptr1)));
  }

  /// Return pointers that are possibly written to on the path from a
  /// forwarding store to a load.  These need to be alias-checked against the
  /// forwarding candidates.
  SmallPtrSet<Value *, 4> findPointersWrittenOnForwardingPath(
      const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) {
    // Between the first forwarding store and the last forwarded-to load in
    // program order, none of the candidate loads may overlap with any of the
    // stores.
    LoadInst *LastLoad =
        max_element(Candidates,
                    [&](const StoreToLoadForwardingCandidate &A,
                        const StoreToLoadForwardingCandidate &B) {
                      return getInstrIndex(A.Load) < getInstrIndex(B.Load);
                    })
            ->Load;
    StoreInst *FirstStore =
        min_element(Candidates,
                    [&](const StoreToLoadForwardingCandidate &A,
                        const StoreToLoadForwardingCandidate &B) {
                      return getInstrIndex(A.Store) < getInstrIndex(B.Store);
                    })
            ->Store;

    // We're looking for stores after the first forwarding store until the end
    // of the loop, then from the beginning of the loop until the last
    // forwarded-to load.  Collect the pointers of those stores.
    SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath;

    auto InsertStorePtr = [&](Instruction *I) {
      if (auto *S = dyn_cast<StoreInst>(I))
        PtrsWrittenOnFwdingPath.insert(S->getPointerOperand());
    };
    const auto &MemInstrs = LAI.getDepChecker().getMemoryInstructions();
    std::for_each(MemInstrs.begin() + getInstrIndex(FirstStore) + 1,
                  MemInstrs.end(), InsertStorePtr);
    std::for_each(MemInstrs.begin(), &MemInstrs[getInstrIndex(LastLoad)],
                  InsertStorePtr);

    return PtrsWrittenOnFwdingPath;
  }
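
  // Illustrative sketch (hypothetical accesses): with per-iteration program
  // order
  //
  //   st1 E[i]
  //   ld0 A[i]        <-- LastLoad
  //   st0 A[i+1]      <-- FirstStore
  //   st2 D[i]
  //
  // st2 (after FirstStore) runs later in the same iteration as the forwarding
  // store, and st1 (before LastLoad) runs in the next iteration before the
  // forwarded-to load, so both execute between st0 and ld0 dynamically.  Their
  // pointers (&D[i], &E[i]) are collected here and later alias-checked against
  // the candidate load pointer (&A[i]) by collectMemchecks().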

  /// Determine the pointer alias checks needed to prove that there are no
  /// intervening stores between a forwarding store and its load.
  SmallVector<RuntimePointerCheck, 4> collectMemchecks(
      const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) {
    SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath =
        findPointersWrittenOnForwardingPath(Candidates);

    // Collect the pointers of the candidate loads.
    SmallPtrSet<Value *, 4> CandLoadPtrs;
    for (const auto &Candidate : Candidates)
      CandLoadPtrs.insert(Candidate.getLoadPtr());

    const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks();
    SmallVector<RuntimePointerCheck, 4> Checks;

    copy_if(AllChecks, std::back_inserter(Checks),
            [&](const RuntimePointerCheck &Check) {
              for (auto PtrIdx1 : Check.first->Members)
                for (auto PtrIdx2 : Check.second->Members)
                  if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath,
                                    CandLoadPtrs))
                    return true;
              return false;
            });

    LLVM_DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));

    return Checks;
  }

  /// Given a store-to-load forwarding candidate, replace the load's uses with
  /// a PHI that carries the stored value across the backedge; the value for
  /// the first iteration is loaded in the preheader.
  void
  propagateStoredValueToLoadUsers(const StoreToLoadForwardingCandidate &Cand,
                                  SCEVExpander &SEE) {
    Value *Ptr = Cand.Load->getPointerOperand();
    auto *PtrSCEV = cast<SCEVAddRecExpr>(PSE.getSCEV(Ptr));
    auto *PH = L->getLoopPreheader();
    assert(PH && "Preheader should exist!");
    Value *InitialPtr = SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->getType(),
                                          PH->getTerminator());
    Value *Initial =
        new LoadInst(Cand.Load->getType(), InitialPtr, "load_initial",
                     /* isVolatile */ false, Cand.Load->getAlign(),
                     PH->getTerminator()->getIterator());

    PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded");
    PHI->insertBefore(L->getHeader()->begin());
    PHI->addIncoming(Initial, PH);

    Type *LoadType = Initial->getType();
    Type *StoreType = Cand.Store->getValueOperand()->getType();
    auto &DL = Cand.Load->getDataLayout();
    (void)DL;
    assert(DL.getTypeSizeInBits(LoadType) == DL.getTypeSizeInBits(StoreType) &&
           "The type sizes should match!");

    Value *StoreValue = Cand.Store->getValueOperand();
    if (LoadType != StoreType) {
      // The stored value is known to be bit/pointer castable to the load type
      // (checked when the candidate was collected).
      StoreValue = CastInst::CreateBitOrPointerCast(StoreValue, LoadType,
                                                    "store_forward_cast",
                                                    Cand.Store->getIterator());
    }

    PHI->addIncoming(StoreValue, L->getLoopLatch());

    Cand.Load->replaceAllUsesWith(PHI);
  }
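
  // Before/after sketch of the rewrite performed above (illustrative names):
  //
  //   loop:
  //     %x = load %gep_i
  //        = ... %x
  //     store %y, %gep_i_plus_1
  //
  // becomes
  //
  //   ph:
  //     %x.initial = load %gep_0
  //   loop:
  //     %x.forwarded = phi [ %x.initial, %ph ], [ %y, %loop ]
  //     %x = load %gep_i            ; now dead
  //        = ... %x.forwarded
  //     store %y, %gep_i_plus_1
  //
  // The original load is left in place (now dead) for later passes to remove.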

  /// Top level driver for each loop: find store->load forwarding candidates,
  /// add run-time checks and perform the transformation.
  bool processLoop() {
    LLVM_DEBUG(dbgs() << "\nIn \"" << L->getHeader()->getParent()->getName()
                      << "\" checking " << *L << "\n");

    auto StoreToLoadDependences = findStoreToLoadDependences(LAI);
    if (StoreToLoadDependences.empty())
      return false;

    // Index every load and store according to the original program order.
    InstOrder = LAI.getDepChecker().generateInstructionOrderMap();

    // To keep things simple, drop loads that may be fed by multiple stores.
    removeDependencesFromMultipleStores(StoreToLoadDependences);
    if (StoreToLoadDependences.empty())
      return false;

    SmallVector<StoreToLoadForwardingCandidate, 4> Candidates;
    for (const StoreToLoadForwardingCandidate &Cand : StoreToLoadDependences) {
      // The store must dominate all latches and the load must be unconditional.
      if (!doesStoreDominatesAllLatches(Cand.Store->getParent(), L, DT) ||
          isLoadConditional(Cand.Load, L))
        continue;

      if (!Cand.isDependenceDistanceOfOne(PSE, L))
        continue;

      assert(isa<SCEVAddRecExpr>(PSE.getSCEV(Cand.Load->getPointerOperand())) &&
             "Loading from something other than indvar?");
      assert(isa<SCEVAddRecExpr>(PSE.getSCEV(Cand.Store->getPointerOperand())) &&
             "Storing to something other than indvar?");

      Candidates.push_back(Cand);
      LLVM_DEBUG(dbgs() << Candidates.size()
                        << ". Valid store-to-load forwarding across the loop "
                           "backedge\n");
    }
    if (Candidates.empty())
      return false;

    // Intervening may-alias stores need runtime checks; too many checks are
    // likely to outweigh the benefit of the forwarding.
    SmallVector<RuntimePointerCheck, 4> Checks = collectMemchecks(Candidates);
    if (Checks.size() > Candidates.size() * CheckPerElim)
      return false;

    if (LAI.getPSE().getPredicate().getComplexity() >
        LoadElimSCEVCheckThreshold)
      return false;

    if (!L->isLoopSimplifyForm())
      return false;

    if (!Checks.empty() || !LAI.getPSE().getPredicate().isAlwaysTrue()) {
      if (LAI.hasConvergentOp()) {
        LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with "
                             "convergent calls\n");
        return false;
      }

      auto *HeaderBB = L->getHeader();
      if (shouldOptimizeForSize(HeaderBB, PSI, BFI, PGSOQueryType::IRPass)) {
        LLVM_DEBUG(
            dbgs() << "Versioning is needed but not allowed when optimizing "
                      "for size.\n");
        return false;
      }

      // Point of no return: version the loop so the transformed copy runs only
      // when the runtime checks pass.
      LoopVersioning LV(LAI, Checks, L, LI, DT, PSE.getSE());
      LV.versionLoop();

      // After versioning, some candidate pointers may no longer be add-recs;
      // filter those candidates out.
      auto NoLongerGoodCandidate =
          [this](const StoreToLoadForwardingCandidate &Cand) {
            return !isa<SCEVAddRecExpr>(
                PSE.getSCEV(Cand.Load->getPointerOperand()));
          };
      erase_if(Candidates, NoLongerGoodCandidate);
    }

    // Propagate the stored values to the users of the loads, materializing the
    // initial value of each load in the preheader.
    SCEVExpander SEE(*PSE.getSE(), L->getHeader()->getDataLayout(),
                     "storeforward");
    for (const auto &Cand : Candidates)
      propagateStoredValueToLoadUsers(Cand, SEE);
    NumLoopLoadEliminted += Candidates.size();

    return true;
  }

private:
  Loop *L;
  /// Maps load/store instructions to their program-order index.
  DenseMap<Instruction *, unsigned> InstOrder;
  LoopInfo *LI;
  const LoopAccessInfo &LAI;
  DominatorTree *DT;
  BlockFrequencyInfo *BFI;
  ProfileSummaryInfo *PSI;
  PredicatedScalarEvolution PSE;
};

static bool eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI,
                                      DominatorTree &DT, BlockFrequencyInfo *BFI,
                                      ProfileSummaryInfo *PSI,
                                      ScalarEvolution *SE, AssumptionCache *AC,
                                      LoopAccessInfoManager &LAIs) {
  // Build up a worklist of inner-most loops to avoid iterator invalidation.
  SmallVector<Loop *, 8> Worklist;
  bool Changed = false;

  for (Loop *TopLevelLoop : LI)
    for (Loop *L : depth_first(TopLevelLoop)) {
      Changed |= simplifyLoop(L, &DT, &LI, SE, AC, /*MSSAU=*/nullptr, false);
      if (L->isInnermost())
        Worklist.push_back(L);
    }

  for (Loop *L : Worklist) {
    // Only bottom-tested loops with a single exiting block are handled.
    if (!L->isRotatedForm() || !L->getExitingBlock())
      continue;
    LoadEliminationForLoop LEL(L, &LI, LAIs.getInfo(*L), &DT, BFI, PSI);
    Changed |= LEL.processLoop();
  }
  return Changed;
}

PreservedAnalyses LoopLoadEliminationPass::run(Function &F,
                                               FunctionAnalysisManager &AM) {
  // ... the required analysis results are fetched from AM; BFI is only
  // computed when profile summary information is available:
  auto *BFI = (PSI && PSI->hasProfileSummary())
                  ? &AM.getResult<BlockFrequencyAnalysis>(F)
                  : nullptr;
  // ...
}
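
// Usage sketch (illustrative, not part of this file's logic): the pass is
// registered under the name given by LLE_OPTION, so it can be run standalone
// with opt or added to a function pass pipeline:
//
//   opt -passes=loop-load-elim -S input.ll
//
//   FunctionPassManager FPM;
//   FPM.addPass(LoopLoadEliminationPass());
//
// The cl::opt flags defined above (-runtime-check-per-loop-load-elim and
// -loop-load-elimination-scev-check-threshold) bound how many runtime memory
// checks and SCEV predicate checks the transformation may introduce.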