You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

312 lines
8.6 KiB

-- COMMAND LINE ----------------------------------------------------------------
local function print_help()
print "Usage: target.lua [options] [regex]"
print "\t-h show this message"
print "\t-q quiet mode"
print "\t-s disable colours"
end
local quiet, grey, test_regex
if arg then
for i = 1, #arg do
if arg[i] == "-h" then
print_help()
os.exit(0)
elseif arg[i] == "-q" then
quiet = true
elseif arg[i] == "-s" then
grey = true
elseif not test_regex then
test_regex = arg[i]
else
print_help()
os.exit(1)
end
end
end
-- UTILS -----------------------------------------------------------------------
local function red(str) return grey and str or "\27[1;31m" .. str .. "\27[0m" end
local function blue(str) return grey and str or "\27[1;34m" .. str .. "\27[0m" end
local function green(str) return grey and str or "\27[1;32m" .. str .. "\27[0m" end
local function yellow(str) return grey and str or "\27[1;33m" .. str .. "\27[0m" end
local tab_tag = blue "[----------]"
local done_tag = blue "[==========]"
local run_tag = blue "[ RUN ]"
local ok_tag = green "[ OK ]"
local fail_tag = red "[ FAIL]"
local disabled_tag = yellow "[ DISABLED ]"
local passed_tag = green "[ PASSED ]"
local failed_tag = red "[ FAILED ]"
local ntests = 0
local failed = false
local failed_list = {}
local function trace(start_frame)
print "Trace:"
local frame = start_frame
while true do
local info = debug.getinfo(frame, "Sl")
if not info then break end
if info.what == "C" then
print(frame - start_frame, "??????")
else
print(frame - start_frame, info.short_src .. ":" .. info.currentline .. " ")
end
frame = frame + 1
end
end
local function log(msg)
if not quiet then
print(msg)
end
end
local function fail(msg, start_frame)
failed = true
print("Fail: " .. msg)
trace(start_frame or 4)
end
local function stringize_var_arg(...)
local args = { n = select("#", ...), ... }
local result = {}
for i = 1, args.n do
if type(args[i]) == 'string' then
result[i] = '"' .. tostring(args[i]) .. '"'
else
result[i] = tostring(args[i])
end
end
return table.concat(result, ", ")
end
local function test_pretty_name(suite_name, test_name)
if suite_name == "__root" then
return test_name
else
return suite_name .. "." .. test_name
end
end
-- PUBLIC API -----------------------------------------------------------------
local api = { test_suite_name = "__root", skip = false }
api.assert = function (cond)
if not cond then
fail("assertion " .. tostring(cond) .. " failed")
end
end
api.equal = function (l, r)
if l ~= r then
fail(tostring(l) .. " ~= " .. tostring(r))
end
end
api.not_equal = function (l, r)
if l == r then
fail(tostring(l) .. " == " .. tostring(r))
end
end
api.almost_equal = function (l, r, diff)
if require("math").abs(l - r) > diff then
fail("|" .. tostring(l) .. " - " .. tostring(r) .."| > " .. tostring(diff))
end
end
api.is_false = function (maybe_false)
if maybe_false or type(maybe_false) ~= "boolean" then
fail("got " .. tostring(maybe_false) .. " instead of false")
end
end
api.is_true = function (maybe_true)
if not maybe_true or type(maybe_true) ~= "boolean" then
fail("got " .. tostring(maybe_true) .. " instead of true")
end
end
api.is_not_nil = function (maybe_not_nil)
if type(maybe_not_nil) == "nil" then
fail("got nil")
end
end
api.error_raised = function (f, error_message, ...)
local status, err = pcall(f, ...)
if status == true then
fail("error not raised")
else
if error_message ~= nil then
-- we set "plain" to true to avoid pattern matching in string.find
if err:find(error_message, 1, true) == nil then
fail("'" .. error_message .. "' not found in error '" .. tostring(err) .. "'")
end
end
end
end
api.register_assert = function(assert_name, assert_func)
rawset(api, assert_name, function(...)
local result, msg = assert_func(...)
if not result then
msg = msg or "Assertion "..assert_name.." failed"
fail(msg)
end
end)
end
local function make_type_checker(typename)
api["is_" .. typename] = function (maybe_type)
if type(maybe_type) ~= typename then
fail("got " .. tostring(maybe_type) .. " instead of " .. typename, 4)
end
end
end
local supported_types = {"nil", "boolean", "string", "number", "userdata", "table", "function", "thread"}
for i, supported_type in ipairs(supported_types) do
make_type_checker(supported_type)
end
local last_test_suite
local function run_test(test_suite, test_name, test_function, ...)
local suite_name = test_suite.test_suite_name
local full_test_name = test_pretty_name(suite_name, test_name)
if test_regex and not string.match(full_test_name, test_regex) then
return
end
if test_suite.skip then
log(disabled_tag .. " " .. full_test_name)
return
end
if suite_name ~= last_test_suite then
log(tab_tag)
last_test_suite = suite_name
end
ntests = ntests + 1
failed = false
log(run_tag .. " " .. full_test_name)
local start = os.time()
local status, err
for _, f in ipairs({test_suite.start_up, test_function, test_suite.tear_down}) do
status, err = pcall(f, ...)
if not status then
failed = true
print(tostring(err))
end
end
local stop = os.time()
local is_test_failed = not status or failed
log(string.format("%s %s %d sec",
is_test_failed and fail_tag or ok_tag,
full_test_name,
os.difftime(stop, start)))
if is_test_failed then
table.insert(failed_list, full_test_name)
end
end
api.summary = function ()
log(done_tag)
local nfailed = #failed_list
if nfailed == 0 then
log(passed_tag .. " " .. ntests .. " test(s)")
os.exit(0)
else
log(failed_tag .. " " .. nfailed .. " out of " .. ntests .. ":")
for _, test_name in ipairs(failed_list) do
log(failed_tag .. "\t" .. test_name)
end
os.exit(1)
end
end
api.result = function ( ... )
return ntests, #failed_list
end
local default_start_up = function () end
local default_tear_down = function () collectgarbage() end
api.start_up = default_start_up
api.tear_down = default_tear_down
local all_test_cases = { __root = {} }
local function handle_new_test(suite, test_name, test_function)
local suite_name = suite.test_suite_name
if not all_test_cases[suite_name] then
all_test_cases[suite_name] = {}
end
all_test_cases[suite_name][test_name] = test_function
local info = debug.getinfo(test_function)
if info.nparams == nil and
string.sub(test_name, #test_name - 1, #test_name) ~= "_p"
or info.nparams == 0 then
run_test(suite, test_name, test_function)
end
end
local function lookup_test_with_params(suite, test_name)
local suite_name = suite.test_suite_name
if all_test_cases[suite_name] and all_test_cases[suite_name][test_name] then
return function (...)
run_test(suite
, test_name .. "(" .. stringize_var_arg(...) .. ")"
, all_test_cases[suite_name][test_name], ...)
end
else
local full_test_name = test_pretty_name(suite_name, test_name)
table.insert(failed_list, full_test_name)
ntests = ntests + 1
log(fail_tag .. " No " .. full_test_name .. " parametrized test case!")
end
end
local function new_test_suite(_, name)
local test_suite = {
test_suite_name = name,
start_up = default_start_up,
tear_down = default_tear_down,
skip = false }
setmetatable(test_suite, {
__newindex = handle_new_test,
__index = lookup_test_with_params })
return test_suite
end
local test_suites = {}
setmetatable(api, {
__index = function (tbl, name)
if all_test_cases.__root[name] then
return lookup_test_with_params(tbl, name)
end
if not test_suites[name] then
test_suites[name] = new_test_suite(tbl, name)
end
return test_suites[name]
end,
__newindex = handle_new_test
})
return api