@@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
112112
113113bool LLVM::LoadOp::storesTo (const MemorySlot &slot) { return false ; }
114114
115- Value LLVM::LoadOp::getStored (const MemorySlot &slot, RewriterBase &rewriter) {
115+ Value LLVM::LoadOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
116+ const DataLayout &dataLayout) {
116117 llvm_unreachable (" getStored should not be called on LoadOp" );
117118}
118119
@@ -122,37 +123,121 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
122123 return getAddr () == slot.ptr ;
123124}
124125
125- // / Checks that two types are the same or can be cast into one another.
126- static bool areCastCompatible (const DataLayout &layout, Type lhs, Type rhs) {
127- return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
128- !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
129- layout.getTypeSize (lhs) == layout.getTypeSize (rhs));
126+ // / Checks if `type` can be used in any kind of conversion sequences.
127+ static bool isSupportedTypeForConversion (Type type) {
128+ // Aggregate types are not bitcastable.
129+ if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
130+ return false ;
131+
132+ // LLVM vector types are only used for either pointers or target specific
133+ // types. These types cannot be casted in the general case, thus the memory
134+ // optimizations do not support them.
135+ if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
136+ return false ;
137+
138+ // Scalable types are not supported.
139+ if (auto vectorType = dyn_cast<VectorType>(type))
140+ return !vectorType.isScalable ();
141+ return true ;
130142}
131143
144+ // / Checks that `rhs` can be converted to `lhs` by a sequence of casts and
145+ // / truncations.
146+ static bool areConversionCompatible (const DataLayout &layout, Type targetType,
147+ Type srcType) {
148+ if (targetType == srcType)
149+ return true ;
150+
151+ if (!isSupportedTypeForConversion (targetType) ||
152+ !isSupportedTypeForConversion (srcType))
153+ return false ;
154+
155+ // Pointer casts will only be sane when the bitsize of both pointer types is
156+ // the same.
157+ if (isa<LLVM::LLVMPointerType>(targetType) &&
158+ isa<LLVM::LLVMPointerType>(srcType))
159+ return layout.getTypeSize (targetType) == layout.getTypeSize (srcType);
160+
161+ return layout.getTypeSize (targetType) <= layout.getTypeSize (srcType);
162+ }
163+
164+ // / Checks if `dataLayout` describes a little endian layout.
165+ static bool isBigEndian (const DataLayout &dataLayout) {
166+ auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness ());
167+ return endiannessStr && endiannessStr == " big" ;
168+ }
169+
170+ // / The size of a byte in bits.
171+ constexpr const static uint64_t kBitsInByte = 8 ;
172+
132173// / Constructs operations that convert `inputValue` into a new value of type
133174// / `targetType`. Assumes that this conversion is possible.
134175static Value createConversionSequence (RewriterBase &rewriter, Location loc,
135- Value inputValue, Type targetType) {
136- if (inputValue.getType () == targetType)
137- return inputValue;
138-
139- if (!isa<LLVM::LLVMPointerType>(targetType) &&
140- !isa<LLVM::LLVMPointerType>(inputValue.getType ()))
141- return rewriter.createOrFold <LLVM::BitcastOp>(loc, targetType, inputValue);
176+ Value srcValue, Type targetType,
177+ const DataLayout &dataLayout) {
178+ // Get the types of the source and target values.
179+ Type srcType = srcValue.getType ();
180+ assert (areConversionCompatible (dataLayout, targetType, srcType) &&
181+ " expected that the compatibility was checked before" );
182+
183+ uint64_t srcTypeSize = dataLayout.getTypeSize (srcType);
184+ uint64_t targetTypeSize = dataLayout.getTypeSize (targetType);
185+
186+ // Nothing has to be done if the types are already the same.
187+ if (srcType == targetType)
188+ return srcValue;
189+
190+ // In the special case of casting one pointer to another, we want to generate
191+ // an address space cast. Bitcasts of pointers are not allowed and using
192+ // pointer to integer conversions are not equivalent due to the loss of
193+ // provenance.
194+ if (isa<LLVM::LLVMPointerType>(targetType) &&
195+ isa<LLVM::LLVMPointerType>(srcType))
196+ return rewriter.createOrFold <LLVM::AddrSpaceCastOp>(loc, targetType,
197+ srcValue);
198+
199+ IntegerType valueSizeInteger =
200+ rewriter.getIntegerType (srcTypeSize * kBitsInByte );
201+ Value replacement = srcValue;
202+
203+ // First, cast the value to a same-sized integer type.
204+ if (isa<LLVM::LLVMPointerType>(srcType))
205+ replacement = rewriter.createOrFold <LLVM::PtrToIntOp>(loc, valueSizeInteger,
206+ replacement);
207+ else if (replacement.getType () != valueSizeInteger)
208+ replacement = rewriter.createOrFold <LLVM::BitcastOp>(loc, valueSizeInteger,
209+ replacement);
210+
211+ // Truncate the integer if the size of the target is less than the value.
212+ if (targetTypeSize != srcTypeSize) {
213+ if (isBigEndian (dataLayout)) {
214+ uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte ;
215+ auto shiftConstant = rewriter.create <LLVM::ConstantOp>(
216+ loc, rewriter.getIntegerAttr (srcType, shiftAmount));
217+ replacement =
218+ rewriter.createOrFold <LLVM::LShrOp>(loc, srcValue, shiftConstant);
219+ }
142220
143- if (!isa<LLVM::LLVMPointerType>(targetType))
144- return rewriter.createOrFold <LLVM::PtrToIntOp>(loc, targetType, inputValue);
221+ replacement = rewriter.create <LLVM::TruncOp>(
222+ loc, rewriter.getIntegerType (targetTypeSize * kBitsInByte ),
223+ replacement);
224+ }
145225
146- if (!isa<LLVM::LLVMPointerType>(inputValue.getType ()))
147- return rewriter.createOrFold <LLVM::IntToPtrOp>(loc, targetType, inputValue);
226+ // Now cast the integer to the actual target type if required.
227+ if (isa<LLVM::LLVMPointerType>(targetType))
228+ replacement =
229+ rewriter.createOrFold <LLVM::IntToPtrOp>(loc, targetType, replacement);
230+ else if (replacement.getType () != targetType)
231+ replacement =
232+ rewriter.createOrFold <LLVM::BitcastOp>(loc, targetType, replacement);
148233
149- return rewriter.createOrFold <LLVM::AddrSpaceCastOp>(loc, targetType,
150- inputValue);
234+ return replacement;
151235}
152236
153- Value LLVM::StoreOp::getStored (const MemorySlot &slot, RewriterBase &rewriter) {
154- return createConversionSequence (rewriter, getLoc (), getValue (),
155- slot.elemType );
237+ Value LLVM::StoreOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
238+ const DataLayout &dataLayout) {
239+ return createConversionSequence (rewriter, getLoc (), getValue (), slot.elemType ,
240+ dataLayout);
156241}
157242
158243bool LLVM::LoadOp::canUsesBeRemoved (
@@ -167,17 +252,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
167252 // be removed (provided it loads the exact stored value and is not
168253 // volatile).
169254 return blockingUse == slot.ptr && getAddr () == slot.ptr &&
170- areCastCompatible (dataLayout, getResult ().getType (), slot.elemType ) &&
255+ areConversionCompatible (dataLayout, getResult ().getType (),
256+ slot.elemType ) &&
171257 !getVolatile_ ();
172258}
173259
174260DeletionKind LLVM::LoadOp::removeBlockingUses (
175261 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
176- RewriterBase &rewriter, Value reachingDefinition) {
262+ RewriterBase &rewriter, Value reachingDefinition,
263+ const DataLayout &dataLayout) {
177264 // `canUsesBeRemoved` checked this blocking use must be the loaded slot
178265 // pointer.
179- Value newResult = createConversionSequence (
180- rewriter, getLoc (), reachingDefinition, getResult ().getType ());
266+ Value newResult =
267+ createConversionSequence (rewriter, getLoc (), reachingDefinition,
268+ getResult ().getType (), dataLayout);
181269 rewriter.replaceAllUsesWith (getResult (), newResult);
182270 return DeletionKind::Delete;
183271}
@@ -194,13 +282,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
194282 // store OF the slot pointer, only INTO the slot pointer.
195283 return blockingUse == slot.ptr && getAddr () == slot.ptr &&
196284 getValue () != slot.ptr &&
197- areCastCompatible (dataLayout, slot.elemType , getValue ().getType ()) &&
285+ areConversionCompatible (dataLayout, slot.elemType ,
286+ getValue ().getType ()) &&
198287 !getVolatile_ ();
199288}
200289
201290DeletionKind LLVM::StoreOp::removeBlockingUses (
202291 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
203- RewriterBase &rewriter, Value reachingDefinition) {
292+ RewriterBase &rewriter, Value reachingDefinition,
293+ const DataLayout &dataLayout) {
204294 return DeletionKind::Delete;
205295}
206296
@@ -747,8 +837,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
747837 return getDst () == slot.ptr ;
748838}
749839
750- Value LLVM::MemsetOp::getStored (const MemorySlot &slot,
751- RewriterBase &rewriter ) {
840+ Value LLVM::MemsetOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
841+ const DataLayout &dataLayout ) {
752842 // TODO: Support non-integer types.
753843 return TypeSwitch<Type, Value>(slot.elemType )
754844 .Case ([&](IntegerType intType) -> Value {
@@ -802,7 +892,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
802892
803893DeletionKind LLVM::MemsetOp::removeBlockingUses (
804894 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
805- RewriterBase &rewriter, Value reachingDefinition) {
895+ RewriterBase &rewriter, Value reachingDefinition,
896+ const DataLayout &dataLayout) {
806897 return DeletionKind::Delete;
807898}
808899
@@ -1059,8 +1150,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
10591150 return memcpyStoresTo (*this , slot);
10601151}
10611152
1062- Value LLVM::MemcpyOp::getStored (const MemorySlot &slot,
1063- RewriterBase &rewriter ) {
1153+ Value LLVM::MemcpyOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
1154+ const DataLayout &dataLayout ) {
10641155 return memcpyGetStored (*this , slot, rewriter);
10651156}
10661157
@@ -1074,7 +1165,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
10741165
10751166DeletionKind LLVM::MemcpyOp::removeBlockingUses (
10761167 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1077- RewriterBase &rewriter, Value reachingDefinition) {
1168+ RewriterBase &rewriter, Value reachingDefinition,
1169+ const DataLayout &dataLayout) {
10781170 return memcpyRemoveBlockingUses (*this , slot, blockingUses, rewriter,
10791171 reachingDefinition);
10801172}
@@ -1109,7 +1201,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
11091201}
11101202
11111203Value LLVM::MemcpyInlineOp::getStored (const MemorySlot &slot,
1112- RewriterBase &rewriter) {
1204+ RewriterBase &rewriter,
1205+ const DataLayout &dataLayout) {
11131206 return memcpyGetStored (*this , slot, rewriter);
11141207}
11151208
@@ -1123,7 +1216,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
11231216
11241217DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses (
11251218 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1126- RewriterBase &rewriter, Value reachingDefinition) {
1219+ RewriterBase &rewriter, Value reachingDefinition,
1220+ const DataLayout &dataLayout) {
11271221 return memcpyRemoveBlockingUses (*this , slot, blockingUses, rewriter,
11281222 reachingDefinition);
11291223}
@@ -1159,8 +1253,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
11591253 return memcpyStoresTo (*this , slot);
11601254}
11611255
1162- Value LLVM::MemmoveOp::getStored (const MemorySlot &slot,
1163- RewriterBase &rewriter ) {
1256+ Value LLVM::MemmoveOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
1257+ const DataLayout &dataLayout ) {
11641258 return memcpyGetStored (*this , slot, rewriter);
11651259}
11661260
@@ -1174,7 +1268,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
11741268
11751269DeletionKind LLVM::MemmoveOp::removeBlockingUses (
11761270 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1177- RewriterBase &rewriter, Value reachingDefinition) {
1271+ RewriterBase &rewriter, Value reachingDefinition,
1272+ const DataLayout &dataLayout) {
11781273 return memcpyRemoveBlockingUses (*this , slot, blockingUses, rewriter,
11791274 reachingDefinition);
11801275}
0 commit comments