Skip to content

Commit

Permalink
fix(resolver) percent-encode query args
Browse files Browse the repository at this point in the history
Percent-encode query args when re-attaching them to the `upstream_uri`.
Since `ngx.encode_args` does not perform percent-encoding on various
reserved characters, this implements a custom `utils.encode_args`
function which uses LuaSocket's `url.encode` function. It tries to mimic
the `ngx.encode_uri` behaviour 100%.

Ideally, `ngx.encode_args` would proceed to the percent-encoding itself (see
openresty/lua-nginx-module#542).

This also makes some perf and style changes.

Fix #749
  • Loading branch information
thibaultcha committed Dec 1, 2015
1 parent 571b631 commit 6d467c4
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 96 deletions.
14 changes: 8 additions & 6 deletions kong/core/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@
--
-- @see https://github.com/openresty/lua-nginx-module#ngxctx

local url = require "socket.url"
local utils = require "kong.tools.utils"
local reports = require "kong.core.reports"
local stringy = require "stringy"
local resolver = require "kong.core.resolver"
local constants = require "kong.constants"
local certificate = require "kong.core.certificate"

local table_insert = table.insert
local math_floor = math.floor
local unpack = unpack
local type = type
local ipairs = ipairs
local math_floor = math.floor
local table_insert = table.insert

local MULT = 10^3
local function round(num)
Expand All @@ -55,9 +56,10 @@ return {
ngx.ctx.KONG_PROXIED = true

-- Append any querystring parameters modified during plugins execution
local upstream_url = unpack(stringy.split(ngx.ctx.upstream_url, "?"))
if utils.table_size(ngx.req.get_uri_args()) > 0 then
upstream_url = upstream_url.."?"..ngx.encode_args(ngx.req.get_uri_args())
local upstream_url = ngx.ctx.upstream_url
local uri_args = ngx.req.get_uri_args()
if utils.table_size(uri_args) > 0 then
upstream_url = upstream_url.."?"..utils.encode_args(uri_args)
end

-- Set the `$upstream_url` and `$upstream_host` variables for the `proxy_pass` nginx
Expand Down
2 changes: 1 addition & 1 deletion kong/core/resolver.lua
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ local function url_has_path(url)
end

function _M.execute(request_uri, request_headers)
local uri = stringy.split(request_uri, "?")[1]
local uri = unpack(stringy.split(request_uri, "?"))
local err, api, matched_host, hosts_list, strip_request_path_pattern = find_api(uri, request_headers)
if err then
return responses.send_HTTP_INTERNAL_SERVER_ERROR(err)
Expand Down
66 changes: 63 additions & 3 deletions kong/tools/utils.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
---
-- Module containing some general utility functions
-- Module containing some general utility functions used in many places in Kong.
--
-- NOTE: Before implementing a function here, consider if it will be used in many places
-- across Kong. If not, a local function in the appropriate module is prefered.
--

local url = require "socket.url"
local uuid = require "lua_uuid"

local type = type
local pairs = pairs
local ipairs = ipairs
local tostring = tostring
local table_sort = table.sort
local table_concat = table.concat
local table_insert = table.insert
local string_find = string.find
local string_format = string.format

local _M = {}

--- Generates a random unique string
Expand All @@ -11,6 +26,51 @@ function _M.random_string()
return uuid():gsub("-", "")
end

local function encode_args_value(key, value)
key = url.escape(key)
if value ~= nil then
return string_format("%s=%s", key, url.escape(value))
else
return key
end
end

--- Encode a Lua table to a querystring
-- Tries to mimic ngx_lua's `ngx.encode_args`, but also percent-encode querystring values.
-- Supports multi-value query args, boolean values.
-- @TODO drop and use `ngx.encode_args` once it implements percent-encoding.
-- @see https://github.com/Mashape/kong/issues/749
-- @param[type=table] args A key/value table containing the query args to encode
-- @treturn string A valid querystring (without the prefixing '?')
function _M.encode_args(args)
local query = {}
local keys = {}

for k in pairs(args) do
keys[#keys+1] = k
end

table_sort(keys)

for _, key in ipairs(keys) do
local value = args[key]
if type(value) == "table" then
for _, sub_value in ipairs(value) do
query[#query+1] = encode_args_value(key, sub_value)
end
elseif value == true then
query[#query+1] = encode_args_value(key)
elseif value ~= false and value ~= nil then
value = tostring(value)
if value ~= "" then
query[#query+1] = encode_args_value(key, value)
end
end
end

return table_concat(query, "&")
end

--- Calculates a table size.
-- All entries both in array and hash part.
-- @param t The table to use
Expand Down Expand Up @@ -99,7 +159,7 @@ function _M.add_error(errors, k, v)
errors[k] = setmetatable({errors[k]}, err_list_mt)
end

table.insert(errors[k], v)
table_insert(errors[k], v)
else
errors[k] = v
end
Expand All @@ -118,7 +178,7 @@ function _M.load_module_if_exists(module_name)
if status then
return true, res
-- Here we match any character because if a module has a dash '-' in its name, we would need to escape it.
elseif type(res) == "string" and string.find(res, "module '"..module_name.."' not found", nil, true) then
elseif type(res) == "string" and string_find(res, "module '"..module_name.."' not found", nil, true) then
return false
else
error(res)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ local function parse_cert(cert)
end

describe("Resolver", function()

setup(function()
spec_helper.prepare_db()
spec_helper.insert_fixtures {
Expand Down Expand Up @@ -55,22 +54,6 @@ describe("Resolver", function()
spec_helper.stop_kong()
end)

describe("Test URI", function()

it("should URL decode the URI with querystring", function()
local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2F", { hello = "world"}, {host = "mockbin-uri.com"})
assert.equal(200, status)
assert.equal("http://mockbin.org/request/hello%2f?hello=world", cjson.decode(response).url)
end)

it("should URL decode the URI without querystring", function()
local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2F", nil, {host = "mockbin-uri.com"})
assert.equal(200, status)
assert.equal("http://mockbin.org/request/hello%2f", cjson.decode(response).url)
end)

end)

describe("Inexistent API", function()
it("should return Not Found when the API is not in Kong", function()
local response, status, headers = http_client.get(spec_helper.STUB_GET_URL, nil, {host = "foo.com"})
Expand Down Expand Up @@ -171,18 +154,6 @@ describe("Resolver", function()
assert.equal("/somerequest_path/status/200", body.request_path)
assert.equal(404, status)
end)
it("should proxy and strip the request_path if `strip_request_path` is true", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin/request")
assert.equal(200, status)
local body = cjson.decode(response)
assert.equal("http://mockbin.com/request", body.url)
end)
it("should proxy and strip the request_path if `strip_request_path` is true if request_path has pattern characters", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin-with-pattern/request")
assert.equal(200, status)
local body = cjson.decode(response)
assert.equal("http://mockbin.com/request", body.url)
end)
it("should proxy when the request_path has a deep level", function()
local _, status = http_client.get(spec_helper.PROXY_URL.."/deep/request_path/status/200")
assert.equal(200, status)
Expand All @@ -191,33 +162,11 @@ describe("Resolver", function()
local _, status = http_client.get(spec_helper.PROXY_URL.."/mockbin?foo=bar")
assert.equal(200, status)
end)
it("should not strip if the `request_path` pattern is repeated in the request_uri", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/har/har/of/request")
assert.equal(200, status)
local body = cjson.decode(response)
local upstream_url = body.log.entries[1].request.url
assert.equal("http://mockbin.com/har/of/request", upstream_url)
end)
it("should not add a trailing slash when strip_path is enabled", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash", { hello = "world"})
assert.equal(200, status)
assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url)
end)
it("should not add a trailing slash when strip_path is disabled", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash2", { hello = "world"})
local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash2", {hello = "world"})
assert.equal(200, status)
assert.equal("http://www.mockbin.org/request/test-trailing-slash2?hello=world", cjson.decode(response).url)
end)
it("should not add a trailing slash when strip_path is enabled and upstream_url has no path", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash3/request", { hello = "world"})
assert.equal(200, status)
assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url)
end)
it("should not add a trailing slash when strip_path is enabled and upstream_url has single path", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash4/request", { hello = "world"})
assert.equal(200, status)
assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url)
end)
end)

it("should return the correct Server and Via headers when the request was proxied", function()
Expand All @@ -240,7 +189,7 @@ describe("Resolver", function()
end)
end)

describe("Preseve Host", function()
describe("preserve_host", function()
it("should not preserve the host (default behavior)", function()
local response, status = http_client.get(PROXY_URL.."/get", nil, {host = "httpbin-nopreserve.com"})
assert.equal(200, status)
Expand All @@ -255,5 +204,59 @@ describe("Resolver", function()
assert.equal("httpbin-preserve.com", parsed_response.headers["Host"])
end)
end)


describe("strip_path", function()
it("should strip the request_path if `strip_request_path` is true", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin/request")
assert.equal(200, status)
local body = cjson.decode(response)
assert.equal("http://mockbin.com/request", body.url)
end)
it("should strip the request_path if `strip_request_path` is true if `request_path` has pattern characters", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin-with-pattern/request")
assert.equal(200, status)
local body = cjson.decode(response)
assert.equal("http://mockbin.com/request", body.url)
end)
it("should not strip if the `request_path` pattern is repeated in the request_uri", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/har/har/of/request")
assert.equal(200, status)
local body = cjson.decode(response)
local upstream_url = body.log.entries[1].request.url
assert.equal("http://mockbin.com/har/of/request", upstream_url)
end)
it("should not add a trailing slash when strip_path is enabled", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash", {hello = "world"})
assert.equal(200, status)
assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url)
end)
it("should not add a trailing slash when strip_path is enabled and upstream_url has no path", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash3/request", {hello = "world"})
assert.equal(200, status)
assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url)
end)
it("should not add a trailing slash when strip_path is enabled and upstream_url has single path", function()
local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash4/request", {hello = "world"})
assert.equal(200, status)
assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url)
end)
end)

describe("Percent-encoding", function()
it("should leave percent-encoded values in URI untouched", function()
local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2Fworld", {}, {host = "mockbin-uri.com"})
assert.equal(200, status)
assert.equal("http://mockbin.org/request/hello%2fworld", cjson.decode(response).url)
end)
it("should leave untouched percent-encoded values in querystring", function()
local response, status = http_client.get(spec_helper.STUB_GET_URL.."/", {foo = "abc%7Cdef%2c%20world"}, {host = "mockbin-uri.com"})
assert.equal(200, status)
assert.equal("http://mockbin.org/request/?foo=abc%7cdef%2c%20world", cjson.decode(response).url)
end)
it("should leave untouched percent-encoded keys in querystring", function()
local response, status = http_client.get(spec_helper.STUB_GET_URL.."/", {["hello%20world"] = "foo"}, {host = "mockbin-uri.com"})
assert.equal(200, status)
assert.equal("http://mockbin.org/request/?hello%20world=foo", cjson.decode(response).url)
end)
end)
end)
Loading

0 comments on commit 6d467c4

Please sign in to comment.