Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,4 @@
<PackageVersion Include="xunit.v3.assert" Version="3.2.1" />
<PackageVersion Include="xunit.v3.extensibility.core" Version="3.2.1" />
</ItemGroup>
</Project>
</Project>
74 changes: 63 additions & 11 deletions TUnit.Analyzers.CodeFixers/Base/AssertionRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,35 @@ protected AssertionRewriter(SemanticModel semanticModel)
var convertedAssertion = ConvertAssertionIfNeeded(node);
if (convertedAssertion != null)
{
// Preserve the original trivia (whitespace, comments, etc.)
var conversionTrivia = convertedAssertion.GetLeadingTrivia();
var originalTrivia = node.GetLeadingTrivia();

SyntaxTriviaList finalTrivia;
// Only do special handling when there's actually a TODO comment
var hasComment = conversionTrivia.Any(t => t.IsKind(SyntaxKind.SingleLineCommentTrivia));
if (hasComment)
{
// Conversion added trivia (TODO comments). Structure should be:
// [original whitespace] [TODO comment] [newline] [original whitespace] [await expression]
var whitespaceTrivia = originalTrivia.Where(t => t.IsKind(SyntaxKind.WhitespaceTrivia)).ToList();
var nonWhitespaceTrivia = originalTrivia.Where(t => !t.IsKind(SyntaxKind.WhitespaceTrivia)).ToList();

var builder = new List<SyntaxTrivia>();
builder.AddRange(nonWhitespaceTrivia); // Add any non-whitespace (e.g., leading newlines)
builder.AddRange(whitespaceTrivia); // Add indentation
builder.AddRange(conversionTrivia); // Add TODO comment + newline
builder.AddRange(whitespaceTrivia); // Add indentation again for the await

finalTrivia = new SyntaxTriviaList(builder);
}
else
{
// No TODO comment, just use original trivia
finalTrivia = originalTrivia;
}

return convertedAssertion
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithLeadingTrivia(finalTrivia)
.WithTrailingTrivia(node.GetTrailingTrivia());
}

Expand Down Expand Up @@ -91,7 +117,10 @@ protected ExpressionSyntax CreateTUnitAssertionWithMessage(
}

// Now wrap the entire thing in await: await Assert.That(actualValue).MethodName(args).Because(message)
return SyntaxFactory.AwaitExpression(fullInvocation);
// Need to add a trailing space after 'await' keyword
var awaitKeyword = SyntaxFactory.Token(SyntaxKind.AwaitKeyword)
.WithTrailingTrivia(SyntaxFactory.Space);
return SyntaxFactory.AwaitExpression(awaitKeyword, fullInvocation);
}

private static bool IsEmptyOrNullMessage(ExpressionSyntax message)
Expand Down Expand Up @@ -171,26 +200,47 @@ protected static ExpressionSyntax CreateMessageExpression(

/// <summary>
/// Checks if the argument at the given index appears to be a comparer (IComparer, IEqualityComparer).
/// Returns null if the type cannot be determined.
/// </summary>
protected bool IsLikelyComparerArgument(ArgumentSyntax argument)
protected bool? IsLikelyComparerArgument(ArgumentSyntax argument)
{
var typeInfo = SemanticModel.GetTypeInfo(argument.Expression);
if (typeInfo.Type == null) return false;
if (typeInfo.Type == null || typeInfo.Type.TypeKind == TypeKind.Error)
{
// Type couldn't be resolved - return null to indicate unknown
return null;
}

var typeName = typeInfo.Type.ToDisplayString();

// If it's a string type, it's definitely a message, not a comparer
if (typeInfo.Type.SpecialType == SpecialType.System_String ||
typeName == "string" || typeName == "System.String")
{
return false;
}

// Check for IComparer, IComparer<T>, IEqualityComparer, IEqualityComparer<T>
if (typeName.Contains("IComparer") || typeName.Contains("IEqualityComparer"))
{
return true;
}

// Check interfaces
// Check interfaces - also check for generic interface names like IComparer`1
if (typeInfo.Type is INamedTypeSymbol namedType)
{
return namedType.AllInterfaces.Any(i =>
i.Name == "IComparer" ||
i.Name == "IEqualityComparer");
if (namedType.AllInterfaces.Any(i =>
i.Name.StartsWith("IComparer") ||
i.Name.StartsWith("IEqualityComparer")))
{
return true;
}
}

// Also check if the type name itself contains Comparer (for StringComparer, etc.)
if (typeName.Contains("Comparer"))
{
return true;
}

return false;
Expand All @@ -206,12 +256,14 @@ protected static SyntaxTrivia CreateTodoComment(string message)

protected bool IsFrameworkAssertion(InvocationExpressionSyntax invocation)
{
var symbol = SemanticModel.GetSymbolInfo(invocation).Symbol;
var symbolInfo = SemanticModel.GetSymbolInfo(invocation);
var symbol = symbolInfo.Symbol;

if (symbol is not IMethodSymbol methodSymbol)
{
return false;
}

var namespaceName = methodSymbol.ContainingNamespace?.ToDisplayString() ?? "";
return IsFrameworkAssertionNamespace(namespaceName);
}
Expand Down
Loading
Loading