Skip to content

Commit

Permalink
在错误信息中将参数名称也显示出来
Browse files Browse the repository at this point in the history
  • Loading branch information
xebecnan committed Nov 1, 2021
1 parent ee2f9ae commit 3ec4f59
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 33 deletions.
20 changes: 15 additions & 5 deletions src/binder.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ local TYPE_NAME2ID = Types.TYPE_NAME2ID
local sf = string.format
local ast_error = Util.ast_error
local dump_table = Util.dump_table
local errorf = function(...)
error(sf(...))
end

local function convert_type(ast)
if ast.tag == 'TypeFunction' then
Expand Down Expand Up @@ -83,17 +86,20 @@ end
local function inference_type_for_function(ast)
local n_funcname = ast[1]
local n_parlist = ast[2]
local n_args = { tag='TypeArgList', info=n_parlist.info }
local n_partypes = { tag='TypeArgList', info=n_parlist.info }
for i = 1, #n_parlist do
local n_par = n_parlist[i]
if n_par.tag == 'VarArg' then
n_args[#n_args+1] = { tag='VarArg', info=n_par.info }
n_partypes[#n_partypes+1] = { tag='VarArg', info=n_par.info }
elseif n_par.tag == 'Id' then
local any = { tag='Id', 'Any' }
n_partypes[#n_partypes+1] = { tag='FuncParameter', info=n_par.info, any, n_par }
else
n_args[#n_args+1] = { tag='Id', 'Any' }
errorf('bad parameter tag: %s', n_par.tag)
end
end
local n_ret = { tag='Id', 'Any' }
return { tag='TypeFunction', info=ast.info, n_args, n_ret }
return { tag='TypeFunction', info=ast.info, n_partypes, n_ret }
end

local function function_def_common(ast, env, walk_node)
Expand All @@ -120,6 +126,10 @@ local function function_def_common(ast, env, walk_node)
if n_par.tag ~= 'VarArg' then
Symbols.set_var(n_par, n_type)
end
if n_par.tag == 'Id' and n_type.tag ~= 'FuncParameter' then
n_type = { tag='FuncParameter', info=n_type.info, n_type, n_par }
par_types[i] = n_type
end
if n_type.tag ~= 'VarArg' then
i = i + 1
end
Expand All @@ -128,7 +138,7 @@ local function function_def_common(ast, env, walk_node)
if not error_flag and i <= #par_types then
local n_type = par_types[i]
if n_type.tag ~= 'VarArg' then
ast_error(ast, 'missing arg #%d (%s)', i, Types.get_full_type_name(n_type))
ast_error(ast, 'missing arg #%d (%s)', i, Types.get_full_type_name(n_type, true))
end
end
end
Expand Down
34 changes: 23 additions & 11 deletions src/typechecker.lua
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ match_type = function(expect, given)
return true
end

if given.tag == 'FuncParameter' then
return match_type(expect, given[1])
end

if expect.tag == 'Require' and given.tag == 'Require' then
if match_type(expect[1], given[1]) then
return true
Expand Down Expand Up @@ -252,8 +256,12 @@ match_type = function(expect, given)
return true
end
end
elseif expect.tag == 'FuncParameter' then
if match_type(expect[1], given) then
return true
end
end
return false, sf('expect "%s", but given "%s"', Types.get_full_type_name(expect), Types.get_full_type_name(given))
return false, sf('expect "%s", but given "%s"', Types.get_full_type_name(expect, false), Types.get_full_type_name(given, false))
end

local function match_node_type(node, tp)
Expand Down Expand Up @@ -317,13 +325,13 @@ function F:Call(ast, env, walk_node)
return
end

local n_parlist
local n_partypes
if si.tag == 'TypeFunction' then
n_parlist = si[1]
n_partypes = si[1]
elseif si.tag == 'OptArg' and si[1].tag == 'TypeFunction' then
n_parlist = si[1][1]
n_partypes = si[1][1]
elseif si.tag == 'Require' then
n_parlist = si[1][1]
n_partypes = si[1][1]
elseif si.tag == 'Id' and si[1] == 'Any' then
-- 调用的函数为 any 类型,跳过检查
walk_node(self, ast)
Expand All @@ -336,15 +344,19 @@ function F:Call(ast, env, walk_node)
local error_flag = false
for i_arg = 1, #n_arglist do
local n_given = n_arglist[i_arg]
local n_expet = n_parlist[i_par]
local n_expet = n_partypes[i_par]
if not n_expet then
ast_error(ast, "too many arguments to function '%s'", dump_funcname(n_funcname))
error_flag = true
break
end
local ok, err = match_node_type(n_given, n_expet)
if not ok then
ast_error(ast, sf('arg #%d, %s', i_par, err))
if n_expet == 'FuncParameter' then
ast_error(ast, sf('arg #%d "%s", %s', i_par, n_expet[2][1], err))
else
ast_error(ast, sf('arg #%d, %s', i_par, err))
end
error_flag = true
break
end
Expand All @@ -354,11 +366,11 @@ function F:Call(ast, env, walk_node)
end
end

if not error_flag and i_par <= #n_parlist then
local n_expet = n_parlist[i_par]
if not error_flag and i_par <= #n_partypes then
local n_expet = n_partypes[i_par]
if n_expet.tag ~= 'VarArg' and n_expet.tag ~= 'OptArg' then
ast_error(ast, "missing arg #%d (%s) to function '%s'",
i_par, Types.get_full_type_name(n_expet), dump_funcname(n_funcname))
ast_error(ast, 'missing arg #%d "%s" to function "%s"',
i_par, Types.get_full_type_name(n_expet, true), dump_funcname(n_funcname))
end
end

Expand Down
20 changes: 13 additions & 7 deletions src/types.lua
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ local function dump_type_obj(ast)
b[#b+1] = ' '
b[#b+1] = k
b[#b+1] = ':'
b[#b+1] = M.get_full_type_name(n_fieldtype)
b[#b+1] = M.get_full_type_name(n_fieldtype, false)
if i < #ast then
b[#b+1] = ';'
end
Expand All @@ -271,7 +271,7 @@ local function dump_type_table(ast)
b[#b+1] = '[?]'
end
b[#b+1] = ':'
b[#b+1] = M.get_full_type_name(M.get_node_type(nv))
b[#b+1] = M.get_full_type_name(M.get_node_type(nv), false)
if i < #ast - 1 then
b[#b+1] = ';'
end
Expand All @@ -280,19 +280,19 @@ local function dump_type_table(ast)
return table.concat(b, '')
end

function M.get_full_type_name(ast)
function M.get_full_type_name(ast, with_par_name)
if ast.tag == 'Id' then
return M.get_type_name(ast[1])
elseif ast.tag == 'TypeFunction' then
local b = {}
for i = 1, #ast[1] do
b[#b+1] = M.get_full_type_name(ast[1][i])
b[#b+1] = M.get_full_type_name(ast[1][i], false)
if i ~= #ast[1] then
b[#b+1] = ', '
end
end
b[#b+1] = ' >> '
b[#b+1] = M.get_full_type_name(ast[2])
b[#b+1] = M.get_full_type_name(ast[2], false)
return table.concat(b, '')
elseif ast.tag == 'TypeAlias' then
return ast[1]
Expand All @@ -303,9 +303,15 @@ function M.get_full_type_name(ast)
elseif ast.tag == 'VarArg' then
return '...'
elseif ast.tag == 'OptArg' then
return M.get_full_type_name(ast[1]) .. '?'
return M.get_full_type_name(ast[1], with_par_name) .. '?'
elseif ast.tag == 'Require' then
return '(require) ' .. M.get_full_type_name(ast[1])
return '(require) ' .. M.get_full_type_name(ast[1], with_par_name)
elseif ast.tag == 'FuncParameter' then
if with_par_name then
return sf("%s (%s)", ast[2][1], M.get_full_type_name(ast[1], false))
else
return sf("%s", M.get_full_type_name(ast[1], false))
end
-- elseif ast.tag == 'CloseTypeObj' then
-- error
else
Expand Down
26 changes: 16 additions & 10 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,41 @@ local function assert_func(v)
end


local function foo()
local function foo1()
-->> bar :: number >> number
local function bar(v)
return v + 1
end

-- OK
bar(1)
bar(1.1)
assert_func(10 > 1)

-- ERROR
bar('a')
bar({})
bar(false)
bar(true)

assert_func(10 > 1)
assert_func(10)
end

local function bar()
local function foo2()
-->> bar :: string >> number
local function bar(v)
return #v
end

-- OK
bar('a')
assert_func(true)

-- ERROR
bar(1)
bar(1.1)
bar('a')
bar({})
bar(false)
bar(true)

assert_func(true)
assert_func(print())
end

Expand All @@ -45,16 +49,18 @@ local function add(a, b)
return a + b
end

add(1, 2)

-->> do_add :: (number, number >> number), number, number >> number
local function do_add(f_add, a, b)
return f_add(a, b)
end

do_add(assert_func, 3, 4)
-- OK
add(1, 2)
do_add(add, 3, 4)

-- ERROR
do_add(assert_func, 3, 4)

--[[>>
AstNode = {
tag : string;
Expand Down

0 comments on commit 3ec4f59

Please sign in to comment.