30#define DEBUG_TYPE "dxil-flatten-arrays"
35class DXILFlattenArraysLegacy :
public ModulePass {
38 bool runOnModule(
Module &M)
override;
46 Value *RootPointerOperand;
51class DXILFlattenArraysVisitor
52 :
public InstVisitor<DXILFlattenArraysVisitor, bool> {
54 DXILFlattenArraysVisitor(
56 : GlobalMap(GlobalMap) {}
// Visitor hook for integer-compare instructions. Does nothing to ICI and
// returns false (presumably "no change made" -- confirm against how the
// visitor's results are consumed).
64 bool visitICmpInst(
ICmpInst &ICI) {
return false; }
// Visitor hook for floating-point-compare instructions. Does nothing to FCI
// and returns false (presumably "no change made" -- confirm against how the
// visitor's results are consumed).
65 bool visitFCmpInst(
FCmpInst &FCI) {
return false; }
// Visitor hook for cast instructions. Does nothing to CI and returns false
// (presumably "no change made" -- confirm against how the visitor's results
// are consumed).
68 bool visitCastInst(
CastInst &CI) {
return false; }
// Visitor hook for bitcast instructions. Does nothing to BCI and returns
// false (presumably "no change made" -- confirm against how the visitor's
// results are consumed).
69 bool visitBitCastInst(
BitCastInst &BCI) {
return false; }
// Visitor hook for PHI nodes. Does nothing to PHI and returns false
// (presumably "no change made" -- confirm against how the visitor's results
// are consumed).
73 bool visitPHINode(
PHINode &
PHI) {
return false; }
// Visitor hook for call instructions. Does nothing to the call and returns
// false (presumably "no change made" -- confirm against how the visitor's
// results are consumed).
// NOTE(review): the parameter is named ICI although it is a CallInst; the
// name looks copied from visitICmpInst. It is unused in this body.
76 bool visitCallInst(
CallInst &ICI) {
return false; }
// Visitor hook for freeze instructions. Does nothing to FI and returns false
// (presumably "no change made" -- confirm against how the visitor's results
// are consumed).
77 bool visitFreezeInst(
FreezeInst &FI) {
return false; }
// Static predicate over a type. Elsewhere in this file it guards the
// "Expected root array type to be flattened" assertion and filters the value
// types of globals before they are processed.
// NOTE(review): presumably true when T is an array whose element type is
// itself an array -- the definition's body is not visible here; confirm.
78 static bool isMultiDimensionalArray(
Type *
T);
// Returns a (total element count, type) pair for ArrayTy. The definition
// starts the count at 1, multiplies it by getNumElements() of an inner array
// type, and steps the current type to that array's element type before
// returning both.
// NOTE(review): presumably this repeats once per nested array level so the
// returned type is the innermost non-array element type -- the loop header
// is not visible here; confirm.
79 static std::pair<unsigned, Type *> getElementCountAndType(
Type *ArrayTy);
95bool DXILFlattenArraysVisitor::finish() {
96 GEPChainInfoMap.clear();
101bool DXILFlattenArraysVisitor::isMultiDimensionalArray(
Type *
T) {
107std::pair<unsigned, Type *>
108DXILFlattenArraysVisitor::getElementCountAndType(
Type *ArrayTy) {
109 unsigned TotalElements = 1;
110 Type *CurrArrayTy = ArrayTy;
112 TotalElements *= InnerArrayTy->getNumElements();
113 CurrArrayTy = InnerArrayTy->getElementType();
115 return std::make_pair(TotalElements, CurrArrayTy);
118ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices(
121 "Indicies and dimmensions should be the same");
122 unsigned FlatIndex = 0;
123 unsigned Multiplier = 1;
125 for (
int I = Indices.
size() - 1;
I >= 0; --
I) {
126 unsigned DimSize = Dims[
I];
128 assert(CIndex &&
"This function expects all indicies to be ConstantInt");
130 Multiplier *= DimSize;
135Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices(
137 if (Indices.
size() == 1)
141 unsigned Multiplier = 1;
143 for (
int I = Indices.
size() - 1;
I >= 0; --
I) {
144 unsigned DimSize = Dims[
I];
147 FlatIndex = Builder.
CreateAdd(FlatIndex, ScaledIndex);
148 Multiplier *= DimSize;
153bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) {
155 for (
unsigned I = 0;
I < NumOperands; ++
I) {
158 if (CE &&
CE->getOpcode() == Instruction::GetElementPtr) {
159 GetElementPtrInst *OldGEP =
169 visitGetElementPtrInst(*OldGEP);
176bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) {
177 unsigned NumOperands =
SI.getNumOperands();
178 for (
unsigned I = 0;
I < NumOperands; ++
I) {
179 Value *CurrOpperand =
SI.getOperand(
I);
181 if (CE &&
CE->getOpcode() == Instruction::GetElementPtr) {
182 GetElementPtrInst *OldGEP =
187 StoreInst *NewStore = Builder.
CreateStore(
SI.getValueOperand(), OldGEP);
189 SI.replaceAllUsesWith(NewStore);
190 SI.eraseFromParent();
191 visitGetElementPtrInst(*OldGEP);
198bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) {
204 auto [TotalElements,
BaseType] = getElementCountAndType(ArrType);
207 AllocaInst *FlatAlloca =
215bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &
GEP) {
220 Value *PtrOperand =
GEP.getPointerOperand();
225 "Pointer operand of GEP should not be a PHI Node");
230 PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) {
231 GetElementPtrInst *OldGEPI =
238 Builder.
CreateGEP(
GEP.getSourceElementType(), OldGEPI, Indices,
239 GEP.getName(),
GEP.getNoWrapFlags());
241 "Expected newly-created GEP to be an instruction");
244 GEP.replaceAllUsesWith(NewGEPI);
245 GEP.eraseFromParent();
246 visitGetElementPtrInst(*OldGEPI);
247 visitGetElementPtrInst(*NewGEPI);
255 const DataLayout &
DL =
GEP.getDataLayout();
256 unsigned BitWidth =
DL.getIndexTypeSizeInBits(
GEP.getType());
258 [[maybe_unused]]
bool Success =
GEP.collectOffset(
270 if (!GEPChainInfoMap.contains(PtrOpGEP))
273 GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP];
274 Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType;
275 Info.RootPointerOperand = PGEPInfo.RootPointerOperand;
276 for (
auto &VariableOffset : PGEPInfo.VariableOffsets)
277 Info.VariableOffsets.insert(VariableOffset);
278 Info.ConstantOffset += PGEPInfo.ConstantOffset;
280 Info.RootPointerOperand = PtrOperand;
285 Type *RootTy =
GEP.getSourceElementType();
287 if (GlobalMap.contains(GlobalVar))
292 RootTy = Alloca->getAllocatedType();
293 assert(!isMultiDimensionalArray(RootTy) &&
294 "Expected root array type to be flattened");
306 bool ReplaceThisGEP =
GEP.users().empty();
309 ReplaceThisGEP =
true;
311 if (ReplaceThisGEP) {
312 unsigned BytesPerElem =
313 DL.getTypeAllocSize(
Info.RootFlattenedArrayType->getArrayElementType());
315 "Bytes per element should be a power of 2");
321 uint64_t ConstantOffset =
323 assert(ConstantOffset < UINT32_MAX &&
324 "Constant byte offset for flat GEP index must fit within 32 bits");
326 for (
auto [VarIndex, Multiplier] :
Info.VariableOffsets) {
327 assert(Multiplier.getActiveBits() <= 32 &&
328 "The multiplier for a flat GEP index must fit within 32 bits");
329 assert(VarIndex->getType()->isIntegerTy(32) &&
330 "Expected i32-typed GEP indices");
332 if (Multiplier.getZExtValue() % BytesPerElem != 0) {
337 Builder.
getInt32(Multiplier.getZExtValue()));
342 Builder.
getInt32(Multiplier.getZExtValue() / BytesPerElem));
343 FlattenedIndex = Builder.
CreateAdd(FlattenedIndex, VI);
348 Info.RootFlattenedArrayType,
Info.RootPointerOperand,
349 {ZeroIndex, FlattenedIndex},
GEP.getName(),
GEP.getNoWrapFlags());
357 Info.RootFlattenedArrayType,
Info.RootPointerOperand,
358 {ZeroIndex, FlattenedIndex},
GEP.getNoWrapFlags(),
GEP.getName(),
364 GEP.replaceAllUsesWith(NewGEP);
365 GEP.eraseFromParent();
373 PotentiallyDeadInstrs.emplace_back(&
GEP);
377bool DXILFlattenArraysVisitor::visit(Function &
F) {
378 bool MadeChange =
false;
379 ReversePostOrderTraversal<Function *> RPOT(&
F);
393 Elements.push_back(
Init);
396 unsigned ArrSize = ArrayTy->getNumElements();
398 for (
unsigned I = 0;
I < ArrSize; ++
I)
405 for (
unsigned I = 0;
I < ArrayConstant->getNumOperands(); ++
I) {
409 for (
unsigned I = 0;
I < DataArrayConstant->getNumElements(); ++
I) {
414 "Expected a ConstantArray or ConstantDataArray for array initializer!");
434 assert(FlattenedType->getNumElements() == FlattenedElements.
size() &&
435 "The number of collected elements should match the FlattenedType");
443 Type *OrigType =
G.getValueType();
444 if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType))
449 DXILFlattenArraysVisitor::getElementCountAndType(ArrType);
456 nullptr,
G.getName() +
".1dim", &
G,
457 G.getThreadLocalMode(),
G.getAddressSpace(),
458 G.isExternallyInitialized());
462 if (
G.getAlignment() > 0) {
466 if (
G.hasInitializer()) {
472 GlobalMap[&
G] = NewGlobal;
477 bool MadeChange =
false;
480 DXILFlattenArraysVisitor Impl(GlobalMap);
482 if (
F.isDeclaration())
484 MadeChange |= Impl.visit(
F);
486 for (
auto &[Old, New] : GlobalMap) {
487 Old->replaceAllUsesWith(New);
488 Old->eraseFromParent();
502bool DXILFlattenArraysLegacy::runOnModule(
Module &M) {
506char DXILFlattenArraysLegacy::ID = 0;
509 "DXIL Array Flattener",
false,
false)
514 return new DXILFlattenArraysLegacy();
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Analysis containing CSE Info
static Constant * transformInitializer(Constant *Init, Type *OrigType, Type *NewType, LLVMContext &Ctx)
static void collectElements(Constant *Init, SmallVectorImpl< Constant * > &Elements)
static bool flattenArrays(Module &M)
static void flattenGlobalArrays(Module &M, SmallDenseMap< GlobalVariable *, GlobalVariable * > &GlobalMap)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
BaseType
A given derived pointer can have multiple base pointers through phi/selects.
Class for arbitrary precision integers.
an instruction to allocate memory on the stack
Align getAlign() const
Return the alignment of the memory that is being allocated by the instruction.
Type * getAllocatedType() const
Return the type that is being allocated by the instruction.
void setAlignment(Align Align)
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory), i.e. a start pointer and a length.
size_t size() const
size - Get the array size.
static LLVM_ABI ArrayType * get(Type *ElementType, uint64_t NumElements)
This static method is the primary way to construct an ArrayType.
This class represents a no-op cast from one type to another.
This class represents a function call, abstracting a target machine's calling convention.
This is the base class for all instructions that perform data casts.
static LLVM_ABI ConstantAggregateZero * get(Type *Ty)
static LLVM_ABI Constant * get(ArrayType *T, ArrayRef< Constant * > V)
This is the shared class of boolean and integer constants.
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate for the type of this constant.
This is an important base class in LLVM.
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &)
This instruction compares its operands according to the predicate given to the constructor.
This class represents a freeze function that returns a random concrete value if an operand is either a poison value or an undef value.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
static GetElementPtrInst * Create(Type *PointeeType, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
void setUnnamedAddr(UnnamedAddr Val)
LLVM_ABI void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if InitVal==NULL.
void setAlignment(Align Align)
Sets the alignment attribute of the GlobalVariable.
This instruction compares its operands according to the predicate given to the constructor.
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
BasicBlock::iterator GetInsertPoint() const
Value * CreateLShr(Value *LHS, Value *RHS, const Twine &Name="", bool isExact=false)
Value * CreateGEP(Type *Ty, Value *Ptr, ArrayRef< Value * > IdxList, const Twine &Name="", GEPNoWrapFlags NW=GEPNoWrapFlags::none())
ConstantInt * getInt32(uint32_t C)
Get a constant 32-bit value.
LoadInst * CreateLoad(Type *Ty, Value *Ptr, const char *Name)
Provided to resolve 'CreateLoad(Ty, Ptr, "...")' correctly, instead of converting the string to 'bool...
StoreInst * CreateStore(Value *Val, Value *Ptr, bool isVolatile=false)
Value * CreateAdd(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
Value * CreateMul(Value *LHS, Value *RHS, const Twine &Name="", bool HasNUW=false, bool HasNSW=false)
This provides a uniform API for creating instructions and inserting them into a basic block: either at the end of a BasicBlock, or at a specific iterator location in a block.
This instruction inserts a single (scalar) element into a VectorType value.
Base class for instruction visitors.
void visit(Iterator Start, Iterator End)
LLVM_ABI void insertBefore(InstListType::iterator InsertPos)
Insert an unlinked instruction into a basic block immediately before the specified position.
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
This is an important class for using LLVM in a threaded context.
An instruction for reading from memory.
void setAlignment(Align Align)
Align getAlign() const
Return the alignment of the access that is being performed.
ModulePass class - This class is used to implement unstructured interprocedural optimizations and analyses.
A Module instance is used to store all the information related to an LLVM module.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
This class represents the LLVM 'select' instruction.
This instruction constructs a fixed permutation of two input vectors.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
void setAlignment(Align Align)
The instances of the Type class are immutable: once they are created, they are never changed.
static LLVM_ABI UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Value * getOperand(unsigned i) const
unsigned getNumOperands() const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
self_iterator getIterator()
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
unsigned ID
LLVM IR allows arbitrary numbers to be used as calling convention identifiers.
@ CE
Windows NT (Windows on ARM)
This is an optimization pass for GlobalISel generic memory operations.
FunctionAddr VTableAddr Value
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.
ModulePass * createDXILFlattenArraysLegacyPass()
Pass to flatten arrays into a one dimensional DXIL legal form.
unsigned Log2_32(uint32_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
constexpr bool isPowerOf2_32(uint32_t Value)
Return true if the argument is a power of two > 0.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type arguments.
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI bool RecursivelyDeleteTriviallyDeadInstructionsPermissive(SmallVectorImpl< WeakTrackingVH > &DeadInsts, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr, std::function< void(Value *)> AboutToDeleteCallback=std::function< void(Value *)>())
Same functionality as RecursivelyDeleteTriviallyDeadInstructions, but allow instructions that are not trivially dead.
constexpr unsigned BitWidth
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
A MapVector that performs no allocations if smaller than a certain size.