Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Microsoft.Windows.CsWin32/FastSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ internal static ForStatementSyntax ForStatement(VariableDeclarationSyntax? decla
return SyntaxFactory.ForStatement(Token(SyntaxKind.ForKeyword), Token(SyntaxKind.OpenParenToken), declaration!, default, semicolonToken, condition, semicolonToken, incrementors, Token(SyntaxKind.CloseParenToken), statement);
}

internal static ForEachStatementSyntax ForEachStatement(TypeSyntax type, SyntaxToken identifier, ExpressionSyntax expression, StatementSyntax statement) => SyntaxFactory.ForEachStatement(type, identifier, expression, statement);

internal static StatementSyntax EmptyStatement() => SyntaxFactory.EmptyStatement(Token(SyntaxKind.SemicolonToken));

internal static NamespaceDeclarationSyntax NamespaceDeclaration(NameSyntax name) => SyntaxFactory.NamespaceDeclaration(Token(TriviaList(), SyntaxKind.NamespaceKeyword, TriviaList(Space)), name.WithTrailingTrivia(LineFeed), OpenBrace, default, default, default, CloseBrace, default);
Expand Down
138 changes: 135 additions & 3 deletions src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ namespace Microsoft.Windows.CsWin32;

public partial class Generator
{
private static readonly TypeSyntax PCWSTRTypeSyntax = QualifiedName(QualifiedName(IdentifierName(GlobalWinmdRootNamespaceAlias), IdentifierName("Foundation")), IdentifierName("PCWSTR"));

private enum FriendlyOverloadOf
{
ExternMethod,
Expand Down Expand Up @@ -268,9 +270,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
{
// TODO: add support for in/out size parameters. (e.g. RSGetViewports)
// TODO: add support for lists of pointers via a generated pointer-wrapping struct (e.g. PSSetSamplers)

// It is possible that countParamIndex points to a parameter that is not on the extern method
// when the parameter is the last one and was moved to a return value.
if (!isPointerToPointer && TryHandleCountParam(elementType, nullableSource: true))
{
// This block intentionally left blank.
Expand Down Expand Up @@ -305,6 +304,136 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
VariableDeclarator(localName.Identifier).WithInitializer(EqualsValueClause(origName))));
arguments[param.SequenceNumber - 1] = Argument(localName);
}

// Translate ReadOnlySpan<PCWSTR> to ReadOnlySpan<string>
if (isIn && !isOut && isConst && externParam.Type is PointerTypeSyntax { ElementType: QualifiedNameSyntax { Right: { Identifier: { ValueText: "PCWSTR" } } } })
{
signatureChanged = true;

// Change the parameter type to ReadOnlySpan<string>
parameters[param.SequenceNumber - 1] = externParam
.WithType(MakeReadOnlySpanOfT(PredefinedType(Token(SyntaxKind.StringKeyword))));

IdentifierNameSyntax gcHandlesLocal = IdentifierName($"{origName}GCHandles");
IdentifierNameSyntax pcwstrLocal = IdentifierName($"{origName}Pointers");

// var paramNameGCHandles = ArrayPool<GCHandle>.Shared.Rent(paramName.Length);
var gcHandlesArrayDecl = LocalDeclarationStatement(VariableDeclaration(
ArrayType(IdentifierName("var"))).AddVariables(
VariableDeclarator(gcHandlesLocal.Identifier).WithInitializer(EqualsValueClause(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName("global::System.Buffers.ArrayPool<global::System.Runtime.InteropServices.GCHandle>"),
IdentifierName("Shared")),
IdentifierName("Rent")))
.WithArgumentList(ArgumentList().AddArguments(Argument(GetSpanLength(origName, false))))))));

// var paramNamePointers = ArrayPool<PCWSTR>.Shared.Rent(paramName.Length);
var strsArrayDecl = LocalDeclarationStatement(VariableDeclaration(
ArrayType(IdentifierName("var"))).AddVariables(
VariableDeclarator(pcwstrLocal.Identifier).WithInitializer(EqualsValueClause(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName($"global::System.Buffers.ArrayPool<{PCWSTRTypeSyntax.ToString()}>"),
IdentifierName("Shared")),
IdentifierName("Rent")))
.WithArgumentList(ArgumentList().AddArguments(Argument(GetSpanLength(origName, false))))))));

// for (int i = 0; i < paramName.Length; i++)
// {
// paramNameGCHandles[i] = GCHandle.Alloc(paramName[i], GCHandleType.Pinned);
// paramNamePointers[i] = (char*)paramNameGCHandles[i].AddrOfPinnedObject();
// }
IdentifierNameSyntax loopVariable = IdentifierName("i");
var forLoop = ForStatement(
VariableDeclaration(PredefinedType(Token(SyntaxKind.IntKeyword))).AddVariables(
VariableDeclarator(loopVariable.Identifier).WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0))))),
BinaryExpression(SyntaxKind.LessThanExpression, loopVariable, GetSpanLength(origName, false)),
SingletonSeparatedList<ExpressionSyntax>(PostfixUnaryExpression(SyntaxKind.PostIncrementExpression, loopVariable)),
Block().AddStatements(
ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
ElementAccessExpression(gcHandlesLocal).AddArgumentListArguments(Argument(loopVariable)),
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName("global::System.Runtime.InteropServices.GCHandle"),
IdentifierName("Alloc")))
.WithArgumentList(ArgumentList().AddArguments(
Argument(ElementAccessExpression(origName).AddArgumentListArguments(Argument(loopVariable))),
Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, ParseTypeName("global::System.Runtime.InteropServices.GCHandleType"), IdentifierName("Pinned"))))))),
ExpressionStatement(AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
ElementAccessExpression(pcwstrLocal).AddArgumentListArguments(Argument(loopVariable)),
CastExpression(
PointerType(PredefinedType(Token(SyntaxKind.CharKeyword))),
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ElementAccessExpression(gcHandlesLocal).AddArgumentListArguments(Argument(loopVariable)),
IdentifierName("AddrOfPinnedObject"))).WithArgumentList(ArgumentList()))))));

leadingOutsideTryStatements.AddRange([gcHandlesArrayDecl, strsArrayDecl, forLoop]);

// foreach (var gcHandle in paramNameGCHandles) gcHandle.Free();
var freeHandleStatement = ForEachStatement(
IdentifierName("var").WithTrailingTrivia(Space),
Identifier("gcHandle").WithTrailingTrivia(Space),
gcHandlesLocal.WithLeadingTrivia(Space),
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("gcHandle"),
IdentifierName("Free")))).WithLeadingTrivia(LineFeed));

// ArrayPool<GCHandle>.Shared.Return(gcHandlesArray);
var returnGCHandlesArray = ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName("global::System.Buffers.ArrayPool<global::System.Runtime.InteropServices.GCHandle>"),
IdentifierName("Shared.Return")))
.WithArgumentList(ArgumentList().AddArguments(Argument(gcHandlesLocal))));

// ArrayPool<PCWSTR>.Shared.Return(paramNamePointers);
var returnStrsArray = ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName($"global::System.Buffers.ArrayPool<{PCWSTRTypeSyntax.ToString()}> "),
IdentifierName("Shared.Return")))
.WithArgumentList(ArgumentList().AddArguments(Argument(pcwstrLocal))));

finallyStatements.AddRange([freeHandleStatement, returnGCHandlesArray, returnStrsArray]);

// Update fixed blocks already created to consume our array of pinned pointers
bool found = false;
for (int i = 0; i < fixedBlocks.Count; i++)
{
if (fixedBlocks[i] is VariableDeclarationSyntax { Variables: [VariableDeclaratorSyntax { Initializer: { Value: IdentifierNameSyntax { Identifier: SyntaxToken id } } initializer } variable] } declaration
&& id.ValueText == externParam.Identifier.ValueText)
{
// fixed (PCWSTR* paramNamePointersPtr = strsArray)
fixedBlocks[i] = declaration.WithVariables(SingletonSeparatedList(variable.WithInitializer(initializer.WithValue(pcwstrLocal))));
found = true;
break;
}
}

if (!found)
{
throw new GenerationFailedException("Unable to find existing fixed block to change.");
}

arguments[param.SequenceNumber - 1] = Argument(localName);
}
}
else if (isIn && isOptional && !isOut && !isPointerToPointer)
{
Expand Down Expand Up @@ -485,6 +614,9 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverloads(MethodDefi
bool TryHandleCountParam(TypeSyntax elementType, bool nullableSource)
{
IdentifierNameSyntax localName = IdentifierName(origName + "Local");

// It is possible that countParamIndex points to a parameter that is not on the extern method
// when the parameter is the last one and was moved to a return value.
if (countParamIndex.HasValue
&& this.canUseSpan
&& externMethodDeclaration.ParameterList.Parameters.Count > countParamIndex.Value
Expand Down
14 changes: 14 additions & 0 deletions src/Microsoft.Windows.CsWin32/Generator.WhitespaceRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,20 @@ internal WhitespaceRewriter()
}
}

public override SyntaxNode? VisitForEachStatement(ForEachStatementSyntax node)
{
node = this.WithIndentingTrivia(node);
if (node.Statement is BlockSyntax)
{
return base.VisitForEachStatement(node);
}
else
{
using var indent = new Indent(this);
return base.VisitForEachStatement(node);
}
}

public override SyntaxNode? VisitReturnStatement(ReturnStatementSyntax node)
{
return base.VisitReturnStatement(this.WithIndentingTrivia(node));
Expand Down
6 changes: 6 additions & 0 deletions test/GenerationSandbox.Tests/GeneratedForm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Windows.Win32.Networking.ActiveDirectory;
using Windows.Win32.System.Com;
using Windows.Win32.System.Diagnostics.Debug;
using Windows.Win32.System.RestartManager;
using Windows.Win32.System.Threading;

#pragma warning disable CA1812 // dead code
Expand Down Expand Up @@ -81,6 +82,11 @@ private static void WriteFile()
PInvoke.WriteFile((SafeHandle?)null, new byte[2], &written, (NativeOverlapped*)null);
}

private static void RmRegisterResources()
{
PInvoke.RmRegisterResources(0, ["a", "b"], [default(RM_UNIQUE_PROCESS)], ["a", "b"]);
}

private class MyStream : IStream
{
public HRESULT Read(void* pv, uint cb, uint* pcbRead)
Expand Down
1 change: 1 addition & 0 deletions test/GenerationSandbox.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ GetProcAddress
GetTickCount
GetWindowText
GetWindowTextLength
RmRegisterResources
HDC_UserSize
HRESULT_FROM_WIN32
IDirectorySearch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public void OutPWSTR_Parameters_AsSpan()

[Theory]
[InlineData("WSManGetSessionOptionAsString")] // Uses the reserved keyword 'string' as a parameter name
[InlineData("RmRegisterResources")] // Parameter with PCWSTR* (an array of native strings)
public void InterestingAPIs(string name)
{
this.Generate(name);
Expand Down