`timescale 1ns / 1ps

// 4x4 signed matrix multiplier with valid/ready handshakes on input and output.
//
// a_in / b_in each carry sixteen signed 16-bit elements packed row-major:
// element (row, col) occupies bits [(row*4 + col)*16 +: 16]. c_out carries
// the sixteen 32-bit result elements packed the same way at
// [(row*4 + col)*32 +: 32].
//
// Sequence:
//   IDLE    : in_ready is high; an in_valid/in_ready handshake latches A and B.
//   COMPUTE : one output element per clock (16 clocks), written directly into
//             its c_out slot. out_valid rises on the same edge that stores the
//             last element.
//   DONE    : out_valid is held high until out_ready is seen, then the FSM
//             returns to IDLE.
//
// c_out is only meaningful while out_valid is high; it is overwritten
// element-by-element during COMPUTE.
//
// NOTE: each element is a sum of four 16x16 signed products kept in 32 bits.
// With full-range operands (e.g. all -32768) the exact sum needs 34 bits, so
// such results wrap modulo 2^32.
module matrix_mult_4x4 (
    input  wire         clk,
    input  wire         rst,
    input  wire [255:0] a_in,
    input  wire [255:0] b_in,
    output reg  [511:0] c_out,
    input  wire         in_valid,
    output wire         in_ready,
    output reg          out_valid,
    input  wire         out_ready
);

    // FSM states
    localparam IDLE    = 2'd0,
               COMPUTE = 2'd1,
               DONE    = 2'd2;

    reg [1:0] state;

    // Latched input matrices, row-major: element (row, col) at index row*4+col.
    reg signed [15:0] a [0:15];
    reg signed [15:0] b [0:15];

    // Index of the output element being computed, encoded as row*4 + col.
    reg  [3:0] comp_idx;
    wire [1:0] row = comp_idx[3:2];
    wire [1:0] col = comp_idx[1:0];

    // C[row][col] = sum over k of A[row][k] * B[k][col].
    // The 32-bit signed target makes every (signed) operand sign-extend to
    // 32 bits before multiplying, so this is the dot product modulo 2^32.
    wire signed [31:0] dot =
        a[{row, 2'd0}] * b[{2'd0, col}] +
        a[{row, 2'd1}] * b[{2'd1, col}] +
        a[{row, 2'd2}] * b[{2'd2, col}] +
        a[{row, 2'd3}] * b[{2'd3, col}];

    integer idx;

    // Ready to accept a new pair of matrices only while idle.
    assign in_ready = (state == IDLE);

    always @(posedge clk) begin
        if (rst) begin
            state     <= IDLE;
            out_valid <= 1'b0;
            c_out     <= 512'd0;
            comp_idx  <= 4'd0;
        end else begin
            case (state)
                IDLE: begin
                    if (in_valid && in_ready) begin
                        // Latch both operand matrices in one cycle.
                        for (idx = 0; idx < 16; idx = idx + 1) begin
                            a[idx] <= a_in[idx*16 +: 16];
                            b[idx] <= b_in[idx*16 +: 16];
                        end
                        comp_idx <= 4'd0;
                        state    <= COMPUTE;
                    end
                end

                COMPUTE: begin
                    // Store each element straight into its output slot; no
                    // separate result array or packing cycle is required.
                    c_out[comp_idx*32 +: 32] <= dot;

                    if (comp_idx == 4'd15) begin
                        // The final element is written on this same edge, so
                        // c_out is complete exactly when out_valid rises.
                        out_valid <= 1'b1;
                        state     <= DONE;
                    end else begin
                        comp_idx <= comp_idx + 4'd1;
                    end
                end

                DONE: begin
                    // out_valid is high throughout DONE; hold the result
                    // until the consumer accepts it.
                    if (out_ready) begin
                        out_valid <= 1'b0;
                        state     <= IDLE;
                    end
                end

                default: begin
                    // Unreachable encoding: recover to a clean idle state.
                    out_valid <= 1'b0;
                    state     <= IDLE;
                end
            endcase
        end
    end

endmodule
