`timescale 1ns / 1ps

module tb_axis_width_converter;
    // Clock and reset, both driven by the testbench.
    reg clk;
    reg rst;

    // Byte-wide input stream: the testbench drives data/valid/last and the
    // DUT answers with tready.
    reg [7:0] s_axis_tdata;
    reg s_axis_tvalid;
    wire s_axis_tready;
    reg s_axis_tlast;

    // Word-wide output stream: the DUT drives data/keep/valid/last and the
    // testbench supplies tready as back-pressure.
    wire [31:0] m_axis_tdata;
    wire [3:0] m_axis_tkeep;
    wire m_axis_tvalid;
    reg m_axis_tready;
    wire m_axis_tlast;

    // Scoreboard counters. tests_run is bumped once per step_cycle call, so
    // it counts checked clock cycles; i is the loop index of the randomized
    // regression.
    integer errors = 0;
    integer tests_run = 0;
    integer i;

    // Reference packer state: the word currently being assembled and the
    // number of bytes already placed in it (see golden_accept_byte).
    reg [31:0] golden_pack;
    integer golden_count;

    // Circular queue of output beats the reference model expects the DUT to
    // produce. exp_head is the next beat to check, exp_tail the next free
    // slot; enqueue_expected refuses to fill the last slot, so at most
    // QUEUE_SIZE-1 beats are ever stored.
    localparam integer QUEUE_SIZE = 512;
    reg [31:0] expected_data [0:QUEUE_SIZE-1];
    reg [3:0] expected_keep [0:QUEUE_SIZE-1];
    reg expected_last [0:QUEUE_SIZE-1];
    integer exp_head;
    integer exp_tail;

    // Staging buffer for the packet handed to send_loaded_packet (32 bytes
    // maximum), and a flag step_cycle sets when the offered input byte was
    // accepted during that cycle.
    reg [7:0] packet_bytes [0:31];
    reg last_input_handshake;

    // Back-pressure patterns understood by choose_ready.
    localparam integer READY_HIGH = 0;
    localparam integer READY_RANDOM = 1;
    localparam integer READY_STALL_WINDOW = 2;

    // Device under test. Every port is connected by name to the testbench
    // signal of the same name declared above.
    axis_width_converter uut (
        .clk(clk),
        .rst(rst),
        .s_axis_tdata(s_axis_tdata),
        .s_axis_tvalid(s_axis_tvalid),
        .s_axis_tready(s_axis_tready),
        .s_axis_tlast(s_axis_tlast),
        .m_axis_tdata(m_axis_tdata),
        .m_axis_tkeep(m_axis_tkeep),
        .m_axis_tvalid(m_axis_tvalid),
        .m_axis_tready(m_axis_tready),
        .m_axis_tlast(m_axis_tlast)
    );

    // Free-running clock: starts low and inverts every 5 time units, giving
    // rising edges at t = 5, 15, 25, ...
    initial begin
        clk = 1'b0;
        forever begin
            #5;
            clk = ~clk;
        end
    end

    // Number of expected beats currently waiting in the circular queue.
    // The raw tail-minus-head distance goes negative once the tail has
    // wrapped past the end of the array, so it is corrected by one full lap.
    function integer queue_depth;
        integer span;
        begin
            span = exp_tail - exp_head;
            if (span < 0)
                span = span + QUEUE_SIZE;
            queue_depth = span;
        end
    endfunction

    // Pick the m_axis_tready value for one cycle.
    //   mode      - one of the READY_* selectors; anything else behaves like
    //               READY_HIGH.
    //   cycle_idx - cycle number within the current packet; only consulted by
    //               READY_STALL_WINDOW, which holds ready low on cycles 4..6.
    // The random source is only consulted in READY_RANDOM mode.
    function choose_ready;
        input integer mode;
        input integer cycle_idx;
        begin
            choose_ready = 1'b1;
            if (mode == READY_RANDOM)
                choose_ready = $urandom_range(0, 1);
            else if (mode == READY_STALL_WINDOW)
                choose_ready = (cycle_idx < 4 || cycle_idx > 6);
        end
    endfunction

    // Return the reference model to its empty state: no queued expected
    // beats and no bytes held in the word under assembly.
    task clear_golden;
        begin
            exp_head = 0;
            exp_tail = 0;
            golden_count = 0;
            golden_pack = 32'b0;
        end
    endtask

    // Append one expected output beat (data, keep, last) at the tail of the
    // circular queue. One slot is always left unused so that head == tail
    // unambiguously means "empty"; trying to use it is reported as an error
    // and the beat is dropped.
    task enqueue_expected;
        input [31:0] data;
        input [3:0] keep;
        input last;
        reg queue_full;
        begin
            queue_full = (queue_depth() >= QUEUE_SIZE - 1);
            if (queue_full) begin
                $display("ERROR: expected queue overflow");
                errors = errors + 1;
            end else begin
                expected_data[exp_tail] = data;
                expected_keep[exp_tail] = keep;
                expected_last[exp_tail] = last;
                // Advance the tail, wrapping at the end of the storage.
                exp_tail = (exp_tail == QUEUE_SIZE - 1) ? 0 : exp_tail + 1;
            end
        end
    endtask

    // Check a beat the DUT has just transferred against the oldest expected
    // beat, then retire that expected beat.
    //   tag                          - label printed in error messages
    //   got_data, got_keep, got_last - values observed on the output stream
    // A transfer while nothing is expected is itself an error. Comparisons
    // use !== so X/Z on the observed values are reported as mismatches.
    task pop_and_compare;
        input [255:0] tag;
        input [31:0] got_data;
        input [3:0] got_keep;
        input got_last;
        reg [31:0] want_data;
        reg [3:0] want_keep;
        reg want_last;
        begin
            if (queue_depth() == 0) begin
                $display("ERROR [%s]: unexpected output handshake", tag);
                errors = errors + 1;
            end else begin
                // Snapshot the head entry once, then compare field by field.
                want_data = expected_data[exp_head];
                want_keep = expected_keep[exp_head];
                want_last = expected_last[exp_head];

                if (got_data !== want_data) begin
                    $display("ERROR [%s]: m_axis_tdata mismatch. Expected %h, got %h",
                             tag, want_data, got_data);
                    errors = errors + 1;
                end
                if (got_keep !== want_keep) begin
                    $display("ERROR [%s]: m_axis_tkeep mismatch. Expected %b, got %b",
                             tag, want_keep, got_keep);
                    errors = errors + 1;
                end
                if (got_last !== want_last) begin
                    $display("ERROR [%s]: m_axis_tlast mismatch. Expected %b, got %b",
                             tag, want_last, got_last);
                    errors = errors + 1;
                end

                // Retire the entry, wrapping at the end of the storage.
                exp_head = (exp_head == QUEUE_SIZE - 1) ? 0 : exp_head + 1;
            end
        end
    endtask

    // Reference model of the packer: account for one accepted input byte.
    //   data_byte - the byte that was transferred
    //   last_flag - the TLAST value that accompanied it
    // Bytes fill the word from the least-significant lane upwards. A word is
    // queued as an expected output beat when its fourth byte arrives or when
    // the byte carries TLAST; the keep mask then marks the lanes filled so
    // far and the beat inherits last_flag.
    task golden_accept_byte;
        input [7:0] data_byte;
        input last_flag;
        reg [31:0] packed_word;
        reg [3:0] keep_mask;
        reg word_done;
        begin
            // Place the byte in the lane selected by the running byte count;
            // an out-of-range count zeroes the word instead.
            packed_word = golden_pack;
            if (golden_count >= 0 && golden_count <= 3)
                packed_word[8 * golden_count +: 8] = data_byte;
            else
                packed_word = 32'b0;

            word_done = last_flag || (golden_count == 3);
            if (word_done) begin
                // (golden_count + 1) low bits set; all four for any other count.
                if (golden_count >= 0 && golden_count < 3)
                    keep_mask = 4'b1111 >> (3 - golden_count);
                else
                    keep_mask = 4'b1111;

                enqueue_expected(packed_word, keep_mask, last_flag);
                golden_pack = 32'b0;
                golden_count = 0;
            end else begin
                golden_pack = packed_word;
                golden_count = golden_count + 1;
            end
        end
    endtask

    // Compare what the DUT is presenting on the output stream right now with
    // the reference model.
    //   tag - label printed in error messages
    // With nothing expected, tvalid and tlast must both be 0. Otherwise
    // tvalid must be 1 and data/keep/last must equal the oldest expected
    // beat. The queue is only inspected here, never advanced.
    task check_current_outputs;
        input [255:0] tag;
        reg nothing_pending;
        reg [31:0] head_data;
        reg [3:0] head_keep;
        reg head_last;
        begin
            nothing_pending = (queue_depth() == 0);
            if (nothing_pending) begin
                if (m_axis_tvalid !== 1'b0) begin
                    $display("ERROR [%s]: m_axis_tvalid should be 0 when no output is pending", tag);
                    errors = errors + 1;
                end
                if (m_axis_tlast !== 1'b0) begin
                    $display("ERROR [%s]: m_axis_tlast should be 0 when no output is pending", tag);
                    errors = errors + 1;
                end
            end else begin
                head_data = expected_data[exp_head];
                head_keep = expected_keep[exp_head];
                head_last = expected_last[exp_head];

                if (m_axis_tvalid !== 1'b1) begin
                    $display("ERROR [%s]: m_axis_tvalid should be 1 when output is pending", tag);
                    errors = errors + 1;
                end
                if (m_axis_tdata !== head_data) begin
                    $display("ERROR [%s]: current m_axis_tdata mismatch. Expected %h, got %h",
                             tag, head_data, m_axis_tdata);
                    errors = errors + 1;
                end
                if (m_axis_tkeep !== head_keep) begin
                    $display("ERROR [%s]: current m_axis_tkeep mismatch. Expected %b, got %b",
                             tag, head_keep, m_axis_tkeep);
                    errors = errors + 1;
                end
                if (m_axis_tlast !== head_last) begin
                    $display("ERROR [%s]: current m_axis_tlast mismatch. Expected %b, got %b",
                             tag, head_last, m_axis_tlast);
                    errors = errors + 1;
                end
            end
        end
    endtask

    // Drive one clock cycle of stimulus and check the DUT against the
    // reference model.
    //   tag          - label printed in error messages
    //   next_rst     - reset value for this cycle
    //   next_s_valid - input-stream tvalid for this cycle
    //   next_s_data  - input-stream tdata for this cycle
    //   next_s_last  - input-stream tlast for this cycle
    //   next_m_ready - output-stream tready for this cycle
    // On return, last_input_handshake is 1 only if the offered input byte was
    // accepted (tvalid driven high and tready seen high before the edge) in a
    // non-reset cycle.
    // The statement order below (drive, wait, sample, clock edge, wait,
    // check) is what makes the handshake bookkeeping correct; keep it intact.
    task step_cycle;
        input [255:0] tag;
        input next_rst;
        input next_s_valid;
        input [7:0] next_s_data;
        input next_s_last;
        input next_m_ready;
        reg pre_s_ready;
        reg pre_m_valid;
        reg [31:0] pre_m_data;
        reg [3:0] pre_m_keep;
        reg pre_m_last;
        begin
            // Apply this cycle's stimulus.
            rst = next_rst;
            s_axis_tvalid = next_s_valid;
            s_axis_tdata = next_s_data;
            s_axis_tlast = next_s_last;
            m_axis_tready = next_m_ready;

            // Wait one time unit, then record the DUT outputs as they stand
            // before the clock edge; these decide which handshakes occur on
            // that edge.
            #1;
            pre_s_ready = s_axis_tready;
            pre_m_valid = m_axis_tvalid;
            pre_m_data = m_axis_tdata;
            pre_m_keep = m_axis_tkeep;
            pre_m_last = m_axis_tlast;

            last_input_handshake = 1'b0;

            // Advance past the rising edge and wait one more time unit before
            // looking at the post-edge outputs.
            @(posedge clk);
            #1;

            tests_run = tests_run + 1;

            if (next_rst) begin
                // Reset cycle: the reference model is emptied (any handshake
                // seen before the edge is ignored) and the DUT must show
                // neither tvalid nor tlast afterwards.
                clear_golden;
                if (m_axis_tvalid !== 1'b0) begin
                    $display("ERROR [%s]: output valid not cleared by reset", tag);
                    errors = errors + 1;
                end
                if (m_axis_tlast !== 1'b0) begin
                    $display("ERROR [%s]: output last not cleared by reset", tag);
                    errors = errors + 1;
                end
            end else begin
                // Output transfer on this edge: retire and compare the oldest
                // expected beat using the pre-edge values. This runs before
                // the input is accounted for, so a beat queued by this
                // cycle's input byte cannot be the one retired here.
                if (pre_m_valid && next_m_ready)
                    pop_and_compare(tag, pre_m_data, pre_m_keep, pre_m_last);

                // Input transfer on this edge: feed the byte to the
                // reference packer and tell the caller it was accepted.
                if (next_s_valid && pre_s_ready) begin
                    golden_accept_byte(next_s_data, next_s_last);
                    last_input_handshake = 1'b1;
                end

                // Post-edge outputs must match the head of the expected
                // queue, or be idle when the queue is empty.
                check_current_outputs(tag);
            end
        end
    endtask

    // Stream the first `length` bytes of packet_bytes into the DUT as one
    // packet, asserting TLAST on the final byte.
    //   tag        - label printed in error messages
    //   length     - number of bytes to send; must not exceed the 32-entry
    //                packet_bytes buffer
    //   ready_mode - READY_* selector for the output back-pressure pattern
    // Each byte is re-offered every cycle until step_cycle reports that it
    // was accepted. The loop is bounded: if the DUT stops accepting input
    // (for example s_axis_tready stuck low) an error is logged and the task
    // returns instead of hanging the simulation.
    task send_loaded_packet;
        input [255:0] tag;
        input integer length;
        input integer ready_mode;
        integer idx;
        integer local_cycle;
        integer cycle_limit;
        reg chosen_ready;
        begin
            idx = 0;
            local_cycle = 0;
            // Generous budget: far more cycles than any stall pattern in
            // choose_ready can plausibly require for this many bytes.
            cycle_limit = 64 * length + 256;

            while ((idx < length) && (local_cycle < cycle_limit)) begin
                chosen_ready = choose_ready(ready_mode, local_cycle);
                step_cycle(tag, 1'b0, 1'b1, packet_bytes[idx], (idx == length - 1), chosen_ready);
                // Only move to the next byte once this one was transferred.
                if (last_input_handshake)
                    idx = idx + 1;
                local_cycle = local_cycle + 1;
            end

            if (idx < length) begin
                $display("ERROR [%s]: input timeout after %0d cycles, only %0d of %0d bytes accepted",
                         tag, local_cycle, idx, length);
                errors = errors + 1;
            end
        end
    endtask

    // Run `count` cycles with no input offered (tvalid low, data 0) while
    // holding m_axis_tready at `ready_value`. A count of zero or less runs
    // nothing.
    //   tag - label printed in error messages
    task run_idle_cycles;
        input [255:0] tag;
        input integer count;
        input ready_value;
        integer cycles_left;
        begin
            cycles_left = count;
            while (cycles_left > 0) begin
                step_cycle(tag, 1'b0, 1'b0, 8'h00, 1'b0, ready_value);
                cycles_left = cycles_left - 1;
            end
        end
    endtask

    // Drain the DUT: with the input idle and m_axis_tready high, keep
    // clocking until the expected queue is empty and tvalid has dropped, or
    // until max_cycles cycles have been spent. Anything still outstanding
    // after that is reported as a timeout.
    //   tag        - label printed in error messages
    //   max_cycles - upper bound on the number of drain cycles
    task flush_outputs;
        input [255:0] tag;
        input integer max_cycles;
        integer spent;
        begin
            for (spent = 0;
                 (spent < max_cycles) && (queue_depth() != 0 || m_axis_tvalid !== 1'b0);
                 spent = spent + 1)
                step_cycle(tag, 1'b0, 1'b0, 8'h00, 1'b0, 1'b1);

            if (queue_depth() != 0 || m_axis_tvalid !== 1'b0) begin
                $display("ERROR [%s]: flush timeout with pending outputs still present", tag);
                errors = errors + 1;
            end
        end
    endtask

    // Main stimulus. Runs the directed scenarios in order (reset, single
    // word, short packets, multi-word packets, back-pressure holds,
    // back-to-back packets, reset recovery), then a randomized regression,
    // and finally prints the pass/fail summary and ends the simulation.
    // Every cycle is checked inside step_cycle; this block only sequences
    // the traffic.
    initial begin
        rst = 1'b0;
        s_axis_tdata = 8'h00;
        s_axis_tvalid = 1'b0;
        s_axis_tlast = 1'b0;
        m_axis_tready = 1'b0;
        clear_golden;

        $display("===========================================");
        $display("  AXI4-Stream Width Converter Testbench");
        $display("===========================================");

        // Reset and idle. The second reset cycle offers a byte on purpose:
        // input presented during reset must not produce any output.
        step_cycle("reset_0", 1'b1, 1'b0, 8'h00, 1'b0, 1'b0);
        step_cycle("reset_1", 1'b1, 1'b1, 8'hAA, 1'b1, 1'b1);
        step_cycle("idle_after_reset", 1'b0, 1'b0, 8'h00, 1'b0, 1'b1);

        // Full 4-byte packet.
        packet_bytes[0] = 8'h11;
        packet_bytes[1] = 8'h22;
        packet_bytes[2] = 8'h33;
        packet_bytes[3] = 8'h44;
        send_loaded_packet("full_word", 4, READY_HIGH);
        flush_outputs("full_word_flush", 8);

        // Partial packets: 1, 2, and 3 bytes.
        packet_bytes[0] = 8'hA5;
        send_loaded_packet("partial_len1", 1, READY_HIGH);
        flush_outputs("partial_len1_flush", 8);

        packet_bytes[0] = 8'hC1;
        packet_bytes[1] = 8'hD2;
        send_loaded_packet("partial_len2", 2, READY_HIGH);
        flush_outputs("partial_len2_flush", 8);

        packet_bytes[0] = 8'h01;
        packet_bytes[1] = 8'h02;
        packet_bytes[2] = 8'h03;
        send_loaded_packet("partial_len3", 3, READY_HIGH);
        flush_outputs("partial_len3_flush", 8);

        // Multi-word packets.
        packet_bytes[0] = 8'h10;
        packet_bytes[1] = 8'h20;
        packet_bytes[2] = 8'h30;
        packet_bytes[3] = 8'h40;
        packet_bytes[4] = 8'h50;
        send_loaded_packet("multi_len5", 5, READY_HIGH);
        flush_outputs("multi_len5_flush", 12);

        packet_bytes[0] = 8'h81;
        packet_bytes[1] = 8'h82;
        packet_bytes[2] = 8'h83;
        packet_bytes[3] = 8'h84;
        packet_bytes[4] = 8'h85;
        packet_bytes[5] = 8'h86;
        packet_bytes[6] = 8'h87;
        send_loaded_packet("multi_len7", 7, READY_HIGH);
        flush_outputs("multi_len7_flush", 16);

        packet_bytes[0] = 8'h91;
        packet_bytes[1] = 8'h92;
        packet_bytes[2] = 8'h93;
        packet_bytes[3] = 8'h94;
        packet_bytes[4] = 8'h95;
        packet_bytes[5] = 8'h96;
        packet_bytes[6] = 8'h97;
        packet_bytes[7] = 8'h98;
        send_loaded_packet("multi_len8", 8, READY_HIGH);
        flush_outputs("multi_len8_flush", 16);

        // Hold a full word under output backpressure.
        packet_bytes[0] = 8'hDE;
        packet_bytes[1] = 8'hAD;
        packet_bytes[2] = 8'hBE;
        packet_bytes[3] = 8'hEF;
        send_loaded_packet("stall_full_word", 4, READY_HIGH);
        run_idle_cycles("stall_full_word_hold", 3, 1'b0);
        flush_outputs("stall_full_word_flush", 8);

        // Hold a partial word under output backpressure.
        packet_bytes[0] = 8'h12;
        packet_bytes[1] = 8'h34;
        packet_bytes[2] = 8'h56;
        send_loaded_packet("stall_partial_word", 3, READY_HIGH);
        run_idle_cycles("stall_partial_hold", 2, 1'b0);
        flush_outputs("stall_partial_flush", 8);

        // Keep offering input while output is intermittently blocked.
        packet_bytes[0] = 8'h01;
        packet_bytes[1] = 8'h11;
        packet_bytes[2] = 8'h21;
        packet_bytes[3] = 8'h31;
        packet_bytes[4] = 8'h41;
        packet_bytes[5] = 8'h51;
        packet_bytes[6] = 8'h61;
        packet_bytes[7] = 8'h71;
        send_loaded_packet("stall_during_stream", 8, READY_STALL_WINDOW);
        flush_outputs("stall_during_stream_flush", 20);

        // Back-to-back packets with no intentional gap.
        packet_bytes[0] = 8'hAA;
        packet_bytes[1] = 8'hBB;
        packet_bytes[2] = 8'hCC;
        packet_bytes[3] = 8'hDD;
        send_loaded_packet("back_to_back_a", 4, READY_HIGH);
        packet_bytes[0] = 8'hE1;
        packet_bytes[1] = 8'hE2;
        packet_bytes[2] = 8'hE3;
        send_loaded_packet("back_to_back_b", 3, READY_HIGH);
        flush_outputs("back_to_back_flush", 16);

        // Reset clears a pending short word. The second byte carries TLAST,
        // so the reference model has already queued the 2-byte word when
        // reset is asserted; the word must be discarded, not delivered.
        packet_bytes[0] = 8'h55;
        packet_bytes[1] = 8'h66;
        send_loaded_packet("reset_partial_send", 2, READY_HIGH);
        step_cycle("reset_partial_assert", 1'b1, 1'b0, 8'h00, 1'b0, 1'b1);
        step_cycle("reset_partial_release", 1'b0, 1'b0, 8'h00, 1'b0, 1'b1);

        // Reset clears a word that is still being assembled. These bytes are
        // offered without TLAST, so accepted bytes are held mid-word when
        // reset hits. The 4-byte packet that follows must come out as one
        // clean word; leftover bytes or a stale byte count would corrupt it.
        step_cycle("reset_midword_byte0", 1'b0, 1'b1, 8'h77, 1'b0, 1'b1);
        step_cycle("reset_midword_byte1", 1'b0, 1'b1, 8'h88, 1'b0, 1'b1);
        step_cycle("reset_midword_assert", 1'b1, 1'b0, 8'h00, 1'b0, 1'b1);
        step_cycle("reset_midword_release", 1'b0, 1'b0, 8'h00, 1'b0, 1'b1);

        // Reset clears a pending full word.
        packet_bytes[0] = 8'hF1;
        packet_bytes[1] = 8'hF2;
        packet_bytes[2] = 8'hF3;
        packet_bytes[3] = 8'hF4;
        send_loaded_packet("reset_pending_send", 4, READY_HIGH);
        run_idle_cycles("reset_pending_hold", 1, 1'b0);
        step_cycle("reset_pending_assert", 1'b1, 1'b0, 8'h00, 1'b0, 1'b0);
        step_cycle("reset_pending_release", 1'b0, 1'b0, 8'h00, 1'b0, 1'b1);

        // Randomized regression: 25 packets of 1..12 random bytes sent under
        // random back-pressure, each followed by 0..2 idle cycles with a
        // random ready level. The block is named so that its local
        // declarations are legal in tools that reject them in unnamed blocks.
        for (i = 0; i < 25; i = i + 1) begin : random_iter
            integer pkt_len;
            integer gap_len;
            integer j;

            pkt_len = $urandom_range(1, 12);
            for (j = 0; j < pkt_len; j = j + 1)
                packet_bytes[j] = $urandom;

            send_loaded_packet("random_packet", pkt_len, READY_RANDOM);

            gap_len = $urandom_range(0, 2);
            run_idle_cycles("random_gap", gap_len, $urandom_range(0, 1));
        end
        flush_outputs("random_flush", 100);

        $display("");
        $display("===========================================");
        $display("  Tests Run: %0d", tests_run);
        $display("===========================================");

        if (errors == 0) begin
            $display("TEST_RESULT: PASS");
        end else begin
            $display("TEST_RESULT: FAIL (%0d errors)", errors);
        end

        $finish;
    end
endmodule
