@@ -302,9 +302,8 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
302
302
303
303
std::shared_ptr<ConstCacheProxy> createConstCacheProxy (size_t size) {
304
304
// simply allocate buffer and return
305
- std::shared_ptr<void > base =
306
- std::shared_ptr<void >{std::aligned_alloc (64 , size), [](void *p) {
307
- std::free (p); }};
305
+ std::shared_ptr<void > base = std::shared_ptr<void >{
306
+ std::aligned_alloc (64 , size), [](void *p) { std::free (p); }};
308
307
return std::make_shared<ConstCacheProxy>(base, base.get (), size, true );
309
308
}
310
309
@@ -324,72 +323,63 @@ struct constGraphTensorCacheManager {
324
323
}
325
324
326
325
// alloc and set the buf_base_ and offset_ attributes of cache
327
- std::vector<uint64_t > alloc (std::vector<size_t > buffers_size ) {
328
- size_t total_size = 0 ;
329
- for (size_t i = 0 ; i < buffers_size .size (); i++) {
330
- total_size += divideAndCeil (buffers_size [i], 64 ) * 64 ;
326
+ std::vector<uint64_t > alloc (std::vector<size_t > buffersSize ) {
327
+ size_t totalSize = 0 ;
328
+ for (size_t i = 0 ; i < buffersSize .size (); i++) {
329
+ totalSize += divideAndCeil (buffersSize [i], 64 ) * 64 ;
331
330
}
332
- llvm::dbgs () << " Alloc total size: " << total_size << ' \n ' ;
333
- auto base = createConstCacheProxy (total_size );
334
- std::vector<uint64_t > global_ids (buffers_size .size ());
331
+ llvm::dbgs () << " Alloc total size: " << totalSize << ' \n ' ;
332
+ auto base = createConstCacheProxy (totalSize );
333
+ std::vector<uint64_t > globalIds (buffersSize .size ());
335
334
size_t offset = 0 ;
336
- for (size_t i = 0 ; i < buffers_size .size (); i++) {
335
+ for (size_t i = 0 ; i < buffersSize .size (); i++) {
337
336
llvm::dbgs () << " Alloc offset: " << offset << ' \n ' ;
338
337
regCachedTensor (cachedTensorGlobalId, base, offset);
339
- global_ids [i] = cachedTensorGlobalId;
338
+ globalIds [i] = cachedTensorGlobalId;
340
339
++cachedTensorGlobalId;
341
- offset += divideAndCeil (buffers_size [i], 64 ) * 64 ;
340
+ offset += divideAndCeil (buffersSize [i], 64 ) * 64 ;
342
341
}
343
- return global_ids ;
342
+ return globalIds ;
344
343
}
345
344
};
346
345
347
- static void addGlobal (ModuleOp module , Location loc, OpBuilder &builder,
348
- StringRef name, int64_t value) {
346
+ static void addGlobalI32 (ModuleOp module , Location loc, OpBuilder &builder,
347
+ StringRef name, int32_t value) {
349
348
OpBuilder::InsertionGuard insertGuard (builder);
350
349
builder.setInsertionPointToStart (module .getBody ());
351
350
352
- auto type = IntegerType::get (builder.getContext (), 8 );
351
+ auto type = IntegerType::get (builder.getContext (), 32 );
353
352
LLVM::GlobalOp global = builder.create <LLVM::GlobalOp>(
354
353
loc, type, /* isConstant=*/ true , LLVM::Linkage::Internal, name,
355
- builder.getIndexAttr (value),
354
+ builder.getI32IntegerAttr (value),
356
355
/* alignment=*/ 0 );
357
356
}
358
357
359
- static void addGlobalArray (ModuleOp module , Location loc, OpBuilder &builder,
360
- StringRef name, ArrayRef<int64_t > array) {
358
+ static void addGlobalI64Array (ModuleOp module , Location loc, OpBuilder &builder,
359
+ StringRef name, ArrayRef<int64_t > array) {
361
360
OpBuilder::InsertionGuard insertGuard (builder);
362
361
builder.setInsertionPointToStart (module .getBody ());
363
362
364
363
auto type = LLVM::LLVMArrayType::get (
365
- IntegerType::get (builder.getContext (), 8 ), array.size ());
364
+ IntegerType::get (builder.getContext (), 64 ), array.size ());
366
365
LLVM::GlobalOp global = builder.create <LLVM::GlobalOp>(
367
366
loc, type, /* isConstant=*/ true , LLVM::Linkage::Internal, name,
368
- builder.getIndexArrayAttr (array),
367
+ builder.getI64ArrayAttr (array),
369
368
/* alignment=*/ 0 );
370
369
}
371
370
372
- // static void addGlobalArray(ModuleOp module, Location loc, OpBuilder &builder,
373
- // StringRef name, ArrayRef<int64_t> array) {
374
- // OpBuilder::InsertionGuard insertGuard(builder);
375
- // builder.setInsertionPointToStart(module.getBody());
376
-
377
- // MemRefType type = MemRefType::Builder(array.size(), builder.getIndexType());
378
- // IntegerAttr memrefAlignment = IntegerAttr();
379
- // auto global = builder.create<memref::GlobalOp>(
380
- // loc, name,
381
- // /*sym_visibility=*/builder.getStringAttr("public"),
382
- // /*type=*/type,
383
- // /*initial_value=*/builder.getIndexTensorAttr(array),
384
- // /*constant=*/true,
385
- // /*alignment=*/memrefAlignment);
386
- // }
387
-
388
- // static void addGlobal(ModuleOp module, Location loc, OpBuilder &builder,
389
- // StringRef name, int64_t value) {
390
- // SmallVector<int64_t> array{value};
391
- // addGlobalArray(module, loc, builder, name, array);
392
- // }
371
+ static void addGlobalI32Array (ModuleOp module , Location loc, OpBuilder &builder,
372
+ StringRef name, ArrayRef<int32_t > array) {
373
+ OpBuilder::InsertionGuard insertGuard (builder);
374
+ builder.setInsertionPointToStart (module .getBody ());
375
+
376
+ auto type = LLVM::LLVMArrayType::get (
377
+ IntegerType::get (builder.getContext (), 32 ), array.size ());
378
+ LLVM::GlobalOp global = builder.create <LLVM::GlobalOp>(
379
+ loc, type, /* isConstant=*/ true , LLVM::Linkage::Internal, name,
380
+ builder.getI32ArrayAttr (array),
381
+ /* alignment=*/ 0 );
382
+ }
393
383
394
384
// Operate on tensors. Create fold() and compute() on module. The
395
385
// folded weights and first-run flag is maintained by upper-level runtime.
@@ -547,16 +537,16 @@ void CST::runOnOperation() {
547
537
globalIndexes.push_back (id);
548
538
}
549
539
globalIndexes.insert (globalIndexes.begin (), globalIndexes.size ());
550
- addGlobalArray (moduleOp, moduleOp.getLoc (), builder, " __fold_buffer_ids" ,
551
- globalIndexes);
540
+ addGlobalI64Array (moduleOp, moduleOp.getLoc (), builder, " __fold_buffer_ids" ,
541
+ globalIndexes);
552
542
553
543
foldFunc.setVisibility (SymbolTable::Visibility::Public);
554
544
moduleOp.push_back (foldFunc);
555
545
symbolTable.insert (foldFunc);
556
546
557
- SmallVector<int64_t > foldArgs;
558
- SmallVector<int64_t > foldIds;
559
- SmallVector<int64_t > computeArgs;
547
+ SmallVector<int32_t > foldArgs;
548
+ SmallVector<int32_t > foldIds;
549
+ SmallVector<int32_t > computeArgs;
560
550
561
551
// modify the BlockArguments of block
562
552
size_t oriNumArgs = block.getNumArguments ();
@@ -607,14 +597,15 @@ void CST::runOnOperation() {
607
597
foldArgs.insert (foldArgs.end (), id);
608
598
}
609
599
foldArgs.insert (foldArgs.begin (), foldArgs.size ());
610
- addGlobalArray (moduleOp, moduleOp.getLoc (), builder, " __fold_args" , foldArgs);
600
+ addGlobalI32Array (moduleOp, moduleOp.getLoc (), builder, " __fold_args" ,
601
+ foldArgs);
611
602
612
603
computeArgs.insert (computeArgs.begin (), computeArgs.size ());
613
- addGlobalArray (moduleOp, moduleOp.getLoc (), builder, " __compute_args" ,
614
- computeArgs);
604
+ addGlobalI32Array (moduleOp, moduleOp.getLoc (), builder, " __compute_args" ,
605
+ computeArgs);
615
606
616
- addGlobal (moduleOp, moduleOp.getLoc (), builder, " __num_orig_num_args" ,
617
- oriNumArgs);
607
+ addGlobalI32 (moduleOp, moduleOp.getLoc (), builder, " __num_orig_num_args" ,
608
+ oriNumArgs);
618
609
619
610
// modify the compute func signature
620
611
func::FuncOp computeFunc = cast<func::FuncOp>(topFunc);
0 commit comments