Skip to content

Commit

Permalink
[FIRRTL][CAPI] Allow constructing integers larger than 64 bits
Browse files Browse the repository at this point in the history
  • Loading branch information
SpriteOvO committed Apr 6, 2024
1 parent 923a5ee commit 8868394
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 4 deletions.
5 changes: 5 additions & 0 deletions include/circt-c/Dialect/FIRRTL.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ MLIR_CAPI_EXPORTED MlirAttribute firrtlAttrGetMemDir(MlirContext ctx,
MLIR_CAPI_EXPORTED MlirAttribute
firrtlAttrGetEventControl(MlirContext ctx, FIRRTLEventControl eventControl);

// Workaround:
// https://github.com/llvm/llvm-project/issues/84190#issuecomment-2035552035
MLIR_CAPI_EXPORTED MlirAttribute firrtlAttrGetIntegerFromString(
MlirType type, unsigned numBits, MlirStringRef str, uint8_t radix);

//===----------------------------------------------------------------------===//
// Utility API.
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 6 additions & 0 deletions lib/CAPI/Dialect/FIRRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ MlirAttribute firrtlAttrGetEventControl(MlirContext ctx,
return wrap(EventControlAttr::get(unwrap(ctx), value));
}

MlirAttribute firrtlAttrGetIntegerFromString(MlirType type, unsigned numBits,
MlirStringRef str, uint8_t radix) {
auto value = APInt{numBits, unwrap(str), radix};
return wrap(IntegerAttr::get(unwrap(type), value));
}

FIRRTLValueFlow firrtlValueFoldFlow(MlirValue value, FIRRTLValueFlow flow) {
Flow flowValue;

Expand Down
82 changes: 78 additions & 4 deletions test/CAPI/firrtl.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

void exportCallback(MlirStringRef message, void *userData) {
printf("%.*s", (int)message.length, message.data);
void dumpCallback(MlirStringRef message, void *userData) {
fprintf(stderr, "%.*s", (int)message.length, message.data);
}

void appendBufferCallback(MlirStringRef message, void *userData) {
char *buffer = (char *)userData;
sprintf(buffer + strlen(buffer), "%.*s", (int)message.length, message.data);
}

void testExport(MlirContext ctx) {
Expand All @@ -39,7 +45,7 @@ void testExport(MlirContext ctx) {
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(testFIR));

MlirLogicalResult result = mlirExportFIRRTL(module, exportCallback, NULL);
MlirLogicalResult result = mlirExportFIRRTL(module, dumpCallback, NULL);
assert(mlirLogicalResultIsSuccess(result));

// CHECK: FIRRTL version 4.0.0
Expand Down Expand Up @@ -104,18 +110,86 @@ void testImportAnnotations(MlirContext ctx) {
firCircuit, mlirStringRefCreateFromCString("rawAnnotations"),
rawAnnotationsAttr);

mlirOperationPrint(mlirModuleGetOperation(module), exportCallback, NULL);
mlirOperationPrint(mlirModuleGetOperation(module), dumpCallback, NULL);

// clang-format off
// CHECK: firrtl.circuit "AnnoTest" attributes {rawAnnotations = [{class = "firrtl.transforms.DontTouchAnnotation", target = "~AnnoTest|AnnoTest>in"}]} {
// clang-format on
}

void assertAttrEqual(MlirAttribute lhs, MlirAttribute rhs) {
char lhsBuffer[256] = {0}, rhsBuffer[256] = {0};
mlirAttributePrint(lhs, appendBufferCallback, lhsBuffer);
mlirAttributePrint(rhs, appendBufferCallback, rhsBuffer);
assert(strcmp(lhsBuffer, rhsBuffer) == 0);
}

void testAttrGetIntegerFromString(MlirContext ctx) {
// large negative hex
assertAttrEqual(
mlirAttributeParseGet(
ctx, mlirStringRefCreateFromCString("0xFF0000000000000000 : i72")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 72), 72,
mlirStringRefCreateFromCString("FF0000000000000000"), 16));

// large positive hex
assertAttrEqual(
mlirAttributeParseGet(
ctx, mlirStringRefCreateFromCString("0xFF0000000000000000 : i73")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 73), 73,
mlirStringRefCreateFromCString("FF0000000000000000"), 16));

// large negative dec
assertAttrEqual(
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(
"-12345678912345678912345 : i75")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 75), 75,
mlirStringRefCreateFromCString("-12345678912345678912345"), 10));

// large positive dec
assertAttrEqual(
mlirAttributeParseGet(
ctx, mlirStringRefCreateFromCString("12345678912345678912345 : i75")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 75), 75,
mlirStringRefCreateFromCString("12345678912345678912345"), 10));

// small negative hex
assertAttrEqual(
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0xFF : i8")),
firrtlAttrGetIntegerFromString(mlirIntegerTypeGet(ctx, 8), 8,
mlirStringRefCreateFromCString("FF"), 16));

// small positive hex
assertAttrEqual(
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0xFF : i9")),
firrtlAttrGetIntegerFromString(mlirIntegerTypeGet(ctx, 9), 9,
mlirStringRefCreateFromCString("FF"), 16));

// small negative dec
assertAttrEqual(mlirAttributeParseGet(
ctx, mlirStringRefCreateFromCString("-114514 : i18")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 18), 18,
mlirStringRefCreateFromCString("-114514"), 10));

// small positive dec
assertAttrEqual(mlirAttributeParseGet(
ctx, mlirStringRefCreateFromCString("114514 : i18")),
firrtlAttrGetIntegerFromString(
mlirIntegerTypeGet(ctx, 18), 18,
mlirStringRefCreateFromCString("114514"), 10));
}

int main(void) {
MlirContext ctx = mlirContextCreate();
mlirDialectHandleLoadDialect(mlirGetDialectHandle__firrtl__(), ctx);
testExport(ctx);
testValueFoldFlow(ctx);
testImportAnnotations(ctx);
testAttrGetIntegerFromString(ctx);
return 0;
}

0 comments on commit 8868394

Please sign in to comment.