`timescale 1ns / 1ps

// Self-checking testbench for rv32_micro_core.
//
// Each test loads a small hand-assembled RV32I program into a behavioral
// instruction memory, resets the core, runs it until `halted` or `trap` is
// observed high (or a cycle budget runs out), and then checks:
//   * which termination flag rose and the exact number of cycles it took,
//   * that halted/trap stay asserted and no store occurs for two more cycles,
//   * the ordered trace of external stores (address + data),
//   * the final contents of the behavioral data memory.
// Data memory is pre-filled with per-word sentinels (sentinel_word) so words
// the core never wrote can be distinguished from words it did.
module tb_rv32_micro_core;
    reg clk;
    reg rst;

    wire [31:0] imem_addr;
    wire [31:0] imem_rdata;   // driven by a continuous assignment below
    wire [31:0] dmem_addr;
    wire [31:0] dmem_wdata;
    wire        dmem_we;
    wire [31:0] dmem_rdata;   // driven by a continuous assignment below
    wire        halted;
    wire        trap;

    integer errors = 0;       // number of failed checks across all tests
    integer tests_run = 0;    // despite the name: clock cycles stepped across all tests
    integer run_cycles = 0;   // cycles the current program took to terminate
    integer store_count = 0;  // stores recorded in the trace for the current test
    integer i;

    // Both memories are indexed with byte-address bits [9:2], which matches
    // the 256-word depth below; keep the two in sync if the depth changes.
    localparam integer IMEM_WORDS = 256;
    localparam integer DMEM_WORDS = 256;
    localparam integer MAX_STORES = 16;

    reg [31:0] imem [0:IMEM_WORDS-1];
    reg [31:0] dmem [0:DMEM_WORDS-1];
    reg [31:0] observed_store_addr [0:MAX_STORES-1];
    reg [31:0] observed_store_data [0:MAX_STORES-1];

    rv32_micro_core uut (
        .clk(clk),
        .rst(rst),
        .imem_addr(imem_addr),
        .imem_rdata(imem_rdata),
        .dmem_addr(dmem_addr),
        .dmem_wdata(dmem_wdata),
        .dmem_we(dmem_we),
        .dmem_rdata(dmem_rdata),
        .halted(halted),
        .trap(trap)
    );

    // Distinct per-word fill value for data memory: 0xCAFE0000 + word index.
    function [31:0] sentinel_word;
        input integer idx;
        begin
            sentinel_word = 32'hCAFE0000 + idx[31:0];
        end
    endfunction

    // Free-running clock: toggles every 5 time units (10-unit period).
    initial begin
        clk = 1'b0;
        forever #5 clk = ~clk;
    end

    // Combinational memory read ports.
    // Continuous assignments are used instead of an always @(*) block because
    // simulators differ on whether @* re-evaluates when an array word is
    // written while the address is unchanged (e.g. load_program updating imem,
    // or a store updating dmem under a stable dmem_addr). Continuous
    // assignments also hold a value from time zero.
    // If the index bits [9:2] contain X/Z (e.g. before the first reset), the
    // fetch returns 0x00000013 (addi x0, x0, 0, a NOP) and the data read
    // returns zero, keeping X out of the core's datapath. Address bits above
    // bit 9 are ignored, so reads alias every 1 KiB.
    assign imem_rdata = (^imem_addr[9:2] === 1'bx) ? 32'h00000013
                                                   : imem[imem_addr[9:2]];
    assign dmem_rdata = (^dmem_addr[9:2] === 1'bx) ? 32'h00000000
                                                   : dmem[dmem_addr[9:2]];

    // Store monitor. On each rising edge where dmem_we is exactly 1:
    //   * the store is flagged if reset is asserted,
    //   * address/data are appended to the observed-store trace (error on
    //     overflow past MAX_STORES),
    //   * the word is committed to dmem with a nonblocking write, so a read by
    //     the core in the same time step still sees the old value,
    //   * stores with an unknown, misaligned, or out-of-window address are
    //     reported instead of being silently dropped or aliased into low
    //     memory, and do not modify dmem.
    always @(posedge clk) begin
        if (dmem_we === 1'b1) begin
            if (rst === 1'b1) begin
                $display("ERROR: store attempted while reset is asserted");
                errors = errors + 1;
            end

            if (store_count >= MAX_STORES) begin
                $display("ERROR: observed store trace overflow");
                errors = errors + 1;
            end else begin
                observed_store_addr[store_count] = dmem_addr;
                observed_store_data[store_count] = dmem_wdata;
                store_count = store_count + 1;
            end

            if (^dmem_addr === 1'bx) begin
                // An X/Z address would otherwise fall through to a write at
                // an unknown index, which Verilog silently discards.
                $display("ERROR: external store with unknown address 0x%08h", dmem_addr);
                errors = errors + 1;
            end else if (dmem_addr[1:0] != 2'b00) begin
                $display("ERROR: misaligned external store observed at address 0x%08h", dmem_addr);
                errors = errors + 1;
            end else if (|dmem_addr[31:10]) begin
                // Only bits [9:2] index dmem; writing here would alias onto a
                // low word and corrupt sentinel/preloaded data.
                $display("ERROR: external store outside modeled data memory at address 0x%08h", dmem_addr);
                errors = errors + 1;
            end else begin
                dmem[dmem_addr[9:2]] <= dmem_wdata;
            end
        end
    end

    // Fill imem with NOPs, fill dmem with sentinels, clear the store trace,
    // and zero the per-test counters.
    task init_memories;
        begin
            for (i = 0; i < IMEM_WORDS; i = i + 1)
                imem[i] = 32'h00000013;

            for (i = 0; i < DMEM_WORDS; i = i + 1)
                dmem[i] = sentinel_word(i);

            for (i = 0; i < MAX_STORES; i = i + 1) begin
                observed_store_addr[i] = 32'h0;
                observed_store_data[i] = 32'h0;
            end

            store_count = 0;
            run_cycles = 0;
        end
    endtask

    // Write the hand-assembled program `program_id` into imem starting at
    // address 0. Words not written keep whatever init_memories put there
    // (NOPs). An unknown id is counted as an error.
    task load_program;
        input integer program_id;
        begin
            case (program_id)
                1: begin
                    imem[0] = 32'h00100073;   // ebreak
                end
                2: begin
                    imem[0]  = 32'h10000513;  // addi x10, x0, 0x100
                    imem[1]  = 32'h123450b7;  // lui  x1, 0x12345
                    imem[2]  = 32'h01500113;  // addi x2, x0, 21
                    imem[3]  = 32'h00a00193;  // addi x3, x0, 10
                    imem[4]  = 32'h00310233;  // add  x4, x2, x3
                    imem[5]  = 32'h403102b3;  // sub  x5, x2, x3
                    imem[6]  = 32'h00527333;  // and  x6, x4, x5
                    imem[7]  = 32'h0012e3b3;  // or   x7, x5, x1
                    imem[8]  = 32'h0013c433;  // xor  x8, x7, x1
                    imem[9]  = 32'h06300013;  // addi x0, x0, 99   (x0 must stay 0)
                    imem[10] = 32'h00152023;  // sw   x1, 0(x10)
                    imem[11] = 32'h00452223;  // sw   x4, 4(x10)
                    imem[12] = 32'h00552423;  // sw   x5, 8(x10)
                    imem[13] = 32'h00652623;  // sw   x6, 12(x10)
                    imem[14] = 32'h00752823;  // sw   x7, 16(x10)
                    imem[15] = 32'h00852a23;  // sw   x8, 20(x10)
                    imem[16] = 32'h00052c23;  // sw   x0, 24(x10)
                    imem[17] = 32'h00100073;  // ebreak
                end
                3: begin
                    imem[0]  = 32'h04000513;  // addi x10, x0, 0x40
                    imem[1]  = 32'h00052083;  // lw   x1, 0(x10)
                    imem[2]  = 32'h00452103;  // lw   x2, 4(x10)
                    imem[3]  = 32'h002081b3;  // add  x3, x1, x2
                    imem[4]  = 32'h0020c233;  // xor  x4, x1, x2
                    imem[5]  = 32'h12000593;  // addi x11, x0, 0x120
                    imem[6]  = 32'h0015a023;  // sw   x1, 0(x11)
                    imem[7]  = 32'h0025a223;  // sw   x2, 4(x11)
                    imem[8]  = 32'h0035a423;  // sw   x3, 8(x11)
                    imem[9]  = 32'h0045a623;  // sw   x4, 12(x11)
                    imem[10] = 32'h00100073;  // ebreak
                end
                4: begin
                    imem[0]  = 32'h14000513;  // addi x10, x0, 0x140
                    imem[1]  = 32'h00000093;  // addi x1, x0, 0
                    imem[2]  = 32'h00300113;  // addi x2, x0, 3
                    imem[3]  = 32'h00100193;  // addi x3, x0, 1
                    imem[4]  = 32'h02010263;  // beq  x2, x0, +36  (-> imem[13], not taken)
                    imem[5]  = 32'h002080b3;  // add  x1, x1, x2   (loop body)
                    imem[6]  = 32'h40310133;  // sub  x2, x2, x3
                    imem[7]  = 32'hfe011ce3;  // bne  x2, x0, -8   (-> imem[5])
                    imem[8]  = 32'h00108463;  // beq  x1, x1, +8   (-> imem[10], always taken)
                    imem[9]  = 32'h06300213;  // addi x4, x0, 99   (must be skipped)
                    imem[10] = 32'h00152023;  // sw   x1, 0(x10)
                    imem[11] = 32'h00252223;  // sw   x2, 4(x10)
                    imem[12] = 32'h00100073;  // ebreak
                    imem[13] = 32'h04d00213;  // addi x4, x0, 77   (unreachable)
                    imem[14] = 32'h00452423;  // sw   x4, 8(x10)   (unreachable)
                    imem[15] = 32'h00100073;  // ebreak
                end
                5: begin
                    imem[0] = 32'h16000513;   // addi x10, x0, 0x160
                    imem[1] = 32'h00700093;   // addi x1, x0, 7
                    imem[2] = 32'h00c002ef;   // jal  x5, +12      (-> imem[5], x5 = 0xc)
                    imem[3] = 32'h06300093;   // addi x1, x0, 99   (must be skipped)
                    imem[4] = 32'h05800093;   // addi x1, x0, 88   (must be skipped)
                    imem[5] = 32'h00108133;   // add  x2, x1, x1
                    imem[6] = 32'h00552023;   // sw   x5, 0(x10)
                    imem[7] = 32'h00252223;   // sw   x2, 4(x10)
                    imem[8] = 32'h00100073;   // ebreak
                end
                6: begin
                    imem[0] = 32'h18000513;   // addi x10, x0, 0x180
                    imem[1] = 32'h00000000;   // all-zero word: illegal instruction
                    imem[2] = 32'h00a52023;   // sw   x10, 0(x10)  (must not execute)
                    imem[3] = 32'h00100073;   // ebreak
                end
                7: begin
                    imem[0] = 32'h04100513;   // addi x10, x0, 0x41
                    imem[1] = 32'h00052083;   // lw   x1, 0(x10)   (misaligned)
                    imem[2] = 32'h00102023;   // sw   x1, 0(x0)    (must not execute)
                    imem[3] = 32'h00100073;   // ebreak
                end
                8: begin
                    imem[0] = 32'h10000513;   // addi x10, x0, 0x100
                    imem[1] = 32'h04200093;   // addi x1, x0, 0x42
                    imem[2] = 32'h10200593;   // addi x11, x0, 0x102
                    imem[3] = 32'h0015a023;   // sw   x1, 0(x11)   (misaligned)
                    imem[4] = 32'h00152023;   // sw   x1, 0(x10)   (must not execute)
                    imem[5] = 32'h00100073;   // ebreak
                end
                default: begin
                    $display("ERROR: unknown program id %0d", program_id);
                    errors = errors + 1;
                end
            endcase
        end
    endtask

    // Seed dmem with input data that a program reads; call after
    // init_memories so the sentinels are overwritten only where needed.
    task preload_program_data;
        input integer program_id;
        begin
            case (program_id)
                3: begin
                    dmem[32'h040 >> 2] = 32'h11223344;
                    dmem[32'h044 >> 2] = 32'h01020304;
                end
                default: begin
                end
            endcase
        end
    endtask

    // Compare the dmem word containing `byte_addr` (bits [9:2]) against
    // `expected` with 4-state equality; any mismatch, including X/Z, is an
    // error.
    task check_word;
        input [255:0] tag;
        input [31:0] byte_addr;
        input [31:0] expected;
        reg [31:0] got;
        begin
            got = dmem[byte_addr[9:2]];
            if (got !== expected) begin
                $display("ERROR [%s]: memory mismatch at 0x%08h. expected=0x%08h got=0x%08h",
                         tag, byte_addr, expected, got);
                errors = errors + 1;
            end
        end
    endtask

    // Check the number of stores recorded in the trace for this test.
    task check_store_count;
        input [255:0] tag;
        input integer expected_count;
        begin
            if (store_count !== expected_count) begin
                $display("ERROR [%s]: store count mismatch. expected=%0d got=%0d",
                         tag, expected_count, store_count);
                errors = errors + 1;
            end
        end
    endtask

    // Check the idx-th recorded store (0-based, in program order) against the
    // expected byte address and data; a missing entry is an error.
    task check_store;
        input [255:0] tag;
        input integer idx;
        input [31:0] expected_addr;
        input [31:0] expected_data;
        begin
            if (idx >= store_count) begin
                $display("ERROR [%s]: missing store %0d. expected addr=0x%08h data=0x%08h",
                         tag, idx, expected_addr, expected_data);
                errors = errors + 1;
            end else begin
                if (observed_store_addr[idx] !== expected_addr) begin
                    $display("ERROR [%s]: store %0d address mismatch. expected=0x%08h got=0x%08h",
                             tag, idx, expected_addr, observed_store_addr[idx]);
                    errors = errors + 1;
                end
                if (observed_store_data[idx] !== expected_data) begin
                    $display("ERROR [%s]: store %0d data mismatch. expected=0x%08h got=0x%08h",
                             tag, idx, expected_data, observed_store_data[idx]);
                    errors = errors + 1;
                end
            end
        end
    endtask

    // Reset the core and run the loaded program to completion.
    //   * rst is held high across two rising edges, then released; halted and
    //     trap must both read 0 right after release.
    //   * The core is then clocked one edge at a time (sampling 1 unit after
    //     each edge) until halted or trap is 1, or max_cycles edges elapse.
    //   * expected_halted / expected_trap (only bit 0 is compared) and
    //     expected_cycles (edges counted until termination was seen) are
    //     then checked.
    //   * For two further edges the termination flags must stay unchanged
    //     and no new store may appear in the trace.
    task run_program;
        input [255:0] tag;
        input integer expected_halted;
        input integer expected_trap;
        input integer expected_cycles;
        input integer max_cycles;
        integer freeze_stores;
        begin
            rst = 1'b1;
            repeat (2) @(posedge clk);
            #1;
            rst = 1'b0;
            #1;

            if (halted !== 1'b0) begin
                $display("ERROR [%s]: halted not low after reset release", tag);
                errors = errors + 1;
            end
            if (trap !== 1'b0) begin
                $display("ERROR [%s]: trap not low after reset release", tag);
                errors = errors + 1;
            end

            run_cycles = 0;
            while ((halted !== 1'b1) && (trap !== 1'b1) && (run_cycles < max_cycles)) begin
                @(posedge clk);
                #1;
                run_cycles = run_cycles + 1;
                tests_run = tests_run + 1;
            end

            if ((halted !== 1'b1) && (trap !== 1'b1)) begin
                $display("ERROR [%s]: program did not terminate within %0d cycles", tag, max_cycles);
                errors = errors + 1;
            end

            if (halted !== expected_halted[0]) begin
                $display("ERROR [%s]: halted mismatch. expected=%0d got=%b",
                         tag, expected_halted, halted);
                errors = errors + 1;
            end

            if (trap !== expected_trap[0]) begin
                $display("ERROR [%s]: trap mismatch. expected=%0d got=%b",
                         tag, expected_trap, trap);
                errors = errors + 1;
            end

            if (run_cycles !== expected_cycles) begin
                $display("ERROR [%s]: termination cycle mismatch. expected=%0d got=%0d",
                         tag, expected_cycles, run_cycles);
                errors = errors + 1;
            end

            // Post-termination window: the terminal state must hold and the
            // core must not issue further stores.
            freeze_stores = store_count;
            repeat (2) begin
                @(posedge clk);
                #1;
                tests_run = tests_run + 1;
                if (halted !== expected_halted[0]) begin
                    $display("ERROR [%s]: halted did not remain sticky after termination", tag);
                    errors = errors + 1;
                end
                if (trap !== expected_trap[0]) begin
                    $display("ERROR [%s]: trap did not remain sticky after termination", tag);
                    errors = errors + 1;
                end
                if (store_count !== freeze_stores) begin
                    $display("ERROR [%s]: observed store after terminal state", tag);
                    errors = errors + 1;
                    freeze_stores = store_count;
                end
            end
        end
    endtask

    // p01: a lone ebreak halts after one cycle with no stores.
    task run_test_p01;
        begin
            init_memories;
            load_program(1);
            run_program("p01_halt_smoke", 1, 0, 1, 8);
            check_store_count("p01_halt_smoke", 0);
        end
    endtask

    // p02: lui/addi/add/sub/and/or/xor results plus a write to x0 that must
    // be discarded (x0 is stored last and must read 0).
    task run_test_p02;
        begin
            init_memories;
            load_program(2);
            run_program("p02_arith_logic_x0", 1, 0, 18, 32);

            check_store_count("p02_arith_logic_x0", 7);
            check_store("p02_arith_logic_x0", 0, 32'h00000100, 32'h12345000);
            check_store("p02_arith_logic_x0", 1, 32'h00000104, 32'h0000001f);
            check_store("p02_arith_logic_x0", 2, 32'h00000108, 32'h0000000b);
            check_store("p02_arith_logic_x0", 3, 32'h0000010c, 32'h0000000b);
            check_store("p02_arith_logic_x0", 4, 32'h00000110, 32'h1234500b);
            check_store("p02_arith_logic_x0", 5, 32'h00000114, 32'h0000000b);
            check_store("p02_arith_logic_x0", 6, 32'h00000118, 32'h00000000);

            check_word("p02_arith_logic_x0", 32'h00000100, 32'h12345000);
            check_word("p02_arith_logic_x0", 32'h00000104, 32'h0000001f);
            check_word("p02_arith_logic_x0", 32'h00000108, 32'h0000000b);
            check_word("p02_arith_logic_x0", 32'h0000010c, 32'h0000000b);
            check_word("p02_arith_logic_x0", 32'h00000110, 32'h1234500b);
            check_word("p02_arith_logic_x0", 32'h00000114, 32'h0000000b);
            check_word("p02_arith_logic_x0", 32'h00000118, 32'h00000000);
        end
    endtask

    // p03: two preloaded words are loaded, combined with add/xor, and the
    // four values stored to a separate region.
    task run_test_p03;
        begin
            init_memories;
            preload_program_data(3);
            load_program(3);
            run_program("p03_load_store", 1, 0, 11, 24);

            check_store_count("p03_load_store", 4);
            check_store("p03_load_store", 0, 32'h00000120, 32'h11223344);
            check_store("p03_load_store", 1, 32'h00000124, 32'h01020304);
            check_store("p03_load_store", 2, 32'h00000128, 32'h12243648);
            check_store("p03_load_store", 3, 32'h0000012c, 32'h10203040);

            check_word("p03_load_store", 32'h00000120, 32'h11223344);
            check_word("p03_load_store", 32'h00000124, 32'h01020304);
            check_word("p03_load_store", 32'h00000128, 32'h12243648);
            check_word("p03_load_store", 32'h0000012c, 32'h10203040);
        end
    endtask

    // p04: beq/bne loop summing 3+2+1 into x1, then a taken beq that skips
    // one instruction; the unreachable store at 0x148 must leave its
    // sentinel intact.
    task run_test_p04;
        begin
            init_memories;
            load_program(4);
            run_program("p04_branch_loop", 1, 0, 18, 32);

            check_store_count("p04_branch_loop", 2);
            check_store("p04_branch_loop", 0, 32'h00000140, 32'h00000006);
            check_store("p04_branch_loop", 1, 32'h00000144, 32'h00000000);

            check_word("p04_branch_loop", 32'h00000140, 32'h00000006);
            check_word("p04_branch_loop", 32'h00000144, 32'h00000000);
            check_word("p04_branch_loop", 32'h00000148, sentinel_word(32'h148 >> 2));
        end
    endtask

    // p05: jal writes the return address (0xc) into x5 and skips the two
    // instructions that would overwrite x1.
    task run_test_p05;
        begin
            init_memories;
            load_program(5);
            run_program("p05_jal_link", 1, 0, 7, 20);

            check_store_count("p05_jal_link", 2);
            check_store("p05_jal_link", 0, 32'h00000160, 32'h0000000c);
            check_store("p05_jal_link", 1, 32'h00000164, 32'h0000000e);

            check_word("p05_jal_link", 32'h00000160, 32'h0000000c);
            check_word("p05_jal_link", 32'h00000164, 32'h0000000e);
        end
    endtask

    // p06: an all-zero instruction word must trap before the following store.
    task run_test_p06;
        begin
            init_memories;
            load_program(6);
            run_program("p06_illegal_word", 0, 1, 2, 12);
            check_store_count("p06_illegal_word", 0);
            check_word("p06_illegal_word", 32'h00000180, sentinel_word(32'h180 >> 2));
        end
    endtask

    // p07: a word load from a misaligned address must trap before the
    // following store.
    task run_test_p07;
        begin
            init_memories;
            load_program(7);
            run_program("p07_misaligned_lw", 0, 1, 2, 12);
            check_store_count("p07_misaligned_lw", 0);
            check_word("p07_misaligned_lw", 32'h00000000, sentinel_word(0));
        end
    endtask

    // p08: a word store to a misaligned address must trap without asserting
    // dmem_we, and the aligned store after it must not execute.
    task run_test_p08;
        begin
            init_memories;
            load_program(8);
            run_program("p08_misaligned_sw", 0, 1, 4, 12);
            check_store_count("p08_misaligned_sw", 0);
            check_word("p08_misaligned_sw", 32'h00000100, sentinel_word(32'h100 >> 2));
        end
    endtask

    // Test sequencer: runs every program test in order, prints the cycle
    // total and a PASS/FAIL verdict, then ends the simulation.
    initial begin
        rst = 1'b0;

        $display("===========================================");
        $display("         RV32I Micro-Core Testbench");
        $display("===========================================");

        run_test_p01;
        run_test_p02;
        run_test_p03;
        run_test_p04;
        run_test_p05;
        run_test_p06;
        run_test_p07;
        run_test_p08;

        $display("");
        $display("===========================================");
        $display("  Cycle Steps: %0d", tests_run);
        $display("===========================================");

        if (errors == 0) begin
            $display("TEST_RESULT: PASS");
        end else begin
            $display("TEST_RESULT: FAIL (%0d errors)", errors);
        end

        $finish;
    end
endmodule
