`timescale 1ns / 1ps

// Self-checking testbench for streaming_histogram64.
//
// Every stimulus task consumes exactly one rising clock edge (via step_cycle)
// and then compares the DUT outputs with a 64-entry shadow histogram
// (expected_bins). Frames of 256 accepted samples are followed by a 64-cycle
// readout check. The exception is one frame whose readout is cut short by a reset.
module tb_streaming_histogram64;
    // ---------------- DUT stimulus ----------------
    reg        clk;
    reg        rst;
    reg        in_valid;
    reg  [5:0] sample_in;

    // ---------------- DUT responses ----------------
    wire        out_valid;
    wire  [5:0] bin_index;
    wire [15:0] bin_count;

    // ---------------- Bookkeeping ----------------
    integer errors = 0;       // number of mismatches reported so far
    integer tests_run = 0;    // bumped once per step_cycle call
    integer cycle_count = 0;  // clock edges consumed by step_cycle
    integer i;                // stimulus loop index (main sequence)
    integer j;                // outer stimulus loop index (main sequence)
    integer accepted_count;   // samples folded into the shadow model since the last clear
    integer rng_seed;         // seed variable threaded through every $random call

    // Shadow model: expected count for each of the 64 bins.
    reg [15:0] expected_bins [0:63];

    streaming_histogram64 uut (
        .clk      (clk),
        .rst      (rst),
        .in_valid (in_valid),
        .sample_in(sample_in),
        .out_valid(out_valid),
        .bin_index(bin_index),
        .bin_count(bin_count)
    );

    // 10-time-unit clock period: low at t=0, toggling every 5 units.
    initial begin : clock_gen
        clk = 1'b0;
        forever #5 clk = ~clk;
    end

    // Deterministic sample pattern: the low six bits of k*k + 7*k + 11.
    // k*(k+7) is the same product as k*k + k*7 under 32-bit integer arithmetic.
    function [5:0] quadratic_pattern;
        input integer k;
        begin
            quadratic_pattern = (k * (k + 7) + 11) & 6'h3f;
        end
    endfunction

    // Zero every shadow bin and restart the accepted-sample counter.
    task clear_model;
        integer b;
        begin
            b = 0;
            while (b < 64) begin
                expected_bins[b] = 16'd0;
                b = b + 1;
            end
            accepted_count = 0;
        end
    endtask

    // Apply the next input values, advance past one rising edge, then wait #1
    // so that outputs registered on that edge can be sampled safely.
    // `tag` is accepted for call-site symmetry but is not used here.
    task step_cycle;
        input [255:0] tag;
        input         drv_rst;
        input         drv_valid;
        input   [5:0] drv_sample;
        begin
            rst       = drv_rst;
            in_valid  = drv_valid;
            sample_in = drv_sample;

            @(posedge clk);
            #1;

            tests_run   = tests_run + 1;
            cycle_count = cycle_count + 1;
        end
    endtask

    // Require out_valid, bin_index and bin_count all to be exactly zero
    // (4-state compare, so X/Z also counts as a mismatch). The drv_* arguments
    // only decorate the error messages.
    task check_zero_outputs;
        input [255:0] tag;
        input         drv_rst;
        input         drv_valid;
        input   [5:0] drv_sample;
        begin
            if (out_valid !== 1'b0) begin
                $display("ERROR [cycle %0d][%0s]: out_valid mismatch. Expected 0, got %b (rst=%b in_valid=%b sample=%0d)",
                         cycle_count, tag, out_valid, drv_rst, drv_valid, drv_sample);
                errors = errors + 1;
            end

            if (bin_index !== 6'd0) begin
                $display("ERROR [cycle %0d][%0s]: bin_index mismatch. Expected 0, got %0d (rst=%b in_valid=%b sample=%0d)",
                         cycle_count, tag, bin_index, drv_rst, drv_valid, drv_sample);
                errors = errors + 1;
            end

            if (bin_count !== 16'd0) begin
                $display("ERROR [cycle %0d][%0s]: bin_count mismatch. Expected 0, got %0d (rst=%b in_valid=%b sample=%0d)",
                         cycle_count, tag, bin_count, drv_rst, drv_valid, drv_sample);
                errors = errors + 1;
            end
        end
    endtask

    // One cycle with rst high. The shadow model is cleared and all outputs must read zero.
    task send_reset_cycle;
        input [255:0] tag;
        input         drv_valid;
        input   [5:0] drv_sample;
        begin
            step_cycle(tag, 1'b1, drv_valid, drv_sample);
            clear_model;
            check_zero_outputs(tag, 1'b1, drv_valid, drv_sample);
        end
    endtask

    // One cycle with rst and in_valid both low. All outputs must read zero.
    task send_idle_cycle;
        input [255:0] tag;
        begin
            step_cycle(tag, 1'b0, 1'b0, 6'd0);
            check_zero_outputs(tag, 1'b0, 1'b0, 6'd0);
        end
    endtask

    // One cycle presenting `value` with in_valid high. The value is counted in the
    // shadow model, and outputs must stay zero while a frame is being collected.
    // More than 256 samples per frame is reported as a testbench error.
    task send_sample_cycle;
        input [255:0] tag;
        input   [5:0] value;
        begin
            step_cycle(tag, 1'b0, 1'b1, value);

            if (accepted_count > 255) begin
                $display("ERROR [cycle %0d][%0s]: attempted to send more than 256 accepted samples before readout", cycle_count, tag);
                errors = errors + 1;
            end

            accepted_count = accepted_count + 1;
            expected_bins[value] = expected_bins[value] + 16'd1;

            check_zero_outputs(tag, 1'b0, 1'b1, value);
        end
    endtask

    // One idle readout cycle. The DUT must present bin `want_idx` with
    // out_valid high and the count held by the shadow model.
    task check_readout_cycle;
        input [255:0] tag;
        input integer want_idx;
        begin
            step_cycle(tag, 1'b0, 1'b0, 6'd0);

            if (out_valid !== 1'b1) begin
                $display("ERROR [cycle %0d][%0s]: out_valid mismatch. Expected 1, got %b (expected_idx=%0d)",
                         cycle_count, tag, out_valid, want_idx);
                errors = errors + 1;
            end

            if (bin_index !== want_idx[5:0]) begin
                $display("ERROR [cycle %0d][%0s]: bin_index mismatch. Expected %0d, got %0d",
                         cycle_count, tag, want_idx, bin_index);
                errors = errors + 1;
            end

            if (bin_count !== expected_bins[want_idx]) begin
                $display("ERROR [cycle %0d][%0s]: bin_count mismatch for bin %0d. Expected %0d, got %0d",
                         cycle_count, tag, want_idx, expected_bins[want_idx], bin_count);
                errors = errors + 1;
            end
        end
    endtask

    // Check a complete 64-bin readout after a full 256-sample frame:
    //   - every bin is checked in order;
    //   - both the model total and the DUT total must equal 256;
    //   - the output must return to idle on the next cycle.
    // The shadow model is cleared afterwards, ready for the next frame.
    task verify_full_readout;
        input [255:0] tag;
        integer b;
        integer sum_model;
        integer sum_dut;
        begin
            if (accepted_count !== 256) begin
                $display("ERROR [cycle %0d][%0s]: attempted readout check with accepted_count=%0d instead of 256",
                         cycle_count, tag, accepted_count);
                errors = errors + 1;
            end

            // check_readout_cycle never modifies the shadow model, so the
            // expected total can be computed before the readout starts.
            sum_model = 0;
            for (b = 0; b < 64; b = b + 1)
                sum_model = sum_model + expected_bins[b];

            sum_dut = 0;
            for (b = 0; b < 64; b = b + 1) begin
                check_readout_cycle(tag, b);
                sum_dut = sum_dut + bin_count;
            end

            if (sum_model !== 256) begin
                $display("ERROR [cycle %0d][%0s]: internal shadow histogram sum mismatch. Expected 256, got %0d",
                         cycle_count, tag, sum_model);
                errors = errors + 1;
            end

            if (sum_dut !== 256) begin
                $display("ERROR [cycle %0d][%0s]: observed readout sum mismatch. Expected 256, got %0d",
                         cycle_count, tag, sum_dut);
                errors = errors + 1;
            end

            send_idle_cycle("post_readout_idle");
            clear_model;
        end
    endtask

    // Check only the first `outputs_before_reset` bins of a readout, then
    // assert reset mid-readout and expect the DUT to go quiet on both that
    // cycle and the following idle cycle.
    task verify_partial_readout_then_reset;
        input [255:0] tag;
        input integer outputs_before_reset;
        integer b;
        begin
            if (accepted_count !== 256) begin
                $display("ERROR [cycle %0d][%0s]: attempted partial-readout reset check with accepted_count=%0d instead of 256",
                         cycle_count, tag, accepted_count);
                errors = errors + 1;
            end

            b = 0;
            while (b < outputs_before_reset) begin
                check_readout_cycle(tag, b);
                b = b + 1;
            end

            send_reset_cycle("mid_readout_reset", 1'b0, 6'd0);
            send_idle_cycle("post_mid_readout_reset_idle");
        end
    endtask

    // Main stimulus sequence.
    initial begin : stimulus
        rst            = 1'b0;
        in_valid       = 1'b0;
        sample_in      = 6'd0;
        accepted_count = 0;
        rng_seed       = 32'h13579bdf;
        clear_model;

        $display("===========================================");
        $display("   64-Bin Streaming Histogram Testbench");
        $display("===========================================");

        // Power-on resets, the second with a (to-be-ignored) valid sample.
        send_reset_cycle("startup_reset_0", 1'b0, 6'd0);
        send_reset_cycle("startup_reset_1", 1'b1, 6'd17);
        send_idle_cycle("post_reset_idle");

        // Partial frame with gaps, abandoned by a reset.
        for (i = 0; i < 20; i = i + 1) begin
            send_sample_cycle("partial_frame_with_idles", (i * 3) & 6'h3f);
            if ((i % 5) == 2)
                send_idle_cycle("partial_frame_gap");
        end
        send_reset_cycle("mid_frame_reset", 1'b0, 6'd0);
        send_idle_cycle("mid_frame_reset_idle");

        // Every sample lands in one bin.
        for (i = 0; i < 256; i = i + 1)
            send_sample_cycle("all_same_frame", 6'd21);
        verify_full_readout("all_same_readout");

        // Two bins, alternating (bin 9 on even i, bin 42 on odd i).
        for (i = 0; i < 256; i = i + 1)
            send_sample_cycle("alternating_frame", i[0] ? 6'd42 : 6'd9);
        verify_full_readout("alternating_readout");

        // Four passes over the full 0..63 range: every bin ends at 4.
        for (j = 0; j < 4; j = j + 1)
            for (i = 0; i < 64; i = i + 1)
                send_sample_cycle("uniform_full_range_frame", i[5:0]);
        verify_full_readout("uniform_full_range_readout");

        // Quadratic pattern with an idle gap before every sixth sample.
        for (i = 0; accepted_count < 256; i = i + 1) begin
            if ((i % 6) == 5)
                send_idle_cycle("pattern_frame_gap");
            send_sample_cycle("pattern_frame", quadratic_pattern(i));
        end
        verify_full_readout("pattern_readout");

        // Random frame biased towards bin 12, interrupted mid-readout by reset.
        for (i = 0; i < 256; i = i + 1) begin
            if (($random(rng_seed) & 16'h7) < 16'h5)
                send_sample_cycle("biased_random_frame", 6'd12);
            else
                send_sample_cycle("biased_random_frame", $random(rng_seed) & 6'h3f);
        end
        verify_partial_readout_then_reset("biased_random_readout", 11);

        // Two fully random frames back to back.
        for (j = 0; j < 2; j = j + 1) begin
            for (i = 0; i < 256; i = i + 1)
                send_sample_cycle("random_regression_frame", $random(rng_seed) & 6'h3f);
            verify_full_readout("random_regression_readout");
        end

        $display("");
        $display("===========================================");
        $display("  Tests Run: %0d", tests_run);
        $display("===========================================");

        if (errors != 0)
            $display("TEST_RESULT: FAIL (%0d errors)", errors);
        else
            $display("TEST_RESULT: PASS");

        $finish;
    end
endmodule
