Skip to content

Commit 86c9eff

Browse files
[mlir][Parser] Add nan and inf keywords
1 parent 4548bff commit 86c9eff

File tree

6 files changed

+121
-32
lines changed

6 files changed

+121
-32
lines changed

mlir/lib/AsmParser/AttributeParser.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
#include "mlir/IR/DialectImplementation.h"
2222
#include "mlir/IR/DialectResourceBlobManager.h"
2323
#include "mlir/IR/IntegerSet.h"
24+
#include "llvm/ADT/APFloat.h"
2425
#include "llvm/ADT/StringExtras.h"
2526
#include "llvm/Support/Endian.h"
27+
#include <cmath>
2628
#include <optional>
2729

2830
using namespace mlir;
@@ -121,14 +123,17 @@ Attribute Parser::parseAttribute(Type type) {
121123

122124
// Parse floating point and integer attributes.
123125
case Token::floatliteral:
126+
case Token::kw_inf:
127+
case Token::kw_nan:
124128
return parseFloatAttr(type, /*isNegative=*/false);
125129
case Token::integer:
126130
return parseDecOrHexAttr(type, /*isNegative=*/false);
127131
case Token::minus: {
128132
consumeToken(Token::minus);
129133
if (getToken().is(Token::integer))
130134
return parseDecOrHexAttr(type, /*isNegative=*/true);
131-
if (getToken().is(Token::floatliteral))
135+
if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) ||
136+
getToken().is(Token::kw_nan))
132137
return parseFloatAttr(type, /*isNegative=*/true);
133138

134139
return (emitWrongTokenError(
@@ -342,21 +347,24 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
342347

343348
/// Parse a float attribute.
344349
Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
345-
auto val = getToken().getFloatingPointValue();
346-
if (!val)
347-
return (emitError("floating point value too large for attribute"), nullptr);
348-
consumeToken(Token::floatliteral);
350+
const Token tok = getToken();
351+
consumeToken();
349352
if (!type) {
350353
// Default to F64 when no type is specified.
351354
if (!consumeIf(Token::colon))
352355
type = builder.getF64Type();
353356
else if (!(type = parseType()))
354357
return nullptr;
355358
}
356-
if (!isa<FloatType>(type))
359+
auto floatType = dyn_cast<FloatType>(type);
360+
if (!floatType)
357361
return (emitError("floating point value not valid for specified type"),
358362
nullptr);
359-
return FloatAttr::get(type, isNegative ? -*val : *val);
363+
std::optional<APFloat> apResult;
364+
if (failed(parseFloatFromLiteral(apResult, tok, isNegative,
365+
floatType.getFloatSemantics())))
366+
return Attribute();
367+
return FloatAttr::get(floatType, *apResult);
360368
}
361369

362370
/// Construct an APint from a parsed value, a known attribute type and
@@ -622,7 +630,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
622630
}
623631

624632
// Check to see if floating point values were parsed.
625-
if (token.is(Token::floatliteral)) {
633+
if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
626634
return p.emitError(tokenLoc)
627635
<< "expected integer elements, but parsed floating-point";
628636
}
@@ -729,6 +737,8 @@ ParseResult TensorLiteralParser::parseElement() {
729737
// Parse a boolean element.
730738
case Token::kw_true:
731739
case Token::kw_false:
740+
case Token::kw_inf:
741+
case Token::kw_nan:
732742
case Token::floatliteral:
733743
case Token::integer:
734744
storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -738,7 +748,8 @@ ParseResult TensorLiteralParser::parseElement() {
738748
// Parse a signed integer or a negative floating-point element.
739749
case Token::minus:
740750
p.consumeToken(Token::minus);
741-
if (!p.getToken().isAny(Token::floatliteral, Token::integer))
751+
if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
752+
Token::integer))
742753
return p.emitError("expected integer or floating point literal");
743754
storage.emplace_back(/*isNegative=*/true, p.getToken());
744755
p.consumeToken();

mlir/lib/AsmParser/Parser.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,11 +350,33 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
350350
ParseResult Parser::parseFloatFromLiteral(std::optional<APFloat> &result,
351351
const Token &tok, bool isNegative,
352352
const llvm::fltSemantics &semantics) {
353+
// Check for inf keyword.
354+
if (tok.is(Token::kw_inf)) {
355+
if (!APFloat::semanticsHasInf(semantics))
356+
return emitError(tok.getLoc())
357+
<< "floating point type does not support infinity";
358+
result = APFloat::getInf(semantics, isNegative);
359+
return success();
360+
}
361+
362+
// Check for NaN keyword.
363+
if (tok.is(Token::kw_nan)) {
364+
if (!APFloat::semanticsHasNan(semantics))
365+
return emitError(tok.getLoc())
366+
<< "floating point type does not support NaN";
367+
result = APFloat::getNaN(semantics, isNegative);
368+
return success();
369+
}
370+
353371
// Check for a floating point value.
354372
if (tok.is(Token::floatliteral)) {
355373
auto val = tok.getFloatingPointValue();
356374
if (!val)
357375
return emitError(tok.getLoc()) << "floating point value too large";
376+
if (std::fpclassify(*val) == FP_ZERO &&
377+
!APFloat::semanticsHasZero(semantics))
378+
return emitError(tok.getLoc())
379+
<< "floating point type does not support zero";
358380

359381
result.emplace(isNegative ? -*val : *val);
360382
bool unused;

mlir/lib/AsmParser/TokenKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
111111
TOK_KEYWORD(for)
112112
TOK_KEYWORD(func)
113113
TOK_KEYWORD(index)
114+
TOK_KEYWORD(inf)
114115
TOK_KEYWORD(loc)
115116
TOK_KEYWORD(max)
116117
TOK_KEYWORD(memref)
117118
TOK_KEYWORD(min)
118119
TOK_KEYWORD(mod)
120+
TOK_KEYWORD(nan)
119121
TOK_KEYWORD(none)
120122
TOK_KEYWORD(offset)
121123
TOK_KEYWORD(size)

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
18801880
// CHECK-NEXT: %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
18811881
// CHECK-NEXT: return %[[X]], %arg0, %arg0
18821882
%c0 = arith.constant 0.0 : f32
1883-
%inf = arith.constant 0x7F800000 : f32
1883+
%inf = arith.constant inf : f32
18841884
%0 = arith.minimumf %c0, %arg0 : f32
18851885
%1 = arith.minimumf %arg0, %arg0 : f32
18861886
%2 = arith.minimumf %inf, %arg0 : f32
@@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
18951895
// CHECK-NEXT: %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
18961896
// CHECK-NEXT: return %[[X]], %arg0, %arg0
18971897
%c0 = arith.constant 0.0 : f32
1898-
%-inf = arith.constant 0xFF800000 : f32
1898+
%-inf = arith.constant -inf : f32
18991899
%0 = arith.maximumf %c0, %arg0 : f32
19001900
%1 = arith.maximumf %arg0, %arg0 : f32
19011901
%2 = arith.maximumf %-inf, %arg0 : f32
@@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
19101910
// CHECK-NEXT: %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
19111911
// CHECK-NEXT: return %[[X]], %arg0, %arg0
19121912
%c0 = arith.constant 0.0 : f32
1913-
%inf = arith.constant 0x7F800000 : f32
1913+
%inf = arith.constant inf : f32
19141914
%0 = arith.minnumf %c0, %arg0 : f32
19151915
%1 = arith.minnumf %arg0, %arg0 : f32
19161916
%2 = arith.minnumf %inf, %arg0 : f32
@@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
19251925
// CHECK-NEXT: %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
19261926
// CHECK-NEXT: return %[[X]], %arg0, %arg0
19271927
%c0 = arith.constant 0.0 : f32
1928-
%-inf = arith.constant 0xFF800000 : f32
1928+
%-inf = arith.constant -inf : f32
19291929
%0 = arith.maxnumf %c0, %arg0 : f32
19301930
%1 = arith.maxnumf %arg0, %arg0 : f32
19311931
%2 = arith.maxnumf %-inf, %arg0 : f32
@@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
20242024
// CHECK-DAG: %[[T:.*]] = arith.constant true
20252025
// CHECK-DAG: %[[F:.*]] = arith.constant false
20262026
// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
2027-
%nan = arith.constant 0x7fffffff : f32
2027+
%nan = arith.constant nan : f32
20282028
%0 = arith.cmpf olt, %nan, %arg0 : f32
20292029
%1 = arith.cmpf olt, %arg0, %nan : f32
20302030
%2 = arith.cmpf ugt, %nan, %arg0 : f32

mlir/test/IR/attribute.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
108108
// CHECK: float_attr = 2.000000e+00 : f128
109109
float_attr = 2. : f128
110110
} : () -> ()
111+
"test.float_attrs"() {
112+
// Note: nan/inf are printed in binary format because there may be multiple
113+
// nan/inf representations.
114+
// CHECK: float_attr = 0x7FC00000 : f32
115+
float_attr = nan : f32
116+
} : () -> ()
117+
"test.float_attrs"() {
118+
// CHECK: float_attr = 0x7C : f8E4M3
119+
float_attr = nan : f8E4M3
120+
} : () -> ()
121+
"test.float_attrs"() {
122+
// CHECK: float_attr = 0xFFC00000 : f32
123+
float_attr = -nan : f32
124+
} : () -> ()
125+
"test.float_attrs"() {
126+
// CHECK: float_attr = 0xFC : f8E4M3
127+
float_attr = -nan : f8E4M3
128+
} : () -> ()
129+
"test.float_attrs"() {
130+
// CHECK: float_attr = 0x7F800000 : f32
131+
float_attr = inf : f32
132+
} : () -> ()
133+
"test.float_attrs"() {
134+
// CHECK: float_attr = 0x78 : f8E4M3
135+
float_attr = inf : f8E4M3
136+
} : () -> ()
137+
"test.float_attrs"() {
138+
// CHECK: float_attr = 0xFF800000 : f32
139+
float_attr = -inf : f32
140+
} : () -> ()
141+
"test.float_attrs"() {
142+
// CHECK: float_attr = 0xF8 : f8E4M3
143+
float_attr = -inf : f8E4M3
144+
} : () -> ()
111145
return
112146
}
113147

148+
// -----
149+
150+
func.func @float_nan_unsupported() {
151+
"test.float_attrs"() {
152+
// expected-error @below{{floating point type does not support NaN}}
153+
float_attr = nan : f4E2M1FN
154+
} : () -> ()
155+
}
156+
157+
// -----
158+
159+
func.func @float_inf_unsupported() {
160+
"test.float_attrs"() {
161+
// expected-error @below{{floating point type does not support infinity}}
162+
float_attr = inf : f4E2M1FN
163+
} : () -> ()
164+
}
165+
166+
// -----
167+
114168
//===----------------------------------------------------------------------===//
115169
// Test integer attributes
116170
//===----------------------------------------------------------------------===//

mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func.func @tanh() {
4141
call @tanh_8xf32(%v2) : (vector<8xf32>) -> ()
4242

4343
// CHECK: nan
44-
%nan = arith.constant 0x7fc00000 : f32
44+
%nan = arith.constant nan : f32
4545
call @tanh_f32(%nan) : (f32) -> ()
4646

4747
return
@@ -87,15 +87,15 @@ func.func @log() {
8787
call @log_f32(%zero) : (f32) -> ()
8888

8989
// CHECK: nan
90-
%nan = arith.constant 0x7fc00000 : f32
90+
%nan = arith.constant nan : f32
9191
call @log_f32(%nan) : (f32) -> ()
9292

9393
// CHECK: inf
94-
%inf = arith.constant 0x7f800000 : f32
94+
%inf = arith.constant inf : f32
9595
call @log_f32(%inf) : (f32) -> ()
9696

9797
// CHECK: -inf, nan, inf, 0.693147
98-
%special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32>
98+
%special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32>
9999
call @log_4xf32(%special_vec) : (vector<4xf32>) -> ()
100100

101101
return
@@ -141,11 +141,11 @@ func.func @log2() {
141141
call @log2_f32(%neg_one) : (f32) -> ()
142142

143143
// CHECK: inf
144-
%inf = arith.constant 0x7f800000 : f32
144+
%inf = arith.constant inf : f32
145145
call @log2_f32(%inf) : (f32) -> ()
146146

147147
// CHECK: -inf, nan, inf, 1.58496
148-
%special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
148+
%special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32>
149149
call @log2_4xf32(%special_vec) : (vector<4xf32>) -> ()
150150

151151
return
@@ -192,11 +192,11 @@ func.func @log1p() {
192192
call @log1p_f32(%neg_two) : (f32) -> ()
193193

194194
// CHECK: inf
195-
%inf = arith.constant 0x7f800000 : f32
195+
%inf = arith.constant inf : f32
196196
call @log1p_f32(%inf) : (f32) -> ()
197197

198198
// CHECK: -inf, nan, inf, 9.99995e-06
199-
%special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32>
199+
%special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32>
200200
call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> ()
201201

202202
return
@@ -247,7 +247,7 @@ func.func @erf() {
247247
call @erf_f32(%val7) : (f32) -> ()
248248

249249
// CHECK: -1
250-
%negativeInf = arith.constant 0xff800000 : f32
250+
%negativeInf = arith.constant -inf : f32
251251
call @erf_f32(%negativeInf) : (f32) -> ()
252252

253253
// CHECK: -1, -1, -0.913759, -0.731446
@@ -263,11 +263,11 @@ func.func @erf() {
263263
call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> ()
264264

265265
// CHECK: 1
266-
%inf = arith.constant 0x7f800000 : f32
266+
%inf = arith.constant inf : f32
267267
call @erf_f32(%inf) : (f32) -> ()
268268

269269
// CHECK: nan
270-
%nan = arith.constant 0x7fc00000 : f32
270+
%nan = arith.constant nan : f32
271271
call @erf_f32(%nan) : (f32) -> ()
272272

273273
return
@@ -306,15 +306,15 @@ func.func @exp() {
306306
call @exp_4xf32(%special_vec) : (vector<4xf32>) -> ()
307307

308308
// CHECK: inf
309-
%inf = arith.constant 0x7f800000 : f32
309+
%inf = arith.constant inf : f32
310310
call @exp_f32(%inf) : (f32) -> ()
311311

312312
// CHECK: 0
313-
%negative_inf = arith.constant 0xff800000 : f32
313+
%negative_inf = arith.constant -inf : f32
314314
call @exp_f32(%negative_inf) : (f32) -> ()
315315

316316
// CHECK: nan
317-
%nan = arith.constant 0x7fc00000 : f32
317+
%nan = arith.constant nan : f32
318318
call @exp_f32(%nan) : (f32) -> ()
319319

320320
return
@@ -358,19 +358,19 @@ func.func @expm1() {
358358
call @expm1_8xf32(%v2) : (vector<8xf32>) -> ()
359359

360360
// CHECK: -1
361-
%neg_inf = arith.constant 0xff800000 : f32
361+
%neg_inf = arith.constant -inf : f32
362362
call @expm1_f32(%neg_inf) : (f32) -> ()
363363

364364
// CHECK: inf
365-
%inf = arith.constant 0x7f800000 : f32
365+
%inf = arith.constant inf : f32
366366
call @expm1_f32(%inf) : (f32) -> ()
367367

368368
// CHECK: -1, inf, 1e-10
369-
%special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
369+
%special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32>
370370
call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> ()
371371

372372
// CHECK: nan
373-
%nan = arith.constant 0x7fc00000 : f32
373+
%nan = arith.constant nan : f32
374374
call @expm1_f32(%nan) : (f32) -> ()
375375

376376
return

0 commit comments

Comments
 (0)