Skip to content

Commit 73e1a47

Browse files
Refactor and fixes in source input model (#3445)
1 parent e9bd019 commit 73e1a47

File tree

4 files changed

+147
-99
lines changed

4 files changed

+147
-99
lines changed

src/AutoRest.CSharp/Common/Generation/Types/CSharpType.cs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,18 +163,6 @@ internal static CSharpType FromSystemType(Type type, string defaultNamespace, So
163163
internal static CSharpType FromSystemType(BuildContext context, Type type)
164164
=> FromSystemType(type, context.DefaultNamespace, context.SourceInputModel);
165165

166-
public bool IsCollectionType()
167-
{
168-
if (!IsFrameworkType)
169-
return false;
170-
171-
return FrameworkType.Equals(typeof(IList<>)) ||
172-
FrameworkType.Equals(typeof(IEnumerable<>)) ||
173-
FrameworkType == typeof(IReadOnlyList<>) ||
174-
FrameworkType.Equals(typeof(IDictionary<,>)) ||
175-
FrameworkType == typeof(IReadOnlyDictionary<,>);
176-
}
177-
178166
public CSharpType GetNonNullable()
179167
{
180168
if (!IsNullable)
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Collections.Immutable;
6+
using System.Diagnostics.CodeAnalysis;
7+
using System.Linq;
8+
using Azure.Core;
9+
using Microsoft.CodeAnalysis;
10+
11+
namespace AutoRest.CSharp.Input.Source
12+
{
13+
public class CodeGenAttributes
14+
{
15+
public CodeGenAttributes(Compilation compilation)
16+
{
17+
CodeGenSuppressAttribute = GetSymbol(compilation, typeof(CodeGenSuppressAttribute));
18+
CodeGenMemberAttribute = GetSymbol(compilation, typeof(CodeGenMemberAttribute));
19+
CodeGenTypeAttribute = GetSymbol(compilation, typeof(CodeGenTypeAttribute));
20+
CodeGenModelAttribute = GetSymbol(compilation, typeof(CodeGenModelAttribute));
21+
CodeGenClientAttribute = GetSymbol(compilation, typeof(CodeGenClientAttribute));
22+
CodeGenMemberSerializationAttribute = GetSymbol(compilation, typeof(CodeGenMemberSerializationAttribute));
23+
CodeGenMemberSerializationHooksAttribute = GetSymbol(compilation, typeof(CodeGenMemberSerializationHooksAttribute));
24+
}
25+
26+
public INamedTypeSymbol CodeGenSuppressAttribute { get; }
27+
28+
public INamedTypeSymbol CodeGenMemberAttribute { get; }
29+
30+
public INamedTypeSymbol CodeGenTypeAttribute { get; }
31+
32+
public INamedTypeSymbol CodeGenModelAttribute { get; }
33+
34+
public INamedTypeSymbol CodeGenClientAttribute { get; }
35+
36+
public INamedTypeSymbol CodeGenMemberSerializationAttribute { get; }
37+
38+
public INamedTypeSymbol CodeGenMemberSerializationHooksAttribute { get; }
39+
40+
private static INamedTypeSymbol GetSymbol(Compilation compilation, Type type) => compilation.GetTypeByMetadataName(type.FullName!) ?? throw new InvalidOperationException($"cannot load symbol of attribute {type}");
41+
42+
private static bool CheckAttribute(AttributeData attributeData, INamedTypeSymbol codeGenAttribute)
43+
=> SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, codeGenAttribute);
44+
45+
public bool TryGetCodeGenMemberAttributeValue(AttributeData attributeData, [MaybeNullWhen(false)] out string name)
46+
{
47+
name = null;
48+
if (!CheckAttribute(attributeData, CodeGenMemberAttribute))
49+
return false;
50+
51+
name = attributeData.ConstructorArguments.FirstOrDefault().Value as string;
52+
return name != null;
53+
}
54+
55+
public bool TryGetCodeGenMemberSerializationAttributeValue(AttributeData attributeData, [MaybeNullWhen(false)] out string[] propertyNames)
56+
{
57+
propertyNames = null;
58+
if (!CheckAttribute(attributeData, CodeGenMemberSerializationAttribute))
59+
return false;
60+
61+
if (attributeData.ConstructorArguments.Length > 0)
62+
{
63+
propertyNames = ToStringArray(attributeData.ConstructorArguments[0].Values);
64+
}
65+
66+
return propertyNames != null;
67+
}
68+
69+
public bool TryGetCodeGenMemberSerializationHooksAttributeValue(AttributeData attributeData, out (string? SerializationHook, string? DeserializationHook) hooks)
70+
{
71+
hooks = default;
72+
if (!CheckAttribute(attributeData, CodeGenMemberSerializationHooksAttribute))
73+
return false;
74+
75+
string? serializationHook = null;
76+
string? deserializationHook = null;
77+
78+
var arguments = attributeData.ConstructorArguments;
79+
serializationHook = arguments[0].Value as string;
80+
deserializationHook = arguments[1].Value as string;
81+
82+
hooks = (serializationHook, deserializationHook);
83+
return serializationHook != null || deserializationHook != null;
84+
}
85+
86+
public bool TryGetCodeGenModelAttributeValue(AttributeData attributeData, out string[]? usage, out string[]? formats)
87+
{
88+
usage = null;
89+
formats = null;
90+
if (!CheckAttribute(attributeData, CodeGenModelAttribute))
91+
return false;
92+
foreach (var namedArgument in attributeData.NamedArguments)
93+
{
94+
switch (namedArgument.Key)
95+
{
96+
case nameof(Azure.Core.CodeGenModelAttribute.Usage):
97+
usage = ToStringArray(namedArgument.Value.Values);
98+
break;
99+
case nameof(Azure.Core.CodeGenModelAttribute.Formats):
100+
formats = ToStringArray(namedArgument.Value.Values);
101+
break;
102+
}
103+
}
104+
105+
return usage != null || formats != null;
106+
}
107+
108+
private static string[]? ToStringArray(ImmutableArray<TypedConstant> values)
109+
{
110+
if (values.IsDefaultOrEmpty)
111+
{
112+
return null;
113+
}
114+
115+
return values
116+
.Select(v => (string?)v.Value)
117+
.OfType<string>()
118+
.ToArray();
119+
}
120+
}
121+
}

src/AutoRest.CSharp/Common/Input/Source/ModelTypeMapping.cs

Lines changed: 17 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -20,103 +20,49 @@ public class ModelTypeMapping
2020
public string[]? Usage { get; }
2121
public string[]? Formats { get; }
2222

23-
public ModelTypeMapping(INamedTypeSymbol modelAttribute, INamedTypeSymbol memberAttribute, INamedTypeSymbol serializationAttribute, INamedTypeSymbol serializationHooksAttribute, INamedTypeSymbol? existingType)
23+
public ModelTypeMapping(CodeGenAttributes codeGenAttributes, INamedTypeSymbol existingType)
2424
{
2525
_existingType = existingType;
2626
_propertyMappings = new();
2727
_serializationMappings = new(SymbolEqualityComparer.Default);
2828

2929
foreach (ISymbol member in GetMembers(existingType))
3030
{
31+
string[]? serializationPath = null;
32+
(string? SerializationHook, string? DeserializationHook)? serializationHooks = null;
3133
foreach (var attributeData in member.GetAttributes())
3234
{
33-
var attributeTypeSymbol = attributeData.AttributeClass;
3435
// handle CodeGenMember attribute
35-
if (SymbolEqualityComparer.Default.Equals(attributeTypeSymbol, memberAttribute) && TryGetCodeGenMemberAttributeValue(member, attributeData, out var schemaMemberName))
36+
if (codeGenAttributes.TryGetCodeGenMemberAttributeValue(attributeData, out var schemaMemberName))
3637
{
3738
_propertyMappings.Add(schemaMemberName, member);
3839
}
39-
string[]? serializationPath = null;
40-
(string? SerializationHook, string? DeserializationHook)? serializationHooks = null;
41-
if (SymbolEqualityComparer.Default.Equals(attributeTypeSymbol, serializationAttribute) && TryGetSerializationAttributeValue(member, attributeData, out var pathResult))
40+
// handle CodeGenMemberSerialization attribute
41+
if (codeGenAttributes.TryGetCodeGenMemberSerializationAttributeValue(attributeData, out var pathResult))
4242
{
4343
serializationPath = pathResult;
4444
}
45-
if (SymbolEqualityComparer.Default.Equals(attributeTypeSymbol, serializationHooksAttribute) && TryGetSerializationHooks(member, attributeData, out var hooks))
45+
// handle CodeGenMemberSerializationHooks attribute
46+
if (codeGenAttributes.TryGetCodeGenMemberSerializationHooksAttributeValue(attributeData, out var hooks))
4647
{
4748
serializationHooks = hooks;
4849
}
49-
if (serializationPath != null || serializationHooks != null)
50-
{
51-
_serializationMappings.Add(member, new SourcePropertySerializationMapping(member, serializationPath, serializationHooks?.SerializationHook, serializationHooks?.DeserializationHook));
52-
}
5350
}
54-
}
55-
56-
if (existingType != null)
57-
{
58-
foreach (var attributeData in existingType.GetAttributes())
51+
if (serializationPath != null || serializationHooks != null)
5952
{
60-
if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, modelAttribute))
61-
{
62-
foreach (var namedArgument in attributeData.NamedArguments)
63-
{
64-
switch (namedArgument.Key)
65-
{
66-
case nameof(CodeGenModelAttribute.Usage):
67-
Usage = ToStringArray(namedArgument.Value.Values);
68-
break;
69-
case nameof(CodeGenModelAttribute.Formats):
70-
Formats = ToStringArray(namedArgument.Value.Values);
71-
break;
72-
}
73-
}
74-
}
53+
_serializationMappings.Add(member, new SourcePropertySerializationMapping(member, serializationPath, serializationHooks?.SerializationHook, serializationHooks?.DeserializationHook));
7554
}
7655
}
77-
}
78-
79-
private static bool TryGetSerializationHooks(ISymbol symbol, AttributeData attributeData, out (string? SerializationHook, string? DeserializationHook) hooks)
80-
{
81-
string? serializationHook = null;
82-
string? deserializationHook = null;
83-
84-
var arguments = attributeData.ConstructorArguments;
85-
serializationHook = arguments[0].Value as string;
86-
deserializationHook = arguments[1].Value as string;
8756

88-
hooks = (serializationHook, deserializationHook);
89-
return serializationHook != null || deserializationHook != null;
90-
}
91-
92-
private static bool TryGetCodeGenMemberAttributeValue(ISymbol symbol, AttributeData attributeData, [MaybeNullWhen(false)] out string name)
93-
{
94-
name = attributeData.ConstructorArguments.FirstOrDefault().Value as string;
95-
return name != null;
96-
}
97-
98-
private static bool TryGetSerializationAttributeValue(ISymbol symbol, AttributeData attributeData, [MaybeNullWhen(false)] out string[] propertyNames)
99-
{
100-
propertyNames = null;
101-
if (attributeData.ConstructorArguments.Length > 0)
57+
foreach (var attributeData in existingType.GetAttributes())
10258
{
103-
propertyNames = ToStringArray(attributeData.ConstructorArguments[0].Values);
104-
}
105-
106-
return propertyNames != null;
107-
}
108-
109-
private static string[]? ToStringArray(ImmutableArray<TypedConstant> values)
110-
{
111-
if (values.IsDefaultOrEmpty)
112-
{
113-
return null;
59+
// handle CodeGenModel attribute
60+
if (codeGenAttributes.TryGetCodeGenModelAttributeValue(attributeData, out var usage, out var formats))
61+
{
62+
Usage = usage;
63+
Formats = formats;
64+
}
11465
}
115-
116-
return values
117-
.Select(v => (string?)v.Value)
118-
.OfType<string>()
119-
.ToArray();
12066
}
12167

12268
public SourceMemberMapping? GetForMember(string name)

src/AutoRest.CSharp/Common/Input/Source/SourceInputModel.cs

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,15 @@ public class SourceInputModel
1515
{
1616
private readonly Compilation _compilation;
1717
private readonly CompilationInput? _existingCompilation;
18-
private readonly INamedTypeSymbol _typeAttribute;
19-
private readonly INamedTypeSymbol _modelAttribute;
20-
private readonly INamedTypeSymbol _clientAttribute;
21-
private readonly INamedTypeSymbol _schemaMemberNameAttribute;
22-
private readonly INamedTypeSymbol _serializationAttribute;
23-
private readonly INamedTypeSymbol _serializationHooksAttribute;
18+
private readonly CodeGenAttributes _codeGenAttributes;
2419
private readonly Dictionary<string, INamedTypeSymbol> _nameMap = new Dictionary<string, INamedTypeSymbol>(StringComparer.OrdinalIgnoreCase);
2520

2621
public SourceInputModel(Compilation compilation, CompilationInput? existingCompilation = null)
2722
{
2823
_compilation = compilation;
2924
_existingCompilation = existingCompilation;
3025

31-
_schemaMemberNameAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenMemberAttribute).FullName!)!;
32-
_serializationAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenMemberSerializationAttribute).FullName!)!;
33-
_serializationHooksAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenMemberSerializationHooksAttribute).FullName!)!;
34-
_typeAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenTypeAttribute).FullName!)!;
35-
_modelAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenModelAttribute).FullName!)!;
36-
_clientAttribute = compilation.GetTypeByMetadataName(typeof(CodeGenClientAttribute).FullName!)!;
26+
_codeGenAttributes = new CodeGenAttributes(compilation);
3727

3828
IAssemblySymbol assembly = _compilation.Assembly;
3929

@@ -58,9 +48,12 @@ public SourceInputModel(Compilation compilation, CompilationInput? existingCompi
5848
return osvAttribute?.ConstructorArguments[0].Values.Select(v => v.Value).OfType<string>().ToList();
5949
}
6050

61-
public ModelTypeMapping CreateForModel(INamedTypeSymbol? symbol)
51+
public ModelTypeMapping? CreateForModel(INamedTypeSymbol? symbol)
6252
{
63-
return new ModelTypeMapping(_modelAttribute, _schemaMemberNameAttribute, _serializationAttribute, _serializationHooksAttribute, symbol);
53+
if (symbol == null)
54+
return null;
55+
56+
return new ModelTypeMapping(_codeGenAttributes, symbol);
6457
}
6558

6659
internal IMethodSymbol? FindMethod(string namespaceName, string typeName, string methodName, IEnumerable<CSharpType> parameters)
@@ -87,7 +80,7 @@ internal bool TryGetClientSourceInput(INamedTypeSymbol type, [NotNullWhen(true)]
8780
var attributeType = attribute.AttributeClass;
8881
while (attributeType != null)
8982
{
90-
if (SymbolEqualityComparer.Default.Equals(attributeType, _clientAttribute))
83+
if (SymbolEqualityComparer.Default.Equals(attributeType, _codeGenAttributes.CodeGenClientAttribute))
9184
{
9285
INamedTypeSymbol? parentClientType = null;
9386
foreach ((var argumentName, TypedConstant constant) in attribute.NamedArguments)
@@ -119,7 +112,7 @@ private bool TryGetName(ISymbol symbol, [NotNullWhen(true)] out string? name)
119112
var type = attribute.AttributeClass;
120113
while (type != null)
121114
{
122-
if (SymbolEqualityComparer.Default.Equals(type, _typeAttribute))
115+
if (SymbolEqualityComparer.Default.Equals(type, _codeGenAttributes.CodeGenTypeAttribute))
123116
{
124117
if (attribute?.ConstructorArguments.Length > 0)
125118
{

0 commit comments

Comments
 (0)