Skip to content

Commit 25f611e

Browse files
committed
Change dtype
1 parent 8c50b67 commit 25f611e

File tree

1 file changed

+44
-53
lines changed

1 file changed

+44
-53
lines changed

lib/gc/Transforms/CST.cpp

Lines changed: 44 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,8 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
302302

303303
std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
304304
// simply allocate buffer and return
305-
std::shared_ptr<void> base =
306-
std::shared_ptr<void>{std::aligned_alloc(64, size), [](void *p) {
307-
std::free(p); }};
305+
std::shared_ptr<void> base = std::shared_ptr<void>{
306+
std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
308307
return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
309308
}
310309

@@ -324,72 +323,63 @@ struct constGraphTensorCacheManager {
324323
}
325324

326325
// Allocate the buffer and set the buf_base_ and offset_ attributes of the cache.
327-
std::vector<uint64_t> alloc(std::vector<size_t> buffers_size) {
328-
size_t total_size = 0;
329-
for (size_t i = 0; i < buffers_size.size(); i++) {
330-
total_size += divideAndCeil(buffers_size[i], 64) * 64;
326+
std::vector<uint64_t> alloc(std::vector<size_t> buffersSize) {
327+
size_t totalSize = 0;
328+
for (size_t i = 0; i < buffersSize.size(); i++) {
329+
totalSize += divideAndCeil(buffersSize[i], 64) * 64;
331330
}
332-
llvm::dbgs() << "Alloc total size: " << total_size << '\n';
333-
auto base = createConstCacheProxy(total_size);
334-
std::vector<uint64_t> global_ids(buffers_size.size());
331+
llvm::dbgs() << "Alloc total size: " << totalSize << '\n';
332+
auto base = createConstCacheProxy(totalSize);
333+
std::vector<uint64_t> globalIds(buffersSize.size());
335334
size_t offset = 0;
336-
for (size_t i = 0; i < buffers_size.size(); i++) {
335+
for (size_t i = 0; i < buffersSize.size(); i++) {
337336
llvm::dbgs() << "Alloc offset: " << offset << '\n';
338337
regCachedTensor(cachedTensorGlobalId, base, offset);
339-
global_ids[i] = cachedTensorGlobalId;
338+
globalIds[i] = cachedTensorGlobalId;
340339
++cachedTensorGlobalId;
341-
offset += divideAndCeil(buffers_size[i], 64) * 64;
340+
offset += divideAndCeil(buffersSize[i], 64) * 64;
342341
}
343-
return global_ids;
342+
return globalIds;
344343
}
345344
};
346345

347-
static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder,
348-
StringRef name, int64_t value) {
346+
static void addGlobalI32(ModuleOp module, Location loc, OpBuilder &builder,
347+
StringRef name, int32_t value) {
349348
OpBuilder::InsertionGuard insertGuard(builder);
350349
builder.setInsertionPointToStart(module.getBody());
351350

352-
auto type = IntegerType::get(builder.getContext(), 8);
351+
auto type = IntegerType::get(builder.getContext(), 32);
353352
LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
354353
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
355-
builder.getIndexAttr(value),
354+
builder.getI32IntegerAttr(value),
356355
/*alignment=*/0);
357356
}
358357

359-
static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder,
360-
StringRef name, ArrayRef<int64_t> array) {
358+
static void addGlobalI64Array(ModuleOp module, Location loc, OpBuilder &builder,
359+
StringRef name, ArrayRef<int64_t> array) {
361360
OpBuilder::InsertionGuard insertGuard(builder);
362361
builder.setInsertionPointToStart(module.getBody());
363362

364363
auto type = LLVM::LLVMArrayType::get(
365-
IntegerType::get(builder.getContext(), 8), array.size());
364+
IntegerType::get(builder.getContext(), 64), array.size());
366365
LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
367366
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
368-
builder.getIndexArrayAttr(array),
367+
builder.getI64ArrayAttr(array),
369368
/*alignment=*/0);
370369
}
371370

372-
// static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder,
373-
// StringRef name, ArrayRef<int64_t> array) {
374-
// OpBuilder::InsertionGuard insertGuard(builder);
375-
// builder.setInsertionPointToStart(module.getBody());
376-
377-
// MemRefType type = MemRefType::Builder(array.size(), builder.getIndexType());
378-
// IntegerAttr memrefAlignment = IntegerAttr();
379-
// auto global = builder.create<memref::GlobalOp>(
380-
// loc, name,
381-
// /*sym_visibility=*/builder.getStringAttr("public"),
382-
// /*type=*/type,
383-
// /*initial_value=*/builder.getIndexTensorAttr(array),
384-
// /*constant=*/true,
385-
// /*alignment=*/memrefAlignment);
386-
// }
387-
388-
// static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder,
389-
// StringRef name, int64_t value) {
390-
// SmallVector<int64_t> array{value};
391-
// addGlobalArray(module, loc, builder, name, array);
392-
// }
371+
static void addGlobalI32Array(ModuleOp module, Location loc, OpBuilder &builder,
372+
StringRef name, ArrayRef<int32_t> array) {
373+
OpBuilder::InsertionGuard insertGuard(builder);
374+
builder.setInsertionPointToStart(module.getBody());
375+
376+
auto type = LLVM::LLVMArrayType::get(
377+
IntegerType::get(builder.getContext(), 32), array.size());
378+
LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
379+
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
380+
builder.getI32ArrayAttr(array),
381+
/*alignment=*/0);
382+
}
393383

394384
// Operate on tensors. Create fold() and compute() on module. The
395385
// folded weights and first-run flag are maintained by the upper-level runtime.
@@ -547,16 +537,16 @@ void CST::runOnOperation() {
547537
globalIndexes.push_back(id);
548538
}
549539
globalIndexes.insert(globalIndexes.begin(), globalIndexes.size());
550-
addGlobalArray(moduleOp, moduleOp.getLoc(), builder, "__fold_buffer_ids",
551-
globalIndexes);
540+
addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder, "__fold_buffer_ids",
541+
globalIndexes);
552542

553543
foldFunc.setVisibility(SymbolTable::Visibility::Public);
554544
moduleOp.push_back(foldFunc);
555545
symbolTable.insert(foldFunc);
556546

557-
SmallVector<int64_t> foldArgs;
558-
SmallVector<int64_t> foldIds;
559-
SmallVector<int64_t> computeArgs;
547+
SmallVector<int32_t> foldArgs;
548+
SmallVector<int32_t> foldIds;
549+
SmallVector<int32_t> computeArgs;
560550

561551
// modify the BlockArguments of block
562552
size_t oriNumArgs = block.getNumArguments();
@@ -607,14 +597,15 @@ void CST::runOnOperation() {
607597
foldArgs.insert(foldArgs.end(), id);
608598
}
609599
foldArgs.insert(foldArgs.begin(), foldArgs.size());
610-
addGlobalArray(moduleOp, moduleOp.getLoc(), builder, "__fold_args", foldArgs);
600+
addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args",
601+
foldArgs);
611602

612603
computeArgs.insert(computeArgs.begin(), computeArgs.size());
613-
addGlobalArray(moduleOp, moduleOp.getLoc(), builder, "__compute_args",
614-
computeArgs);
604+
addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args",
605+
computeArgs);
615606

616-
addGlobal(moduleOp, moduleOp.getLoc(), builder, "__num_orig_num_args",
617-
oriNumArgs);
607+
addGlobalI32(moduleOp, moduleOp.getLoc(), builder, "__num_orig_num_args",
608+
oriNumArgs);
618609

619610
// modify the compute func signature
620611
func::FuncOp computeFunc = cast<func::FuncOp>(topFunc);

0 commit comments

Comments
 (0)