403 lines
9.9 KiB
Lua
403 lines
9.9 KiB
Lua
local util = require("test-samurai.util")
|
|
|
|
-- Go adapter for test-samurai: finds Go tests in a buffer, builds
-- `go test -json` commands and parses the resulting JSON event stream.
local runner = {}
runner.name = "go"
|
|
|
|
--- Find the line that closes the first brace block opened at or after `start_idx`.
-- Braces inside Go string literals ("..." and `...`), rune literals ('...'),
-- line comments (//) and block comments (/* */) are ignored, so code such as
-- `fmt.Println("}")` cannot end the block early. A `}` seen before the first
-- `{` is ignored.
-- @tparam string[] lines buffer lines (1-based array)
-- @tparam integer start_idx 1-based line at which scanning starts
-- @treturn integer 0-based index of the line holding the matching `}`, or
--   `#lines - 1` when no balanced block is found
local function find_block_end(lines, start_idx)
  local depth = 0
  local started = false
  -- Lexer states that may continue onto following lines.
  local in_raw_string = false
  local in_block_comment = false

  for i = start_idx, #lines do
    local line = lines[i]
    local len = #line
    local j = 1
    while j <= len do
      local ch = line:sub(j, j)
      local next_ch = line:sub(j + 1, j + 1)
      if in_block_comment then
        if ch == "*" and next_ch == "/" then
          in_block_comment = false
          j = j + 1 -- consume the "/"
        end
      elseif in_raw_string then
        if ch == "`" then
          in_raw_string = false
        end
      elseif ch == "/" and next_ch == "/" then
        break -- rest of the line is a comment
      elseif ch == "/" and next_ch == "*" then
        in_block_comment = true
        j = j + 1 -- consume the "*" so "/*/" is not read as a close
      elseif ch == "`" then
        in_raw_string = true
      elseif ch == '"' or ch == "'" then
        -- Interpreted string or rune literal: jump to the unescaped closing
        -- quote on this line (or past the end if it is unterminated).
        local k = j + 1
        while k <= len do
          local c = line:sub(k, k)
          if c == "\\" then
            k = k + 2
          elseif c == ch then
            break
          else
            k = k + 1
          end
        end
        j = k
      elseif ch == "{" then
        depth = depth + 1
        started = true
      elseif ch == "}" and started then
        depth = depth - 1
        if depth == 0 then
          return i - 1
        end
      end
      j = j + 1
    end
  end

  return #lines - 1
end
|
|
|
|
--- Collect Go functions whose header line mentions `*testing.T`.
-- Plain functions (`func Name(`) and methods (`func (r Recv) Name(`) are both
-- recognised; the signature must start on a single line.
-- @tparam string[] lines buffer lines (1-based array)
-- @treturn table[] entries `{ name = string, start = int, ["end"] = int }`
--   with 0-based rows, in source order
local function find_test_functions(lines)
  local found = {}
  for idx, text in ipairs(lines) do
    local fname = text:match("^%s*func%s+([%w_]+)%s*%(")
      or text:match("^%s*func%s+%([^)]-%)%s+([%w_]+)%s*%(")
    if fname and text:find("%*testing%.T") then
      found[#found + 1] = {
        name = fname,
        start = idx - 1,
        ["end"] = find_block_end(lines, idx),
      }
    end
  end
  return found
end
|
|
|
|
--- Find `t.Run(<literal>, ...)` subtest calls inside a test function.
-- The subtest name must be a Go interpreted string ("...") or raw string
-- (`...`) literal. Computed names such as `tt.name` are not recognised.
-- Apostrophes inside a double-quoted name are part of the name
-- (`"it's ok"` -> `it's ok`).
-- @tparam string[] lines buffer lines (1-based array)
-- @tparam table func entry from find_test_functions (0-based start/end)
-- @treturn table[] entries `{ name = string, start = int, ["end"] = int }`
--   with 0-based rows, in source order
local function find_t_runs(lines, func)
  local subtests = {}
  -- Walk 0-based rows after the function header up to its closing line.
  for row = func.start + 1, func["end"] do
    local line = lines[row + 1]
    if line then
      local name = line:match('t%.Run%(%s*"([^"]+)"')
        or line:match("t%.Run%(%s*`([^`]+)`")
      if name then
        table.insert(subtests, {
          name = name,
          start = row,
          ["end"] = find_block_end(lines, row + 1),
        })
      end
    end
  end
  return subtests
end
|
|
|
|
--- Escape Go (RE2) regexp metacharacters so `s` matches literally.
-- Each metacharacter gets a single backslash prefix. Go's regexp syntax treats
-- `\` followed by ASCII punctuation as that literal character. The previous
-- double backslash meant "literal backslash, then any character" once the
-- pattern reached `go test` as an argv element. "/" is deliberately left
-- unescaped because `go test -run` uses it as the subtest level separator.
-- @tparam string? s raw name (nil is treated as "")
-- @treturn string escaped pattern fragment
local function escape_go_regex(s)
  s = s or ""
  return (s:gsub("([\\.^$|()%%[%]{}*+?%-])", "\\%1"))
end
|
|
|
|
--- Build the `go test -run` expression selecting `spec.test_path`.
-- `go test` splits -run on unbracketed "/" and matches each piece against the
-- corresponding level of a test's name. Every level is therefore anchored
-- separately (`^TestX$/^sub$`). A single `^TestX/sub$` would be split into
-- `^TestX` and `sub$`, also running `TestXY` and sibling subtests such as
-- `bad_sub`.
-- @tparam table spec `{ test_path = string?, scope = string? }`
-- @treturn string
local function build_run_pattern(spec)
  local name = spec.test_path or ""

  if spec.scope == "function" then
    -- The "/" sits inside parentheses, so go test keeps this as a single
    -- first-level expression; deeper levels (subtests) are unconstrained.
    return "^" .. escape_go_regex(name) .. "($|/)"
  end

  local levels = {}
  -- Appending "/" makes every segment (including an empty name) end in "/".
  for part in (name .. "/"):gmatch("([^/]*)/") do
    levels[#levels + 1] = escape_go_regex(part)
  end
  return "^" .. table.concat(levels, "$/^") .. "$"
end
|
|
|
|
--- Compute the package argument for `go test` from `spec.file` and `spec.cwd`.
-- Returns:
--   "./"          when the file's directory is cwd itself,
--   "./<rel/dir>" when the file lives under cwd,
--   "./..."       when either path is missing or the file is outside cwd.
-- The containment check works on whole path components, so cwd "/proj" does
-- not claim "/project2". Trailing "/" on cwd is ignored.
-- Assumes "/"-separated absolute paths (TODO confirm Windows behaviour).
-- @tparam table spec `{ file = string?, cwd = string? }`
-- @treturn string
local function build_pkg_arg(spec)
  local file = spec.file
  local cwd = spec.cwd
  if not file or not cwd or file == "" or cwd == "" then
    return "./..."
  end

  local dir = vim.fs.dirname(file)
  if not dir then
    return "./..."
  end

  -- "/proj/" -> "/proj"; the filesystem root "/" becomes "".
  local base = (cwd:gsub("/+$", ""))
  if dir == base then
    return "./"
  end

  local prefix = base .. "/"
  if dir:sub(1, #prefix) ~= prefix then
    return "./..."
  end

  local rel = dir:sub(#prefix + 1)
  if rel == "" then
    return "./"
  end
  return "./" .. rel
end
|
|
|
|
--- Return the distinct, non-empty entries of `list` in first-seen order.
-- Falsy entries and "" are dropped.
-- @tparam table list sequence of values
-- @treturn table new sequence without duplicates
local function collect_unique(list)
  local result, present = {}, {}
  for _, value in ipairs(list) do
    local fresh = value and value ~= "" and not present[value]
    if fresh then
      present[value] = true
      result[#result + 1] = value
    end
  end
  return result
end
|
|
|
|
--- Report whether buffer `bufnr` is a Go test file (name ends in `_test.go`).
-- @tparam integer bufnr buffer handle
-- @treturn boolean
function runner.is_test_file(bufnr)
  local suffix = "_test.go"
  local path = util.get_buf_path(bufnr)
  if not path or path == "" then
    return false
  end
  return path:sub(-#suffix) == suffix
end
|
|
|
|
--- Locate the Go test function, and any enclosing `t.Run` subtests, at `row`.
-- With nested subtests the whole chain is kept, so the innermost subtest is
-- targeted (`TestX/outer/inner`) rather than the entire outer subtest.
-- @tparam integer bufnr buffer handle
-- @tparam integer row cursor row, compared against the 0-based block ranges
-- @param _col unused
-- @treturn[1] table `{ file, cwd, test_path, scope = "function"|"subtest",
--   func, subtest? }` where `subtest` is the "/"-joined subtest chain
-- @treturn[2] nil
-- @treturn[2] string reason
function runner.find_nearest(bufnr, row, _col)
  if not runner.is_test_file(bufnr) then
    return nil, "not a Go test file"
  end

  local lines = util.get_buf_lines(bufnr)

  local current
  for _, f in ipairs(find_test_functions(lines)) do
    if row >= f.start and row <= f["end"] then
      current = f
      break
    end
  end
  if not current then
    return nil, "cursor not inside a test function"
  end

  -- Every t.Run block containing the cursor. find_t_runs reports them in line
  -- order, so enclosing subtests come out outermost first.
  local chain = {}
  for _, sub in ipairs(find_t_runs(lines, current)) do
    if row >= sub.start and row <= sub["end"] then
      chain[#chain + 1] = sub.name
    end
  end

  local path = util.get_buf_path(bufnr)
  local root = util.find_root(path, { "go.mod", ".git" })
  -- No go.mod/.git found: fall back to Neovim's working directory.
  if not root or root == "" then
    root = vim.loop.cwd()
  end

  local spec = {
    file = path,
    cwd = root,
    test_path = current.name,
    scope = "function",
    func = current.name,
  }
  if #chain > 0 then
    spec.subtest = table.concat(chain, "/")
    spec.test_path = current.name .. "/" .. spec.subtest
    spec.scope = "subtest"
  end
  return spec
end
|
|
|
|
--- Build the `go test -json` invocation for a spec produced by find_nearest.
-- @tparam table spec `{ file, cwd, test_path, scope, ... }`
-- @treturn table `{ cmd = string[], cwd = string? }`
function runner.build_command(spec)
  local run_expr = build_run_pattern(spec)
  local package_arg = build_pkg_arg(spec)
  return {
    cmd = { "go", "test", "-json", package_arg, "-run", run_expr },
    cwd = spec.cwd,
  }
end
|
|
|
|
--- Build a `go test -json` command limited to the tests defined in this buffer.
-- The package is derived from the buffer path relative to the project root
-- (go.mod/.git, else Neovim's cwd). When the buffer defines any test
-- functions, a `-run ^(A|B|...)$` filter restricts the run to them.
-- @tparam integer bufnr buffer handle
-- @treturn table|nil `{ cmd = string[], cwd = string }`, or nil without a path
function runner.build_file_command(bufnr)
  local path = util.get_buf_path(bufnr)
  if not path or path == "" then
    return nil
  end

  local root = util.find_root(path, { "go.mod", ".git" })
  if not root or root == "" then
    root = vim.loop.cwd()
  end

  local cmd = { "go", "test", "-json", build_pkg_arg({ file = path, cwd = root }) }

  local raw_names = {}
  for _, fn in ipairs(find_test_functions(util.get_buf_lines(bufnr))) do
    raw_names[#raw_names + 1] = fn.name
  end

  local alternatives = {}
  for _, name in ipairs(collect_unique(raw_names)) do
    alternatives[#alternatives + 1] = escape_go_regex(name)
  end

  if #alternatives > 0 then
    cmd[#cmd + 1] = "-run"
    cmd[#cmd + 1] = "^(" .. table.concat(alternatives, "|") .. ")$"
  end

  return { cmd = cmd, cwd = root }
end
|
|
|
|
--- Build a `go test -json ./...` command for the whole project.
-- The working directory is the root found from the buffer path (go.mod/.git),
-- falling back to Neovim's cwd.
-- @tparam integer bufnr buffer handle
-- @treturn table `{ cmd = string[], cwd = string }`
function runner.build_all_command(bufnr)
  local path = util.get_buf_path(bufnr)
  local root = (path and path ~= "") and util.find_root(path, { "go.mod", ".git" }) or nil
  if not root or root == "" then
    root = vim.loop.cwd()
  end
  return { cmd = { "go", "test", "-json", "./..." }, cwd = root }
end
|
|
|
|
--- Summarise the complete `go test -json` output of a run.
-- Only events that carry a non-empty `Test` and an Action of "pass", "fail"
-- or "skip" are counted; undecodable lines are ignored.
-- @tparam string? output JSON-lines stdout
-- @treturn table `{ passes, failures, skips, display = { passes, failures, skips } }`.
--   The top-level lists hold unique full test names. The `display` lists hold
--   the last "/" segment of every event in arrival order (not deduplicated).
function runner.parse_results(output)
  local outcome_key = { pass = "passes", fail = "failures", skip = "skips" }
  local full = { passes = {}, failures = {}, skips = {} }
  local display = { passes = {}, failures = {}, skips = {} }

  if output and output ~= "" then
    for raw in output:gmatch("[^\n]+") do
      local decoded, event = pcall(vim.json.decode, raw)
      if decoded and type(event) == "table" then
        local test = event.Test
        local key = outcome_key[event.Action]
        if key and test and test ~= "" then
          table.insert(full[key], test)
          table.insert(display[key], test:match("([^/]+)$") or test)
        end
      end
    end
  end

  return {
    passes = collect_unique(full.passes),
    failures = collect_unique(full.failures),
    skips = collect_unique(full.skips),
    display = display,
  }
end
|
|
|
|
--- Split `text` on "\n", dropping the one empty piece a trailing newline leaves.
-- Interior empty lines are kept.
-- @tparam string? text
-- @treturn string[] (empty for nil or "")
local function split_output_lines(text)
  local pieces = {}
  if not text or text == "" then
    return pieces
  end

  local from = 1
  while true do
    local nl = text:find("\n", from, true)
    if not nl then
      pieces[#pieces + 1] = text:sub(from)
      break
    end
    pieces[#pieces + 1] = text:sub(from, nl - 1)
    from = nl + 1
  end

  if pieces[#pieces] == "" then
    pieces[#pieces] = nil
  end
  return pieces
end
|
|
|
|
--- Group the "output" events of a `go test -json` run by test name.
-- @tparam string? output JSON-lines stdout
-- @treturn table map from full test name to its list of output lines
function runner.parse_test_output(output)
  local by_test = {}
  if not output or output == "" then
    return by_test
  end

  for raw in output:gmatch("[^\n]+") do
    local decoded, event = pcall(vim.json.decode, raw)
    local usable = decoded
      and type(event) == "table"
      and event.Action == "output"
      and event.Test
      and event.Output
    if usable then
      local bucket = by_test[event.Test] or {}
      by_test[event.Test] = bucket
      for _, text in ipairs(split_output_lines(event.Output)) do
        bucket[#bucket + 1] = text
      end
    end
  end

  return by_test
end
|
|
|
|
--- Create an incremental parser for streamed `go test -json` output.
-- `on_line(line, state)` decodes one JSON event. For the first "pass", "fail"
-- or "skip" event of a given test name, it returns a delta:
--   { passes, failures, skips, display = { passes, failures, skips }, failures_all }
-- Exactly one of the three lists holds the full test name, the matching
-- `display` list holds its last "/" segment, and `failures_all` is a copy of
-- every failure reported so far. Every other line yields nil.
-- Each outcome is deduplicated independently (a skip no longer marks the name
-- as "passed").
-- `on_complete(output, state)` contributes nothing and returns nil.
-- @treturn table `{ on_line = function, on_complete = function }`
function runner.output_parser()
  local seen = { pass = {}, fail = {}, skip = {} }
  local bucket_of = { pass = "passes", fail = "failures", skip = "skips" }
  -- Every failure reported so far; snapshotted into each delta.
  local failures = {}

  return {
    on_line = function(line, _state)
      local ok, data = pcall(vim.json.decode, line)
      if not ok or type(data) ~= "table" then
        return nil
      end

      local name = data.Test
      if not name or name == "" then
        return nil
      end

      local action = data.Action
      local bucket = bucket_of[action]
      if not bucket or seen[action][name] then
        return nil
      end
      seen[action][name] = true

      if action == "fail" then
        failures[#failures + 1] = name
      end

      local short = name:match("([^/]+)$") or name
      local delta = {
        passes = {},
        failures = {},
        skips = {},
        display = { passes = {}, failures = {}, skips = {} },
      }
      delta[bucket] = { name }
      delta.display[bucket] = { short }
      delta.failures_all = vim.deepcopy(failures)
      return delta
    end,

    on_complete = function(_output, _state)
      return nil
    end,
  }
end
|
|
|
|
--- Rebuild `last_command` so that it reruns only the given failed tests.
-- Any existing `-run <pattern>` pair is removed from the argv, and
-- `-run ^(f1|f2|...)$` is appended.
-- @tparam table last_command `{ cmd = string[], cwd = string? }`
-- @tparam string[]? failures full test names
-- @param _scope_kind unused
-- @treturn table|nil `{ cmd, cwd }`, or nil when there is nothing to rerun
function runner.build_failed_command(last_command, failures, _scope_kind)
  if not last_command or type(last_command.cmd) ~= "table" then
    return nil
  end

  local escaped = {}
  for _, name in ipairs(failures or {}) do
    escaped[#escaped + 1] = escape_go_regex(name)
  end
  if #escaped == 0 then
    return nil
  end

  local previous = last_command.cmd
  local argv = {}
  local idx = 1
  while idx <= #previous do
    if previous[idx] == "-run" then
      idx = idx + 2 -- drop the flag together with its value
    else
      argv[#argv + 1] = previous[idx]
      idx = idx + 1
    end
  end

  argv[#argv + 1] = "-run"
  argv[#argv + 1] = "^(" .. table.concat(escaped, "|") .. ")$"

  return { cmd = argv, cwd = last_command.cwd }
end
|
|
|
|
-- Module export: the Go runner adapter table.
return runner
|