// axis_width_converter
//
// Packs an 8-bit valid/ready stream (s_axis_*) into 32-bit words (m_axis_*).
//
//   * Bytes are placed little-endian: the first byte of a word lands in
//     m_axis_tdata[7:0], the fourth in m_axis_tdata[31:24].
//   * A word is emitted when four bytes have been collected, or earlier when
//     the incoming byte carries s_axis_tlast.
//   * m_axis_tkeep marks the populated byte lanes (4'b0001 .. 4'b1111).
//     Unpopulated lanes of a short word read as zero, because the packing
//     register is cleared after every emitted word.
//   * m_axis_tlast mirrors the s_axis_tlast of the byte that completed the word.
//   * rst is synchronous and active high.
//
// Flow control: a new byte is accepted only while the output register is
// empty (s_axis_tready = ~m_axis_tvalid). The output register is therefore
// held stable until the consumer takes it, and the input stalls for at least
// one cycle per emitted word.
module axis_width_converter (
    input  wire        clk,
    input  wire        rst,
    input  wire [7:0]  s_axis_tdata,
    input  wire        s_axis_tvalid,
    output wire        s_axis_tready,
    input  wire        s_axis_tlast,
    output reg  [31:0] m_axis_tdata,
    output reg  [3:0]  m_axis_tkeep,
    output reg         m_axis_tvalid,
    input  wire        m_axis_tready,
    output reg         m_axis_tlast
);
    // Registered state: bytes collected so far and how many of them.
    reg [31:0] pack_data;
    reg [1:0]  pack_count;

    // Combinational helpers (driven only from the always @* block below).
    reg [31:0] next_word;   // pack_data with the incoming byte merged in
    reg [3:0]  next_keep;   // byte-lane mask of next_word

    // Input is ready exactly when the output register is free.
    assign s_axis_tready = ~m_axis_tvalid;

    // A byte is transferred on the input side this cycle.
    wire s_axis_fire = s_axis_tvalid && !m_axis_tvalid;

    // The incoming byte completes a word: end of packet or fourth byte.
    wire word_done = s_axis_tlast || (pack_count == 2'd3);

    // Word assembly is kept purely combinational so that the clocked block
    // below contains only non-blocking assignments.
    always @* begin
        next_word = pack_data;
        case (pack_count)
            2'd0: begin
                next_word[7:0]   = s_axis_tdata;
                next_keep        = 4'b0001;
            end
            2'd1: begin
                next_word[15:8]  = s_axis_tdata;
                next_keep        = 4'b0011;
            end
            2'd2: begin
                next_word[23:16] = s_axis_tdata;
                next_keep        = 4'b0111;
            end
            default: begin
                next_word[31:24] = s_axis_tdata;
                next_keep        = 4'b1111;
            end
        endcase
    end

    always @(posedge clk) begin
        if (rst) begin
            pack_data     <= 32'b0;
            pack_count    <= 2'b0;
            m_axis_tdata  <= 32'b0;
            m_axis_tkeep  <= 4'b0;
            m_axis_tvalid <= 1'b0;
            m_axis_tlast  <= 1'b0;
        end else begin
            // Output handshake: the consumer took the word, free the register.
            if (m_axis_tvalid && m_axis_tready) begin
                m_axis_tvalid <= 1'b0;
                m_axis_tlast  <= 1'b0;
            end

            // Input handshake. Mutually exclusive with the branch above,
            // because it requires m_axis_tvalid to be low.
            if (s_axis_fire) begin
                if (word_done) begin
                    // Emit the completed (possibly short) word and restart
                    // packing from an all-zero word.
                    m_axis_tdata  <= next_word;
                    m_axis_tkeep  <= next_keep;
                    m_axis_tvalid <= 1'b1;
                    m_axis_tlast  <= s_axis_tlast;
                    pack_data     <= 32'b0;
                    pack_count    <= 2'b0;
                end else begin
                    // Keep collecting bytes.
                    pack_data  <= next_word;
                    pack_count <= pack_count + 2'd1;
                end
            end
        end
    end
endmodule
