From 6d467c46bdc9ca34eb52668d70180d342d85d20b Mon Sep 17 00:00:00 2001 From: Thibault Charbonnier Date: Mon, 30 Nov 2015 19:24:11 -0800 Subject: [PATCH] fix(resolver) percent-encode query args Percent-encode query args when re-attaching them to the `upstream_uri`. Since `ngx.encode_args` does not perform percent-encoding on various reserved characters, this implements a custom `utils.encode_args` function which uses LuaSocket's `url.encode` function. It tries to mimic the `ngx.encode_uri` behaviour 100%. Ideally, `ngx.encode_args` would proceed to the percent-encoding itself (see https://github.com/openresty/lua-nginx-module/pull/542). This also makes some perf and style changes. Fix #749 --- kong/core/handler.lua | 14 ++- kong/core/resolver.lua | 2 +- kong/tools/utils.lua | 66 ++++++++++- ...pi_resolver_spec.lua => resolver_spec.lua} | 111 +++++++++--------- spec/unit/tools/utils_spec.lua | 110 ++++++++++++----- 5 files changed, 207 insertions(+), 96 deletions(-) rename spec/integration/proxy/{api_resolver_spec.lua => resolver_spec.lua} (76%) diff --git a/kong/core/handler.lua b/kong/core/handler.lua index 986fede09c05..87f404cc1956 100644 --- a/kong/core/handler.lua +++ b/kong/core/handler.lua @@ -18,6 +18,7 @@ -- -- @see https://github.com/openresty/lua-nginx-module#ngxctx +local url = require "socket.url" local utils = require "kong.tools.utils" local reports = require "kong.core.reports" local stringy = require "stringy" @@ -25,10 +26,10 @@ local resolver = require "kong.core.resolver" local constants = require "kong.constants" local certificate = require "kong.core.certificate" -local table_insert = table.insert -local math_floor = math.floor -local unpack = unpack +local type = type local ipairs = ipairs +local math_floor = math.floor +local table_insert = table.insert local MULT = 10^3 local function round(num) @@ -55,9 +56,10 @@ return { ngx.ctx.KONG_PROXIED = true -- Append any querystring parameters modified during plugins execution - local upstream_url = unpack(stringy.split(ngx.ctx.upstream_url, "?")) - if utils.table_size(ngx.req.get_uri_args()) > 0 then - upstream_url = upstream_url.."?"..ngx.encode_args(ngx.req.get_uri_args()) + local upstream_url = ngx.ctx.upstream_url + local uri_args = ngx.req.get_uri_args() + if utils.table_size(uri_args) > 0 then + upstream_url = upstream_url.."?"..utils.encode_args(uri_args) end -- Set the `$upstream_url` and `$upstream_host` variables for the `proxy_pass` nginx diff --git a/kong/core/resolver.lua b/kong/core/resolver.lua index 5805de7c7218..6b20eb442c80 100644 --- a/kong/core/resolver.lua +++ b/kong/core/resolver.lua @@ -214,7 +214,7 @@ local function url_has_path(url) end function _M.execute(request_uri, request_headers) - local uri = stringy.split(request_uri, "?")[1] + local uri = unpack(stringy.split(request_uri, "?")) local err, api, matched_host, hosts_list, strip_request_path_pattern = find_api(uri, request_headers) if err then return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) diff --git a/kong/tools/utils.lua b/kong/tools/utils.lua index 8423d3cd5e8b..e145197f9ad4 100644 --- a/kong/tools/utils.lua +++ b/kong/tools/utils.lua @@ -1,8 +1,23 @@ --- --- Module containing some general utility functions +-- Module containing some general utility functions used in many places in Kong. +-- +-- NOTE: Before implementing a function here, consider if it will be used in many places +-- across Kong. If not, a local function in the appropriate module is prefered. +-- +local url = require "socket.url" local uuid = require "lua_uuid" +local type = type +local pairs = pairs +local ipairs = ipairs +local tostring = tostring +local table_sort = table.sort +local table_concat = table.concat +local table_insert = table.insert +local string_find = string.find +local string_format = string.format + local _M = {} --- Generates a random unique string @@ -11,6 +26,51 @@ function _M.random_string() return uuid():gsub("-", "") end +local function encode_args_value(key, value) + key = url.escape(key) + if value ~= nil then + return string_format("%s=%s", key, url.escape(value)) + else + return key + end +end + +--- Encode a Lua table to a querystring +-- Tries to mimic ngx_lua's `ngx.encode_args`, but also percent-encode querystring values. +-- Supports multi-value query args, boolean values. +-- @TODO drop and use `ngx.encode_args` once it implements percent-encoding. +-- @see https://github.com/Mashape/kong/issues/749 +-- @param[type=table] args A key/value table containing the query args to encode +-- @treturn string A valid querystring (without the prefixing '?') +function _M.encode_args(args) + local query = {} + local keys = {} + + for k in pairs(args) do + keys[#keys+1] = k + end + + table_sort(keys) + + for _, key in ipairs(keys) do + local value = args[key] + if type(value) == "table" then + for _, sub_value in ipairs(value) do + query[#query+1] = encode_args_value(key, sub_value) + end + elseif value == true then + query[#query+1] = encode_args_value(key) + elseif value ~= false and value ~= nil then + value = tostring(value) + if value ~= "" then + query[#query+1] = encode_args_value(key, value) + end + end + end + + return table_concat(query, "&") +end + --- Calculates a table size. -- All entries both in array and hash part. -- @param t The table to use @@ -99,7 +159,7 @@ function _M.add_error(errors, k, v) errors[k] = setmetatable({errors[k]}, err_list_mt) end - table.insert(errors[k], v) + table_insert(errors[k], v) else errors[k] = v end @@ -118,7 +178,7 @@ function _M.load_module_if_exists(module_name) if status then return true, res -- Here we match any character because if a module has a dash '-' in its name, we would need to escape it. - elseif type(res) == "string" and string.find(res, "module '"..module_name.."' not found", nil, true) then + elseif type(res) == "string" and string_find(res, "module '"..module_name.."' not found", nil, true) then return false else error(res) diff --git a/spec/integration/proxy/api_resolver_spec.lua b/spec/integration/proxy/resolver_spec.lua similarity index 76% rename from spec/integration/proxy/api_resolver_spec.lua rename to spec/integration/proxy/resolver_spec.lua index 90557dc9401b..3639948e0ed3 100644 --- a/spec/integration/proxy/api_resolver_spec.lua +++ b/spec/integration/proxy/resolver_spec.lua @@ -21,7 +21,6 @@ local function parse_cert(cert) end describe("Resolver", function() - setup(function() spec_helper.prepare_db() spec_helper.insert_fixtures { @@ -55,22 +54,6 @@ describe("Resolver", function() spec_helper.stop_kong() end) - describe("Test URI", function() - - it("should URL decode the URI with querystring", function() - local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2F", { hello = "world"}, {host = "mockbin-uri.com"}) - assert.equal(200, status) - assert.equal("http://mockbin.org/request/hello%2f?hello=world", cjson.decode(response).url) - end) - - it("should URL decode the URI without querystring", function() - local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2F", nil, {host = "mockbin-uri.com"}) - assert.equal(200, status) - assert.equal("http://mockbin.org/request/hello%2f", cjson.decode(response).url) - end) - - end) - describe("Inexistent API", function() it("should return Not Found when the API is not in Kong", function() local response, status, headers = http_client.get(spec_helper.STUB_GET_URL, nil, {host = "foo.com"}) @@ -171,18 +154,6 @@ describe("Resolver", function() assert.equal("/somerequest_path/status/200", body.request_path) assert.equal(404, status) end) - it("should proxy and strip the request_path if `strip_request_path` is true", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin/request") - assert.equal(200, status) - local body = cjson.decode(response) - assert.equal("http://mockbin.com/request", body.url) - end) - it("should proxy and strip the request_path if `strip_request_path` is true if request_path has pattern characters", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin-with-pattern/request") - assert.equal(200, status) - local body = cjson.decode(response) - assert.equal("http://mockbin.com/request", body.url) - end) it("should proxy when the request_path has a deep level", function() local _, status = http_client.get(spec_helper.PROXY_URL.."/deep/request_path/status/200") assert.equal(200, status) @@ -191,33 +162,11 @@ describe("Resolver", function() local _, status = http_client.get(spec_helper.PROXY_URL.."/mockbin?foo=bar") assert.equal(200, status) end) - it("should not strip if the `request_path` pattern is repeated in the request_uri", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/har/har/of/request") - assert.equal(200, status) - local body = cjson.decode(response) - local upstream_url = body.log.entries[1].request.url - assert.equal("http://mockbin.com/har/of/request", upstream_url) - end) - it("should not add a trailing slash when strip_path is enabled", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash", { hello = "world"}) - assert.equal(200, status) - assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) - end) it("should not add a trailing slash when strip_path is disabled", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash2", { hello = "world"}) + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash2", {hello = "world"}) assert.equal(200, status) assert.equal("http://www.mockbin.org/request/test-trailing-slash2?hello=world", cjson.decode(response).url) end) - it("should not add a trailing slash when strip_path is enabled and upstream_url has no path", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash3/request", { hello = "world"}) - assert.equal(200, status) - assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) - end) - it("should not add a trailing slash when strip_path is enabled and upstream_url has single path", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash4/request", { hello = "world"}) - assert.equal(200, status) - assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) - end) end) it("should return the correct Server and Via headers when the request was proxied", function() @@ -240,7 +189,7 @@ describe("Resolver", function() end) end) - describe("Preseve Host", function() + describe("preserve_host", function() it("should not preserve the host (default behavior)", function() local response, status = http_client.get(PROXY_URL.."/get", nil, {host = "httpbin-nopreserve.com"}) assert.equal(200, status) @@ -255,5 +204,59 @@ describe("Resolver", function() assert.equal("httpbin-preserve.com", parsed_response.headers["Host"]) end) end) - + + describe("strip_path", function() + it("should strip the request_path if `strip_request_path` is true", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin/request") + assert.equal(200, status) + local body = cjson.decode(response) + assert.equal("http://mockbin.com/request", body.url) + end) + it("should strip the request_path if `strip_request_path` is true if `request_path` has pattern characters", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin-with-pattern/request") + assert.equal(200, status) + local body = cjson.decode(response) + assert.equal("http://mockbin.com/request", body.url) + end) + it("should not strip if the `request_path` pattern is repeated in the request_uri", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/har/har/of/request") + assert.equal(200, status) + local body = cjson.decode(response) + local upstream_url = body.log.entries[1].request.url + assert.equal("http://mockbin.com/har/of/request", upstream_url) + end) + it("should not add a trailing slash when strip_path is enabled", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash", {hello = "world"}) + assert.equal(200, status) + assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) + end) + it("should not add a trailing slash when strip_path is enabled and upstream_url has no path", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash3/request", {hello = "world"}) + assert.equal(200, status) + assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) + end) + it("should not add a trailing slash when strip_path is enabled and upstream_url has single path", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash4/request", {hello = "world"}) + assert.equal(200, status) + assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) + end) + end) + + describe("Percent-encoding", function() + it("should leave percent-encoded values in URI untouched", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2Fworld", {}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request/hello%2fworld", cjson.decode(response).url) + end) + it("should leave untouched percent-encoded values in querystring", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL.."/", {foo = "abc%7Cdef%2c%20world"}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request/?foo=abc%7cdef%2c%20world", cjson.decode(response).url) + end) + it("should leave untouched percent-encoded keys in querystring", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL.."/", {["hello%20world"] = "foo"}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request/?hello%20world=foo", cjson.decode(response).url) + end) + end) end) diff --git a/spec/unit/tools/utils_spec.lua b/spec/unit/tools/utils_spec.lua index 92e997ca1eb8..198a64216504 100644 --- a/spec/unit/tools/utils_spec.lua +++ b/spec/unit/tools/utils_spec.lua @@ -2,17 +2,80 @@ local utils = require "kong.tools.utils" describe("Utils", function() - describe("strings", function() - local first = utils.random_string() - assert.truthy(first) - assert.falsy(first:find("-")) - local second = utils.random_string() - assert.falsy(first == second) - end) + describe("string", function() + describe("random_string()", function() + it("should return a random string", function() + local first = utils.random_string() + assert.truthy(first) + assert.falsy(first:find("-")) + + local second = utils.random_string() + assert.not_equal(first, second) + end) + end) - describe("tables", function() - describe("#table_size()", function() + describe("encode_args()", function() + it("should encode a Lua table to a querystring", function() + local str = utils.encode_args { + foo = "bar", + hello = "world" + } + assert.equal("foo=bar&hello=world", str) + end) + it("should encode multi-value query args", function() + local str = utils.encode_args { + foo = {"bar", "zoo"}, + hello = "world" + } + assert.equal("foo=bar&foo=zoo&hello=world", str) + end) + it("should percent-encode given values", function() + local str = utils.encode_args { + encode = {"abc|def", ",$@|`"} + } + assert.equal("encode=abc%7cdef&encode=%2c%24%40%7c%60", str) + end) + it("should percent-encode given query args keys", function() + local str = utils.encode_args { + ["hello world"] = "foo" + } + assert.equal("hello%20world=foo", str) + end) + it("should support Lua numbers", function() + local str = utils.encode_args { + a = 1, + b = 2 + } + assert.equal("a=1&b=2", str) + end) + it("should support a boolean argument", function() + local str = utils.encode_args { + a = true, + b = 1 + } + assert.equal("a&b=1", str) + end) + it("should ignore nil and false values", function() + local str = utils.encode_args { + a = nil, + b = false + } + assert.equal("", str) + end) + it("should encode complex query args", function() + local str = utils.encode_args { + multiple = {"hello, world"}, + hello = "world", + ignore = false, + ["multiple values"] = true + } + assert.equal("hello=world&multiple=hello%2c%20world&multiple%20values", str) + end) + end) + end) + describe("table", function() + describe("table_size()", function() it("should return the size of a table", function() assert.are.same(0, utils.table_size(nil)) assert.are.same(0, utils.table_size({})) @@ -20,44 +83,36 @@ describe("Utils", function() assert.are.same(2, utils.table_size({ foo = "bar", bar = "baz" })) assert.are.same(2, utils.table_size({ "foo", "bar" })) end) - end) - describe("#table_contains()", function() - + describe("table_contains()", function() it("should return false if a value is not contained in a nil table", function() assert.False(utils.table_contains(nil, "foo")) end) - it("should return true if a value is contained in a table", function() local t = { foo = "hello", bar = "world" } assert.True(utils.table_contains(t, "hello")) end) - it("should return false if a value is not contained in a table", function() local t = { foo = "hello", bar = "world" } assert.False(utils.table_contains(t, "foo")) end) - end) - describe("#is_array()", function() - + describe("is_array()", function() it("should know when an array ", function() assert.True(utils.is_array({ "a", "b", "c", "d" })) assert.True(utils.is_array({ ["1"] = "a", ["2"] = "b", ["3"] = "c", ["4"] = "d" })) assert.False(utils.is_array({ "a", "b", "c", foo = "d" })) end) - end) - describe("#add_error()", function() + describe("add_error()", function() local add_error = utils.add_error it("should create a table if given `errors` is nil", function() assert.same({hello = "world"}, add_error(nil, "hello", "world")) end) - it("should add a key/value when the key does not exists", function() local errors = {hello = "world"} assert.same({ @@ -65,10 +120,8 @@ describe("Utils", function() foo = "bar" }, add_error(errors, "foo", "bar")) end) - it("should transform previous values to a list if the same key is given again", function() - local e = nil - + local e e = add_error(e, "key1", "value1") e = add_error(e, "key2", "value2") assert.same({key1 = "value1", key2 = "value2"}, e) @@ -82,10 +135,8 @@ describe("Utils", function() e = add_error(e, "key2", "value7") assert.same({key1 = {"value1", "value3", "value4", "value5", "value6"}, key2 = {"value2", "value7"}}, e) end) - it("should also list tables pushed as errors", function() - local e = nil - + local e e = add_error(e, "key1", "value1") e = add_error(e, "key2", "value2") e = add_error(e, "key1", "value3") @@ -100,11 +151,9 @@ describe("Utils", function() keyO = {{message = "some error"}, {message = "another"}} }, e) end) - end) - describe("#load_module_if_exists()", function() - + describe("load_module_if_exists()", function() it("should return false if the module does not exist", function() local loaded, mod assert.has_no.errors(function() @@ -113,7 +162,6 @@ describe("Utils", function() assert.False(loaded) assert.falsy(mod) end) - it("should throw an error if the module is invalid", function() local loaded, mod assert.has.errors(function() @@ -122,7 +170,6 @@ describe("Utils", function() assert.falsy(loaded) assert.falsy(mod) end) - it("should load a module if it was found and valid", function() local loaded, mod assert.has_no.errors(function() @@ -132,7 +179,6 @@ describe("Utils", function() assert.truthy(mod) assert.are.same("All your base are belong to us.", mod.exposed) end) - end) end) end)