Skip to content

Commit f05633c

Browse files
committed
Use code generator for cloning responses
1 parent f56464f commit f05633c

File tree

10 files changed

+189
-49
lines changed

10 files changed

+189
-49
lines changed

gen/SourceGenerator/Extensions.cs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (c) .NET Foundation and Contributors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
namespace SourceGenerator;
16+
17+
static class Extensions {
18+
public static IEnumerable<ClassDeclarationSyntax> FindClasses(this Compilation compilation, Func<ClassDeclarationSyntax, bool> predicate)
19+
=> compilation.SyntaxTrees
20+
.Select(tree => compilation.GetSemanticModel(tree))
21+
.SelectMany(model => model.SyntaxTree.GetRoot().DescendantNodes().OfType<ClassDeclarationSyntax>())
22+
.Where(predicate);
23+
24+
public static IEnumerable<ClassDeclarationSyntax> FindAnnotatedClass(this Compilation compilation, string attributeName, bool strict) {
25+
return compilation.FindClasses(
26+
syntax => syntax.AttributeLists.Any(list => list.Attributes.Any(CheckAttribute))
27+
);
28+
29+
bool CheckAttribute(AttributeSyntax attr) {
30+
var name = attr.Name.ToString();
31+
return strict ? name == attributeName : name.StartsWith(attributeName);
32+
}
33+
}
34+
35+
public static IEnumerable<ITypeSymbol> GetBaseTypesAndThis(this ITypeSymbol type) {
36+
var current = type;
37+
38+
while (current != null) {
39+
yield return current;
40+
41+
current = current.BaseType;
42+
}
43+
}
44+
}

gen/SourceGenerator/ImmutableGenerator.cs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,6 @@
1313
// limitations under the License.
1414
//
1515

16-
using System.Text;
17-
using Microsoft.CodeAnalysis;
18-
using Microsoft.CodeAnalysis.CSharp;
19-
using Microsoft.CodeAnalysis.CSharp.Syntax;
20-
using Microsoft.CodeAnalysis.Text;
21-
2216
namespace SourceGenerator;
2317

2418
[Generator]
@@ -28,10 +22,7 @@ public void Initialize(GeneratorInitializationContext context) { }
2822
public void Execute(GeneratorExecutionContext context) {
2923
var compilation = context.Compilation;
3024

31-
var mutableClasses = compilation.SyntaxTrees
32-
.Select(tree => compilation.GetSemanticModel(tree))
33-
.SelectMany(model => model.SyntaxTree.GetRoot().DescendantNodes().OfType<ClassDeclarationSyntax>())
34-
.Where(syntax => syntax.AttributeLists.Any(list => list.Attributes.Any(attr => attr.Name.ToString() == "GenerateImmutable")));
25+
var mutableClasses = compilation.FindAnnotatedClass("GenerateImmutable", strict: true);
3526

3627
foreach (var mutableClass in mutableClasses) {
3728
var immutableClass = GenerateImmutableClass(mutableClass, compilation);
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Copyright (c) .NET Foundation and Contributors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
namespace SourceGenerator;
16+
17+
[Generator]
18+
public class InheritedCloneGenerator : ISourceGenerator {
19+
const string AttributeName = "GenerateClone";
20+
21+
public void Initialize(GeneratorInitializationContext context) { }
22+
23+
public void Execute(GeneratorExecutionContext context) {
24+
var compilation = context.Compilation;
25+
26+
var candidates = compilation.FindAnnotatedClass(AttributeName, false);
27+
28+
foreach (var candidate in candidates) {
29+
var semanticModel = compilation.GetSemanticModel(candidate.SyntaxTree);
30+
var genericClassSymbol = semanticModel.GetDeclaredSymbol(candidate);
31+
if (genericClassSymbol == null) continue;
32+
33+
// Get the method name from the attribute Name argument
34+
var attributeData = genericClassSymbol.GetAttributes().FirstOrDefault(a => a.AttributeClass?.Name == $"{AttributeName}Attribute");
35+
var methodName = (string)attributeData.NamedArguments.FirstOrDefault(arg => arg.Key == "Name").Value.Value;
36+
37+
// Get the generic argument type where properties need to be copied from
38+
var attributeSyntax = candidate.AttributeLists
39+
.SelectMany(l => l.Attributes)
40+
.FirstOrDefault(a => a.Name.ToString().StartsWith(AttributeName));
41+
if (attributeSyntax == null) continue; // This should never happen
42+
43+
var typeArgumentSyntax = ((GenericNameSyntax)attributeSyntax.Name).TypeArgumentList.Arguments[0];
44+
var typeSymbol = (INamedTypeSymbol)semanticModel.GetSymbolInfo(typeArgumentSyntax).Symbol;
45+
46+
var code = GenerateMethod(candidate, genericClassSymbol, typeSymbol, methodName);
47+
context.AddSource($"{genericClassSymbol.Name}.Clone.g.cs", SourceText.From(code, Encoding.UTF8));
48+
}
49+
}
50+
51+
static string GenerateMethod(
52+
TypeDeclarationSyntax classToExtendSyntax,
53+
INamedTypeSymbol classToExtendSymbol,
54+
INamedTypeSymbol classToClone,
55+
string methodName
56+
) {
57+
var namespaceName = classToExtendSymbol.ContainingNamespace.ToDisplayString();
58+
var className = classToExtendSyntax.Identifier.Text;
59+
var genericTypeParameters = string.Join(", ", classToExtendSymbol.TypeParameters.Select(tp => tp.Name));
60+
var classDeclaration = classToExtendSymbol.TypeParameters.Length > 0 ? $"{className}<{genericTypeParameters}>" : className;
61+
62+
var all = classToClone.GetBaseTypesAndThis();
63+
var props = all.SelectMany(x => x.GetMembers().OfType<IPropertySymbol>()).ToArray();
64+
var usings = classToExtendSyntax.SyntaxTree.GetCompilationUnitRoot().Usings.Select(u => u.ToString());
65+
66+
var constructorParams = classToExtendSymbol.Constructors.First().Parameters.ToArray();
67+
var constructorArgs = string.Join(", ", constructorParams.Select(p => $"original.{GetPropertyName(p.Name, props)}"));
68+
var constructorParamNames = constructorParams.Select(p => p.Name).ToArray();
69+
70+
var properties = props
71+
// ReSharper disable once PossibleUnintendedLinearSearchInSet
72+
.Where(prop => !constructorParamNames.Contains(prop.Name, StringComparer.OrdinalIgnoreCase) && prop.SetMethod != null)
73+
.Select(prop => $" {prop.Name} = original.{prop.Name},")
74+
.ToArray();
75+
76+
const string template = """
77+
{Usings}
78+
79+
namespace {Namespace};
80+
81+
public partial class {ClassDeclaration} {
82+
public static {ClassDeclaration} {MethodName}({OriginalClassName} original)
83+
=> new {ClassDeclaration}({ConstructorArgs}) {
84+
{Properties}
85+
};
86+
}
87+
""";
88+
89+
var code = template
90+
.Replace("{Usings}", string.Join("\n", usings))
91+
.Replace("{Namespace}", namespaceName)
92+
.Replace("{ClassDeclaration}", classDeclaration)
93+
.Replace("{OriginalClassName}", classToClone.Name)
94+
.Replace("{MethodName}", methodName)
95+
.Replace("{ConstructorArgs}", constructorArgs)
96+
.Replace("{Properties}", string.Join("\n", properties).TrimEnd(','));
97+
98+
return code;
99+
100+
static string GetPropertyName(string parameterName, IPropertySymbol[] properties) {
101+
var property = properties.FirstOrDefault(p => string.Equals(p.Name, parameterName, StringComparison.OrdinalIgnoreCase));
102+
return property?.Name ?? parameterName;
103+
}
104+
}
105+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"$schema": "http://json.schemastore.org/launchsettings.json",
3+
"profiles": {
4+
"Generators": {
5+
"commandName": "DebugRoslynComponent",
6+
"targetProject": "../../src/RestSharp/RestSharp.csproj"
7+
}
8+
}
9+
}

gen/SourceGenerator/SourceGenerator.csproj

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,18 @@
99
<IsPackable>false</IsPackable>
1010
</PropertyGroup>
1111
<ItemGroup>
12-
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" PrivateAssets="All" />
13-
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" PrivateAssets="All" />
12+
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" PrivateAssets="All"/>
13+
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" PrivateAssets="All"/>
1414
</ItemGroup>
1515
<ItemGroup>
16-
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false" />
16+
<None Include="$(OutputPath)\$(AssemblyName).dll" Pack="true" PackagePath="analyzers/dotnet/cs" Visible="false"/>
17+
</ItemGroup>
18+
<ItemGroup>
19+
20+
<Using Include="System.Text"/>
21+
<Using Include="Microsoft.CodeAnalysis"/>
22+
<Using Include="Microsoft.CodeAnalysis.CSharp"/>
23+
<Using Include="Microsoft.CodeAnalysis.CSharp.Syntax"/>
24+
<Using Include="Microsoft.CodeAnalysis.Text"/>
1725
</ItemGroup>
1826
</Project>

src/RestSharp/Extensions/GenerateImmutableAttribute.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
namespace RestSharp.Extensions;
1717

1818
[AttributeUsage(AttributeTargets.Class)]
19-
class GenerateImmutableAttribute : Attribute { }
19+
class GenerateImmutableAttribute : Attribute;
20+
21+
[AttributeUsage(AttributeTargets.Class)]
22+
class GenerateCloneAttribute<T> : Attribute where T : class {
23+
public string? Name { get; set; }
24+
};
2025

2126
[AttributeUsage(AttributeTargets.Property)]
22-
class Exclude : Attribute { }
27+
class Exclude : Attribute;

src/RestSharp/Extensions/HttpResponseExtensions.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,12 @@ static class HttpResponseExtensions {
2727
: new HttpRequestException($"Request failed with status code {httpResponse.StatusCode}");
2828
#endif
2929

30-
public static string GetResponseString(this HttpResponseMessage response, byte[] bytes, Encoding clientEncoding) {
30+
public static async Task<string> GetResponseString(this HttpResponseMessage response, byte[] bytes, Encoding clientEncoding) {
3131
var encodingString = response.Content.Headers.ContentType?.CharSet;
3232
var encoding = encodingString != null ? TryGetEncoding(encodingString) : clientEncoding;
3333

3434
using var reader = new StreamReader(new MemoryStream(bytes), encoding);
35-
return reader.ReadToEnd();
36-
35+
return await reader.ReadToEndAsync();
3736
Encoding TryGetEncoding(string es) {
3837
try {
3938
return Encoding.GetEncoding(es);
@@ -69,4 +68,4 @@ Encoding TryGetEncoding(string es) {
6968
return original == null ? null : streamWriter(original);
7069
}
7170
}
72-
}
71+
}

src/RestSharp/Response/RestResponse.cs

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
// limitations under the License.
1414

1515
using System.Diagnostics;
16-
using System.Net;
1716
using System.Text;
1817
using RestSharp.Extensions;
1918

@@ -25,34 +24,13 @@ namespace RestSharp;
2524
/// Container for data sent back from API including deserialized data
2625
/// </summary>
2726
/// <typeparam name="T">Type of data to deserialize to</typeparam>
28-
[DebuggerDisplay("{" + nameof(DebuggerDisplay) + "()}")]
29-
public class RestResponse<T>(RestRequest request) : RestResponse(request) {
27+
[GenerateClone<RestResponse>(Name = "FromResponse")]
28+
[DebuggerDisplay($"{{{nameof(DebuggerDisplay)}()}}")]
29+
public partial class RestResponse<T>(RestRequest request) : RestResponse(request) {
3030
/// <summary>
3131
/// Deserialized entity data
3232
/// </summary>
3333
public T? Data { get; set; }
34-
35-
public static RestResponse<T> FromResponse(RestResponse response)
36-
=> new(response.Request) {
37-
Content = response.Content,
38-
ContentEncoding = response.ContentEncoding,
39-
ContentHeaders = response.ContentHeaders,
40-
ContentLength = response.ContentLength,
41-
ContentType = response.ContentType,
42-
Cookies = response.Cookies,
43-
ErrorException = response.ErrorException,
44-
ErrorMessage = response.ErrorMessage,
45-
Headers = response.Headers,
46-
IsSuccessStatusCode = response.IsSuccessStatusCode,
47-
RawBytes = response.RawBytes,
48-
ResponseStatus = response.ResponseStatus,
49-
ResponseUri = response.ResponseUri,
50-
RootElement = response.RootElement,
51-
Server = response.Server,
52-
StatusCode = response.StatusCode,
53-
StatusDescription = response.StatusDescription,
54-
Version = response.Version
55-
};
5634
}
5735

5836
/// <summary>
@@ -78,7 +56,7 @@ async Task<RestResponse> GetDefaultResponse() {
7856
#endif
7957

8058
var bytes = stream == null ? null : await stream.ReadAsBytes(cancellationToken).ConfigureAwait(false);
81-
var content = bytes == null ? null : httpResponse.GetResponseString(bytes, encoding);
59+
var content = bytes == null ? null : await httpResponse.GetResponseString(bytes, encoding);
8260

8361
return new RestResponse(request) {
8462
Content = content,

src/RestSharp/Response/RestResponseBase.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
using System.Diagnostics;
16-
using System.Net;
16+
// ReSharper disable PropertyCanBeMadeInitOnly.Global
1717

1818
namespace RestSharp;
1919

@@ -65,12 +65,13 @@ protected RestResponseBase(RestRequest request) {
6565
public HttpStatusCode StatusCode { get; set; }
6666

6767
/// <summary>
68-
/// Whether or not the HTTP response status code indicates success
68+
/// Whether the HTTP response status code indicates success
6969
/// </summary>
7070
public bool IsSuccessStatusCode { get; set; }
7171

7272
/// <summary>
73-
/// Whether or not the HTTP response status code indicates success and no other error occurred (deserialization, timeout, ...)
73+
/// Whether the HTTP response status code indicates success and no other error occurred
74+
/// (deserialization, timeout, ...)
7475
/// </summary>
7576
public bool IsSuccessful => IsSuccessStatusCode && ResponseStatus == ResponseStatus.Completed;
7677

src/RestSharp/RestClient.Async.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public async Task<RestResponse> ExecuteAsync(RestRequest request, CancellationTo
4343
/// <inheritdoc />
4444
[PublicAPI]
4545
public async Task<Stream?> DownloadStreamAsync(RestRequest request, CancellationToken cancellationToken = default) {
46-
// Make sure we only read the headers so we can stream the content body efficiently
46+
// Make sure we only read the headers, so we can stream the content body efficiently
4747
request.CompletionOption = HttpCompletionOption.ResponseHeadersRead;
4848
var response = await ExecuteRequestAsync(request, cancellationToken).ConfigureAwait(false);
4949

0 commit comments

Comments
 (0)