Skip to content

Commit a7a08ba

Browse files
authored
Merge pull request #1126 from microsoft/fix987
Fix handling of signed bit fields
2 parents 67273c0 + 3c457ec commit a7a08ba

File tree

5 files changed

+214
-34
lines changed

5 files changed

+214
-34
lines changed

src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,10 @@ internal static SeparatedSyntaxList<TNode> SeparatedList<TNode>()
390390

391391
internal static IsPatternExpressionSyntax IsPatternExpression(ExpressionSyntax expression, PatternSyntax pattern) => SyntaxFactory.IsPatternExpression(expression, Token(TriviaList(Space), SyntaxKind.IsKeyword, TriviaList(Space)), pattern);
392392

393+
internal static BinaryPatternSyntax BinaryPattern(SyntaxKind kind, PatternSyntax left, PatternSyntax right) => SyntaxFactory.BinaryPattern(kind, left, TokenWithSpaces(GetBinaryPatternOperatorTokenKind(kind)), right);
394+
395+
internal static RelationalPatternSyntax RelationalPattern(SyntaxToken operatorToken, ExpressionSyntax expression) => SyntaxFactory.RelationalPattern(operatorToken, expression);
396+
393397
internal static ConditionalExpressionSyntax ConditionalExpression(ExpressionSyntax condition, ExpressionSyntax whenTrue, ExpressionSyntax whenFalse) => SyntaxFactory.ConditionalExpression(condition, Token(TriviaList(Space), SyntaxKind.QuestionToken, TriviaList(Space)), whenTrue, Token(TriviaList(Space), SyntaxKind.ColonToken, TriviaList(Space)), whenFalse);
394398

395399
internal static IfStatementSyntax IfStatement(ExpressionSyntax condition, StatementSyntax whenTrue) => IfStatement(condition, whenTrue, null);
@@ -595,6 +599,14 @@ private static SyntaxKind GetLiteralExpressionTokenKind(SyntaxKind kind)
595599
};
596600
}
597601

602+
private static SyntaxKind GetBinaryPatternOperatorTokenKind(SyntaxKind kind)
603+
=> kind switch
604+
{
605+
SyntaxKind.OrPattern => SyntaxKind.OrKeyword,
606+
SyntaxKind.AndPattern => SyntaxKind.AndKeyword,
607+
_ => throw new ArgumentOutOfRangeException(),
608+
};
609+
598610
private static SyntaxToken XmlReplaceBracketTokens(SyntaxToken originalToken, SyntaxToken rewrittenToken)
599611
{
600612
if (rewrittenToken.IsKind(SyntaxKind.LessThanToken) && string.Equals("<", rewrittenToken.Text, StringComparison.Ordinal))

src/Microsoft.Windows.CsWin32/Generator.Struct.cs

Lines changed: 80 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,20 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
161161
var fieldTypeInfo = (PrimitiveTypeHandleInfo)fieldDef.DecodeSignature(SignatureHandleProvider.Instance, null);
162162

163163
CustomAttributeValue<TypeSyntax> decodedAttribute = bitfieldAttribute.DecodeValue(CustomAttributeTypeProvider.Instance);
164+
(int? fieldBitLength, bool signed) = fieldTypeInfo.PrimitiveTypeCode switch
165+
{
166+
PrimitiveTypeCode.Byte => (8, false),
167+
PrimitiveTypeCode.SByte => (8, true),
168+
PrimitiveTypeCode.UInt16 => (16, false),
169+
PrimitiveTypeCode.Int16 => (16, true),
170+
PrimitiveTypeCode.UInt32 => (32, false),
171+
PrimitiveTypeCode.Int32 => (32, true),
172+
PrimitiveTypeCode.UInt64 => (64, false),
173+
PrimitiveTypeCode.Int64 => (64, true),
174+
PrimitiveTypeCode.UIntPtr => (null, false),
175+
PrimitiveTypeCode.IntPtr => ((int?)null, true),
176+
_ => throw new NotImplementedException(),
177+
};
164178
string propName = (string)decodedAttribute.FixedArguments[0].Value!;
165179
byte propOffset = (byte)(long)decodedAttribute.FixedArguments[1].Value!;
166180
byte propLength = (byte)(long)decodedAttribute.FixedArguments[2].Value!;
@@ -171,18 +185,25 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
171185
continue;
172186
}
173187

174-
TypeSyntax propertyType = propLength switch
188+
long minValue = signed ? -(1L << (propLength - 1)) : 0;
189+
long maxValue = (1L << (propLength - (signed ? 1 : 0))) - 1;
190+
int? leftPad = fieldBitLength.HasValue ? fieldBitLength - (propOffset + propLength) : null;
191+
int rightPad = propOffset;
192+
(TypeSyntax propertyType, int propertyBitLength) = propLength switch
175193
{
176-
1 => PredefinedType(Token(SyntaxKind.BoolKeyword)),
177-
<= 8 => PredefinedType(Token(SyntaxKind.ByteKeyword)),
178-
<= 16 => PredefinedType(Token(SyntaxKind.UShortKeyword)),
179-
<= 32 => PredefinedType(Token(SyntaxKind.UIntKeyword)),
180-
<= 64 => PredefinedType(Token(SyntaxKind.ULongKeyword)),
194+
1 => (PredefinedType(Token(SyntaxKind.BoolKeyword)), 1),
195+
<= 8 => (PredefinedType(Token(signed ? SyntaxKind.SByteKeyword : SyntaxKind.ByteKeyword)), 8),
196+
<= 16 => (PredefinedType(Token(signed ? SyntaxKind.ShortKeyword : SyntaxKind.UShortKeyword)), 16),
197+
<= 32 => (PredefinedType(Token(signed ? SyntaxKind.IntKeyword : SyntaxKind.UIntKeyword)), 32),
198+
<= 64 => (PredefinedType(Token(signed ? SyntaxKind.LongKeyword : SyntaxKind.ULongKeyword)), 64),
181199
_ => throw new NotSupportedException(),
182200
};
183201

184-
AccessorDeclarationSyntax getter = AccessorDeclaration(SyntaxKind.GetAccessorDeclaration);
185-
AccessorDeclarationSyntax setter = AccessorDeclaration(SyntaxKind.SetAccessorDeclaration);
202+
AccessorDeclarationSyntax getter = AccessorDeclaration(SyntaxKind.GetAccessorDeclaration)
203+
.AddModifiers(TokenWithSpace(SyntaxKind.ReadOnlyKeyword))
204+
.AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining)));
205+
AccessorDeclarationSyntax setter = AccessorDeclaration(SyntaxKind.SetAccessorDeclaration)
206+
.AddAttributeLists(AttributeList().AddAttributes(MethodImpl(MethodImplOptions.AggressiveInlining)));
186207

187208
ulong maskNoOffset = (1UL << propLength) - 1;
188209
ulong mask = maskNoOffset << propOffset;
@@ -203,36 +224,61 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
203224
ExpressionSyntax notMaskNoOffset = UncheckedExpression(CastExpression(propertyType, PrefixUnaryExpression(SyntaxKind.BitwiseNotExpression, maskNoOffsetExpr)));
204225
LiteralExpressionSyntax propOffsetExpr = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(propOffset));
205226

206-
// get => (byte)((field & unchecked((FIELDTYPE)getterMask)) >> propOffset);
227+
// signed:
228+
// get => (byte)((field << leftPad) >> (leftPad + rightPad)));
229+
// unsigned:
230+
// get => (byte)((field >> rightPad) & maskNoOffset);
207231
ExpressionSyntax getterExpression =
208-
CastExpression(propertyType, ParenthesizedExpression(BinaryExpression(
209-
SyntaxKind.RightShiftExpression,
210-
ParenthesizedExpression(BinaryExpression(
211-
SyntaxKind.BitwiseAndExpression,
212-
fieldAccess,
213-
UncheckedExpression(CastExpression(fieldType, maskExpr)))),
214-
propOffsetExpr)));
232+
CastExpression(propertyType, ParenthesizedExpression(
233+
signed ?
234+
BinaryExpression(
235+
SyntaxKind.RightShiftExpression,
236+
ParenthesizedExpression(BinaryExpression(
237+
SyntaxKind.LeftShiftExpression,
238+
fieldAccess,
239+
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(leftPad!.Value)))),
240+
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(leftPad.Value + rightPad)))
241+
: BinaryExpression(
242+
SyntaxKind.BitwiseAndExpression,
243+
ParenthesizedExpression(BinaryExpression(SyntaxKind.RightShiftExpression, fieldAccess, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(rightPad)))),
244+
maskNoOffsetExpr)));
215245
getter = getter
216246
.WithExpressionBody(ArrowExpressionClause(getterExpression))
217247
.WithSemicolonToken(SemicolonWithLineFeed);
218248

219-
// if ((value & ~maskNoOffset) != 0) throw new ArgumentOutOfRangeException(nameof(value));
220-
// field = (int)((field & unchecked((int)~mask)) | ((int)value << propOffset)));
221249
IdentifierNameSyntax valueName = IdentifierName("value");
222-
setter = setter.WithBody(Block().AddStatements(
223-
IfStatement(
224-
BinaryExpression(SyntaxKind.NotEqualsExpression, ParenthesizedExpression(BinaryExpression(SyntaxKind.BitwiseAndExpression, valueName, notMaskNoOffset)), LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))),
225-
ThrowStatement(ObjectCreationExpression(IdentifierName(nameof(ArgumentOutOfRangeException))).AddArgumentListArguments(Argument(InvocationExpression(IdentifierName("nameof")).WithArgumentList(ArgumentList().AddArguments(Argument(valueName))))))),
226-
ExpressionStatement(AssignmentExpression(
227-
SyntaxKind.SimpleAssignmentExpression,
228-
fieldAccess,
229-
CastExpression(fieldType, ParenthesizedExpression(
230-
BinaryExpression(
231-
SyntaxKind.BitwiseOrExpression,
232-
//// (field & unchecked((int)~mask))
233-
fieldAndNotMask,
234-
//// ((int)value << propOffset)
235-
ParenthesizedExpression(BinaryExpression(SyntaxKind.LeftShiftExpression, CastExpression(fieldType, valueName), propOffsetExpr)))))))));
250+
251+
List<StatementSyntax> setterStatements = new();
252+
if (propertyBitLength > propLength)
253+
{
254+
// The allowed range is smaller than the property type, so we need to check that the value fits.
255+
// signed:
256+
// global::System.Debug.Assert(value is >= minValue and <= maxValue);
257+
// unsigned:
258+
// global::System.Debug.Assert(value is <= maxValue);
259+
RelationalPatternSyntax max = RelationalPattern(TokenWithSpace(SyntaxKind.LessThanEqualsToken), CastExpression(propertyType, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(maxValue))));
260+
RelationalPatternSyntax? min = signed ? RelationalPattern(TokenWithSpace(SyntaxKind.GreaterThanEqualsToken), CastExpression(propertyType, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(minValue)))) : null;
261+
setterStatements.Add(ExpressionStatement(InvocationExpression(
262+
ParseName("global::System.Diagnostics.Debug.Assert"),
263+
ArgumentList().AddArguments(Argument(
264+
IsPatternExpression(
265+
valueName,
266+
min is null ? max : BinaryPattern(SyntaxKind.AndPattern, min, max)))))));
267+
}
268+
269+
// field = (int)((field & unchecked((int)~mask)) | ((int)(value & mask) << propOffset)));
270+
ExpressionSyntax valueAndMaskNoOffset = ParenthesizedExpression(BinaryExpression(SyntaxKind.BitwiseAndExpression, valueName, maskNoOffsetExpr));
271+
setterStatements.Add(ExpressionStatement(AssignmentExpression(
272+
SyntaxKind.SimpleAssignmentExpression,
273+
fieldAccess,
274+
CastExpression(fieldType, ParenthesizedExpression(
275+
BinaryExpression(
276+
SyntaxKind.BitwiseOrExpression,
277+
//// (field & unchecked((int)~mask))
278+
fieldAndNotMask,
279+
//// ((int)(value & mask) << propOffset)
280+
ParenthesizedExpression(BinaryExpression(SyntaxKind.LeftShiftExpression, CastExpression(fieldType, valueAndMaskNoOffset), propOffsetExpr))))))));
281+
setter = setter.WithBody(Block().AddStatements(setterStatements.ToArray()));
236282
}
237283
else
238284
{
@@ -261,11 +307,12 @@ private StructDeclarationSyntax DeclareStruct(TypeDefinitionHandle typeDefHandle
261307
}
262308

263309
string bitDescription = propLength == 1 ? $"bit {propOffset}" : $"bits {propOffset}-{propOffset + propLength - 1}";
310+
string allowedRange = propLength == 1 ? string.Empty : $" Allowed values are [{minValue}..{maxValue}].";
264311

265312
PropertyDeclarationSyntax bitfieldProperty = PropertyDeclaration(propertyType.WithTrailingTrivia(Space), Identifier(propName).WithTrailingTrivia(LineFeed))
266313
.AddModifiers(TokenWithSpace(this.Visibility))
267314
.WithAccessorList(AccessorList().AddAccessors(getter, setter))
268-
.WithLeadingTrivia(ParseLeadingTrivia($"/// <summary>Gets or sets {bitDescription} in the <see cref=\"{fieldName}\" /> field.</summary>\n"));
315+
.WithLeadingTrivia(ParseLeadingTrivia($"/// <summary>Gets or sets {bitDescription} in the <see cref=\"{fieldName}\" /> field.{allowedRange}</summary>\n"));
269316

270317
members.Add(bitfieldProperty);
271318
}

test/GenerationSandbox.Tests/BitFieldTests.cs

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using Windows.Win32.Devices.Usb;
55
using Windows.Win32.UI.Shell;
6+
using Windows.Win32.UI.TabletPC;
67

78
public class BitFieldTests
89
{
@@ -21,13 +22,28 @@ public void Bool()
2122
Assert.Equal(unchecked((int)0xfffffffb), s._bitfield);
2223
}
2324

25+
#if DEBUG
2426
[Fact]
2527
public void ThrowWhenSetValueIsOutOfBounds()
2628
{
2729
BM_REQUEST_TYPE._BM s = default;
28-
Assert.Throws<ArgumentOutOfRangeException>(() => s.Type = 0b100);
30+
TestUtils.AssertDebugAssertFailed(() => s.Type = 0b100);
2931
}
3032

33+
[Fact]
34+
public void ThrowWhenSetValueIsOutOfBounds_Signed()
35+
{
36+
FLICK_DATA s = default;
37+
38+
// Assert after each invalid set that what ended up being set did not exceed the bounds of the bitfield.
39+
TestUtils.AssertDebugAssertFailed(() => s.iFlickDirection = -5);
40+
Assert.Equal(0, s._bitfield & ~0xe0);
41+
42+
TestUtils.AssertDebugAssertFailed(() => s.iFlickDirection = 4);
43+
Assert.Equal(0, s._bitfield & ~0xe0);
44+
}
45+
#endif
46+
3147
[Fact]
3248
public void SetValueMultiBit()
3349
{
@@ -41,4 +57,53 @@ public void SetValueMultiBit()
4157
s.Type = 0;
4258
Assert.Equal(0b10011111, s._bitfield);
4359
}
60+
61+
[Fact]
62+
public void SignedField()
63+
{
64+
FLICK_DATA s = default;
65+
66+
// iFlickDirection: 3 bits => range -4..3
67+
const int mask = 0b111_00000;
68+
s.iFlickDirection = -1;
69+
Assert.Equal(0b111_00000, s._bitfield);
70+
Assert.Equal(-1, s.iFlickDirection);
71+
72+
s.iFlickDirection = 1;
73+
Assert.Equal(0b001_00000, s._bitfield);
74+
Assert.Equal(1, s.iFlickDirection);
75+
76+
int oldFieldValue = s._bitfield;
77+
for (sbyte i = -4; i <= 3; i++)
78+
{
79+
// Assert that a valid value is retained via the property.
80+
s.iFlickDirection = i;
81+
Assert.Equal(i, s.iFlickDirection);
82+
83+
// Assert that no other bits were touched.
84+
Assert.Equal(oldFieldValue & ~mask, s._bitfield & ~mask);
85+
}
86+
87+
// Repeat the test, but with all 1s in other locations.
88+
s._bitfield = unchecked((int)0xffffffff);
89+
oldFieldValue = s._bitfield;
90+
for (sbyte i = -4; i <= 3; i++)
91+
{
92+
// Assert that a valid value is retained via the property.
93+
s.iFlickDirection = i;
94+
Assert.Equal(i, s.iFlickDirection);
95+
96+
// Assert that no other bits were touched.
97+
Assert.Equal(oldFieldValue & ~mask, s._bitfield & ~mask);
98+
}
99+
}
100+
101+
[Fact]
102+
public void SignedField_HasBoolFor1Bit()
103+
{
104+
FLICK_DATA s = default;
105+
Assert.False(s.fMenuModifier);
106+
s.fMenuModifier = true;
107+
Assert.Equal(0b10_0000_0000, s._bitfield);
108+
}
44109
}

test/GenerationSandbox.Tests/NativeMethods.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ CSIDL_DESKTOP
99
DISPLAYCONFIG_VIDEO_SIGNAL_INFO
1010
EnumWindows
1111
FILE_ACCESS_RIGHTS
12+
FLICK_DATA
1213
GetProcAddress
1314
GetTickCount
1415
GetWindowText
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using System.Diagnostics;
5+
6+
internal static class TestUtils
7+
{
8+
#if DEBUG // Only tests that are conditioned for Debug mode can assert this.
9+
internal static void AssertDebugAssertFailed(Action action)
10+
{
11+
// We're mutating a static collection.
12+
// Protect against concurrent tests mutating the collection while we're using it.
13+
lock (Trace.Listeners)
14+
{
15+
TraceListener[] listeners = Trace.Listeners.Cast<TraceListener>().ToArray();
16+
Trace.Listeners.Clear();
17+
Trace.Listeners.Add(new ThrowingTraceListener());
18+
19+
try
20+
{
21+
action();
22+
Assert.Fail("Expected Debug.Assert to fail.");
23+
}
24+
catch (DebugAssertFailedException)
25+
{
26+
// PASS
27+
}
28+
finally
29+
{
30+
Trace.Listeners.Clear();
31+
Trace.Listeners.AddRange(listeners);
32+
}
33+
}
34+
}
35+
#endif
36+
37+
private class DebugAssertFailedException : Exception
38+
{
39+
}
40+
41+
private class ThrowingTraceListener : TraceListener
42+
{
43+
public override void Fail(string? message) => throw new DebugAssertFailedException();
44+
45+
public override void Fail(string? message, string? detailMessage) => throw new DebugAssertFailedException();
46+
47+
public override void Write(string? message)
48+
{
49+
}
50+
51+
public override void WriteLine(string? message)
52+
{
53+
}
54+
}
55+
}

0 commit comments

Comments
 (0)