25#include "llvm/IR/IntrinsicsSPIRV.h"
33#include <unordered_set>
38using BlockSet = std::unordered_set<BasicBlock *>;
39using Edge = std::pair<BasicBlock *, BasicBlock *>;
46 V.partialOrderVisit(Start, std::move(
Op));
53 if (
Node->Entry == BB)
56 for (
auto *Child :
Node->Children) {
67 std::unordered_set<BasicBlock *> ExitTargets;
75 assert(ExitTargets.size() <= 1);
76 if (ExitTargets.size() == 0)
79 return *ExitTargets.begin();
89 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge &&
90 II->getIntrinsicID() != Intrinsic::spv_selection_merge)
104 if (
II->getIntrinsicID() != Intrinsic::spv_loop_merge)
114 for (
auto &
I : Header) {
167 std::vector<Instruction *> Output;
170 Output.push_back(&
I);
191 std::stack<BasicBlock *> ToVisit;
194 ToVisit.push(&Start);
195 Seen.
insert(ToVisit.top());
196 while (ToVisit.size() != 0) {
220 for (
size_t i = 0; i < BI->getNumSuccessors(); i++) {
221 if (BI->getSuccessor(i) == OldTarget)
222 BI->setSuccessor(i, NewTarget);
226 if (BI->isUnconditional())
230 if (BI->getSuccessor(0) != BI->getSuccessor(1))
236 Builder.SetInsertPoint(BI);
237 Builder.CreateBr(BI->getSuccessor(0));
238 BI->eraseFromParent();
249 if (!
II ||
II->getIntrinsicID() != Intrinsic::spv_selection_merge)
253 II->eraseFromParent();
254 if (!
C->isConstantUsed())
255 C->destroyConstant();
272 for (
size_t i = 0; i <
SI->getNumSuccessors(); i++) {
273 if (
SI->getSuccessor(i) == OldTarget)
274 SI->setSuccessor(i, NewTarget);
279 assert(
false &&
"Unhandled terminator type.");
286 struct DivergentConstruct;
290 using ConstructList = std::vector<std::unique_ptr<DivergentConstruct>>;
296 struct DivergentConstruct {
301 DivergentConstruct *Parent =
nullptr;
302 ConstructList Children;
315 Splitter(Function &F, LoopInfo &LI) : F(F), LI(LI) { invalidate(); }
324 std::vector<BasicBlock *> getLoopConstructBlocks(BasicBlock *Header,
327 std::vector<BasicBlock *> Output;
331 if (DT.dominates(
Merge, BB) || !DT.dominates(Header, BB))
333 Output.push_back(BB);
340 std::vector<BasicBlock *>
341 getSelectionConstructBlocks(DivergentConstruct *Node) {
344 OutsideBlocks.insert(
Node->Merge);
346 for (DivergentConstruct *It =
Node->Parent; It !=
nullptr;
348 OutsideBlocks.insert(It->Merge);
350 OutsideBlocks.insert(It->Continue);
353 std::vector<BasicBlock *> Output;
355 if (OutsideBlocks.count(BB) != 0)
357 if (DT.dominates(Node->Merge, BB) || !DT.dominates(Node->Header, BB))
359 Output.push_back(BB);
366 std::vector<BasicBlock *> getSwitchConstructBlocks(BasicBlock *Header,
370 std::vector<BasicBlock *> Output;
373 if (!DT.dominates(Header, BB))
379 Output.push_back(BB);
386 std::vector<BasicBlock *> getCaseConstructBlocks(BasicBlock *Target,
390 std::vector<BasicBlock *> Output;
394 if (!DT.dominates(Target, BB))
400 Output.push_back(BB);
429 createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
430 std::unordered_set<BasicBlock *> Seen;
431 std::vector<Edge> Output;
432 Output.reserve(Edges.size());
434 for (
auto &[Src, Dst] : Edges) {
435 auto [Iterator,
Inserted] = Seen.insert(Src);
440 F.getContext(), Src->getName() +
".new.src", &F);
443 Builder.CreateBr(Dst);
447 Output.emplace_back(Src, Dst);
453 AllocaInst *CreateVariable(Function &F,
Type *
Type,
455 const DataLayout &
DL = F.getDataLayout();
456 return new AllocaInst(
Type,
DL.getAllocaAddrSpace(),
nullptr,
"reg",
462 BasicBlock *createSingleExitNode(BasicBlock *Header,
463 std::vector<Edge> &Edges) {
465 std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
467 std::vector<BasicBlock *> Dsts;
468 std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
470 Header->getName() +
".new.exit", &F);
472 for (
auto &[Src, Dst] : FixedEdges) {
473 if (DstToIndex.count(Dst) != 0)
475 DstToIndex.emplace(Dst, ExitBuilder.getInt32(DstToIndex.size()));
479 if (Dsts.size() == 1) {
480 for (
auto &[Src, Dst] : FixedEdges) {
483 ExitBuilder.CreateBr(Dsts[0]);
487 AllocaInst *
Variable = CreateVariable(F, ExitBuilder.getInt32Ty(),
488 F.begin()->getFirstInsertionPt());
489 for (
auto &[Src, Dst] : FixedEdges) {
491 B2.SetInsertPoint(Src->getFirstInsertionPt());
492 B2.CreateStore(DstToIndex[Dst], Variable);
496 Value *
Load = ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
501 if (Dsts.size() == 2) {
504 ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
508 SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
510 Sw->
addCase(DstToIndex[BB], BB);
517 Value *createExitVariable(
519 const DenseMap<BasicBlock *, ConstantInt *> &TargetToValue) {
525 Builder.SetInsertPoint(
T);
531 BI->isConditional() ? BI->getSuccessor(1) :
nullptr;
536 if (
LHS ==
nullptr ||
RHS ==
nullptr)
538 return Builder.CreateSelect(BI->getCondition(),
LHS,
RHS);
549 Builder.CreateUnreachable();
554 bool addMergeForLoops(Function &
F) {
555 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
556 auto *TopLevelRegion =
557 getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
559 .getTopLevelRegion();
583 if (
Merge ==
nullptr) {
587 Merge = CreateUnreachable(
F);
588 Builder.SetInsertPoint(Br);
598 SmallVector<Value *, 2>
Args = {MergeAddress, ContinueAddress};
601 for (
unsigned Imm : LoopControlImms)
602 Args.emplace_back(ConstantInt::get(Builder.getInt32Ty(), Imm));
603 Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {
Args});
613 bool addMergeForNodesWithMultiplePredecessors(Function &
F) {
632 Builder.SetInsertPoint(Header->getTerminator());
635 createOpSelectMerge(&Builder, MergeAddress);
648 bool sortSelectionMerge(Function &
F, BasicBlock &
Block) {
649 std::vector<Instruction *> MergeInstructions;
650 for (Instruction &
I :
Block)
652 MergeInstructions.push_back(&
I);
654 if (MergeInstructions.size() <= 1)
657 Instruction *InsertionPoint = *MergeInstructions.begin();
659 PartialOrderingVisitor Visitor(
F);
660 std::sort(MergeInstructions.begin(), MergeInstructions.end(),
661 [&Visitor](Instruction *
Left, Instruction *
Right) {
664 BasicBlock *RightMerge = getDesignatedMergeBlock(Right);
665 BasicBlock *LeftMerge = getDesignatedMergeBlock(Left);
666 return !Visitor.compare(RightMerge, LeftMerge);
669 for (Instruction *
I : MergeInstructions) {
680 bool sortSelectionMergeHeaders(Function &
F) {
682 for (BasicBlock &BB :
F) {
690 bool splitBlocksWithMultipleHeaders(Function &
F) {
691 std::stack<BasicBlock *> Work;
694 if (MergeInstructions.size() <= 1)
699 const bool Modified = Work.size() > 0;
700 while (Work.size() > 0) {
704 std::vector<Instruction *> MergeInstructions =
706 for (
unsigned i = 1; i < MergeInstructions.size(); i++) {
708 Header->splitBasicBlock(MergeInstructions[i],
"new.header");
715 Builder.SetInsertPoint(BI);
716 Builder.CreateCondBr(Builder.getTrue(), NewBlock, Unreachable);
729 bool addMergeForDivergentBlocks(Function &
F) {
741 std::vector<BasicBlock *> Candidates;
750 if (Candidates.size() <= 1)
759 createOpSelectMerge(&Builder, MergeAddress);
767 std::vector<Edge> getExitsFrom(
const BlockSet &Construct,
768 BasicBlock &Header) {
769 std::vector<Edge> Output;
770 visit(Header, [&](BasicBlock *Item) {
771 if (Construct.count(Item) == 0)
786 void constructDivergentConstruct(
BlockSet &Visited, Splitter &S,
787 BasicBlock *BB, DivergentConstruct *Parent) {
788 if (Visited.count(BB) != 0)
793 if (MIS.size() == 0) {
795 constructDivergentConstruct(Visited, S,
Successor, Parent);
805 auto Output = std::make_unique<DivergentConstruct>();
807 Output->Merge =
Merge;
809 Output->Parent = Parent;
811 constructDivergentConstruct(Visited, S,
Merge, Parent);
813 constructDivergentConstruct(Visited, S,
Continue, Output.get());
816 constructDivergentConstruct(Visited, S,
Successor, Output.get());
819 Parent->Children.emplace_back(std::move(Output));
823 BlockSet getConstructBlocks(Splitter &S, DivergentConstruct *Node) {
826 if (
Node->Continue) {
827 auto LoopBlocks = S.getLoopConstructBlocks(
Node->Header,
Node->Merge);
828 return BlockSet(LoopBlocks.begin(), LoopBlocks.end());
831 auto SelectionBlocks = S.getSelectionConstructBlocks(Node);
832 return BlockSet(SelectionBlocks.begin(), SelectionBlocks.end());
837 bool fixupConstruct(Splitter &S, DivergentConstruct *Node) {
839 for (
auto &Child :
Node->Children)
840 Modified |= fixupConstruct(S, Child.get());
844 if (
Node->Parent ==
nullptr)
850 if (
Node->Parent->Header ==
nullptr)
857 BlockSet ConstructBlocks = getConstructBlocks(S, Node);
858 auto Edges = getExitsFrom(ConstructBlocks, *
Node->Header);
861 if (Edges.size() < 1)
864 bool HasBadEdge =
Node->Merge ==
Node->Parent->Merge ||
865 Node->Merge ==
Node->Parent->Continue;
867 for (
auto &[Src, Dst] : Edges) {
874 if (
Node->Merge == Dst)
879 if (
Node->Continue == Dst)
891 BasicBlock *NewExit = S.createSingleExitNode(
Node->Header, Edges);
898 assert(MergeInstructions.size() == 1);
903 I->setOperand(0, MergeAddress);
911 Node->Merge = NewExit;
917 bool splitCriticalEdges(Function &
F) {
918 LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
921 DivergentConstruct Root;
923 constructDivergentConstruct(Visited, S, &*
F.begin(), &Root);
924 return fixupConstruct(S, &Root);
932 bool simplifyBranches(Function &
F) {
935 for (BasicBlock &BB :
F) {
939 if (
SI->getNumCases() > 1)
944 Builder.SetInsertPoint(SI);
946 if (
SI->getNumCases() == 0) {
947 Builder.CreateBr(
SI->getDefaultDest());
951 SI->case_begin()->getCaseValue());
952 Builder.CreateCondBr(Condition,
SI->case_begin()->getCaseSuccessor(),
953 SI->getDefaultDest());
955 SI->eraseFromParent();
964 bool splitSwitchCases(Function &
F) {
967 for (BasicBlock &BB :
F) {
973 Seen.insert(
SI->getDefaultDest());
975 auto It =
SI->case_begin();
976 while (It !=
SI->case_end()) {
978 if (Seen.count(Target) == 0) {
988 Builder.CreateBr(Target);
989 SI->addCase(It->getCaseValue(), NewTarget);
990 It =
SI->removeCase(It);
999 bool removeUselessBlocks(Function &
F) {
1000 std::vector<BasicBlock *>
ToRemove;
1005 for (BasicBlock &BB :
F) {
1012 if (MergeBlocks.count(&BB) != 0 || ContinueBlocks.count(&BB) != 0)
1021 for (BasicBlock *Predecessor : Predecessors)
1032 bool addHeaderToRemainingDivergentDAG(Function &
F) {
1044 for (BasicBlock &BB :
F) {
1045 if (HeaderBlocks.count(&BB) != 0)
1050 size_t CandidateEdges = 0;
1052 if (MergeBlocks.count(
Successor) != 0 ||
1057 CandidateEdges += 1;
1060 if (CandidateEdges <= 1)
1066 bool HasBadBlock =
false;
1067 visit(*Header, [&](
const BasicBlock *Node) {
1072 if (Node == Header || Node ==
Merge)
1075 HasBadBlock |= MergeBlocks.count(Node) != 0 ||
1076 ContinueBlocks.count(Node) != 0 ||
1077 HeaderBlocks.count(Node) != 0;
1078 return !HasBadBlock;
1086 if (
Merge ==
nullptr) {
1089 Builder.SetInsertPoint(Header->getTerminator());
1092 createOpSelectMerge(&Builder, MergeAddress);
1098 SplitInstruction = SplitInstruction->
getPrevNode();
1100 Merge->splitBasicBlockBefore(SplitInstruction,
"new.merge");
1103 Builder.SetInsertPoint(Header->getTerminator());
1106 createOpSelectMerge(&Builder, MergeAddress);
1115 SPIRVStructurizer() : FunctionPass(ID) {}
1135 Modified |= addMergeForNodesWithMultiplePredecessors(
F);
1140 Modified |= sortSelectionMergeHeaders(
F);
1145 Modified |= splitBlocksWithMultipleHeaders(
F);
1151 Modified |= addMergeForDivergentBlocks(
F);
1175 Modified |= addHeaderToRemainingDivergentDAG(
F);
1183 void getAnalysisUsage(AnalysisUsage &AU)
const override {
1186 AU.
addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
1188 AU.
addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
1189 FunctionPass::getAnalysisUsage(AU);
1192 void createOpSelectMerge(
IRBuilder<> *Builder, BlockAddress *MergeAddress) {
1195 MDNode *MDNode = BBTerminatorInst->
getMetadata(
"hlsl.controlflow.hint");
1197 ConstantInt *BranchHint = ConstantInt::get(Builder->
getInt32Ty(), 0);
1201 "invalid metadata hlsl.controlflow.hint");
1205 SmallVector<Value *, 2>
Args = {MergeAddress, BranchHint};
1213char SPIRVStructurizer::ID = 0;
1216 "structurize SPIRV",
false,
false)
1226 return new SPIRVStructurizer();
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
ReachingDefAnalysis InstSet & ToRemove
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
This file defines the DenseMap class.
static bool runOnFunction(Function &F, bool PostInlining)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
uint64_t IntrinsicInst * II
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
static BasicBlock * getDesignatedMergeBlock(Instruction *I)
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
static std::vector< Instruction * > getMergeInstructions(BasicBlock &BB)
static BasicBlock * getDesignatedContinueBlock(Instruction *I)
std::unordered_set< BasicBlock * > BlockSet
static const ConvergenceRegion * getRegionForHeader(const ConvergenceRegion *Node, BasicBlock *BB)
static bool hasLoopMergeInstruction(BasicBlock &BB)
static SmallPtrSet< BasicBlock *, 2 > getContinueBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getMergeBlocks(Function &F)
static SmallPtrSet< BasicBlock *, 2 > getHeaderBlocks(Function &F)
static bool isDefinedAsSelectionMergeBy(BasicBlock &Header, BasicBlock &Merge)
static void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
static void partialOrderVisit(BasicBlock &Start, std::function< bool(BasicBlock *)> Op)
static bool isMergeInstruction(Instruction *I)
static BasicBlock * getExitFor(const ConvergenceRegion *CR)
static void replaceIfBranchTargets(BasicBlock *BB, BasicBlock *OldTarget, BasicBlock *NewTarget)
This file defines the SmallPtrSet class.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
LLVM Basic Block Representation.
const Function * getParent() const
Return the enclosing method, or null if none.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
LLVM_ABI SymbolTableList< BasicBlock >::iterator eraseFromParent()
Unlink 'this' from the containing function and delete it.
InstListType::iterator iterator
Instruction iterators...
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
The address of a basic block.
BasicBlock * getBasicBlock() const
static LLVM_ABI BlockAddress * get(Function *F, BasicBlock *BB)
Return a BlockAddress for the specified function and basic block.
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Represents analyses that only rely on functions' control flow.
This is an important base class in LLVM.
LLVM_ABI bool isConstantUsed() const
Return true if the constant has users other than constant expressions and other dangling things.
LLVM_ABI void destroyConstant()
Called if some element of this constant is no longer valid.
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
bool dominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
dominates - Returns true iff A dominates B.
void recalculate(ParentType &Func)
recalculate - compute a dominator tree for the given function
DomTreeNodeBase< NodeT > * getNode(const NodeT *BB) const
getNode - return the (Post)DominatorTree node for the specified basic block.
Legacy analysis pass which computes a DominatorTree.
FunctionPass class - This class is used to implement most global optimizations.
IntegerType * getInt32Ty()
Fetch the type representing a 32-bit integer.
BasicBlock * GetInsertBlock() const
LLVM_ABI CallInst * CreateIntrinsic(Intrinsic::ID ID, ArrayRef< Type * > Types, ArrayRef< Value * > Args, FMFSource FMFSource={}, const Twine &Name="")
Create a call to intrinsic ID with Args, mangled using Types.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
A wrapper class for inspecting calls to intrinsic functions.
bool isLoopHeader(const BlockT *BB) const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
const MDOperand & getOperand(unsigned I) const
unsigned getNumOperands() const
Return number of MDNode operands.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserveSet()
Mark an analysis set as preserved.
PreservedAnalyses run(Function &M, FunctionAnalysisManager &AM)
SmallPtrSet< BasicBlock *, 2 > Exits
SmallPtrSet< BasicBlock *, 8 > Blocks
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest)
Add an entry to the switch instruction.
Type * getType() const
All values are typed, get the type of this value.
self_iterator getIterator()
FunctionPassManager manages FunctionPasses.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
@ C
The default llvm calling convention, compatible with C.
PostDomTreeBase< BasicBlock > BBPostDomTree
DomTreeBase< BasicBlock > BBDomTree
@ BasicBlock
Various leaf nodes.
std::enable_if_t< detail::IsValidPointer< X, Y >::value, X * > extract(Y &&MD)
Extract a Value from Metadata.
NodeAddr< NodeBase * > Node
friend class Instruction
Iterator for Instructions in a `BasicBlock.
LLVM_ABI iterator begin() const
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
FunctionAddr VTableAddr Value
FunctionPass * createSPIRVStructurizerPass()
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
auto successors(const MachineBasicBlock *BB)
bool sortBlocks(Function &F)
auto pred_size(const MachineBasicBlock *BB)
SmallVector< unsigned, 1 > getSpirvLoopControlOperandsFromLoopMetadata(Loop *L)
auto dyn_cast_or_null(const Y &Val)
auto succ_size(const MachineBasicBlock *BB)
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
DWARFExpression::Operation Op
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
auto predecessors(const MachineBasicBlock *BB)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.