#!/usr/bin/env python3
"""Generate deterministic vectors for the ECG R-peak detector benchmark.

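This helper is run offline. It defines the directed sample sequences,
computes the expected `peak_out` value for every cycle with a software
reference model, and rewrites `testbench.v` (next to this script) with the
stimulus and expected values embedded as constant arrays. The Verilog
testbench never executes Python at runtime, so rerun this script after
changing any sequence or the reference model.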
"""

from __future__ import annotations

from collections import deque
from pathlib import Path


# Minimum distance, in clock cycles, between two peaks the golden model will
# report: after a peak, compute_expected suppresses detections for the next
# REFRACTORY_GAP - 1 non-reset cycles (asserting rst clears the suppression).
REFRACTORY_GAP = 50


def append_cycle(vectors, rst, sample, threshold, tag):
    """Append a single stimulus cycle to ``vectors``.

    The record is a dict with keys ``rst``, ``sample``, ``threshold`` and
    ``tag`` (in that order); the three numeric fields are normalised with
    ``int()`` so later formatting always sees plain integers.
    """
    record = dict(
        rst=int(rst),
        sample=int(sample),
        threshold=int(threshold),
        tag=tag,
    )
    vectors.append(record)


def append_reset(vectors, cycles, tag):
    """Append ``cycles`` cycles with reset asserted, all labelled ``tag``.

    Reset cycles drive sample=0 and threshold=500; the golden model does not
    use either value while rst is high.
    """
    for _step in range(cycles):
        append_cycle(vectors, rst=1, sample=0, threshold=500, tag=tag)


def append_idle(vectors, cycles, sample, threshold, tag):
    """Append ``cycles`` identical non-reset cycles holding one sample/threshold.

    Used as prefill, gap and trailing padding around the directed patterns.
    """
    for _step in range(cycles):
        append_cycle(vectors, rst=0, sample=sample, threshold=threshold, tag=tag)


def append_sequence(vectors, samples, thresholds, tag):
    """Append one non-reset cycle per entry of ``samples``.

    Args:
        vectors: list the cycles are appended to.
        samples: sample value for each consecutive cycle (any iterable).
        thresholds: either a single int applied to every cycle, or an
            iterable with exactly one threshold per sample.
        tag: label attached to every appended cycle.

    Raises:
        ValueError: if ``thresholds`` is a sequence whose length differs from
            ``samples``. ``zip`` would otherwise silently drop the extra
            cycles and shorten the directed pattern.
    """
    samples = list(samples)
    if isinstance(thresholds, int):
        thresholds = [thresholds] * len(samples)
    else:
        thresholds = list(thresholds)
        if len(thresholds) != len(samples):
            raise ValueError(
                f"{tag}: {len(samples)} samples but {len(thresholds)} thresholds"
            )
    for sample, threshold in zip(samples, thresholds):
        append_cycle(vectors, 0, sample, threshold, tag)


def compute_expected(vectors, refractory_gap=None):
    """Return a copy of ``vectors`` with the golden ``expected_peak`` bit added.

    Software reference model used to build the testbench. It keeps the last
    five non-reset (sample, threshold) pairs. A cycle reports a peak when all
    of the following hold:

    * the middle pair (two cycles old) has a sample >= its own threshold;
    * that sample is strictly greater than the other four samples;
    * no peak was reported within the previous ``refractory_gap - 1``
      non-reset cycles.

    Fewer than five samples since the last reset never report a peak. A
    cycle with ``rst`` set clears the window and the refractory counter and
    reports 0.

    Args:
        vectors: sequence of dicts with ``rst``, ``sample`` and ``threshold``
            keys (as built by ``append_cycle``); not modified.
        refractory_gap: minimum spacing in cycles between reported peaks.
            Defaults to the module-level ``REFRACTORY_GAP``.

    Returns:
        New list of dicts, each the original entry plus ``expected_peak``
        (0 or 1).

    Raises:
        ValueError: if the refractory gap is less than 1.
    """
    gap = REFRACTORY_GAP if refractory_gap is None else refractory_gap
    if gap < 1:
        # A gap of 0 would set the counter to -1, which never returns to 0,
        # so no further peak could ever be reported.
        raise ValueError(f"refractory_gap must be >= 1, got {gap!r}")

    window = deque(maxlen=5)  # (sample, threshold) pairs, oldest first
    refractory_count = 0

    enriched = []
    for vec in vectors:
        expected_peak = 0
        if vec["rst"]:
            window.clear()
            refractory_count = 0
        else:
            window.append((vec["sample"], vec["threshold"]))
            if len(window) == window.maxlen:
                center, center_threshold = window[2]
                neighbours = (window[0][0], window[1][0], window[3][0], window[4][0])
                candidate = center >= center_threshold and all(
                    center > other for other in neighbours
                )
                if candidate and refractory_count == 0:
                    expected_peak = 1
                    # This cycle counts as the first of the gap.
                    refractory_count = gap - 1
                elif refractory_count > 0:
                    refractory_count -= 1

        enriched.append({**vec, "expected_peak": expected_peak})

    return enriched


def generate_vectors():
    """Assemble every directed scenario in order and attach golden outputs.

    The scenarios are concatenated into one stimulus stream. Each one after
    the startup block begins with its own reset. In the golden model a reset
    clears the sample window and refractory counter, so each scenario's
    expectations do not depend on the scenarios before it.

    Returns:
        The list produced by ``compute_expected`` for the assembled stream.
    """
    stream = []
    first_peak = [100, 300, 900, 250, 100]
    second_peak = [120, 350, 950, 260, 110]

    def open_scenario(prefix, prefill_sample):
        # Common preamble: one reset cycle, then two flat fill cycles.
        append_reset(stream, 1, f"{prefix}_reset")
        append_idle(stream, 2, prefill_sample, 500, f"{prefix}_prefill")

    # Power-on: hold reset, then feed a few flat samples.
    append_reset(stream, 2, "startup_reset")
    append_idle(stream, 4, 20, 500, "startup_fill")

    # Center sample exactly equal to its own threshold must still fire.
    open_scenario("valid", 30)
    append_sequence(
        stream, [100, 300, 600, 250, 120], [500, 500, 600, 500, 500], "valid_equal_threshold"
    )
    append_idle(stream, 3, 40, 500, "valid_post")

    # Center sample one below its own threshold must not fire.
    open_scenario("threshold_block", 30)
    append_sequence(
        stream, [100, 300, 600, 250, 120], [500, 500, 601, 500, 500], "threshold_block_center"
    )
    append_idle(stream, 3, 40, 500, "threshold_block_post")

    # Plateaus: an equal neighbour on either side defeats the strict-maximum test.
    plateaus = (
        ("left", [100, 700, 700, 300, 100]),
        ("right", [100, 300, 700, 700, 100]),
    )
    for side, pattern in plateaus:
        prefix = f"plateau_{side}"
        open_scenario(prefix, 25)
        append_sequence(stream, pattern, 500, f"{prefix}_equal")
        append_idle(stream, 3, 40, 500, f"{prefix}_post")

    # Refractory window: a 25-cycle gap puts the second peak 30 cycles after
    # the first (blocked); a 45-cycle gap puts it 50 cycles after (the
    # current REFRACTORY_GAP, so allowed).
    for prefix, gap_cycles in (("refractory_block", 25), ("refractory_boundary", 45)):
        open_scenario(prefix, 20)
        append_sequence(stream, first_peak, 500, f"{prefix}_first_peak")
        append_idle(stream, gap_cycles, 15, 500, f"{prefix}_gap")
        append_sequence(stream, second_peak, 500, f"{prefix}_second_peak")
        append_idle(stream, 3, 20, 500, f"{prefix}_post")

    # A mid-stream reset clears the refractory state, so the second peak
    # (23 cycles after the first) still fires.
    open_scenario("reset_clears", 20)
    append_sequence(stream, first_peak, 500, "reset_clears_first_peak")
    append_idle(stream, 15, 15, 500, "reset_clears_gap")
    append_reset(stream, 1, "reset_clears_mid_reset")
    append_idle(stream, 2, 20, 500, "reset_clears_prefill_after")
    append_sequence(stream, second_peak, 500, "reset_clears_second_peak")
    append_idle(stream, 3, 20, 500, "reset_clears_post")

    return compute_expected(stream)


def build_array_block(vectors):
    """Render ``vectors`` as Verilog array declarations plus an ``initial`` block.

    Each vector becomes one index of the parallel arrays ``test_rst``,
    ``test_sample``, ``test_threshold``, ``expected_peak`` and ``test_tag``,
    which the testbench walks one clock cycle per index.

    Args:
        vectors: dicts with ``rst``, ``sample``, ``threshold``,
            ``expected_peak`` and ``tag`` keys (as returned by
            ``compute_expected``).

    Returns:
        The Verilog text, newline-separated, ending with a blank line.

    Raises:
        ValueError: in any of these cases:
            * ``vectors`` is empty (it would declare ``[0:-1]`` arrays and
              report a vacuous PASS);
            * a value does not fit the 1-bit or 12-bit width of its array;
            * a tag is not a plain printable-ASCII string free of ``"`` and
              backslash.
    """
    if not vectors:
        raise ValueError("cannot emit a testbench with zero test cycles")

    max_12bit = (1 << 12) - 1
    for idx, vec in enumerate(vectors):
        # Sized Verilog literals silently truncate wider values, so reject them here.
        for key in ("rst", "expected_peak"):
            if vec[key] not in (0, 1):
                raise ValueError(
                    f"vector {idx} ({vec['tag']!r}): {key}={vec[key]!r} is not 0 or 1"
                )
        for key in ("sample", "threshold"):
            if not 0 <= vec[key] <= max_12bit:
                raise ValueError(
                    f"vector {idx} ({vec['tag']!r}): {key}={vec[key]!r} does not fit in 12 bits"
                )
        # Tags are emitted inside a "..." Verilog string literal. Quotes,
        # backslashes and non-printable/non-ASCII characters would break it.
        if any(not " " <= ch <= "~" or ch in '"\\' for ch in vec["tag"]):
            raise ValueError(
                f"vector {idx}: tag {vec['tag']!r} is not a plain printable-ASCII string"
            )

    # Verilog packs string literals at 8 bits per character and truncates the
    # leftmost characters when the reg is too narrow. Size test_tag for the
    # longest tag, but never narrower than the original 256 bits.
    tag_bits = max(256, 8 * max(len(vec["tag"]) for vec in vectors))

    lines = [
        "// Generated by generate_golden.py. Do not edit by hand.",
        f"localparam integer NUM_TESTS = {len(vectors)};",
        "reg test_rst [0:NUM_TESTS-1];",
        "reg [11:0] test_sample [0:NUM_TESTS-1];",
        "reg [11:0] test_threshold [0:NUM_TESTS-1];",
        "reg expected_peak [0:NUM_TESTS-1];",
        f"reg [{tag_bits - 1}:0] test_tag [0:NUM_TESTS-1];",
        "",
        "initial begin",
    ]
    for idx, vec in enumerate(vectors):
        tag = vec["tag"]
        lines.extend(
            (
                f"    test_rst[{idx}] = 1'b{int(vec['rst'])};",
                f"    test_sample[{idx}] = 12'd{vec['sample']};",
                f"    test_threshold[{idx}] = 12'd{vec['threshold']};",
                f"    expected_peak[{idx}] = 1'b{int(vec['expected_peak'])};",
                f'    test_tag[{idx}] = "{tag}";',
            )
        )
    lines.extend(("end", ""))
    return "\n".join(lines)


def build_testbench(vectors):
    """Return the complete Verilog source of ``tb_ecg_r_peak_detector``.

    The testbench does the following:

    * instantiates ``ecg_r_peak_detector`` (ports clk, rst, sample_in,
      threshold, peak_out);
    * drives one vector per 10 ns clock period;
    * compares ``peak_out`` against the golden bit one time unit after each
      rising edge;
    * finally prints ``TEST_RESULT: PASS`` or ``TEST_RESULT: FAIL (<n> errors)``.

    Args:
        vectors: list of dicts accepted by ``build_array_block``.
    """
    array_block = build_array_block(vectors)
    # The stimulus process waits #1 before reading the first vector. The
    # test_* arrays are filled by a different time-0 initial block, and the
    # execution order of initial blocks at the same time step is not defined
    # by the Verilog standard. Without the delay, vector 0 could be read as X.
    # The first rising clock edge is at t=5, so cycle timing is unchanged.
    return f"""`timescale 1ns / 1ps

module tb_ecg_r_peak_detector;
    reg clk;
    reg rst;
    reg [11:0] sample_in;
    reg [11:0] threshold;

    wire peak_out;

    integer errors = 0;
    integer tests_run = 0;
    integer i;

{array_block}

    ecg_r_peak_detector uut (
        .clk(clk),
        .rst(rst),
        .sample_in(sample_in),
        .threshold(threshold),
        .peak_out(peak_out)
    );

    initial begin
        clk = 1'b0;
        forever #5 clk = ~clk;
    end

    task apply_cycle;
        input integer cycle_idx;
        begin
            rst = test_rst[cycle_idx];
            sample_in = test_sample[cycle_idx];
            threshold = test_threshold[cycle_idx];

            @(posedge clk);
            #1;

            tests_run = tests_run + 1;

            if (peak_out !== expected_peak[cycle_idx]) begin
                $display("ERROR [%s cycle %0d]: peak_out mismatch. expected=%b got=%b rst=%b sample=%0d threshold=%0d",
                         test_tag[cycle_idx], cycle_idx, expected_peak[cycle_idx], peak_out,
                         rst, sample_in, threshold);
                errors = errors + 1;
            end
        end
    endtask

    initial begin
        rst = 1'b0;
        sample_in = 12'd0;
        threshold = 12'd0;

        $display("===========================================");
        $display("      ECG R-Peak Detector Testbench");
        $display("===========================================");

        // Let the time-0 initial block that fills the test_* arrays finish
        // before the first vector is read (initial-block order is undefined).
        #1;

        for (i = 0; i < NUM_TESTS; i = i + 1)
            apply_cycle(i);

        $display("");
        $display("===========================================");
        $display("  Tests Run: %0d", tests_run);
        $display("===========================================");

        if (errors == 0) begin
            $display("TEST_RESULT: PASS");
        end else begin
            $display("TEST_RESULT: FAIL (%0d errors)", errors);
        end

        $finish;
    end
endmodule
"""


def main():
    """Regenerate testbench.v next to this script and report golden pulse cycles."""
    golden = generate_vectors()
    out_path = Path(__file__).resolve().parent / "testbench.v"
    out_path.write_text(build_testbench(golden), encoding="ascii")

    pulse_labels = [str(cycle) for cycle, vec in enumerate(golden) if vec["expected_peak"]]
    print(f"Generated {len(golden)} test cycles.")
    print("Expected peak_out pulses at cycles:", ", ".join(pulse_labels))


if __name__ == "__main__":
    # Regenerate testbench.v only when run as a script, not on import.
    main()
