--- Go runner for test-samurai.
-- Locates Go test functions and `t.Run` subtests in a buffer, builds
-- `go test -json` commands for them, and parses the JSON event stream that
-- `go test -json` prints.
local util = require("test-samurai.util")

local runner = {
  name = "go",
}

--- Maps a `go test -json` Action to the result list it belongs in.
local RESULT_BUCKETS = {
  pass = "passes",
  fail = "failures",
  skip = "skips",
}

--- Find the 0-based index of the line holding the brace that closes the
-- first block opened at or after `start_idx`.
-- Braces inside `//` and `/* */` comments, interpreted strings, rune
-- literals and raw (backtick) strings are ignored so they cannot skew the
-- depth count. A `}` seen before any `{` is ignored.
-- @param lines table 1-based list of source lines
-- @param start_idx integer 1-based line to start scanning from
-- @return integer 0-based line index of the closing brace, or `#lines - 1`
--   when the block never closes
local function find_block_end(lines, start_idx)
  local depth = 0
  local started = false
  -- Raw strings and block comments may span several lines, so their state
  -- is carried across iterations of the outer loop.
  local in_raw = false
  local in_block_comment = false
  for i = start_idx, #lines do
    local line = lines[i]
    local len = #line
    local j = 1
    while j <= len do
      local ch = line:sub(j, j)
      local next_ch = line:sub(j + 1, j + 1)
      if in_raw then
        if ch == "`" then
          in_raw = false
        end
        j = j + 1
      elseif in_block_comment then
        if ch == "*" and next_ch == "/" then
          in_block_comment = false
          j = j + 2
        else
          j = j + 1
        end
      elseif ch == "/" and next_ch == "/" then
        -- The rest of the line is a comment.
        break
      elseif ch == "/" and next_ch == "*" then
        in_block_comment = true
        j = j + 2
      elseif ch == "`" then
        in_raw = true
        j = j + 1
      elseif ch == '"' or ch == "'" then
        -- Skip an interpreted string or rune literal, honoring backslash
        -- escapes; `j` ends just past the closing quote.
        j = j + 1
        while j <= len do
          local c = line:sub(j, j)
          if c == "\\" then
            j = j + 2
          elseif c == ch then
            break
          else
            j = j + 1
          end
        end
        j = j + 1
      elseif ch == "{" then
        depth = depth + 1
        started = true
        j = j + 1
      elseif ch == "}" then
        if started then
          depth = depth - 1
          if depth == 0 then
            return i - 1
          end
        end
        j = j + 1
      else
        j = j + 1
      end
    end
  end
  return #lines - 1
end

--- Find every function or method whose declaration line mentions
-- `*testing.T`.
-- @param lines table 1-based list of source lines
-- @return table list of `{ name, start, ["end"] }` with 0-based line indices
local function find_test_functions(lines)
  local funcs = {}
  for i, line in ipairs(lines) do
    -- Plain function `func Name(`, or method `func (recv T) Name(`.
    local name = line:match("^%s*func%s+([%w_]+)%s*%(")
      or line:match("^%s*func%s+%([^)]-%)%s+([%w_]+)%s*%(")
    if name and line:find("%*testing%.T") then
      table.insert(funcs, {
        name = name,
        start = i - 1,
        ["end"] = find_block_end(lines, i),
      })
    end
  end
  return funcs
end

--- Find `t.Run("name", ...)` calls inside `func`'s body.
-- @param lines table 1-based list of source lines
-- @param func table entry produced by `find_test_functions`
-- @return table list of `{ name, start, ["end"] }` with 0-based line indices
local function find_t_runs(lines, func)
  local subtests = {}
  -- `i` is a 0-based line index; `lines` is 1-based, hence `lines[i + 1]`.
  for i = func.start + 1, func["end"] do
    local line = lines[i + 1]
    if line then
      local name = line:match("t%.Run%(%s*['\"]([^'\"]+)['\"]")
      if name then
        table.insert(subtests, {
          name = name,
          start = i,
          ["end"] = find_block_end(lines, i + 1),
        })
      end
    end
  end
  return subtests
end

--- Escape RE2 metacharacters so `s` matches literally inside a
-- `go test -run` pattern. Each metacharacter gets exactly one backslash:
-- in a Lua `gsub` replacement only `%` is special, so "\\%1" produces a
-- single `\` followed by the captured character.
-- @param s string|nil text to escape (nil is treated as "")
-- @return string escaped text
local function escape_go_regex(s)
  s = s or ""
  return (s:gsub("([\\.^$|()%%[%]{}*+?%-])", "\\%1"))
end

--- Build the `-run` regex for a spec from `runner.find_nearest`.
-- "function" scope also matches the test's subtests (`($|/)`); any other
-- scope matches the exact path.
local function build_run_pattern(spec)
  local escaped = escape_go_regex(spec.test_path or "")
  if spec.scope == "function" then
    return "^" .. escaped .. "($|/)"
  end
  return "^" .. escaped .. "$"
end

--- Compute the package argument for `go test` relative to `spec.cwd`.
-- Returns "./" when the file sits directly in cwd, "./<rel>" for a
-- subdirectory, and "./..." when either path is missing or the file is
-- outside cwd.
local function build_pkg_arg(spec)
  local file = spec.file
  local cwd = spec.cwd
  if not file or not cwd or file == "" or cwd == "" then
    return "./..."
  end
  -- Drop a trailing separator so "/repo/" and "/repo" compare equal.
  if #cwd > 1 and cwd:sub(-1) == "/" then
    cwd = cwd:sub(1, -2)
  end
  local dir = vim.fs.dirname(file)
  if dir == cwd then
    return "./"
  end
  -- Require a separator boundary: "/repo-other/x" is not inside "/repo".
  local prefix = (cwd == "/") and "/" or (cwd .. "/")
  if not dir or dir:sub(1, #prefix) ~= prefix then
    return "./..."
  end
  local rel = dir:sub(#prefix + 1)
  if rel == "" then
    return "./"
  end
  return "./" .. rel
end

--- Return non-empty items of `list` in first-seen order, without duplicates.
local function collect_unique(list)
  local out = {}
  local seen = {}
  for _, item in ipairs(list) do
    if item and item ~= "" and not seen[item] then
      seen[item] = true
      table.insert(out, item)
    end
  end
  return out
end

--- Resolve the project root for `path` (nearest `go.mod` or `.git`),
-- falling back to Neovim's current working directory.
local function resolve_root(path)
  local root
  if path and path ~= "" then
    root = util.find_root(path, { "go.mod", ".git" })
  end
  if not root or root == "" then
    root = vim.loop.cwd()
  end
  return root
end

--- Last path component of a test name ("TestA/sub" -> "sub").
local function short_test_name(name)
  return name:match("([^/]+)$") or name
end

--- Fresh `{ passes, failures, skips }` table of empty lists.
local function empty_result_lists()
  return { passes = {}, failures = {}, skips = {} }
end

--- True when the buffer's path ends in `_test.go`.
function runner.is_test_file(bufnr)
  local path = util.get_buf_path(bufnr)
  if not path or path == "" then
    return false
  end
  return path:sub(-8) == "_test.go"
end

--- Build a spec for the test function or `t.Run` subtest containing `row`.
-- @param bufnr integer buffer handle
-- @param row integer cursor row, compared with the 0-based line indices
--   produced by `find_test_functions`
-- @return table|nil spec, or nil plus an error message
function runner.find_nearest(bufnr, row, _col)
  if not runner.is_test_file(bufnr) then
    return nil, "not a Go test file"
  end
  local lines = util.get_buf_lines(bufnr)

  local current
  for _, f in ipairs(find_test_functions(lines)) do
    if row >= f.start and row <= f["end"] then
      current = f
      break
    end
  end
  if not current then
    return nil, "cursor not inside a test function"
  end

  local inside_sub
  for _, sub in ipairs(find_t_runs(lines, current)) do
    if row >= sub.start and row <= sub["end"] then
      inside_sub = sub
      break
    end
  end

  local path = util.get_buf_path(bufnr)
  local root = resolve_root(path)
  if inside_sub then
    return {
      file = path,
      cwd = root,
      test_path = current.name .. "/" .. inside_sub.name,
      scope = "subtest",
      func = current.name,
      subtest = inside_sub.name,
    }
  end
  return {
    file = path,
    cwd = root,
    test_path = current.name,
    scope = "function",
    func = current.name,
  }
end

--- Build the `go test -json` command for a spec from `runner.find_nearest`.
function runner.build_command(spec)
  local cmd = { "go", "test", "-json", build_pkg_arg(spec), "-run", build_run_pattern(spec) }
  return {
    cmd = cmd,
    cwd = spec.cwd,
  }
end

--- Build a command that runs every test function found in the buffer.
-- @return table|nil `{ cmd, cwd }`, or nil when the buffer has no path
function runner.build_file_command(bufnr)
  local path = util.get_buf_path(bufnr)
  if not path or path == "" then
    return nil
  end
  local root = resolve_root(path)
  local cmd = { "go", "test", "-json", build_pkg_arg({ file = path, cwd = root }) }

  local names = {}
  for _, fn in ipairs(find_test_functions(util.get_buf_lines(bufnr))) do
    table.insert(names, fn.name)
  end
  names = collect_unique(names)

  if #names > 0 then
    local pattern_parts = {}
    for _, name in ipairs(names) do
      table.insert(pattern_parts, escape_go_regex(name))
    end
    table.insert(cmd, "-run")
    table.insert(cmd, "^(" .. table.concat(pattern_parts, "|") .. ")$")
  end
  return {
    cmd = cmd,
    cwd = root,
  }
end

--- Build a command that runs every package under the project root.
function runner.build_all_command(bufnr)
  return {
    cmd = { "go", "test", "-json", "./..." },
    cwd = resolve_root(util.get_buf_path(bufnr)),
  }
end

--- Parse complete `go test -json` output into pass/fail/skip lists.
-- Full test names are deduplicated. The `display` lists hold the last path
-- component of each event's name and are not deduplicated.
function runner.parse_results(output)
  local collected = empty_result_lists()
  local display = empty_result_lists()
  if output and output ~= "" then
    for line in output:gmatch("[^\n]+") do
      local ok, data = pcall(vim.json.decode, line)
      if ok and type(data) == "table" then
        local name = data.Test
        local bucket = RESULT_BUCKETS[data.Action]
        -- Package-level events have no Test; JSON null decodes to vim.NIL.
        if bucket and type(name) == "string" and name ~= "" then
          table.insert(collected[bucket], name)
          table.insert(display[bucket], short_test_name(name))
        end
      end
    end
  end
  return {
    passes = collect_unique(collected.passes),
    failures = collect_unique(collected.failures),
    skips = collect_unique(collected.skips),
    display = display,
  }
end

--- Split text on "\n", dropping the single empty entry a trailing newline
-- produces.
local function split_output_lines(text)
  if not text or text == "" then
    return {}
  end
  local lines = vim.split(text, "\n", { plain = true })
  if #lines > 0 and lines[#lines] == "" then
    table.remove(lines, #lines)
  end
  return lines
end

--- Replace whitespace runs with "_".
-- NOTE(review): presumably this mirrors how `go test` rewrites subtest
-- names with spaces — confirm.
local function normalize_go_name(name)
  if not name or name == "" then
    return nil
  end
  return (name:gsub("%s+", "_"))
end

--- Append a quickfix-style `{ filename, lnum, col, text }` entry under
-- `target[key]`. Entries with a missing key, file or line are ignored.
local function add_location(target, key, file, line, label)
  if not key or key == "" or not file or file == "" or not line then
    return
  end
  target[key] = target[key] or {}
  table.insert(target[key], {
    filename = file,
    lnum = line,
    col = 1,
    text = label or key,
  })
end

--- Index the test functions and `t.Run` subtests of `file` into `target`.
-- Each is keyed by its raw and whitespace-normalized name, with a 1-based
-- line number. Unreadable files are skipped.
local function collect_file_locations(file, target)
  local ok, lines = pcall(vim.fn.readfile, file)
  if not ok or type(lines) ~= "table" then
    return
  end
  for _, fn in ipairs(find_test_functions(lines)) do
    add_location(target, fn.name, file, fn.start + 1, fn.name)
    local normalized = normalize_go_name(fn.name)
    if normalized and normalized ~= fn.name then
      add_location(target, normalized, file, fn.start + 1, fn.name)
    end
    for _, sub in ipairs(find_t_runs(lines, fn)) do
      local full = fn.name .. "/" .. sub.name
      add_location(target, full, file, sub.start + 1, full)
      local normalized_full = normalize_go_name(full)
      if normalized_full and normalized_full ~= full then
        add_location(target, normalized_full, file, sub.start + 1, full)
      end
    end
  end
end

--- List `*_test.go` files under `root` (default: Neovim's cwd).
local function collect_go_test_files(root)
  if not root or root == "" then
    root = vim.loop.cwd()
  end
  local files = vim.fn.globpath(root, "**/*_test.go", false, true)
  if type(files) ~= "table" then
    return {}
  end
  return files
end

--- Group the `output` events of `go test -json` by test name.
-- @return table map of test name -> list of output lines
function runner.parse_test_output(output)
  local out = {}
  if not output or output == "" then
    return out
  end
  for line in output:gmatch("[^\n]+") do
    local ok, data = pcall(vim.json.decode, line)
    if ok and type(data) == "table" and data.Action == "output" and data.Test and data.Output then
      out[data.Test] = out[data.Test] or {}
      for _, item in ipairs(split_output_lines(data.Output)) do
        table.insert(out[data.Test], item)
      end
    end
  end
  return out
end

--- Streaming parser: `on_line` turns one `go test -json` line into an
-- incremental result, or nil. Each (action, test) pair is reported once.
-- Every result carries `failures_all`, a copy of all failures seen so far.
function runner.output_parser()
  -- One seen-set per action, so pass/fail/skip are deduplicated
  -- independently.
  local seen = { pass = {}, fail = {}, skip = {} }
  local failures = {}
  return {
    on_line = function(line, _state)
      local ok, data = pcall(vim.json.decode, line)
      if not ok or type(data) ~= "table" then
        return nil
      end
      local name = data.Test
      if type(name) ~= "string" or name == "" then
        return nil
      end
      local action = data.Action
      local bucket = RESULT_BUCKETS[action]
      if not bucket or seen[action][name] then
        return nil
      end
      seen[action][name] = true
      if action == "fail" then
        table.insert(failures, name)
      end

      local result = empty_result_lists()
      result.display = empty_result_lists()
      result[bucket] = { name }
      result.display[bucket] = { short_test_name(name) }
      result.failures_all = vim.deepcopy(failures)
      return result
    end,
    on_complete = function(_output, _state)
      return nil
    end,
  }
end

--- Rebuild `last_command` so it runs only the given failed tests.
-- Any existing `-run <pattern>` pair is dropped and replaced.
-- @return table|nil `{ cmd, cwd }`, or nil when there is nothing to rerun
function runner.build_failed_command(last_command, failures, _scope_kind)
  if not last_command or type(last_command.cmd) ~= "table" then
    return nil
  end
  local pattern_parts = {}
  for _, name in ipairs(failures or {}) do
    table.insert(pattern_parts, escape_go_regex(name))
  end
  if #pattern_parts == 0 then
    return nil
  end
  local pattern = "^(" .. table.concat(pattern_parts, "|") .. ")$"

  local cmd = {}
  local skip_next = false
  for _, arg in ipairs(last_command.cmd) do
    if skip_next then
      skip_next = false
    elseif arg == "-run" then
      skip_next = true
    else
      table.insert(cmd, arg)
    end
  end
  table.insert(cmd, "-run")
  table.insert(cmd, pattern)
  return {
    cmd = cmd,
    cwd = last_command.cwd,
  }
end

--- Resolve failed test names to quickfix-style source locations.
-- For scope "all", every `*_test.go` under `command.cwd` is scanned;
-- otherwise only `command.file` is. A bare name with no exact match also
-- matches any indexed `<parent>/<name>` entry.
function runner.collect_failed_locations(failures, command, scope_kind)
  if type(failures) ~= "table" or #failures == 0 then
    return {}
  end
  local files = {}
  if scope_kind == "all" then
    files = collect_go_test_files(command and command.cwd or nil)
  elseif command and command.file then
    files = { command.file }
  end
  if #files == 0 then
    return {}
  end

  local locations = {}
  for _, file in ipairs(files) do
    collect_file_locations(file, locations)
  end

  local items = {}
  local seen = {}
  local function add_locations(name, locs)
    for _, loc in ipairs(locs or {}) do
      local key = string.format("%s:%d:%s", loc.filename or "", loc.lnum or 0, loc.text or name or "")
      if not seen[key] then
        seen[key] = true
        table.insert(items, loc)
      end
    end
  end

  for _, name in ipairs(failures) do
    local direct = locations[name]
    if direct then
      add_locations(name, direct)
    elseif not name:find("/", 1, true) then
      for full, locs in pairs(locations) do
        if full:sub(-#name - 1) == "/" .. name then
          add_locations(full, locs)
        end
      end
    end
  end
  return items
end

return runner