@@ -121,7 +121,7 @@ public bool TryGenerateExternMethod(string possiblyQualifiedName, out IReadOnlyL
121121 this . RequestExternMethod ( methodDefHandle ) ;
122122 } ) ;
123123
124- string methodNamespace = this . Reader . GetString ( this . Reader . GetTypeDefinition ( methodDef . GetDeclaringType ( ) ) . Namespace ) ;
124+ string methodNamespace = this . GetMethodNamespace ( methodDef ) ;
125125 preciseApi = ImmutableList . Create ( $ "{ methodNamespace } .{ methodName } ") ;
126126 return true ;
127127 }
@@ -168,6 +168,8 @@ private static bool IsLibraryAllowedAppLocal(string libraryName)
168168 return false ;
169169 }
170170
171+ private string GetMethodNamespace ( MethodDefinition methodDef ) => this . Reader . GetString ( this . Reader . GetTypeDefinition ( methodDef . GetDeclaringType ( ) ) . Namespace ) ;
172+
171173 private void DeclareExternMethod ( MethodDefinitionHandle methodDefinitionHandle )
172174 {
173175 MethodDefinition methodDefinition = this . Reader . GetMethodDefinition ( methodDefinitionHandle ) ;
@@ -204,43 +206,123 @@ private void DeclareExternMethod(MethodDefinitionHandle methodDefinitionHandle)
204206 CustomAttributeHandleCollection ? returnTypeAttributes = this . GetReturnTypeCustomAttributes ( methodDefinition ) ;
205207 TypeSyntaxAndMarshaling returnType = signature . ReturnType . ToTypeSyntax ( typeSettings , returnTypeAttributes , ParameterAttributes . Out ) ;
206208
207- MethodDeclarationSyntax methodDeclaration = MethodDeclaration (
208- List < AttributeListSyntax > ( )
209- . Add ( AttributeList ( )
210- . WithCloseBracketToken ( TokenWithLineFeed ( SyntaxKind . CloseBracketToken ) )
211- . AddAttributes ( DllImport ( import , moduleName , entrypoint , requiresUnicodeCharSet ? CharSet . Unicode : CharSet . Ansi ) ) ) ,
212- modifiers : TokenList ( TokenWithSpace ( this . Visibility ) , TokenWithSpace ( SyntaxKind . StaticKeyword ) , TokenWithSpace ( SyntaxKind . ExternKeyword ) ) ,
209+ // Search for any enum substitutions.
210+ TypeSyntax ? returnTypeEnumName = this . FindAssociatedEnum ( returnTypeAttributes ) ;
211+ TypeSyntax ? [ ] ? parameterEnumType = null ;
212+ foreach ( ParameterHandle parameterHandle in methodDefinition . GetParameters ( ) )
213+ {
214+ Parameter parameter = this . Reader . GetParameter ( parameterHandle ) ;
215+ if ( parameter . SequenceNumber == 0 )
216+ {
217+ continue ;
218+ }
219+
220+ if ( this . FindAssociatedEnum ( parameter . GetCustomAttributes ( ) ) is IdentifierNameSyntax parameterEnumName )
221+ {
222+ parameterEnumType ??= new TypeSyntax ? [ signature . ParameterTypes . Length ] ;
223+ parameterEnumType [ parameter . SequenceNumber - 1 ] = parameterEnumName ;
224+ }
225+ }
226+
227+ AttributeListSyntax CreateDllImportAttributeList ( ) => AttributeList ( )
228+ . WithCloseBracketToken ( TokenWithLineFeed ( SyntaxKind . CloseBracketToken ) )
229+ . AddAttributes ( DllImport ( import , moduleName , entrypoint , requiresUnicodeCharSet ? CharSet . Unicode : CharSet . Ansi ) ) ;
230+
231+ MethodDeclarationSyntax externDeclaration = MethodDeclaration (
232+ List < AttributeListSyntax > ( ) . Add ( CreateDllImportAttributeList ( ) ) ,
233+ modifiers : TokenList ( TokenWithSpace ( SyntaxKind . StaticKeyword ) , TokenWithSpace ( SyntaxKind . ExternKeyword ) ) ,
213234 returnType . Type . WithTrailingTrivia ( TriviaList ( Space ) ) ,
214235 explicitInterfaceSpecifier : null ! ,
215236 SafeIdentifier ( methodName ) ,
216237 null ! ,
217- FixTrivia ( this . CreateParameterList ( methodDefinition , signature , typeSettings ) ) ,
238+ this . CreateParameterList ( methodDefinition , signature , typeSettings ) ,
218239 List < TypeParameterConstraintClauseSyntax > ( ) ,
219240 body : null ! ,
220241 TokenWithLineFeed ( SyntaxKind . SemicolonToken ) ) ;
221- methodDeclaration = returnType . AddReturnMarshalAs ( methodDeclaration ) ;
242+ externDeclaration = returnType . AddReturnMarshalAs ( externDeclaration ) ;
222243
223244 if ( this . generateDefaultDllImportSearchPathsAttribute )
224245 {
225- methodDeclaration = methodDeclaration . AddAttributeLists (
246+ externDeclaration = externDeclaration . AddAttributeLists (
226247 IsLibraryAllowedAppLocal ( moduleName ) ? DefaultDllImportSearchPathsAllowAppDirAttributeList : DefaultDllImportSearchPathsAttributeList ) ;
227248 }
228249
229- if ( this . GetSupportedOSPlatformAttribute ( methodDefinition . GetCustomAttributes ( ) ) is AttributeSyntax supportedOSPlatformAttribute )
250+ bool requiresUnsafe = RequiresUnsafe ( externDeclaration . ReturnType ) || externDeclaration . ParameterList . Parameters . Any ( p => RequiresUnsafe ( p . Type ) ) ;
251+ if ( requiresUnsafe )
230252 {
231- methodDeclaration = methodDeclaration . AddAttributeLists ( AttributeList ( ) . AddAttributes ( supportedOSPlatformAttribute ) ) ;
253+ externDeclaration = externDeclaration . AddModifiers ( TokenWithSpace ( SyntaxKind . UnsafeKeyword ) ) ;
232254 }
233255
234- // Add documentation if we can find it.
235- methodDeclaration = this . AddApiDocumentation ( entrypoint ?? methodName , methodDeclaration ) ;
256+ MethodDeclarationSyntax exposedMethod ;
257+ if ( returnTypeEnumName is null && parameterEnumType is null )
258+ {
259+ // No need for wrapping the extern method, so just expose it directly.
260+ exposedMethod = externDeclaration . WithModifiers ( externDeclaration . Modifiers . Insert ( 0 , TokenWithSpace ( this . Visibility ) ) ) ;
261+ }
262+ else
263+ {
264+ string ns = this . GetMethodNamespace ( methodDefinition ) ;
265+ NameSyntax nsSyntax = ParseName ( ReplaceCommonNamespaceWithAlias ( this , ns ) ) ;
266+ ParameterListSyntax exposedParameterList = this . CreateParameterList ( methodDefinition , signature , typeSettings ) ;
267+ static SyntaxToken RefInOutKeyword ( ParameterSyntax p ) =>
268+ p . Modifiers . Any ( SyntaxKind . OutKeyword ) ? TokenWithSpace ( SyntaxKind . OutKeyword ) :
269+ p . Modifiers . Any ( SyntaxKind . RefKeyword ) ? TokenWithSpace ( SyntaxKind . RefKeyword ) :
270+ default ;
271+ ArgumentListSyntax argumentList = exposedParameterList . Parameters . Aggregate ( ArgumentList ( ) , ( list , p ) => list . AddArguments ( Argument ( IdentifierName ( p . Identifier . ValueText ) ) . WithRefKindKeyword ( RefInOutKeyword ( p ) ) ) ) ;
272+ if ( parameterEnumType is not null )
273+ {
274+ for ( int i = 0 ; i < parameterEnumType . Length ; i ++ )
275+ {
276+ if ( parameterEnumType [ i ] is TypeSyntax parameterType )
277+ {
278+ NameSyntax qualifiedParameterType = QualifiedName ( nsSyntax , ( SimpleNameSyntax ) parameterType ) ;
279+ exposedParameterList = exposedParameterList . ReplaceNode ( exposedParameterList . Parameters [ i ] , exposedParameterList . Parameters [ i ] . WithType ( qualifiedParameterType . WithTrailingTrivia ( Space ) ) ) ;
280+ this . RequestInteropType ( ns , parameterEnumType [ i ] ! . ToString ( ) , this . DefaultContext ) ;
281+ argumentList = argumentList . ReplaceNode ( argumentList . Arguments [ i ] , argumentList . Arguments [ i ] . WithExpression ( CastExpression ( externDeclaration . ParameterList . Parameters [ i ] . Type ! . WithTrailingTrivia ( default ( SyntaxTriviaList ) ) , argumentList . Arguments [ i ] . Expression ) ) ) ;
282+ }
283+ }
284+ }
285+
286+ // We need to specify Entrypoint because our local function will have a different name.
287+ // It must have a unique name because some functions will have the same signature as our exposed method except for the return type.
288+ entrypoint ??= methodName ;
289+ IdentifierNameSyntax localExternFunctionName = IdentifierName ( "LocalExternFunction" ) ;
290+ ExpressionSyntax invocation = InvocationExpression ( localExternFunctionName , argumentList ) ;
291+
292+ if ( returnTypeEnumName is not null )
293+ {
294+ this . RequestInteropType ( ns , returnTypeEnumName . ToString ( ) , this . DefaultContext ) ;
295+ returnTypeEnumName = QualifiedName ( nsSyntax , ( SimpleNameSyntax ) returnTypeEnumName ) ;
296+ invocation = CastExpression ( returnTypeEnumName , invocation ) ;
297+ }
298+
299+ StatementSyntax forwardingStatement = returnType . Type is PredefinedTypeSyntax { Keyword . RawKind : ( int ) SyntaxKind . VoidKeyword } ? ExpressionStatement ( invocation ) : ReturnStatement ( invocation ) ;
300+ LocalFunctionStatementSyntax externFunction = LocalFunctionStatement ( externDeclaration . ReturnType , localExternFunctionName . Identifier )
301+ . AddAttributeLists ( CreateDllImportAttributeList ( ) . WithOpenBracketToken ( Token ( SyntaxKind . OpenBracketToken ) . WithLeadingTrivia ( LineFeed ) ) )
302+ . WithModifiers ( externDeclaration . Modifiers )
303+ . WithParameterList ( externDeclaration . ParameterList )
304+ . WithSemicolonToken ( SemicolonWithLineFeed ) ;
305+
306+ exposedMethod = MethodDeclaration ( returnTypeEnumName ?? returnType . Type , externDeclaration . Identifier )
307+ . AddModifiers ( TokenWithSpace ( this . Visibility ) , TokenWithSpace ( SyntaxKind . StaticKeyword ) )
308+ . WithParameterList ( exposedParameterList )
309+ . AddBodyStatements ( forwardingStatement , externFunction ) ;
310+ if ( requiresUnsafe )
311+ {
312+ exposedMethod = exposedMethod . AddModifiers ( TokenWithSpace ( SyntaxKind . UnsafeKeyword ) ) ;
313+ }
314+ }
236315
237- if ( RequiresUnsafe ( methodDeclaration . ReturnType ) || methodDeclaration . ParameterList . Parameters . Any ( p => RequiresUnsafe ( p . Type ) ) )
316+ if ( this . GetSupportedOSPlatformAttribute ( methodDefinition . GetCustomAttributes ( ) ) is AttributeSyntax supportedOSPlatformAttribute )
238317 {
239- methodDeclaration = methodDeclaration . AddModifiers ( TokenWithSpace ( SyntaxKind . UnsafeKeyword ) ) ;
318+ exposedMethod = exposedMethod . AddAttributeLists ( AttributeList ( ) . AddAttributes ( supportedOSPlatformAttribute ) ) ;
240319 }
241320
242- this . volatileCode . AddMemberToModule ( moduleName , this . DeclareFriendlyOverloads ( methodDefinition , methodDeclaration , this . methodsAndConstantsClassName , FriendlyOverloadOf . ExternMethod , this . injectedPInvokeHelperMethods ) ) ;
243- this . volatileCode . AddMemberToModule ( moduleName , methodDeclaration ) ;
321+ // Add documentation if we can find it.
322+ exposedMethod = this . AddApiDocumentation ( entrypoint ?? methodName , exposedMethod ) ;
323+
324+ this . volatileCode . AddMemberToModule ( moduleName , this . DeclareFriendlyOverloads ( methodDefinition , exposedMethod , this . methodsAndConstantsClassName , FriendlyOverloadOf . ExternMethod , this . injectedPInvokeHelperMethods ) ) ;
325+ this . volatileCode . AddMemberToModule ( moduleName , exposedMethod ) ;
244326 }
245327 catch ( Exception ex )
246328 {
0 commit comments