Skip to content

Commit ad94e2d

Browse files
committed
Refactor code related to missing handle discovery
1 parent b78306f commit ad94e2d

File tree

2 files changed

+104
-90
lines changed

2 files changed

+104
-90
lines changed

generator.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@
188188
"Namespace": "Silk.NET.Vulkan"
189189
},
190190
"ExtractNestedTyping": {
191-
"AssumeMissingTypesOpaque": true
191+
"GenerateMissingHandleTypes": true
192192
},
193193
"PrettifyNames": {
194194
"LongAcronymThreshold": 4,

sources/SilkTouch/SilkTouch/Mods/ExtractNestedTyping.cs

Lines changed: 103 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ public partial class ExtractNestedTyping(
5858
public class Config
5959
{
6060
/// <summary>
61-
/// Whether it should be assumed that missing types are likely opaque if they are only used as a pointer type
62-
/// and therefore should be subjected to handle transformations.
61+
/// Handle types are identified by looking for missing types that are only referenced through a pointer.
62+
/// If true, empty structs representing handle types will be generated.
6363
/// </summary>
64-
public bool AssumeMissingTypesOpaque { get; init; }
64+
public bool GenerateMissingHandleTypes { get; init; }
6565
}
6666

6767
/// <inheritdoc />
@@ -97,86 +97,8 @@ public override async Task ExecuteAsync(IModContext ctx, CancellationToken ct =
9797
walker.Visit(node);
9898
}
9999

100-
var handleDiscoverer = new HandleTypeDiscoverer();
101-
{
102-
// We need to find and generate all missing handle types
103-
// Handle types are types that are only referenced through a pointer
104-
// We do this by parsing through the list of type errors
105-
var typeErrors = compilation.GetDiagnostics(ct)
106-
.Where(d => d.Id == "CS0246") // Type errors
107-
.ToList();
108-
109-
// Find symbols that contain ITypeErrorSymbols
110-
// These symbols are not necessarily ITypeErrorSymbols
111-
var symbolsFound = new HashSet<ISymbol>(SymbolEqualityComparer.Default);
112-
foreach (var typeError in typeErrors)
113-
{
114-
var syntaxTree = typeError.Location.SourceTree;
115-
if (syntaxTree == null)
116-
{
117-
continue;
118-
}
119-
120-
var semanticModel = compilation.GetSemanticModel(syntaxTree);
121-
122-
// Get the syntax node the type error corresponds to
123-
var currentSyntax = syntaxTree.GetRoot().FindNode(typeError.Location.SourceSpan);
124-
125-
// Search upwards to find a syntax node that we can call GetDeclaredSymbol on
126-
// This is because calling GetDeclaredSymbol on the starting node will just return null
127-
var isSuccess = false;
128-
while (currentSyntax != null && currentSyntax is not TypeDeclarationSyntax)
129-
{
130-
switch (currentSyntax)
131-
{
132-
case VariableDeclarationSyntax variableDeclarationSyntax:
133-
{
134-
foreach (var declaratorSyntax in variableDeclarationSyntax.Variables)
135-
{
136-
var symbol = semanticModel.GetDeclaredSymbol(declaratorSyntax, ct);
137-
if (symbol != null)
138-
{
139-
symbolsFound.Add(symbol);
140-
isSuccess = true;
141-
142-
// All of the declarators will have the same type, so getting the first symbol is enough
143-
break;
144-
}
145-
}
146-
147-
break;
148-
}
149-
case MemberDeclarationSyntax memberDeclarationSyntax:
150-
{
151-
var symbol = semanticModel.GetDeclaredSymbol(memberDeclarationSyntax, ct);
152-
if (symbol != null)
153-
{
154-
symbolsFound.Add(symbol);
155-
isSuccess = true;
156-
}
157-
158-
break;
159-
}
160-
}
161-
162-
currentSyntax = currentSyntax.Parent;
163-
}
164-
165-
if (!isSuccess)
166-
{
167-
// This is to warn of unhandled cases
168-
logger.LogWarning("Failed to find corresponding symbol for type error. There may be an unhandled case in the code");
169-
}
170-
}
171-
172-
// These symbols contain at least one IErrorTypeSymbol, we need to search downwards for them
173-
foreach (var symbol in symbolsFound)
174-
{
175-
handleDiscoverer.Visit(symbol);
176-
}
177-
}
178-
179-
var missingHandleTypes = handleDiscoverer.GetMissingHandleTypes();
100+
var handleDiscoverer = new MissingHandleTypeDiscoverer(logger);
101+
var missingHandleTypes = handleDiscoverer.GetMissingHandleTypes(compilation, ct);
180102

181103
// Third pass to modify existing files as per our discovery.
182104
var rewriter = new Rewriter(logger);
@@ -350,18 +272,110 @@ string nativeTypeName
350272
}
351273
}
352274

353-
private class HandleTypeDiscoverer : SymbolVisitor
275+
private class MissingHandleTypeDiscoverer(ILogger logger) : SymbolVisitor
354276
{
355277
private readonly HashSet<IErrorTypeSymbol> _nonHandleTypes = new(SymbolEqualityComparer.Default);
356278
private readonly Dictionary<IErrorTypeSymbol, string> _missingTypes = new(SymbolEqualityComparer.Default);
357279

358280
private string? _currentNamespace = null;
359-
private int pointerTypeDepth = 0;
281+
private int _pointerTypeDepth = 0;
360282

361283
/// <summary>
362284
/// Gets all missing handle types that are found and the namespace that they should be created in.
363285
/// </summary>
364-
public Dictionary<IErrorTypeSymbol, string> GetMissingHandleTypes() => new(_missingTypes.Where(kvp => !_nonHandleTypes.Contains(kvp.Key)), SymbolEqualityComparer.Default);
286+
public Dictionary<IErrorTypeSymbol, string> GetMissingHandleTypes(Compilation compilation, CancellationToken ct)
287+
{
288+
Clear();
289+
290+
// We need to find and generate all missing handle types
291+
// Handle types are types that are only referenced through a pointer
292+
// We do this by parsing through the list of type errors
293+
var typeErrors = compilation.GetDiagnostics(ct)
294+
.Where(d => d.Id == "CS0246") // Type errors
295+
.ToList();
296+
297+
// Find symbols that contain ITypeErrorSymbols
298+
// These symbols are not necessarily ITypeErrorSymbols
299+
var symbolsFound = new HashSet<ISymbol>(SymbolEqualityComparer.Default);
300+
foreach (var typeError in typeErrors)
301+
{
302+
var syntaxTree = typeError.Location.SourceTree;
303+
if (syntaxTree == null)
304+
{
305+
continue;
306+
}
307+
308+
var semanticModel = compilation.GetSemanticModel(syntaxTree);
309+
310+
// Get the syntax node the type error corresponds to
311+
var currentSyntax = syntaxTree.GetRoot().FindNode(typeError.Location.SourceSpan);
312+
313+
// Search upwards to find a syntax node that we can call GetDeclaredSymbol on
314+
// This is because calling GetDeclaredSymbol on the starting node will just return null
315+
var isSuccess = false;
316+
while (currentSyntax != null && currentSyntax is not TypeDeclarationSyntax)
317+
{
318+
switch (currentSyntax)
319+
{
320+
case VariableDeclarationSyntax variableDeclarationSyntax:
321+
{
322+
foreach (var declaratorSyntax in variableDeclarationSyntax.Variables)
323+
{
324+
var symbol = semanticModel.GetDeclaredSymbol(declaratorSyntax, ct);
325+
if (symbol != null)
326+
{
327+
symbolsFound.Add(symbol);
328+
isSuccess = true;
329+
330+
// All of the declarators will have the same type, so getting the first symbol is enough
331+
break;
332+
}
333+
}
334+
335+
break;
336+
}
337+
case MemberDeclarationSyntax memberDeclarationSyntax:
338+
{
339+
var symbol = semanticModel.GetDeclaredSymbol(memberDeclarationSyntax, ct);
340+
if (symbol != null)
341+
{
342+
symbolsFound.Add(symbol);
343+
isSuccess = true;
344+
}
345+
346+
break;
347+
}
348+
}
349+
350+
currentSyntax = currentSyntax.Parent;
351+
}
352+
353+
if (!isSuccess)
354+
{
355+
// This is to warn of unhandled cases
356+
logger.LogWarning("Failed to find corresponding symbol for type error. There may be an unhandled case in the code");
357+
}
358+
}
359+
360+
// These symbols contain at least one IErrorTypeSymbol, we need to search downwards for them
361+
foreach (var symbol in symbolsFound)
362+
{
363+
Visit(symbol);
364+
}
365+
366+
return new Dictionary<IErrorTypeSymbol, string>(_missingTypes.Where(kvp => !_nonHandleTypes.Contains(kvp.Key)), SymbolEqualityComparer.Default);
367+
}
368+
369+
/// <summary>
370+
/// Resets internal state.
371+
/// </summary>
372+
public void Clear()
373+
{
374+
_nonHandleTypes.Clear();
375+
_missingTypes.Clear();
376+
_currentNamespace = null;
377+
_pointerTypeDepth = 0;
378+
}
365379

366380
public override void VisitMethod(IMethodSymbol symbol)
367381
{
@@ -415,9 +429,9 @@ public override void VisitPointerType(IPointerTypeSymbol symbol)
415429
{
416430
base.VisitPointerType(symbol);
417431

418-
pointerTypeDepth++;
432+
_pointerTypeDepth++;
419433
Visit(symbol.PointedAtType);
420-
pointerTypeDepth--;
434+
_pointerTypeDepth--;
421435
}
422436

423437
public override void VisitNamedType(INamedTypeSymbol symbol)
@@ -431,7 +445,7 @@ public override void VisitNamedType(INamedTypeSymbol symbol)
431445
throw new InvalidOperationException($"{nameof(_currentNamespace)} should not be null");
432446
}
433447

434-
if (pointerTypeDepth == 0)
448+
if (_pointerTypeDepth == 0)
435449
{
436450
_nonHandleTypes.Add(errorTypeSymbol);
437451
}

0 commit comments

Comments
 (0)