#!/usr/bin/env python3
"""Generate deterministic vectors for the median7 denoiser benchmark.

This helper is offline-only. It documents the directed sample sequences and
computes the expected `out_valid` / `sample_out` values for each cycle. The
generated values are hardcoded directly into `testbench.v`; the Verilog
testbench does not execute Python at runtime.
"""

from __future__ import annotations

from collections import deque
from statistics import median


def append_cycle(vectors, rst, sample, tag):
    """Record a single stimulus cycle at the end of ``vectors``.

    ``rst`` and ``sample`` are coerced with ``int()``; ``tag`` is stored
    unchanged. The list is mutated in place and nothing is returned.
    """
    cycle = dict(rst=int(rst), sample=int(sample), tag=tag)
    vectors.append(cycle)


def append_reset(vectors, cycles, tag):
    """Append ``cycles`` reset cycles (``rst=1``, ``sample=0``) labelled ``tag``."""
    for _cycle in range(cycles):
        append_cycle(vectors, rst=1, sample=0, tag=tag)


def append_sequence(vectors, samples, tag):
    """Append one non-reset cycle per entry of ``samples``, all labelled ``tag``."""
    for value in samples:
        append_cycle(vectors, rst=0, sample=value, tag=tag)


def signed_literal(value: int, width: int = 12) -> str:
    """Format ``value`` as a sized, signed Verilog hex literal.

    The value is encoded in two's complement over ``width`` bits, e.g.
    ``signed_literal(-1)`` -> ``"12'shfff"`` and
    ``signed_literal(2047)`` -> ``"12'sh7ff"``.

    Args:
        value: Integer to encode.
        width: Bit width of the literal; defaults to 12 to match the
            ``reg signed [11:0]`` testbench arrays.

    Returns:
        The literal text, zero-padded to ``ceil(width / 4)`` hex digits.

    Raises:
        ValueError: if ``width`` is not positive, or ``value`` does not fit
            in a signed ``width``-bit field.
    """
    if width < 1:
        raise ValueError(f"width must be positive, got {width}")
    lo = -(1 << (width - 1))
    hi = (1 << (width - 1)) - 1
    # Reject instead of masking: a silently wrapped sample would reach the
    # DUT while the Python golden model still used the unwrapped value.
    if not lo <= value <= hi:
        raise ValueError(
            f"{value} does not fit in a signed {width}-bit field [{lo}, {hi}]"
        )
    digits = -(-width // 4)  # ceil(width / 4)
    return f"{width}'sh{value & ((1 << width) - 1):0{digits}x}"


def compute_expected(vectors, window: int = 7):
    """Attach golden ``expected_valid`` / ``expected_sample`` to every cycle.

    Models a sliding median filter with one cycle of output latency. On a
    non-reset cycle, the expected output is the median of the ``window``
    samples accepted on earlier non-reset cycles. It becomes valid only once
    ``window`` samples have been seen since the last reset. The current
    cycle's sample enters the window only after its output is computed.
    Reset cycles force both outputs to 0 and flush the window.

    Args:
        vectors: Iterable of cycle dicts carrying at least ``"rst"`` and
            ``"sample"`` keys (as built by ``append_cycle``). The input dicts
            are not mutated.
        window: Median window length. Must be a positive odd integer so the
            median is an actual sample rather than an average of two.
            Defaults to 7.

    Returns:
        A new list of dicts: each input dict's items plus
        ``"expected_valid"`` and ``"expected_sample"``.

    Raises:
        ValueError: if ``window`` is not a positive odd integer.
    """
    if window < 1 or window % 2 == 0:
        raise ValueError(f"window must be a positive odd integer, got {window}")

    # maxlen makes the deque discard the oldest sample automatically.
    history = deque(maxlen=window)
    enriched = []

    for vec in vectors:
        if vec["rst"]:
            expected_valid = 0
            expected_sample = 0
            history.clear()
        else:
            # Output reflects only samples seen *before* this cycle
            # (registered output), so check the window before appending.
            if len(history) == window:
                expected_valid = 1
                # Odd-length median is the middle element; median() sorts itself.
                expected_sample = int(median(history))
            else:
                expected_valid = 0
                expected_sample = 0
            history.append(vec["sample"])

        enriched.append(
            {
                **vec,
                "expected_valid": expected_valid,
                "expected_sample": expected_sample,
            }
        )

    return enriched


def generate_vectors():
    """Build the full directed stimulus list with golden outputs attached.

    The stimulus is described as an ordered plan of steps. Each step is
    either a run of reset cycles or a run of input samples, and carries a
    label. Every scenario starts with a reset. The scenarios are:

    - a monotonic ramp;
    - a single large positive outlier, then a single large negative one;
    - 12-bit extremes with duplicate values;
    - a longer sliding-window run;
    - a reset asserted mid-stream.
    """
    plan = (
        ("reset", 2, "startup_reset"),
        ("samples", (-9, -4, 0, 3, 8, 12, 20, 25, 30), "startup_monotonic"),
        ("reset", 1, "pos_impulse_reset"),
        ("samples", (10, 11, 12, 300, 13, 11, 10, 10, 9), "pos_impulse"),
        ("reset", 1, "neg_impulse_reset"),
        ("samples", (-8, -7, -6, -200, -5, -7, -6, -6, -5), "neg_impulse"),
        ("reset", 1, "signed_dup_reset"),
        ("samples", (2047, -2048, 4, 4, -2, 4, 9, 7, -30), "signed_dup"),
        ("reset", 1, "sliding_reset"),
        (
            "samples",
            (5, 1, 9, 3, 7, 11, 13, 15, -5, 4, 6, 8),
            "sliding_windows",
        ),
        ("reset", 1, "midstream_reset_begin"),
        ("samples", (20, 21, 22, 23, 24, 25, 26, 27), "midstream_before_reset"),
        # Second reset in the same scenario: asserted after 8 samples.
        ("reset", 1, "midstream_reset_assert"),
        ("samples", (-1, -2, -3, -4, -5, -6, -7, -8), "midstream_after_reset"),
    )

    emitters = {"reset": append_reset, "samples": append_sequence}
    stimulus = []
    for step, payload, label in plan:
        emitters[step](stimulus, payload, label)

    return compute_expected(stimulus)


def build_array_block(vectors):
    """Render ``vectors`` as Verilog source to paste into the testbench.

    Emits a ``NUM_TESTS`` localparam and one memory per column (reset,
    sample, expected valid, expected sample, tag). It also emits an
    ``initial`` block that fills every entry.

    Args:
        vectors: Sized sequence of dicts with ``rst``, ``sample``,
            ``expected_valid``, ``expected_sample`` and ``tag`` keys, as
            returned by ``compute_expected``.

    Returns:
        The Verilog text with lines joined by newlines (no trailing newline).

    Raises:
        ValueError: if a tag cannot be stored losslessly as a Verilog string
            literal in the ``test_tag`` register. This covers tags that are
            longer than the register holds, non-printable or non-ASCII, or
            that contain a double quote or backslash.
    """
    tag_bits = 256
    # Verilog packs 8 bits per character. A longer string assigned to the
    # register is truncated from the left, silently losing characters.
    max_tag_chars = tag_bits // 8

    lines = [
        "// Generated by generate_golden.py. Do not edit by hand.",
        f"localparam integer NUM_TESTS = {len(vectors)};",
        "reg test_rst [0:NUM_TESTS-1];",
        "reg signed [11:0] test_sample [0:NUM_TESTS-1];",
        "reg expected_valid [0:NUM_TESTS-1];",
        "reg signed [11:0] expected_sample [0:NUM_TESTS-1];",
        f"reg [{tag_bits - 1}:0] test_tag [0:NUM_TESTS-1];",
        "",
        "initial begin",
    ]
    for idx, vec in enumerate(vectors):
        tag = vec["tag"]
        if (
            len(tag) > max_tag_chars
            or not (tag.isascii() and tag.isprintable())
            or '"' in tag
            or "\\" in tag
        ):
            raise ValueError(
                f"tag {tag!r} at index {idx} cannot be stored in a "
                f"{tag_bits}-bit Verilog string register"
            )
        sample_lit = signed_literal(vec["sample"])
        expected_lit = signed_literal(vec["expected_sample"])
        lines.extend(
            (
                f"    test_rst[{idx}] = 1'b{vec['rst']};",
                f"    test_sample[{idx}] = {sample_lit};",
                f"    expected_valid[{idx}] = 1'b{vec['expected_valid']};",
                f"    expected_sample[{idx}] = {expected_lit};",
                f'    test_tag[{idx}] = "{tag}";',
            )
        )
    lines.append("end")
    return "\n".join(lines)


def main():
    """Generate the golden vectors and print the Verilog block to stdout."""
    verilog_text = build_array_block(vectors=generate_vectors())
    print(verilog_text)


if __name__ == "__main__":
    # Run only when executed as a script, not on import.
    main()
