#!/usr/bin/env python3
"""Generate cycle-accurate golden vectors for streaming_dot_product8."""

from __future__ import annotations

from dataclasses import dataclass


# Cycles from the input beat carrying last_in to the cycle in which simulate()
# expects result_valid with that vector's sum. Presumably mirrors the DUT's
# result latency -- confirm against the RTL when changing it.
LATENCY = 2


@dataclass(frozen=True)
class Cycle:
    """Stimulus applied to the DUT during one clock cycle.

    Positional field order is (tag, rst, in_valid, a_data, b_data, last_in).

    Raises:
        ValueError: (from ``__post_init__``) if a 1-bit control field is not
            the int 0 or 1, or an operand does not fit a signed 8-bit port.
    """

    tag: str  # label emitted into test_tag for diagnostics
    rst: int  # reset, 0 or 1
    in_valid: int  # input beat valid, 0 or 1
    a_data: int  # signed 8-bit operand
    b_data: int  # signed 8-bit operand
    last_in: int  # final beat of a vector, 0 or 1

    def __post_init__(self) -> None:
        # Control fields are rendered as ``1'b{value}``; anything but a plain
        # int 0/1 (bool would print "True") yields an invalid Verilog literal.
        for name in ("rst", "in_valid", "last_in"):
            value = getattr(self, name)
            if type(value) is not int or value not in (0, 1):
                raise ValueError(f"{self.tag}: {name} must be 0 or 1, got {value!r}")
        # Operands drive 8-bit ports while the Python model multiplies the raw
        # ints, so an out-of-range value would desynchronise model and DUT.
        for name in ("a_data", "b_data"):
            value = getattr(self, name)
            if type(value) is not int or not -128 <= value <= 127:
                raise ValueError(
                    f"{self.tag}: {name} must be in [-128, 127], got {value!r}"
                )


def pack_signed(value: int, width: int) -> int:
    """Return ``value`` as an unsigned ``width``-bit two's-complement pattern.

    Used to render signed values as fixed-width Verilog hex literals, e.g.
    ``pack_signed(-3, 8) == 0xFD``.

    Raises:
        ValueError: if ``width`` is not positive, or ``value`` lies outside
            ``[-2**(width-1), 2**(width-1) - 1]``. Masking such a value would
            silently wrap it, so the emitted literal would no longer equal the
            number the Python model used.
    """
    if width <= 0:
        raise ValueError(f"width must be positive, got {width}")
    lo = -(1 << (width - 1))
    hi = (1 << (width - 1)) - 1
    if not lo <= value <= hi:
        raise ValueError(
            f"{value} does not fit in a signed {width}-bit field [{lo}, {hi}]"
        )
    return value & ((1 << width) - 1)


def dot_product(pairs: tuple[tuple[int, int], ...]) -> int:
    """Return the sum of ``a * b`` over every ``(a, b)`` pair (0 when empty)."""
    total = 0
    for lhs, rhs in pairs:
        total += lhs * rhs
    return total


def emit_vector(prefix: str, pairs: tuple[tuple[int, int], ...]) -> list[Cycle]:
    """Expand one 8-element vector into consecutive valid input beats.

    Beat ``i`` is tagged ``f"{prefix}_p{i}"`` and drives ``pairs[i]`` on
    ``a_data``/``b_data`` with ``rst`` low and ``in_valid`` high; ``last_in``
    is raised only on the final beat.

    Raises:
        ValueError: if ``pairs`` does not contain exactly 8 elements.
    """
    # Explicit check (not ``assert``) so it still runs under ``python -O``.
    if len(pairs) != 8:
        raise ValueError(f"{prefix}: expected 8 (a, b) pairs, got {len(pairs)}")
    final_idx = len(pairs) - 1
    return [
        Cycle(
            tag=f"{prefix}_p{idx}",
            rst=0,
            in_valid=1,
            a_data=a_val,
            b_data=b_val,
            last_in=1 if idx == final_idx else 0,
        )
        for idx, (a_val, b_val) in enumerate(pairs)
    ]


def build_sequence() -> list[Cycle]:
    """Return the full stimulus, one Cycle per clock, in the order applied.

    Scenarios, in order:

    * reset, including a reset cycle that also drives in_valid/last_in (reset
      must win), then an idle cycle with a stray last_in (must be ignored);
    * vec0..vec2: a basic vector, mixed signs, and 8-bit extremes;
    * two beats of a vector interrupted by a mid-vector reset;
    * vec3 followed in the very next cycle by a reset, which simulate() treats
      as discarding vec3's still-pending result;
    * vec4 (every product zero) immediately followed by vec5 (back to back);
    * an idle drain whose length is derived from LATENCY, so vec5's result
      always lands inside the sequence.
    """
    seq: list[Cycle] = [
        Cycle("startup_reset", 1, 0, 0, 0, 0),
        Cycle("reset_ignore_input", 1, 1, 12, -3, 1),
        Cycle("post_reset_idle_last_ignored", 0, 0, 0, 0, 1),
    ]

    seq.extend(
        emit_vector(
            "vec0_basic",
            (
                (1, 8),
                (2, 7),
                (3, 6),
                (4, 5),
                (5, 4),
                (6, 3),
                (7, 2),
                (8, 1),
            ),
        )
    )

    seq.extend(
        emit_vector(
            "vec1_mixed",
            (
                (-3, 10),
                (4, -5),
                (0, 9),
                (-8, -2),
                (7, 3),
                (-1, 12),
                (15, -4),
                (-6, -7),
            ),
        )
    )

    seq.extend(
        emit_vector(
            "vec2_extreme",
            (
                (127, 127),
                (-128, 127),
                (64, -64),
                (-64, -64),
                (1, -128),
                (-1, -128),
                (50, -3),
                (-50, -3),
            ),
        )
    )

    # A partial vector must be forgotten when reset arrives mid-stream.
    seq.extend(
        [
            Cycle("partial_before_reset_p0", 0, 1, 9, -2, 0),
            Cycle("partial_before_reset_p1", 0, 1, -11, 5, 0),
            Cycle("mid_vector_reset", 1, 0, 0, 0, 0),
            Cycle("idle_after_mid_vector_reset", 0, 0, 0, 0, 0),
        ]
    )

    seq.extend(
        emit_vector(
            "vec3_flush_pending",
            (
                (10, 10),
                (20, 1),
                (-30, 2),
                (40, -3),
                (-50, -4),
                (60, 5),
                (-70, 6),
                (80, -7),
            ),
        )
    )

    # Reset right after vec3's final beat, before its result is due.
    seq.extend(
        [
            Cycle("reset_flush_pending_output", 1, 0, 0, 0, 0),
            Cycle("idle_after_flush_reset", 0, 0, 0, 0, 0),
        ]
    )

    seq.extend(
        emit_vector(
            "vec4_zero",
            (
                (0, 0),
                (0, 5),
                (12, 0),
                (0, -7),
                (-9, 0),
                (0, 1),
                (3, 0),
                (0, 0),
            ),
        )
    )

    seq.extend(
        emit_vector(
            "vec5_back_to_back",
            (
                (5, -9),
                (-4, -8),
                (3, -7),
                (-2, -6),
                (1, -5),
                (0, -4),
                (-1, -3),
                (2, -2),
            ),
        )
    )

    # vec5's result is due LATENCY cycles after its last beat, so at least
    # LATENCY idle cycles are needed; one extra cycle re-checks that last_in is
    # ignored while in_valid is low. For LATENCY == 2 this yields
    # drain_idle0..drain_idle2, with last_in set only on the final one.
    drain_len = LATENCY + 1
    seq.extend(
        Cycle(f"drain_idle{i}", 0, 0, 0, 0, 1 if i == drain_len - 1 else 0)
        for i in range(drain_len)
    )

    return seq


def simulate(
    sequence: list[Cycle], latency: int | None = None
) -> list[tuple[int, int]]:
    """Reference model: return the expected ``(result_valid, result)`` per cycle.

    Products of valid beats accumulate until a beat with ``last_in``; that
    vector's sum is then expected ``latency`` cycles after its final beat.
    ``rst`` clears the running sum and every result still in flight, and the
    expected output during a reset cycle is ``(0, 0)``. Beats with
    ``in_valid`` low contribute nothing, even if ``last_in`` is set.

    Args:
        sequence: stimulus, one entry per clock cycle.
        latency: cycles from a vector's final beat to its result; defaults
            to the module-level LATENCY.

    Returns:
        One ``(valid, value)`` tuple per input cycle; ``value`` is 0 whenever
        ``valid`` is 0.

    Raises:
        ValueError: if ``latency`` < 1 (the result would be scheduled into a
            cycle whose output was already recorded), or if a result is still
            pending when ``sequence`` ends. Without this check that result
            would be dropped silently and never checked.
    """
    if latency is None:
        latency = LATENCY
    if latency < 1:
        raise ValueError(f"latency must be >= 1, got {latency}")

    accum = 0
    pending: dict[int, int] = {}  # cycle index -> result due in that cycle
    outputs: list[tuple[int, int]] = []

    for cycle_idx, cycle in enumerate(sequence):
        if cycle.rst:
            accum = 0
            pending.clear()
            outputs.append((0, 0))
            continue

        # `is None`, not truthiness: a vector whose dot product is 0 must still
        # be reported as valid.
        due = pending.pop(cycle_idx, None)
        outputs.append((0, 0) if due is None else (1, due))

        if not cycle.in_valid:
            continue

        accum += cycle.a_data * cycle.b_data
        if cycle.last_in:
            pending[cycle_idx + latency] = accum
            accum = 0

    if pending:
        raise ValueError(
            f"sequence of {len(sequence)} cycles ends before result(s) due at "
            f"cycle(s) {sorted(pending)}; add idle cycles to drain them"
        )
    return outputs


def emit_verilog(sequence: list[Cycle], outputs: list[tuple[int, int]]) -> str:
    """Render stimulus and expected results as Verilog array declarations
    followed by an ``initial`` block that fills them.

    Index ``i`` of every array describes cycle ``i``: the DUT inputs come from
    ``sequence[i]`` and the expected ``(valid, result)`` from ``outputs[i]``.
    Signed values are written as two's-complement hex via pack_signed (8-bit
    operands, 32-bit result).

    Raises:
        ValueError: if ``outputs`` and ``sequence`` differ in length, or a tag
            cannot be stored intact as a string literal in ``test_tag``.
    """
    if len(outputs) != len(sequence):
        raise ValueError(
            f"got {len(outputs)} expected outputs for {len(sequence)} cycles"
        )
    # test_tag entries are reg [255:0], i.e. 32 eight-bit characters; a longer
    # string literal would be truncated when assigned.
    max_tag_chars = 256 // 8

    lines = [
        "// Generated by generate_golden.py and copied here. Do not edit by hand.",
        f"localparam integer NUM_TESTS = {len(sequence)};",
        "reg test_rst [0:NUM_TESTS-1];",
        "reg test_in_valid [0:NUM_TESTS-1];",
        "reg [7:0] test_a_data [0:NUM_TESTS-1];",
        "reg [7:0] test_b_data [0:NUM_TESTS-1];",
        "reg test_last_in [0:NUM_TESTS-1];",
        "reg expected_result_valid [0:NUM_TESTS-1];",
        "reg [31:0] expected_result [0:NUM_TESTS-1];",
        "reg [255:0] test_tag [0:NUM_TESTS-1];",
        "",
        "initial begin",
    ]

    for idx, (cycle, (out_valid, out_data)) in enumerate(zip(sequence, outputs)):
        tag = cycle.tag
        # Quotes, backslashes or non-printables would break the "..." literal.
        if (
            len(tag) > max_tag_chars
            or not (tag.isascii() and tag.isprintable())
            or '"' in tag
            or "\\" in tag
        ):
            raise ValueError(
                f"cycle {idx}: tag {tag!r} must be at most {max_tag_chars} "
                "printable ASCII characters without quotes or backslashes"
            )
        lines += [
            f"    test_rst[{idx}] = 1'b{cycle.rst};",
            f"    test_in_valid[{idx}] = 1'b{cycle.in_valid};",
            f"    test_a_data[{idx}] = 8'h{pack_signed(cycle.a_data, 8):02x};",
            f"    test_b_data[{idx}] = 8'h{pack_signed(cycle.b_data, 8):02x};",
            f"    test_last_in[{idx}] = 1'b{cycle.last_in};",
            f"    expected_result_valid[{idx}] = 1'b{out_valid};",
            f"    expected_result[{idx}] = 32'h{pack_signed(out_data, 32):08x};",
            f'    test_tag[{idx}] = "{tag}";',
        ]

    lines.append("end")
    return "\n".join(lines)


def self_check() -> None:
    """Verify dot_product() against hand-computed sums for the test vectors.

    The pairs below duplicate the vectors in build_sequence; keep them in
    sync when editing either place.

    Raises:
        AssertionError: if any vector's dot product differs from its
            hand-computed value.
    """
    cases = (
        ("vec0_basic", ((1, 8), (2, 7), (3, 6), (4, 5), (5, 4), (6, 3), (7, 2), (8, 1)), 120),
        ("vec1_mixed", ((-3, 10), (4, -5), (0, 9), (-8, -2), (7, 3), (-1, 12), (15, -4), (-6, -7)), -43),
        (
            "vec2_extreme",
            ((127, 127), (-128, 127), (64, -64), (-64, -64), (1, -128), (-1, -128), (50, -3), (-50, -3)),
            -127,
        ),
        # vec3's result is flushed by the reset that follows it in the sequence,
        # so -540 must never appear as a valid expected_result.
        ("vec3_flush_pending", ((10, 10), (20, 1), (-30, 2), (40, -3), (-50, -4), (60, 5), (-70, 6), (80, -7)), -540),
        ("vec4_zero", ((0, 0), (0, 5), (12, 0), (0, -7), (-9, 0), (0, 1), (3, 0), (0, 0)), 0),
        ("vec5_back_to_back", ((5, -9), (-4, -8), (3, -7), (-2, -6), (1, -5), (0, -4), (-1, -3), (2, -2)), -28),
    )
    # Explicit raise instead of bare ``assert`` so the check survives ``python -O``.
    for name, pairs, expected in cases:
        got = dot_product(pairs)
        if got != expected:
            raise AssertionError(f"{name}: dot product {got} != expected {expected}")


def main() -> None:
    """Run the self-check, then print the generated Verilog table to stdout."""
    self_check()
    stimulus = build_sequence()
    print(emit_verilog(stimulus, simulate(stimulus)))


# Script entry point: main() prints the generated Verilog to stdout.
if __name__ == "__main__":
    main()
