// 2x2 signed matrix-vector multiplier with serial element input and serial
// result output.
//
//   y0 = w00 * x0 + w01 * x1
//   y1 = w10 * x0 + w11 * x1
//
// Ports:
//   clk, rst     - all state updates on the rising edge of clk; rst is a
//                  synchronous, active-high reset that clears every register.
//   weight_load  - when high, weight_data is written into the four weight
//                  registers: [7:0]=w00, [15:8]=w01, [23:16]=w10,
//                  [31:24]=w11. While weight_load is high, in_valid is
//                  ignored for that cycle.
//   in_valid /   - the first accepted element is stored as x0; the second
//   in_data        accepted element is used as x1 and completes the pair.
//                  in_data is interpreted as a signed 8-bit value.
//   out_valid /  - y0 is presented on the clock edge after the pair completes,
//   out_data       y1 on the edge after that. out_data reads as zero whenever
//                  out_valid is low.
module systolic_mac2x2 (
    input wire clk,
    input wire rst,
    input wire weight_load,
    input wire [31:0] weight_data,
    input wire in_valid,
    input wire [7:0] in_data,
    output wire out_valid,
    output wire [16:0] out_data
);
    // Weight matrix, named by row and column.
    reg signed [7:0] wt_r0c0;
    reg signed [7:0] wt_r0c1;
    reg signed [7:0] wt_r1c0;
    reg signed [7:0] wt_r1c1;

    // First element of the pair currently being assembled, and a flag saying
    // whether it has been captured yet.
    reg signed [7:0] first_x;
    reg              have_first;

    // Result pipeline: stage_b feeds stage_a, stage_a feeds the output register.
    reg               stage_b_valid;
    reg signed [16:0] stage_b_data;
    reg               stage_a_valid;
    reg signed [16:0] stage_a_data;
    reg               out_valid_r;
    reg signed [16:0] out_data_r;

    // Row dot products using the stored first element and the element that is
    // on the input bus right now. Every operand is signed and the targets are
    // 17 bits wide, so the arithmetic is carried out as 17-bit signed.
    wire signed [16:0] row0_sum = (wt_r0c0 * first_x) +
                                  (wt_r0c1 * $signed(in_data));
    wire signed [16:0] row1_sum = (wt_r1c0 * first_x) +
                                  (wt_r1c1 * $signed(in_data));

    // Weight storage.
    always @(posedge clk) begin
        if (rst) begin
            wt_r0c0 <= 8'sd0;
            wt_r0c1 <= 8'sd0;
            wt_r1c0 <= 8'sd0;
            wt_r1c1 <= 8'sd0;
        end else if (weight_load) begin
            wt_r1c1 <= weight_data[31:24];
            wt_r1c0 <= weight_data[23:16];
            wt_r0c1 <= weight_data[15:8];
            wt_r0c0 <= weight_data[7:0];
        end
    end

    // Input pairing and result pipeline.
    always @(posedge clk) begin
        if (rst) begin
            first_x       <= 8'sd0;
            have_first    <= 1'b0;
            stage_b_valid <= 1'b0;
            stage_b_data  <= 17'sd0;
            stage_a_valid <= 1'b0;
            stage_a_data  <= 17'sd0;
            out_valid_r   <= 1'b0;
            out_data_r    <= 17'sd0;
        end else begin
            // Default behaviour each cycle: shift the pipeline one step
            // towards the output and empty the last stage. Data is forced to
            // zero wherever the accompanying valid bit is low.
            stage_b_valid <= 1'b0;
            stage_b_data  <= 17'sd0;
            stage_a_valid <= stage_b_valid;
            stage_a_data  <= stage_b_valid ? stage_b_data : 17'sd0;
            out_valid_r   <= stage_a_valid;
            out_data_r    <= stage_a_valid ? stage_a_data : 17'sd0;

            if (weight_load) begin
                // Weight-write cycle: the input element (if any) is not
                // accepted and the pairing state is left untouched.
            end else if (in_valid) begin
                if (!have_first) begin
                    // First element of a pair: just remember it.
                    first_x    <= in_data;
                    have_first <= 1'b1;
                end else begin
                    // Second element: both results are known. These
                    // nonblocking assignments come later in the block than
                    // the default shift above, so they take precedence.
                    stage_a_valid <= 1'b1;
                    stage_a_data  <= row0_sum;
                    stage_b_valid <= 1'b1;
                    stage_b_data  <= row1_sum;
                    have_first    <= 1'b0;
                end
            end
        end
    end

    assign out_valid = out_valid_r;
    assign out_data = out_valid_r ? out_data_r : 17'b0;
endmodule
