|
18 | 18 | using Microsoft.FeatureManagement; |
19 | 19 | using Microsoft.Graph; |
20 | 20 | using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; |
| 21 | +using Semver; |
21 | 22 | namespace Microsoft.OneFuzz.Service; |
22 | 23 |
|
23 | 24 | public class Program { |
@@ -58,6 +59,67 @@ public async Async.Task Invoke(FunctionContext context, FunctionExecutionDelegat |
58 | 59 | } |
59 | 60 | } |
60 | 61 |
|
| 62 | + /// <summary> |
| 63 | + /// Represents a middleware that can optionally perform strict version checking based on data sent in request headers. |
| 64 | + /// </summary> |
| 65 | + public class VersionCheckingMiddleware : IFunctionsWorkerMiddleware { |
| 66 | + private const string CliVersionHeader = "Cli-Version"; |
| 67 | + private const string StrictVersionHeader = "Strict-Version"; |
| 68 | + private readonly SemVersion _oneFuzzServiceVersion; |
| 69 | + private readonly IRequestHandling _requestHandling; |
| 70 | + |
| 71 | + /// <summary> |
| 72 | + /// Initializes an instance of <see cref="VersionCheckingMiddleware"/> with the provided config and request handling objects. |
| 73 | + /// </summary> |
| 74 | + /// <param name="config">The service config containing the service version.</param> |
| 75 | + /// <param name="requestHandling">The request handling object to create HTTP responses with.</param> |
| 76 | + public VersionCheckingMiddleware(IServiceConfig config, IRequestHandling requestHandling) { |
| 77 | + _oneFuzzServiceVersion = SemVersion.Parse(config.OneFuzzVersion, SemVersionStyles.Strict); |
| 78 | + _requestHandling = requestHandling; |
| 79 | + } |
| 80 | + |
| 81 | + public OneFuzzResultVoid CheckCliVersion(Azure.Functions.Worker.Http.HttpHeadersCollection headers) { |
| 82 | + var doStrictVersionCheck = |
| 83 | + headers.TryGetValues(StrictVersionHeader, out var strictVersion) |
| 84 | + && strictVersion?.FirstOrDefault()?.Equals("true", StringComparison.InvariantCultureIgnoreCase) == true; // "== true" necessary here to avoid implicit null -> bool casting |
| 85 | + |
| 86 | + if (doStrictVersionCheck) { |
| 87 | + if (!headers.TryGetValues(CliVersionHeader, out var cliVersion)) { |
| 88 | + return Error.Create(ErrorCode.INVALID_REQUEST, $"'{StrictVersionHeader}' is set to true without a corresponding '{CliVersionHeader}' header"); |
| 89 | + } |
| 90 | + if (!SemVersion.TryParse(cliVersion?.FirstOrDefault() ?? "", SemVersionStyles.Strict, out var version)) { |
| 91 | + return Error.Create(ErrorCode.INVALID_CLI_VERSION, $"'{CliVersionHeader}' header value is not a valid sematic version"); |
| 92 | + } |
| 93 | + if (version.ComparePrecedenceTo(_oneFuzzServiceVersion) < 0) { |
| 94 | + return Error.Create(ErrorCode.INVALID_CLI_VERSION, "cli is out of date"); |
| 95 | + } |
| 96 | + } |
| 97 | + |
| 98 | + return OneFuzzResultVoid.Ok; |
| 99 | + } |
| 100 | + |
| 101 | + /// <summary> |
| 102 | + /// Checks the request for two headers, cli version and one indicating whether to do strict version checking. |
| 103 | + /// When both are present and the cli is out of date, a descriptive response is sent back. |
| 104 | + /// </summary> |
| 105 | + /// <param name="context">The function context.</param> |
| 106 | + /// <param name="next">The function execution delegate.</param> |
| 107 | + /// <returns>A <seealso cref="Task"/> </returns> |
| 108 | + public async Async.Task Invoke(FunctionContext context, FunctionExecutionDelegate next) { |
| 109 | + var requestData = await context.GetHttpRequestDataAsync(); |
| 110 | + if (requestData is not null) { |
| 111 | + var error = CheckCliVersion(requestData.Headers); |
| 112 | + if (!error.IsOk) { |
| 113 | + var response = await _requestHandling.NotOk(requestData, error.ErrorV, "version middleware"); |
| 114 | + context.GetInvocationResult().Value = response; |
| 115 | + return; |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + await next(context); |
| 120 | + } |
| 121 | + } |
| 122 | + |
61 | 123 |
|
62 | 124 | //Move out expensive resources into separate class, and add those as Singleton |
63 | 125 | // ArmClient, Table Client(s), Queue Client(s), HttpClient, etc. |
@@ -161,6 +223,7 @@ public static async Async.Task Main() { |
161 | 223 | builder.UseMiddleware<LoggingMiddleware>(); |
162 | 224 | builder.UseMiddleware<Auth.AuthenticationMiddleware>(); |
163 | 225 | builder.UseMiddleware<Auth.AuthorizationMiddleware>(); |
| 226 | + builder.UseMiddleware<VersionCheckingMiddleware>(); |
164 | 227 |
|
165 | 228 | //this is a must, to tell the host that worker logging is done by us |
166 | 229 | builder.Services.Configure<WorkerOptions>(workerOptions => workerOptions.Capabilities["WorkerApplicationInsightsLoggingEnabled"] = bool.TrueString); |
|
0 commit comments