Skip to content

Commit

Permalink
specs(resolver) better unit testing for resolver
Browse files Browse the repository at this point in the history
- directly test `resolver.execute()`
- adds a stub of `ngx.shared.DICT` for unit testing where
  `kong.tools.cache` is being used (it requires shm dicts)
- tiny refactors for testability
- more use cases covered in specs
  • Loading branch information
thibaultcha committed Dec 1, 2015
1 parent 24cc4e8 commit 571b631
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 79 deletions.
8 changes: 6 additions & 2 deletions kong/core/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ local certificate = require "kong.core.certificate"

local table_insert = table.insert
local math_floor = math.floor
local unpack = unpack
local ipairs = ipairs

local MULT = 10^3
local function round(num)
Expand All @@ -43,7 +45,7 @@ return {
access = {
before = function()
ngx.ctx.KONG_ACCESS_START = ngx.now()
ngx.ctx.api, ngx.ctx.upstream_url = resolver.execute()
ngx.ctx.api, ngx.ctx.upstream_url, ngx.ctx.upstream_host = resolver.execute(ngx.var.request_uri, ngx.req.get_headers())
end,
-- Only executed if the `resolver` module found an API and allows nginx to proxy it.
after = function()
Expand All @@ -58,8 +60,10 @@ return {
upstream_url = upstream_url.."?"..ngx.encode_args(ngx.req.get_uri_args())
end

-- Set the `$upstream_url` variable for the `proxy_pass` nginx's directive.
-- Set the `$upstream_url` and `$upstream_host` variables for the `proxy_pass` nginx
-- directive in kong.yml.
ngx.var.upstream_url = upstream_url
ngx.var.upstream_host = ngx.ctx.upstream_host
end
},
header_filter = {
Expand Down
53 changes: 28 additions & 25 deletions kong/core/resolver.lua
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ local function get_upstream_url(api)
return result
end

local function get_host_from_url(val)
local function get_host_from_upstream_url(val)
local parsed_url = url.parse(val)

local port
Expand Down Expand Up @@ -99,7 +99,7 @@ function _M.load_apis_in_memory()
end

function _M.find_api_by_request_host(req_headers, apis_dics)
local all_hosts = {}
local hosts_list = {}
for _, header_name in ipairs({"Host", constants.HEADERS.HOST_OVERRIDE}) do
local hosts = req_headers[header_name]
if hosts then
Expand All @@ -109,9 +109,9 @@ function _M.find_api_by_request_host(req_headers, apis_dics)
-- for all values of this header, try to find an API using the apis_by_dns dictionnary
for _, host in ipairs(hosts) do
host = unpack(stringy.split(host, ":"))
table_insert(all_hosts, host)
table_insert(hosts_list, host)
if apis_dics.by_dns[host] then
return apis_dics.by_dns[host]
return apis_dics.by_dns[host], host
else
-- If the API was not found in the dictionary, maybe it is a wildcard request_host.
-- In that case, we need to loop over all of them.
Expand All @@ -125,7 +125,7 @@ function _M.find_api_by_request_host(req_headers, apis_dics)
end
end

return nil, all_hosts
return nil, nil, hosts_list
end

-- To do so, we have to compare entire URI segments (delimited by "/").
Expand Down Expand Up @@ -180,13 +180,14 @@ end
-- We keep APIs in the database cache for a longer time than usual.
-- @see https://github.com/Mashape/kong/issues/15 for an improvement on this.
--
-- @param `uri` The URI for this request.
-- @return `err` Any error encountered during the retrieval.
-- @return `api` The retrieved API, if any.
-- @return `hosts` The list of headers values found in Host and X-Host-Override.
-- @param `uri` The URI for this request.
-- @return `err` Any error encountered during the retrieval.
-- @return `api` The retrieved API, if any.
-- @return `matched_host` The host that was matched for this API, if matched.
-- @return `hosts` The list of headers values found in Host and X-Host-Override.
-- @return `strip_request_path_pattern` If the API was retrieved by request_path, contain the pattern to strip it from the URI.
local function find_api(uri)
local api, all_hosts, strip_request_path_pattern
local function find_api(uri, headers)
local api, matched_host, hosts_list, strip_request_path_pattern

-- Retrieve all APIs
local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", _M.load_apis_in_memory, 60) -- 60 seconds cache, longer than usual
Expand All @@ -195,37 +196,37 @@ local function find_api(uri)
end

-- Find by Host header
api, all_hosts = _M.find_api_by_request_host(ngx.req.get_headers(), apis_dics)

api, matched_host, hosts_list = _M.find_api_by_request_host(headers, apis_dics)
-- If it was found by Host, return
if api then
return nil, api, all_hosts
return nil, api, matched_host, hosts_list
end

-- Otherwise, we look for it by request_path. We have to loop over all APIs and compare the requested URI.
api, strip_request_path_pattern = _M.find_api_by_request_path(uri, apis_dics.request_path_arr)

return nil, api, all_hosts, strip_request_path_pattern
return nil, api, nil, hosts_list, strip_request_path_pattern
end

local function url_has_path(url)
local _, count_slashes = string.gsub(url, "/", "")
local _, count_slashes = string_gsub(url, "/", "")
return count_slashes > 2
end

function _M.execute()
local uri = stringy.split(ngx.var.request_uri, "?")[1]
local err, api, hosts, strip_request_path_pattern = find_api(uri)
function _M.execute(request_uri, request_headers)
local uri = stringy.split(request_uri, "?")[1]
local err, api, matched_host, hosts_list, strip_request_path_pattern = find_api(uri, request_headers)
if err then
return responses.send_HTTP_INTERNAL_SERVER_ERROR(err)
elseif not api then
return responses.send_HTTP_NOT_FOUND {
message = "API not found with these values",
request_host = hosts,
request_host = hosts_list,
request_path = uri
}
end

local upstream_host
local upstream_url = get_upstream_url(api)

-- If API was retrieved by request_path and the request_path needs to be stripped
Expand All @@ -235,13 +236,15 @@ function _M.execute()

upstream_url = upstream_url..uri

-- Set the
if api.preserve_host then
ngx.var.upstream_host = ngx.req.get_headers()["host"]
else
ngx.var.upstream_host = get_host_from_url(upstream_url)
upstream_host = matched_host
end
return api, upstream_url

if upstream_host == nil then
upstream_host = get_host_from_upstream_url(upstream_url)
end

return api, upstream_url, upstream_host
end

return _M
96 changes: 96 additions & 0 deletions kong/tools/ngx_stub.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,101 @@

local reg = require "rex_pcre"

-- DICT Proxy
-- https://github.com/bsm/fakengx/blob/master/fakengx.lua

local SharedDict = {}

local function set(data, key, value)
data[key] = {
value = value,
info = {expired = false}
}
end

function SharedDict:new()
return setmetatable({data = {}}, {__index = self})
end

function SharedDict:get(key)
return self.data[key] and self.data[key].value, nil
end

function SharedDict:set(key, value)
set(self.data, key, value)
return true, nil, false
end

SharedDict.safe_set = SharedDict.set

function SharedDict:add(key, value)
if self.data[key] ~= nil then
return false, "exists", false
end

set(self.data, key, value)
return true, nil, false
end

function SharedDict:replace(key, value)
if self.data[key] == nil then
return false, "not found", false
end

set(self.data, key, value)
return true, nil, false
end

function SharedDict:delete(key)
self.data[key] = nil
end

function SharedDict:incr(key, value)
if not self.data[key] then
return nil, "not found"
elseif type(self.data[key].value) ~= "number" then
return nil, "not a number"
end

self.data[key].value = self.data[key].value + value
return self.data[key].value, nil
end

function SharedDict:flush_all()
for _, item in pairs(self.data) do
item.info.expired = true
end
end

function SharedDict:flush_expired(n)
local data = self.data
local flushed = 0

for key, item in pairs(self.data) do
if item.info.expired then
data[key] = nil
flushed = flushed + 1
if n and flushed == n then
break
end
end
end

self.data = data

return flushed
end

local shared = {}
local shared_mt = {
__index = function(self, key)
if shared[key] == nil then
shared[key] = SharedDict:new()
end
return shared[key]
end
}

_G.ngx = {
req = {},
ctx = {},
Expand All @@ -19,6 +114,7 @@ _G.ngx = {
timer = {
at = function() end
},
shared = setmetatable({}, shared_mt),
re = {
match = reg.match,
gsub = function(str, pattern, sub)
Expand Down
Loading

0 comments on commit 571b631

Please sign in to comment.