`timescale 1ns / 1ps

module tb_matrix_mult_4x4;
    // ---- Clock and reset, both driven by this testbench ----
    reg clk, rst;

    // ---- Operand matrices A and B ----
    // Each is a 4x4 matrix of signed 16-bit values kept in row-major order:
    // element (row, col) is stored at index row*4 + col.
    reg signed [15:0] a [0:15], b [0:15];

    // Flattened operand buses connected to the UUT: 16 elements x 16 bits.
    reg [255:0] a_in, b_in;

    // ---- Result matrix C ----
    // c_out is the flattened UUT result bus (16 elements x 32 bits);
    // c[] holds the same data split into signed 32-bit words for comparison.
    wire [511:0] c_out;
    reg signed [31:0] c [0:15];

    // Reference result computed by the testbench.
    reg signed [31:0] expected [0:15];

    // ---- Valid/ready handshake signals ----
    // The testbench drives in_valid and out_ready; the UUT drives the rest.
    reg  in_valid, out_ready;
    wire in_ready, out_valid;

    // ---- Bookkeeping ----
    // Upper bound, in clock edges, on each handshake wait.
    parameter MAX_LATENCY = 200;

    integer errors = 0;      // running count of reported failures
    integer timeout_count;   // cycle counter used while waiting on the UUT
    integer i;
    integer j;
    integer k;
    
    
    // Unit Under Test: all ports are connected by name to the same-named
    // testbench signals.
    matrix_mult_4x4 uut (
        // clock / reset
        .clk       (clk),
        .rst       (rst),
        // input side: operands plus handshake
        .in_valid  (in_valid),
        .in_ready  (in_ready),
        .a_in      (a_in),
        .b_in      (b_in),
        // output side: result plus handshake
        .out_valid (out_valid),
        .out_ready (out_ready),
        .c_out     (c_out)
    );
    
    // Free-running clock: starts low and toggles every 5 time units
    // (10 ns period = 100 MHz with the 1 ns timescale).
    initial clk = 1'b0;
    always #5 clk = ~clk;
    
    // Flatten a[] and b[] onto the a_in / b_in buses.
    // Element n occupies bits [16*n +: 16], so element 0 sits in the
    // least-significant 16 bits of each bus.
    task pack_inputs;
        integer elem;
        reg [255:0] a_bus;
        reg [255:0] b_bus;
        begin
            a_bus = {256{1'b0}};
            b_bus = {256{1'b0}};
            elem  = 0;
            while (elem < 16) begin
                a_bus[16*elem +: 16] = a[elem];
                b_bus[16*elem +: 16] = b[elem];
                elem = elem + 1;
            end
            // Publish both buses once they are fully assembled.
            a_in = a_bus;
            b_in = b_bus;
        end
    endtask
    
    // Split the 512-bit c_out bus into the sixteen 32-bit words of c[].
    // Word n is taken from bits [32*n +: 32] of the bus.
    task unpack_output;
        integer word;
        begin
            word = 15;
            while (word >= 0) begin
                c[word] = c_out[32*word +: 32];
                word = word - 1;
            end
        end
    endtask
    
    // Reference model: expected = A * B for the current contents of a[] and
    // b[], both treated as row-major 4x4 matrices. All arithmetic is signed
    // and 32 bits wide.
    task calculate_expected;
        integer flat;               // row-major index into expected[]
        integer r, cc, t;           // row, column, inner-product term
        reg signed [31:0] acc;
        reg signed [31:0] prod;
        begin
            for (flat = 0; flat < 16; flat = flat + 1) begin
                r   = flat / 4;
                cc  = flat % 4;
                acc = 32'sd0;
                for (t = 0; t < 4; t = t + 1) begin
                    // Signed operands are sign-extended to the 32-bit
                    // destination width before the multiply.
                    prod = a[4*r + t] * b[4*t + cc];
                    acc  = acc + prod;
                end
                expected[flat] = acc;
            end
        end
    endtask
    
    // Run one test case against the UUT using the valid/ready handshakes.
    //
    // Sequence:
    //   1. Pack a[]/b[] onto a_in/b_in and compute the reference result.
    //   2. Assert in_valid and hold it until a rising clock edge at which
    //      in_ready is sampled high, then drop it.
    //   3. Assert out_ready and wait for a rising clock edge at which
    //      out_valid is sampled high; c_out is captured on that same edge.
    //   4. Compare every element of c[] with expected[].
    //
    // Input:
    //   test_name - label printed in log messages (packed ASCII, up to 32
    //               characters).
    //
    // Side effects: increments the module-level `errors` counter once per
    // timeout or mismatching element, and uses `timeout_count` as the cycle
    // counter. Each wait is bounded by MAX_LATENCY clock edges.
    //
    // Timing notes:
    //   * in_valid / out_ready are driven with nonblocking assignments so the
    //     UUT sees each change on the NEXT rising edge regardless of the
    //     order in which the simulator runs the processes woken by an edge.
    //   * UUT outputs are read immediately after @(posedge clk), i.e. before
    //     nonblocking updates from that edge take effect. This assumes the
    //     UUT updates its state with nonblocking assignments - TODO confirm
    //     against the matrix_mult_4x4 source.
    //   * Handshake signals are compared with === 1'b1, so an X/Z on in_ready
    //     or out_valid is treated as "not asserted" and ends in a timeout
    //     error instead of being mistaken for a completed handshake.
    //   * The task uses if/else rather than `disable` for early exits because
    //     the effect of disabling a block with pending nonblocking
    //     assignments is unspecified.
    task run_test;
        input [255:0] test_name;
        integer errors_at_entry;   // `errors` value when this test started
        integer elem;              // index into c[] / expected[]
        reg     accepted;          // input handshake completed
        reg     produced;          // output handshake completed
        begin : run_test_body
            errors_at_entry = errors;
            pack_inputs;
            calculate_expected;

            // Offer the operands to the UUT.
            @(posedge clk);
            in_valid <= 1'b1;

            // Hold in_valid until an edge where in_ready is high. The flag
            // (not the counter) decides success, so acceptance on the last
            // permitted edge is not misreported as a timeout.
            timeout_count = 0;
            accepted = 1'b0;
            while (!accepted && timeout_count < MAX_LATENCY) begin
                @(posedge clk);
                timeout_count = timeout_count + 1;
                if (in_ready === 1'b1)
                    accepted = 1'b1;
            end

            // Drop in_valid right after the accepting edge so the UUT never
            // sees a second valid cycle for the same operands.
            in_valid <= 1'b0;

            if (!accepted) begin
                $display("ERROR %s: Timeout waiting for in_ready", test_name);
                errors = errors + 1;
            end else begin
                // Ready to take the result.
                out_ready <= 1'b1;

                timeout_count = 0;
                produced = 1'b0;
                while (!produced && timeout_count < MAX_LATENCY) begin
                    @(posedge clk);
                    timeout_count = timeout_count + 1;
                    if (out_valid === 1'b1) begin
                        // Snapshot c_out on the transfer edge itself; after
                        // this edge the UUT is free to change or clear it.
                        unpack_output;
                        produced = 1'b1;
                    end
                end

                // Release out_ready on both the success and timeout paths.
                out_ready <= 1'b0;

                if (!produced) begin
                    $display("ERROR %s: Timeout waiting for out_valid", test_name);
                    errors = errors + 1;
                end else begin
                    // Element-by-element comparison (4-state, so X/Z fails).
                    for (elem = 0; elem < 16; elem = elem + 1) begin
                        if (c[elem] !== expected[elem]) begin
                            $display("ERROR %s: C[%0d] expected %0d, got %0d", 
                                     test_name, elem, expected[elem], c[elem]);
                            errors = errors + 1;
                        end
                    end

                    // PASS depends only on failures raised by THIS test, not
                    // on failures left over from earlier tests.
                    if (errors == errors_at_entry) begin
                        $display("PASS: %s (latency=%0d cycles)", test_name, timeout_count + 1);
                    end
                end

                // One idle cycle so out_ready is low before the next test.
                @(posedge clk);
            end
        end
    endtask
    
    // Helper for the stimulus block below: zero every element of a[] and b[].
    task clear_operands;
        integer n;
        begin
            for (n = 0; n < 16; n = n + 1) begin
                a[n] = 0;
                b[n] = 0;
            end
        end
    endtask

    // Main stimulus: apply reset, run the five directed tests in order, then
    // print the overall verdict and end the simulation.
    initial begin
        // Power-on state: reset asserted, handshakes idle, operands zeroed.
        rst       = 1;
        in_valid  = 0;
        out_ready = 0;
        clear_operands;

        $display("--- Starting Test for matrix_mult_4x4 ---");
        $display("4x4 matrix multiplication: C = A * B");
        $display("16-bit signed inputs, 32-bit signed outputs");
        $display("------------------------------------------");

        // Hold reset for three rising edges, then allow two more to settle.
        repeat (3) @(posedge clk);
        rst = 0;
        repeat (2) @(posedge clk);

        // ---------------- Test 1: A = identity, so C must equal B ----------
        $display("\n[Test 1] Identity matrix multiplication");
        clear_operands;
        a[0]  = 1;
        a[5]  = 1;
        a[10] = 1;
        a[15] = 1;
        // B holds the values 1..16 in row-major order.
        for (i = 0; i < 16; i = i + 1)
            b[i] = i + 1;
        run_test("Identity");

        // ---------------- Test 2: every entry 1, so every C entry is 4 -----
        $display("\n[Test 2] All ones matrices");
        for (i = 0; i < 16; i = i + 1) begin
            a[i] = 1;
            b[i] = 1;
        end
        run_test("All Ones");

        // ---------------- Test 3: 2x2 product in the top-left corner -------
        $display("\n[Test 3] Simple 2x2 in corner");
        clear_operands;
        // A top-left = [1 2; 3 4]
        a[0] = 1;
        a[1] = 2;
        a[4] = 3;
        a[5] = 4;
        // B top-left = [5 6; 7 8]
        b[0] = 5;
        b[1] = 6;
        b[4] = 7;
        b[5] = 8;
        run_test("Simple 2x2");

        // ---------------- Test 4: mixed-sign operands -----------------------
        $display("\n[Test 4] Signed numbers");
        clear_operands;
        a[0] = -1;
        a[1] = 2;
        a[4] = 3;
        a[5] = -4;
        b[0] = 10;
        b[1] = -5;
        b[4] = -2;
        b[5] = 3;
        run_test("Signed");

        // ---------------- Test 5: fully populated, larger magnitudes -------
        $display("\n[Test 5] Larger values");
        for (i = 0; i < 16; i = i + 1) begin
            a[i] = (i + 1) * 100;
            b[i] = (16 - i) * 10;
        end
        run_test("Larger Values");

        // ---------------- Summary -------------------------------------------
        $display("\n------------------------------------------");
        if (errors != 0) begin
            $display("TEST_RESULT: FAIL (%0d errors)", errors);
        end else begin
            $display("TEST_RESULT: PASS");
        end

        $finish;
    end

endmodule
