#!/usr/bin/env python3
"""Generate deterministic BPM test vectors for the sliding-window testbench."""

import random
from collections import deque
from pathlib import Path


# Length of the sliding window in samples; compute_expected shifts in one
# sample per non-reset vector and only reports a valid BPM once this many
# samples have been seen since the last reset.
WINDOW_SIZE = 100
# Multiplier applied to the number of peaks inside the window to form
# bpm_out. Presumably 60 s / window duration (e.g. 100 samples spanning 5 s)
# — NOTE(review): confirm against the DUT's sample rate.
SCALE = 12


def append_reset(vectors, cycles):
    """Append ``cycles`` reset-asserted vectors (``rst=1``, ``peak_in=0``) to ``vectors`` in place."""
    vectors.extend({"rst": 1, "peak_in": 0} for _cycle in range(cycles))


def append_idle(vectors, cycles):
    """Append ``cycles`` idle vectors (reset released, no peak) to ``vectors`` in place."""
    vectors.extend({"rst": 0, "peak_in": 0} for _cycle in range(cycles))


def append_pattern(vectors, cycles, interval, phase=0):
    """Append ``cycles`` non-reset vectors with a periodic peak train.

    A peak is emitted on cycle ``phase`` and then every ``interval`` cycles
    after it; cycles before ``phase`` carry no peak.
    """
    for cycle in range(cycles):
        offset = cycle - phase
        # Short-circuit keeps the modulo from running on pre-phase cycles.
        fires = offset >= 0 and offset % interval == 0
        vectors.append({"rst": 0, "peak_in": int(fires)})


def append_random_legal(vectors, cycles, seed):
    """Append ``cycles`` non-reset vectors with pseudo-random, spaced peaks.

    While not in hold-off, each cycle fires a peak with probability 0.17.
    After a peak, the next 3 cycles are forced quiet, so peaks are at least
    4 cycles apart. The sequence is fully determined by ``seed``.
    """
    rng = random.Random(seed)
    holdoff = 0

    for _cycle in range(cycles):
        if holdoff:
            # Still cooling down: no draw from the RNG on these cycles.
            holdoff -= 1
            fired = 0
        elif rng.random() < 0.17:
            fired = 1
            holdoff = 3
        else:
            fired = 0
        vectors.append({"rst": 0, "peak_in": fired})


def compute_expected(vectors, window_size=None, scale=None):
    """Annotate stimulus vectors with the golden sliding-window BPM outputs.

    Cycle-by-cycle reference model. For each input vector:

    - ``rst`` truthy: the window and beat count are cleared; the outputs
      are ``valid=0, bpm=0``.
    - otherwise ``peak_in`` is shifted into a window holding the most recent
      ``window_size`` samples since the last reset. Until that window has
      been filled once, the outputs stay ``valid=0, bpm=0``. From then on,
      ``valid=1`` and ``bpm = (sum of peak_in in window) * scale``.

    Args:
        vectors: Iterable of dicts with ``rst`` and ``peak_in`` keys.
        window_size: Window length in samples; defaults to ``WINDOW_SIZE``.
        scale: Multiplier from in-window beat count to ``expected_bpm``;
            defaults to ``SCALE``.

    Returns:
        A new list with one dict per input vector, with keys ``rst``,
        ``peak_in``, ``expected_valid`` and ``expected_bpm``.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size is None:
        window_size = WINDOW_SIZE
    if scale is None:
        scale = SCALE
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size!r}")

    # Bounded deque: appending to a full window drops the oldest sample
    # automatically, in O(1) instead of rebuilding a list every cycle.
    history = deque(maxlen=window_size)
    beat_count = 0

    enriched = []
    for vec in vectors:
        rst = vec["rst"]
        peak = vec["peak_in"]

        if rst:
            history.clear()
            beat_count = 0
            bpm_valid = 0
            bpm_out = 0
        else:
            # During warm-up nothing leaves the window; once full, the oldest
            # sample is evicted by the append below and must be subtracted.
            outgoing = history[0] if len(history) == window_size else 0
            history.append(peak)
            beat_count += peak - outgoing
            if len(history) == window_size:
                bpm_valid = 1
                bpm_out = beat_count * scale
            else:
                bpm_valid = 0
                bpm_out = 0

        enriched.append(
            {
                "rst": rst,
                "peak_in": peak,
                "expected_valid": bpm_valid,
                "expected_bpm": bpm_out,
            }
        )

    return enriched


def generate_vectors():
    """Build the full stimulus sequence and annotate it with golden outputs.

    The stimulus is a fixed, deterministic plan of reset, idle, periodic and
    seeded-random segments. It is passed through ``compute_expected``, whose
    result is returned.
    """
    plan = [
        # Scenario 1: one peak every 20 cycles for 100 cycles, then a long
        # idle stretch (25 + 85 cycles).
        (append_reset, {"cycles": 2}),
        (append_pattern, {"cycles": 100, "interval": 20, "phase": 0}),
        (append_idle, {"cycles": 25}),
        (append_idle, {"cycles": 85}),
        # Scenario 2: idle lead-in, then two phase-shifted periodic trains
        # with different intervals back to back.
        (append_reset, {"cycles": 1}),
        (append_idle, {"cycles": 10}),
        (append_pattern, {"cycles": 100, "interval": 25, "phase": 5}),
        (append_pattern, {"cycles": 60, "interval": 20, "phase": 10}),
        # Scenario 3: dense train, one peak every 4 cycles.
        (append_reset, {"cycles": 2}),
        (append_pattern, {"cycles": 120, "interval": 4, "phase": 0}),
        # Scenario 4: seeded random peaks with minimum 4-cycle spacing.
        (append_reset, {"cycles": 1}),
        (append_random_legal, {"cycles": 260, "seed": 20260324}),
    ]

    stimulus = []
    for builder, kwargs in plan:
        builder(stimulus, **kwargs)

    return compute_expected(stimulus)


def build_array_block(vectors):
    """Render test vectors as Verilog array declarations plus an initializer.

    The output declares ``localparam NUM_TESTS`` and four arrays:
    ``test_rst``, ``test_peak_in`` and ``expected_valid`` (1 bit each) and
    ``expected_bpm`` (9 bits). It fills them from an ``initial`` block, one
    group of assignments per vector.

    Args:
        vectors: Sequence of dicts with keys ``rst``, ``peak_in``,
            ``expected_valid`` and ``expected_bpm``.

    Returns:
        The Verilog text, newline-joined, ending with a trailing newline.

    Raises:
        ValueError: if a 1-bit field is not 0/1, or ``expected_bpm`` is not
            an integer that fits the 9-bit ``expected_bpm`` array. Without
            this check an out-of-range value would be emitted as an
            oversized ``9'd`` literal and truncated by the simulator (often
            only with a warning), yielding a wrong golden value.
    """
    bpm_bits = 9
    bpm_max = (1 << bpm_bits) - 1

    lines = [
        "// Generated by generate_golden.py. Do not edit by hand.",
        f"localparam integer NUM_TESTS = {len(vectors)};",
        "reg test_rst [0:NUM_TESTS-1];",
        "reg test_peak_in [0:NUM_TESTS-1];",
        "reg expected_valid [0:NUM_TESTS-1];",
        f"reg [{bpm_bits - 1}:0] expected_bpm [0:NUM_TESTS-1];",
        "",
        "initial begin",
    ]
    for idx, vec in enumerate(vectors):
        for key in ("rst", "peak_in", "expected_valid"):
            if vec[key] not in (0, 1):
                raise ValueError(
                    f"vector {idx}: {key}={vec[key]!r} is not a 1-bit value"
                )
        bpm = vec["expected_bpm"]
        if int(bpm) != bpm or not 0 <= bpm <= bpm_max:
            raise ValueError(
                f"vector {idx}: expected_bpm={bpm!r} does not fit in "
                f"{bpm_bits} bits (0..{bpm_max})"
            )
        # int() normalizes bools so they render as 0/1 rather than True/False.
        lines.append(f"    test_rst[{idx}] = 1'b{int(vec['rst'])};")
        lines.append(f"    test_peak_in[{idx}] = 1'b{int(vec['peak_in'])};")
        lines.append(f"    expected_valid[{idx}] = 1'b{int(vec['expected_valid'])};")
        lines.append(f"    expected_bpm[{idx}] = {bpm_bits}'d{int(bpm)};")
    lines.append("end")
    lines.append("")
    return "\n".join(lines)


def write_verilog(vectors, output_path):
    """Write the Verilog vector-array include for ``vectors`` to ``output_path``.

    The text is rendered completely before the file is opened. If rendering
    fails (e.g. a malformed vector), the previous file is left intact rather
    than being truncated to an empty or partial include.

    Args:
        vectors: Vector dicts accepted by ``build_array_block``.
        output_path: ``pathlib.Path`` to (over)write, encoded as ASCII.
    """
    text = build_array_block(vectors)
    output_path.write_text(text, encoding="ascii")


def write_testbench(vectors, output_path):
    """Write a self-checking Verilog testbench for ``bpm_calculator``.

    The arrays from ``build_array_block`` are inlined into the
    ``tb_bpm_calculator`` module. For each vector, the testbench:

    - drives ``rst``/``peak_in``;
    - waits for the next rising clock edge plus ``#1``;
    - compares ``bpm_valid`` and ``bpm_out`` against the golden arrays with
      ``!==``.

    It then prints ``TEST_RESULT: PASS`` or ``TEST_RESULT: FAIL (<n> errors)``.

    Args:
        vectors: Vector dicts accepted by ``build_array_block``.
        output_path: ``pathlib.Path`` of the ``.v`` file to (over)write,
            encoded as ASCII.
    """
    # Everything that precedes the inlined arrays: timescale, signals, counters.
    preamble = """`timescale 1ns / 1ps

module tb_bpm_calculator;
    reg clk;
    reg rst;
    reg peak_in;

    wire bpm_valid;
    wire [8:0] bpm_out;

    integer errors = 0;
    integer tests_run = 0;
    integer i;

"""
    # Everything after the arrays: DUT instance, clock, checker task, driver.
    postamble = """

    bpm_calculator uut (
        .clk(clk),
        .rst(rst),
        .peak_in(peak_in),
        .bpm_valid(bpm_valid),
        .bpm_out(bpm_out)
    );

    initial begin
        clk = 1'b0;
        forever #5 clk = ~clk;
    end

    task apply_cycle;
        input integer cycle_idx;
        begin
            rst = test_rst[cycle_idx];
            peak_in = test_peak_in[cycle_idx];

            @(posedge clk);
            #1;

            tests_run = tests_run + 1;

            if (bpm_valid !== expected_valid[cycle_idx]) begin
                $display("ERROR [cycle %0d]: bpm_valid mismatch. Expected %b, got %b (rst=%b peak_in=%b)",
                         cycle_idx, expected_valid[cycle_idx], bpm_valid, rst, peak_in);
                errors = errors + 1;
            end

            if (bpm_out !== expected_bpm[cycle_idx]) begin
                $display("ERROR [cycle %0d]: bpm_out mismatch. Expected %0d, got %0d (rst=%b peak_in=%b)",
                         cycle_idx, expected_bpm[cycle_idx], bpm_out, rst, peak_in);
                errors = errors + 1;
            end
        end
    endtask

    initial begin
        rst = 1'b0;
        peak_in = 1'b0;

        $display("=======================================");
        $display("  Sliding-Window BPM Calculator Testbench");
        $display("=======================================");

        for (i = 0; i < NUM_TESTS; i = i + 1)
            apply_cycle(i);

        $display("");
        $display("=======================================");
        $display("  Tests Run: %0d", tests_run);
        $display("=======================================");

        if (errors == 0) begin
            $display("TEST_RESULT: PASS");
        end else begin
            $display("TEST_RESULT: FAIL (%0d errors)", errors);
        end

        $finish;
    end
endmodule
"""
    # The arrays are rendered before the file is opened, as in the f-string form.
    testbench_text = preamble + build_array_block(vectors) + postamble
    output_path.write_text(testbench_text, encoding="ascii")


def main():
    """Generate the golden vectors and write both Verilog outputs.

    Writes ``golden_vectors.vh`` (array block only) and ``testbench.v``
    (complete testbench) into the directory containing this script, then
    prints the number of vectors generated.
    """
    script = Path(__file__).resolve()
    golden = generate_vectors()
    write_verilog(golden, script.with_name("golden_vectors.vh"))
    write_testbench(golden, script.with_name("testbench.v"))
    print(f"Generated {len(golden)} test vectors.")


if __name__ == "__main__":
    main()
