@@ -58,10 +58,10 @@ public partial class ExtractNestedTyping(
5858 public class Config
5959 {
6060 /// <summary>
61- /// Whether it should be assumed that missing types are likely opaque if they are only used as a pointer type
62- /// and therefore should be subjected to handle transformations .
61+ /// Handle types are identified by looking for missing types that are only referenced through a pointer.
62+ /// If true, empty structs representing handle types will be generated .
6363 /// </summary>
64- public bool AssumeMissingTypesOpaque { get ; init ; }
64+ public bool GenerateMissingHandleTypes { get ; init ; }
6565 }
6666
6767 /// <inheritdoc />
@@ -97,86 +97,8 @@ public override async Task ExecuteAsync(IModContext ctx, CancellationToken ct =
9797 walker . Visit ( node ) ;
9898 }
9999
100- var handleDiscoverer = new HandleTypeDiscoverer ( ) ;
101- {
102- // We need to find and generate all missing handle types
103- // Handle types are types that are only referenced through a pointer
104- // We do this by parsing through the list of type errors
105- var typeErrors = compilation . GetDiagnostics ( ct )
106- . Where ( d => d . Id == "CS0246" ) // Type errors
107- . ToList ( ) ;
108-
109- // Find symbols that contain ITypeErrorSymbols
110- // These symbols are not necessarily ITypeErrorSymbols
111- var symbolsFound = new HashSet < ISymbol > ( SymbolEqualityComparer . Default ) ;
112- foreach ( var typeError in typeErrors )
113- {
114- var syntaxTree = typeError . Location . SourceTree ;
115- if ( syntaxTree == null )
116- {
117- continue ;
118- }
119-
120- var semanticModel = compilation . GetSemanticModel ( syntaxTree ) ;
121-
122- // Get the syntax node the type error corresponds to
123- var currentSyntax = syntaxTree . GetRoot ( ) . FindNode ( typeError . Location . SourceSpan ) ;
124-
125- // Search upwards to find a syntax node that we can call GetDeclaredSymbol on
126- // This is because calling GetDeclaredSymbol on the starting node will just return null
127- var isSuccess = false ;
128- while ( currentSyntax != null && currentSyntax is not TypeDeclarationSyntax )
129- {
130- switch ( currentSyntax )
131- {
132- case VariableDeclarationSyntax variableDeclarationSyntax :
133- {
134- foreach ( var declaratorSyntax in variableDeclarationSyntax . Variables )
135- {
136- var symbol = semanticModel . GetDeclaredSymbol ( declaratorSyntax , ct ) ;
137- if ( symbol != null )
138- {
139- symbolsFound . Add ( symbol ) ;
140- isSuccess = true ;
141-
142- // All of the declarators will have the same type, so getting the first symbol is enough
143- break ;
144- }
145- }
146-
147- break ;
148- }
149- case MemberDeclarationSyntax memberDeclarationSyntax :
150- {
151- var symbol = semanticModel . GetDeclaredSymbol ( memberDeclarationSyntax , ct ) ;
152- if ( symbol != null )
153- {
154- symbolsFound . Add ( symbol ) ;
155- isSuccess = true ;
156- }
157-
158- break ;
159- }
160- }
161-
162- currentSyntax = currentSyntax . Parent ;
163- }
164-
165- if ( ! isSuccess )
166- {
167- // This is to warn of unhandled cases
168- logger . LogWarning ( "Failed to find corresponding symbol for type error. There may be an unhandled case in the code" ) ;
169- }
170- }
171-
172- // These symbols contain at least one IErrorTypeSymbol, we need to search downwards for them
173- foreach ( var symbol in symbolsFound )
174- {
175- handleDiscoverer . Visit ( symbol ) ;
176- }
177- }
178-
179- var missingHandleTypes = handleDiscoverer . GetMissingHandleTypes ( ) ;
100+ var handleDiscoverer = new MissingHandleTypeDiscoverer ( logger ) ;
101+ var missingHandleTypes = handleDiscoverer . GetMissingHandleTypes ( compilation , ct ) ;
180102
181103 // Third pass to modify existing files as per our discovery.
182104 var rewriter = new Rewriter ( logger ) ;
@@ -350,18 +272,110 @@ string nativeTypeName
350272 }
351273 }
352274
353- private class HandleTypeDiscoverer : SymbolVisitor
275+ private class MissingHandleTypeDiscoverer ( ILogger logger ) : SymbolVisitor
354276 {
355277 private readonly HashSet < IErrorTypeSymbol > _nonHandleTypes = new ( SymbolEqualityComparer . Default ) ;
356278 private readonly Dictionary < IErrorTypeSymbol , string > _missingTypes = new ( SymbolEqualityComparer . Default ) ;
357279
358280 private string ? _currentNamespace = null ;
359- private int pointerTypeDepth = 0 ;
281+ private int _pointerTypeDepth = 0 ;
360282
361283 /// <summary>
362284 /// Gets all missing handle types that are found and the namespace that they should be created in.
363285 /// </summary>
364- public Dictionary < IErrorTypeSymbol , string > GetMissingHandleTypes ( ) => new ( _missingTypes . Where ( kvp => ! _nonHandleTypes . Contains ( kvp . Key ) ) , SymbolEqualityComparer . Default ) ;
286+ public Dictionary < IErrorTypeSymbol , string > GetMissingHandleTypes ( Compilation compilation , CancellationToken ct )
287+ {
288+ Clear ( ) ;
289+
290+ // We need to find and generate all missing handle types
291+ // Handle types are types that are only referenced through a pointer
292+ // We do this by parsing through the list of type errors
293+ var typeErrors = compilation . GetDiagnostics ( ct )
294+ . Where ( d => d . Id == "CS0246" ) // Type errors
295+ . ToList ( ) ;
296+
297+ // Find symbols that contain ITypeErrorSymbols
298+ // These symbols are not necessarily ITypeErrorSymbols
299+ var symbolsFound = new HashSet < ISymbol > ( SymbolEqualityComparer . Default ) ;
300+ foreach ( var typeError in typeErrors )
301+ {
302+ var syntaxTree = typeError . Location . SourceTree ;
303+ if ( syntaxTree == null )
304+ {
305+ continue ;
306+ }
307+
308+ var semanticModel = compilation . GetSemanticModel ( syntaxTree ) ;
309+
310+ // Get the syntax node the type error corresponds to
311+ var currentSyntax = syntaxTree . GetRoot ( ) . FindNode ( typeError . Location . SourceSpan ) ;
312+
313+ // Search upwards to find a syntax node that we can call GetDeclaredSymbol on
314+ // This is because calling GetDeclaredSymbol on the starting node will just return null
315+ var isSuccess = false ;
316+ while ( currentSyntax != null && currentSyntax is not TypeDeclarationSyntax )
317+ {
318+ switch ( currentSyntax )
319+ {
320+ case VariableDeclarationSyntax variableDeclarationSyntax :
321+ {
322+ foreach ( var declaratorSyntax in variableDeclarationSyntax . Variables )
323+ {
324+ var symbol = semanticModel . GetDeclaredSymbol ( declaratorSyntax , ct ) ;
325+ if ( symbol != null )
326+ {
327+ symbolsFound . Add ( symbol ) ;
328+ isSuccess = true ;
329+
330+ // All of the declarators will have the same type, so getting the first symbol is enough
331+ break ;
332+ }
333+ }
334+
335+ break ;
336+ }
337+ case MemberDeclarationSyntax memberDeclarationSyntax :
338+ {
339+ var symbol = semanticModel . GetDeclaredSymbol ( memberDeclarationSyntax , ct ) ;
340+ if ( symbol != null )
341+ {
342+ symbolsFound . Add ( symbol ) ;
343+ isSuccess = true ;
344+ }
345+
346+ break ;
347+ }
348+ }
349+
350+ currentSyntax = currentSyntax . Parent ;
351+ }
352+
353+ if ( ! isSuccess )
354+ {
355+ // This is to warn of unhandled cases
356+ logger . LogWarning ( "Failed to find corresponding symbol for type error. There may be an unhandled case in the code" ) ;
357+ }
358+ }
359+
360+ // These symbols contain at least one IErrorTypeSymbol, we need to search downwards for them
361+ foreach ( var symbol in symbolsFound )
362+ {
363+ Visit ( symbol ) ;
364+ }
365+
366+ return new Dictionary < IErrorTypeSymbol , string > ( _missingTypes . Where ( kvp => ! _nonHandleTypes . Contains ( kvp . Key ) ) , SymbolEqualityComparer . Default ) ;
367+ }
368+
369+ /// <summary>
370+ /// Resets internal state.
371+ /// </summary>
372+ public void Clear ( )
373+ {
374+ _nonHandleTypes . Clear ( ) ;
375+ _missingTypes . Clear ( ) ;
376+ _currentNamespace = null ;
377+ _pointerTypeDepth = 0 ;
378+ }
365379
366380 public override void VisitMethod ( IMethodSymbol symbol )
367381 {
@@ -415,9 +429,9 @@ public override void VisitPointerType(IPointerTypeSymbol symbol)
415429 {
416430 base . VisitPointerType ( symbol ) ;
417431
418- pointerTypeDepth ++ ;
432+ _pointerTypeDepth ++ ;
419433 Visit ( symbol . PointedAtType ) ;
420- pointerTypeDepth -- ;
434+ _pointerTypeDepth -- ;
421435 }
422436
423437 public override void VisitNamedType ( INamedTypeSymbol symbol )
@@ -431,7 +445,7 @@ public override void VisitNamedType(INamedTypeSymbol symbol)
431445 throw new InvalidOperationException ( $ "{ nameof ( _currentNamespace ) } should not be null") ;
432446 }
433447
434- if ( pointerTypeDepth == 0 )
448+ if ( _pointerTypeDepth == 0 )
435449 {
436450 _nonHandleTypes . Add ( errorTypeSymbol ) ;
437451 }
0 commit comments