`timescale 1ns / 1ps

module tb_register_file_2r1w;
    // Clock and reset
    reg clk;
    reg rst;

    // DUT interface
    reg  [4:0]  raddr1;
    reg  [4:0]  raddr2;
    wire [31:0] rdata1;
    wire [31:0] rdata2;
    reg         we;
    reg  [4:0]  waddr;
    reg  [31:0] wdata;

    // Running count of read mismatches; reported at the end of the run.
    integer errors = 0;
    // Loop index for the top-level stimulus loop only. Tasks use their own
    // local indices so they cannot clobber it.
    integer i;

    // Golden model: expected contents of the 32 registers.
    reg [31:0] golden [0:31];

    // Instantiate DUT
    register_file_2r1w uut (
        .clk(clk),
        .rst(rst),
        .raddr1(raddr1),
        .raddr2(raddr2),
        .rdata1(rdata1),
        .rdata2(rdata2),
        .we(we),
        .waddr(waddr),
        .wdata(wdata)
    );

    // Clock generation (100MHz: 10ns period with the 1ns timescale)
    initial begin
        clk = 0;
        forever #5 clk = ~clk;
    end

    // Clear every golden-model entry to zero.
    // Uses a task-local index: this task is reachable from apply_cycle (when
    // rst is high at the edge), which may itself run inside a loop over `i`.
    task init_golden;
        integer k;
        begin
            for (k = 0; k < 32; k = k + 1) begin
                golden[k] = 32'b0;
            end
        end
    endtask

    // Value the testbench expects on a read port for address `addr`, given
    // the current write-port inputs:
    //   - address 0 always reads as zero;
    //   - a write in progress to the same (nonzero) address is expected to be
    //     visible immediately (same-cycle write-to-read bypass);
    //   - otherwise the golden-model contents.
    function [31:0] expected_read;
        input [4:0] addr;
        begin
            if (addr == 5'd0)
                expected_read = 32'b0;
            else if (we && (waddr == addr) && (waddr != 5'd0))
                expected_read = wdata;
            else
                expected_read = golden[addr];
        end
    endfunction

    // Compare both read ports against the model and count/report mismatches.
    // `!==` is used so X/Z on the DUT outputs is flagged as an error.
    task check_reads;
        input [255:0] tag;
        reg [31:0] exp1;
        reg [31:0] exp2;
        begin
            exp1 = expected_read(raddr1);
            exp2 = expected_read(raddr2);
            if (rdata1 !== exp1) begin
                $display("ERROR %s: rdata1 expected %h, got %h (raddr1=%0d)", tag, exp1, rdata1, raddr1);
                errors = errors + 1;
            end
            if (rdata2 !== exp2) begin
                $display("ERROR %s: rdata2 expected %h, got %h (raddr2=%0d)", tag, exp2, rdata2, raddr2);
                errors = errors + 1;
            end
        end
    endtask

    // Drive one cycle of stimulus and check the read ports both before and
    // after the rising clock edge. Inputs are applied away from the edge
    // (callers return 1ns after an edge) to avoid races with the DUT.
    task apply_cycle;
        input [255:0] tag;
        input         next_we;
        input [4:0]   next_waddr;
        input [31:0]  next_wdata;
        input [4:0]   next_raddr1;
        input [4:0]   next_raddr2;
        begin
            we     = next_we;
            waddr  = next_waddr;
            wdata  = next_wdata;
            raddr1 = next_raddr1;
            raddr2 = next_raddr2;

            // Combinational read check before clock edge
            #1;
            check_reads(tag);

            // Rising edge: mirror the DUT's state update in the golden model
            // (reset clears everything; writes to register 0 are dropped).
            @(posedge clk);
            if (rst) begin
                init_golden;
            end else if (we && (waddr != 5'd0)) begin
                golden[waddr] = wdata;
            end

            // Check again after the edge with the same inputs
            #1;
            check_reads(tag);
        end
    endtask

    // Apply a one-edge reset and verify that every register reads zero.
    task do_reset;
        integer k;
        begin
            rst    = 1;
            we     = 0;
            waddr  = 0;
            wdata  = 0;
            raddr1 = 0;
            raddr2 = 0;

            // Keep rst asserted across the rising edge and release it 1ns
            // later. Deasserting it in the same timestep as the edge (right
            // after @(posedge clk)) races with the DUT's clocked process,
            // which could then sample rst=0 and never reset.
            @(posedge clk);
            #1;
            rst = 0;
            init_golden;

            // Sweep both read ports (in opposite orders) over all addresses.
            for (k = 0; k < 32; k = k + 1) begin
                raddr1 = k[4:0];
                raddr2 = 5'd31 - k[4:0];
                #1;
                check_reads("reset_check");
            end
        end
    endtask

    initial begin
        // Initialize
        rst = 0;
        we = 0;
        waddr = 0;
        wdata = 0;
        raddr1 = 0;
        raddr2 = 0;
        init_golden;

        // Reset test
        do_reset;

        // Directed tests: register 0 is read-only, write bypass, read-back
        apply_cycle("write_zero", 1'b1, 5'd0, 32'hFFFFFFFF, 5'd0, 5'd1);
        apply_cycle("read_zero",  1'b0, 5'd0, 32'h0,        5'd0, 5'd1);
        apply_cycle("write_bypass", 1'b1, 5'd5, 32'h12345678, 5'd5, 5'd6);
        apply_cycle("read_after",   1'b0, 5'd0, 32'h0,        5'd5, 5'd6);
        apply_cycle("write_other",  1'b1, 5'd10, 32'hDEADBEEF, 5'd10, 5'd5);
        apply_cycle("read_after2",  1'b0, 5'd0, 32'h0,        5'd10, 5'd5);

        // Randomized tests
        for (i = 0; i < 100; i = i + 1) begin
            apply_cycle("rand", $urandom_range(0,1), $urandom_range(0,31), $urandom, $urandom_range(0,31), $urandom_range(0,31));
        end

        // ========================================
        // Final Results
        // ========================================
        $display("\n===========================================");
        if (errors == 0) begin
            $display("TEST_RESULT: PASS");
        end else begin
            $display("TEST_RESULT: FAIL (%0d errors)", errors);
        end
        $display("===========================================");
        
        $finish;
    end
endmodule
