diff --git a/spec/open_api_spec.cr b/spec/open_api_spec.cr index c064f3d..5b6308c 100644 --- a/spec/open_api_spec.cr +++ b/spec/open_api_spec.cr @@ -9,7 +9,7 @@ describe ActionController::OpenAPI do it "generates openapi docs" do result = ActionController::OpenAPI.generate_open_api_docs("title", "version", description: "desc") result[:openapi].should eq "3.0.3" - result[:paths].size.should eq 23 + result[:paths].size.should eq 25 result[:info][:description].should eq "desc" end end diff --git a/spec/route_builder_spec.cr b/spec/route_builder_spec.cr index 2261add..46660c1 100644 --- a/spec/route_builder_spec.cr +++ b/spec/route_builder_spec.cr @@ -186,4 +186,33 @@ describe AC::Route::Builder do result = client.get("/skipping_annotation") result.status_code.should eq 200 end + + it "should parse headers as params" do + result = client.get("/filtering/testing/header/values?query_param=12", headers: HTTP::Headers{ + "X-Count" => "123", + }) + result.status_code.should eq 200 + result.body.should eq "123--12" + + result = client.get("/filtering/testing/header/values/default?query_param=13") + result.status_code.should eq 200 + result.body.should eq "12--13" + + # test error handling + expect_raises(ActionController::Route::Param::MissingError, "missing required header 'X-Count'") do + client.get("/filtering/testing/header/values?query_param=12") + end + + expect_raises(ActionController::Route::Param::ValueError, "invalid header value for 'X-Count'") do + client.get("/filtering/testing/header/values?query_param=12", headers: HTTP::Headers{ + "X-Count" => "abc", + }) + end + + expect_raises(ActionController::Route::Param::ValueError, "invalid header value for 'X-Count'") do + client.get("/filtering/testing/header/values/default?query_param=12", headers: HTTP::Headers{ + "X-Count" => "abc", + }) + end + end end diff --git a/spec/spec_helper.cr b/spec/spec_helper.cr index b7c10ca..41c1b51 100644 --- a/spec/spec_helper.cr +++ b/spec/spec_helper.cr @@ -106,6 +106,24 @@ class Filtering < FilterOrdering id end + @[AC::Route::GET("/testing/header/values", content_type: "text/plain")] + def testing_headers( + @[AC::Param::Info(header: "X-Count", description: "number of requests made", example: "34")] + value : Int32, + query_param : Int32 + ) : String + "#{value}--#{query_param}" + end + + @[AC::Route::GET("/testing/header/values/default", content_type: "text/plain")] + def testing_headers_default( + query_param : Int32, + @[AC::Param::Info(header: "X-Count", description: "number of requests made", example: "34")] + value : Int32 = 12, + ) : String + "#{value}--#{query_param}" + end + # Test default arguments and multiple routes for a single method @[AC::Route::GET("/other_route/:id/test")] @[AC::Route::GET("/other_route/test")] diff --git a/src/action-controller/open_api.cr b/src/action-controller/open_api.cr index 6ac36ad..d6e80f8 100644 --- a/src/action-controller/open_api.cr +++ b/src/action-controller/open_api.cr @@ -348,7 +348,7 @@ module ActionController::OpenAPI params: [ {% for param_name, param in params %} { - name: {{ param_name }}, + name: {{ param[:header] || param_name }}, in: {{ param[:in] }}, required: {{ param[:required] }}, schema: ::JSON::Schema.introspect({{ param[:schema] }}, openapi: true).to_json, diff --git a/src/action-controller/router/builder.cr b/src/action-controller/router/builder.cr index 2bcc150..fe1489d 100644 --- a/src/action-controller/router/builder.cr +++ b/src/action-controller/router/builder.cr @@ -387,6 +387,15 @@ module ActionController::Route::Builder {% open_api_param = {} of Nil => Nil %} {% else %} {% open_api_param = open_api_params[query_param_name] || {} of Nil => Nil %} + + # check for header params + {% if (ann_converter && ann_converter[:header]) %} + {% open_api_param[:in] = :header %} + {% open_api_param[:header] = ann_converter[:header].id.stringify %} + {% else %} + {% open_api_param[:in] = open_api_param[:in] || :query %} + {% end %} + {% open_api_param[:in] = open_api_param[:in] || :query %} {% open_api_param[:docs] = (ann_converter && ann_converter[:description]) %} {% open_api_param[:example] = (ann_converter && ann_converter[:example]) %} @@ -463,16 +472,26 @@ module ActionController::Route::Builder {% end %} end + # handle header params + {% elsif open_api_param.has_key?(:header) %} + if param_value = @__context__.request.headers.fetch({{open_api_param[:header]}}, nil) + {{restrictions.join(" || ").id}} + {% if arg.default_value.stringify != "" %} + else + {{arg.default_value}} + {% end %} + end + # Required route param, so we ensure it {% elsif required_params.includes? string_name %} if param_value = route_params[{{query_param_name}}]? {{restrictions.join(" || ").id}} else - raise ::AC::Route::Param::MissingError.new("missing required parameter", {{query_param_name}}, {{arg.restriction.resolve.stringify}}) + raise ::AC::Route::Param::MissingError.new("missing required parameter '#{ {{query_param_name}} }'", {{query_param_name}}, {{arg.restriction.resolve.stringify}}) end - # An optional route param, might be passed as an query param - {% else %} + # An optional route param, might be passed as an query param (not the case for headers) + {% elsif !open_api_param.has_key?(:header) %} if param_value = params[{{query_param_name}}]? {{restrictions.join(" || ").id}} {% if arg.default_value.stringify != "" %} @@ -485,11 +504,19 @@ module ActionController::Route::Builder # Use tap to ensure a good error message if the function param isn't nilable ){% if !nilable %}.tap { |result| if result.nil? - if params.has_key?({{query_param_name}}) - raise ::AC::Route::Param::ValueError.new("invalid parameter value", {{query_param_name}}, {{arg.restriction.resolve.stringify}}) - else - raise ::AC::Route::Param::MissingError.new("missing required parameter", {{query_param_name}}, {{arg.restriction.resolve.stringify}}) - end + {% if open_api_param.has_key?(:header) %} + if @__context__.request.headers.has_key?({{open_api_param[:header]}}) + raise ::AC::Route::Param::ValueError.new("invalid header value for '#{ {{open_api_param[:header]}} }'", {{open_api_param[:header]}}, {{arg.restriction.resolve.stringify}}) + else + raise ::AC::Route::Param::MissingError.new("missing required header '#{ {{open_api_param[:header]}} }'", {{open_api_param[:header]}}, {{arg.restriction.resolve.stringify}}) + end + {% else %} + if params.has_key?({{query_param_name}}) + raise ::AC::Route::Param::ValueError.new("invalid parameter value for '#{ {{query_param_name}} }'", {{query_param_name}}, {{arg.restriction.resolve.stringify}}) + else + raise ::AC::Route::Param::MissingError.new("missing required parameter '#{ {{query_param_name}} }'", {{query_param_name}}, {{arg.restriction.resolve.stringify}}) + end + {% end %} end }.not_nil!{% end %}, {% end %}