diff --git a/.gitignore b/.gitignore index 7ff46e9..2164b93 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ output .nvimlog +.idea diff --git a/lua/test-samurai/core.lua b/lua/test-samurai/core.lua index 942d2bb..7b68fc1 100644 --- a/lua/test-samurai/core.lua +++ b/lua/test-samurai/core.lua @@ -436,13 +436,6 @@ function M.get_runner_for_buf(bufnr) return nil end - if path:sub(-8) == "_test.go" then - local ok, go = pcall(require, "test-samurai.runners.go") - if ok and type(go) == "table" then - return go - end - end - if path:find(".test.", 1, true) or path:find(".spec.", 1, true) then local ok, jsjest = pcall(require, "test-samurai.runners.js-jest") if ok and type(jsjest) == "table" then @@ -1231,9 +1224,81 @@ local function pick_display(results, key, scope_kind) return results[key] end -local function track_result_lines(start_line, results, scope_kind) +local function entry_root(name) + if not name or name == "" then + return nil + end + local idx = name:find("/", 1, true) + if not idx then + return name + end + return name:sub(1, idx - 1) +end + +local function parent_of(name) + if not name or name == "" then + return nil + end + local last = nil + for i = #name, 1, -1 do + if name:sub(i, i) == "/" then + last = i + break + end + end + if not last then + return nil + end + return name:sub(1, last - 1) +end + +local function group_entries_by_parent(entries) + local nodes = {} + local ordered = {} + local roots = {} + + for _, entry in ipairs(entries) do + local full = entry.full + if full and full ~= "" then + if not nodes[full] then + nodes[full] = { entry = entry, children = {} } + table.insert(ordered, full) + end + else + table.insert(roots, { entry = entry, children = {} }) + end + end + + for _, full in ipairs(ordered) do + local node = nodes[full] + local parent = parent_of(full) + if parent and nodes[parent] then + table.insert(nodes[parent].children, node) + else + table.insert(roots, node) + end + end + + local out = {} + local function emit(node) + if node.entry then + table.insert(out, node.entry) + end + for _, child in ipairs(node.children) do + emit(child) + end + end + + for _, node in ipairs(roots) do + emit(node) + end + + return out +end + +local function build_listing_entries(results, scope_kind) if not results then - return + return {} end local entries = {} local function append_kind(kind) @@ -1250,37 +1315,52 @@ local function track_result_lines(start_line, results, scope_kind) if not full_name or full_name == "" then full_name = name end - table.insert(entries, full_name) + table.insert(entries, { + kind = kind, + full = full_name, + display = name, + }) end end append_kind("passes") append_kind("skips") append_kind("failures") - for i, name in ipairs(entries) do - if name and name ~= "" then - state.last_result_line_map[start_line + i] = name + local parent_set = {} + for _, entry in ipairs(entries) do + if entry.full and entry.full ~= "" then + parent_set[entry.full] = true + end + end + for _, entry in ipairs(entries) do + if entry.full and entry.full:find("/", 1, true) then + local parent = parent_of(entry.full) + if parent and parent_set[parent] then + entry.display = entry.full + end + end + end + return group_entries_by_parent(entries) +end + +local function track_result_lines(start_line, results, scope_kind) + local entries = build_listing_entries(results, scope_kind) + for i, entry in ipairs(entries) do + if entry.full and entry.full ~= "" then + state.last_result_line_map[start_line + i] = entry.full end end end local function format_results(results, scope_kind) local lines = {} - local passes = pick_display(results, "passes", scope_kind) - if type(passes) == "table" then - for _, title in ipairs(passes) do - table.insert(lines, "[ PASS ] - " .. title) - end - end - local skips = pick_display(results, "skips", scope_kind) - if type(skips) == "table" then - for _, title in ipairs(skips) do - table.insert(lines, "[ SKIP ] - " .. title) - end - end - local failures = pick_display(results, "failures", scope_kind) - if type(failures) == "table" then - for _, title in ipairs(failures) do - table.insert(lines, "[ FAIL ] - " .. title) + local entries = build_listing_entries(results, scope_kind) + for _, entry in ipairs(entries) do + if entry.kind == "passes" then + table.insert(lines, "[ PASS ] - " .. entry.display) + elseif entry.kind == "skips" then + table.insert(lines, "[ SKIP ] - " .. entry.display) + elseif entry.kind == "failures" then + table.insert(lines, "[ FAIL ] - " .. entry.display) end end return lines @@ -1308,6 +1388,79 @@ local function add_unique_items(target, items) end end +local function init_aggregate_results() + return { + passes = {}, + failures = {}, + skips = {}, + display = { passes = {}, failures = {}, skips = {} }, + } +end + +local function merge_results(agg, results, seen) + if not agg or not results or not seen then + return + end + local function merge_kind(kind) + local items = results[kind] + if type(items) ~= "table" then + return + end + local display_items = nil + if type(results.display) == "table" and type(results.display[kind]) == "table" then + display_items = results.display[kind] + end + for i, name in ipairs(items) do + if name and name ~= "" and not seen[kind][name] then + seen[kind][name] = true + table.insert(agg[kind], name) + if display_items and display_items[i] then + table.insert(agg.display[kind], display_items[i]) + else + table.insert(agg.display[kind], name) + end + end + end + end + merge_kind("passes") + merge_kind("failures") + merge_kind("skips") +end + +local function should_group_results(results) + if not results then + return false + end + local parent_set = {} + for _, kind in ipairs({ "passes", "failures", "skips" }) do + local list = results[kind] + if type(list) == "table" then + for _, name in ipairs(list) do + if name and name ~= "" and not name:find("/", 1, true) then + parent_set[name] = true + end + end + end + end + if not next(parent_set) then + return false + end + for _, kind in ipairs({ "passes", "failures", "skips" }) do + local list = results[kind] + if type(list) == "table" then + for _, name in ipairs(list) do + if name and name:find("/", 1, true) then + local root = entry_root(name) + if root and parent_set[root] then + return true + end + end + end + end + end + return false +end + local function update_summary(summary, results) if not summary or not results then return @@ -1480,6 +1633,14 @@ local function run_command(command, opts) local runner = options.runner local parser_state = {} parser_state.scope_kind = options.scope_kind + parser_state.aggregate_results = nil + parser_state.result_start_line = nil + parser_state.result_end_line = nil + parser_state.seen = { + passes = {}, + failures = {}, + skips = {}, + } local had_parsed_output = false local summary_enabled = options.scope_kind == "file" or options.scope_kind == "all" or options.scope_kind == "nearest" local summary = make_summary_tracker(summary_enabled) @@ -1521,6 +1682,10 @@ local function run_command(command, opts) return end had_parsed_output = true + if not parser_state.aggregate_results then + parser_state.aggregate_results = init_aggregate_results() + end + merge_results(parser_state.aggregate_results, results, parser_state.seen) if type(results.failures) == "table" then for _, name in ipairs(results.failures) do if name and name ~= "" and not failures_seen[name] then @@ -1547,6 +1712,10 @@ local function run_command(command, opts) append_lines(buf, lines) apply_result_highlights(buf, start_line, lines) track_result_lines(start_line, results, options.scope_kind) + if not parser_state.result_start_line then + parser_state.result_start_line = start_line + end + parser_state.result_end_line = vim.api.nvim_buf_line_count(buf) end run_cmd(cmd, cwd, { @@ -1603,6 +1772,18 @@ local function run_command(command, opts) handle_parsed(results) end end + if parser_state.aggregate_results and parser_state.result_start_line and should_group_results(parser_state.aggregate_results) then + local start_line = parser_state.result_start_line + local end_line = parser_state.result_end_line or start_line + local grouped = format_results(parser_state.aggregate_results, options.scope_kind) + vim.api.nvim_buf_set_lines(buf, start_line, end_line, false, grouped) + vim.api.nvim_buf_clear_namespace(buf, result_ns, start_line, end_line) + apply_result_highlights(buf, start_line, grouped) + state.last_result_line_map = {} + track_result_lines(start_line, parser_state.aggregate_results, options.scope_kind) + parser_state.result_end_line = start_line + #grouped + end + local pass_count, fail_count = count_summary(result_counts) if fail_count > 0 then state.last_border_kind = "fail" diff --git a/tests/minimal_init.lua b/tests/minimal_init.lua index 90e05fd..4022136 100644 --- a/tests/minimal_init.lua +++ b/tests/minimal_init.lua @@ -1,2 +1,8 @@ -vim.opt.runtimepath:append(vim.loop.cwd()) +local cwd = vim.loop.cwd() +vim.opt.runtimepath:append(cwd) +package.path = table.concat({ + cwd .. "/lua/?.lua", + cwd .. "/lua/?/init.lua", + package.path, +}, ";") require("plenary.busted") diff --git a/tests/test_samurai_core_spec.lua b/tests/test_samurai_core_spec.lua index 8004dff..ad208ae 100644 --- a/tests/test_samurai_core_spec.lua +++ b/tests/test_samurai_core_spec.lua @@ -22,6 +22,14 @@ describe("test-samurai core", function() assert.equals("go", runner.name) end) + it("does not fallback to Go runner when no runners are configured", function() + test_samurai.setup({ runner_modules = {} }) + local bufnr = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_name(bufnr, "/tmp/no_runner_test.go") + local runner = core.get_runner_for_buf(bufnr) + assert.is_nil(runner) + end) + it("selects JS jest runner for *.test.ts files", function() local bufnr = vim.api.nvim_create_buf(false, true) vim.api.nvim_buf_set_name(bufnr, "/tmp/foo.test.ts") diff --git a/tests/test_samurai_output_spec.lua b/tests/test_samurai_output_spec.lua index c9b636b..b633420 100644 --- a/tests/test_samurai_output_spec.lua +++ b/tests/test_samurai_output_spec.lua @@ -1276,16 +1276,18 @@ describe("test-samurai output formatting", function() assert.is_true(set_calls.TestSamuraiSummarySkip.fg == 333) end) - it("formats go subtests as short names", function() - local json_line = vim.json.encode({ - Action = "pass", - Test = "TestHandleGet/returns_200", - }) + it("groups Go subtests under their parent in listing", function() + local json_lines = { + vim.json.encode({ Action = "pass", Test = "TestHandleGet/returns_200" }), + vim.json.encode({ Action = "fail", Test = "TestOther/returns_500" }), + vim.json.encode({ Action = "pass", Test = "TestHandleGet" }), + vim.json.encode({ Action = "skip", Test = "TestOther" }), + } local orig_jobstart = vim.fn.jobstart vim.fn.jobstart = function(_cmd, opts) if opts and opts.on_stdout then - opts.on_stdout(1, { json_line }, nil) + opts.on_stdout(1, json_lines, nil) end if opts and opts.on_exit then opts.on_exit(1, 0, nil) @@ -1294,21 +1296,23 @@ describe("test-samurai output formatting", function() end local bufnr = vim.api.nvim_create_buf(false, true) - vim.api.nvim_buf_set_name(bufnr, "/tmp/output_go_short_test.go") + vim.api.nvim_buf_set_name(bufnr, "/tmp/output_go_grouped_test.go") vim.bo[bufnr].filetype = "go" vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, { "package main", "import \"testing\"", "", "func TestHandleGet(t *testing.T) {", - " t.Run(\"returns_200\", func(t *testing.T) {", - " -- inside test", - " })", + " t.Run(\"returns_200\", func(t *testing.T) {})", + "}", + "", + "func TestOther(t *testing.T) {", + " t.Run(\"returns_500\", func(t *testing.T) {})", "}", }) vim.api.nvim_set_current_buf(bufnr) - vim.api.nvim_win_set_cursor(0, { 6, 0 }) + vim.api.nvim_win_set_cursor(0, { 5, 0 }) core.run_nearest() @@ -1317,17 +1321,97 @@ describe("test-samurai output formatting", function() vim.fn.jobstart = orig_jobstart - local has_pass = false - local has_raw_json = false - for _, line in ipairs(lines) do - if line == "[ PASS ] - returns_200" then - has_pass = true - elseif line == json_line then - has_raw_json = true + local idx_parent_1 = nil + local idx_sub_1 = nil + local idx_parent_2 = nil + local idx_sub_2 = nil + for i, line in ipairs(lines) do + if line == "[ PASS ] - TestHandleGet" then + idx_parent_1 = i + elseif line == "[ PASS ] - TestHandleGet/returns_200" then + idx_sub_1 = i + elseif line == "[ SKIP ] - TestOther" then + idx_parent_2 = i + elseif line == "[ FAIL ] - TestOther/returns_500" then + idx_sub_2 = i end end - assert.is_true(has_pass) - assert.is_false(has_raw_json) + + assert.is_not_nil(idx_parent_1) + assert.is_not_nil(idx_sub_1) + assert.is_not_nil(idx_parent_2) + assert.is_not_nil(idx_sub_2) + assert.is_true(idx_parent_1 < idx_sub_1) + assert.is_true(idx_parent_2 < idx_sub_2) + end) + + it("groups nested Go subtests under subtest parents in listing", function() + local json_lines = { + vim.json.encode({ Action = "pass", Test = "TestWriteJSON/returns_500_when/data_could_not_be_serialized_and_logs_it" }), + vim.json.encode({ Action = "pass", Test = "TestWriteJSON" }), + vim.json.encode({ Action = "pass", Test = "TestWriteJSON/returns_500_when" }), + vim.json.encode({ Action = "pass", Test = "TestWriteJSON/returns_500_when/error_at_writing_response_occurs_and_logs_it" }), + } + + local orig_jobstart = vim.fn.jobstart + vim.fn.jobstart = function(_cmd, opts) + if opts and opts.on_stdout then + opts.on_stdout(1, json_lines, nil) + end + if opts and opts.on_exit then + opts.on_exit(1, 0, nil) + end + return 1 + end + + local bufnr = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_name(bufnr, "/tmp/output_go_nested_test.go") + vim.bo[bufnr].filetype = "go" + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, { + "package main", + "import \"testing\"", + "", + "func TestWriteJSON(t *testing.T) {", + " t.Run(\"returns_500_when\", func(t *testing.T) {", + " t.Run(\"data_could_not_be_serialized_and_logs_it\", func(t *testing.T) {})", + " t.Run(\"error_at_writing_response_occurs_and_logs_it\", func(t *testing.T) {})", + " })", + "}", + }) + + vim.api.nvim_set_current_buf(bufnr) + vim.api.nvim_win_set_cursor(0, { 5, 0 }) + + core.run_nearest() + + local out_buf = vim.api.nvim_get_current_buf() + local lines = vim.api.nvim_buf_get_lines(out_buf, 0, -1, false) + + vim.fn.jobstart = orig_jobstart + + local idx_parent = nil + local idx_mid = nil + local idx_child_1 = nil + local idx_child_2 = nil + for i, line in ipairs(lines) do + if line == "[ PASS ] - TestWriteJSON" then + idx_parent = i + elseif line == "[ PASS ] - TestWriteJSON/returns_500_when" then + idx_mid = i + elseif line == "[ PASS ] - TestWriteJSON/returns_500_when/data_could_not_be_serialized_and_logs_it" then + idx_child_1 = i + elseif line == "[ PASS ] - TestWriteJSON/returns_500_when/error_at_writing_response_occurs_and_logs_it" then + idx_child_2 = i + end + end + + assert.is_not_nil(idx_parent) + assert.is_not_nil(idx_mid) + assert.is_not_nil(idx_child_1) + assert.is_not_nil(idx_child_2) + assert.is_true(idx_parent < idx_mid) + assert.is_true(idx_mid < idx_child_1) + assert.is_true(idx_mid < idx_child_2) end) it("does not print raw JSON output for mocha json-stream", function()