#!/usr/bin/env python3
"""Generate cycle-accurate golden vectors for prefix_scan8."""

from __future__ import annotations

from dataclasses import dataclass
import random


# Number of cycles between a valid input batch and its result in the expected
# outputs: simulate() schedules each accepted batch at cycle index + LATENCY.
LATENCY = 3


@dataclass(frozen=True)
class Cycle:
    """One cycle of testbench stimulus (immutable)."""

    # Short label for the cycle; emit_verilog() writes it to test_tag.
    tag: str
    # Reset flag; when truthy, simulate() discards all pending results and
    # ignores in_valid/batch for this cycle.
    rst: int
    # Input-valid flag; when truthy (and rst is not), batch is scanned.
    in_valid: int
    # Per-lane input values, or None when no data is driven this cycle
    # (emit_verilog() then writes zero as the input data).
    batch: tuple[int, ...] | None


def pack_signed(values: list[int], width: int) -> int:
    """Concatenate *values* into one integer of fixed-width lanes.

    Each value is reduced to its low ``width`` bits (so negative numbers
    become their two's-complement pattern) and placed at bit offset
    ``width * lane``; lane 0 therefore occupies the least-significant bits.

    Args:
        values: Lane values, lane 0 first.
        width: Number of bits per lane.

    Returns:
        The packed non-negative integer; ``0`` for an empty list.
    """
    lane_mask = (1 << width) - 1
    # Lane fields never overlap, so adding them is the same as OR-ing them.
    return sum(
        (lane_value & lane_mask) << (lane * width)
        for lane, lane_value in enumerate(values)
    )


def prefix_scan(values: tuple[int, ...]) -> int:
    """Return the inclusive prefix sums of *values*, packed into 15-bit lanes.

    Lane ``k`` of the result holds ``values[0] + ... + values[k]``, truncated
    to 15 bits by ``pack_signed``.

    Args:
        values: Input lane values, lane 0 first.

    Returns:
        The packed prefix sums as a single integer.
    """
    sums: list[int] = []
    for lane_value in values:
        # Each partial sum is the previous partial sum plus the current lane.
        previous = sums[-1] if sums else 0
        sums.append(previous + lane_value)
    return pack_signed(sums, 15)


def build_sequence() -> list[Cycle]:
    """Build the full stimulus sequence, one ``Cycle`` per test cycle.

    The sequence consists of 17 hand-written cycles (resets, idle bubbles and
    directed batches), then 3 batches drawn from a ``random.Random`` seeded
    with 45045 (lane values in [-2048, 2047]), then 5 idle cycles.

    Returns:
        The cycles in the order they are to be applied.
    """

    def idle(tag: str) -> Cycle:
        # Reset low, no input driven.
        return Cycle(tag, 0, 0, None)

    def drive(tag: str, *lanes: int) -> Cycle:
        # Reset low, one valid batch on the input.
        return Cycle(tag, 0, 1, lanes)

    stimulus: list[Cycle] = [
        Cycle("startup_reset", 1, 0, None),
        Cycle("reset_ign_valid", 1, 1, (100, -50, 25, -10, 5, -2, 1, -1)),
        idle("post_reset_idle"),
        drive("basic_inc", 1, 2, 3, 4, 5, 6, 7, 8),
        drive("mixed_sign", -3, 7, -2, 1, -8, 4, 0, 5),
        drive("all_zero", 0, 0, 0, 0, 0, 0, 0, 0),
        drive("alt_extreme", 2047, -2048, 2047, -2048, 2047, -2048, 2047, -2048),
        idle("input_bubble"),
        drive("dup_pos", 5, 5, 5, 5, 5, 5, 5, 5),
        drive("neg_sweep", -1, -2, -3, -4, -5, -6, -7, -8),
        drive("single_hot_tail", 0, 0, 0, 0, 0, 0, 0, 13),
        drive("pre_reset_pending", 17, -4, 9, -3, 12, -8, 5, 1),
        Cycle("flush_reset", 1, 0, None),
        idle("idle_after_reset"),
        drive("post_reset_simple", 10, -10, 20, -20, 30, -30, 40, -40),
        drive("post_reset_back2", 100, 200, -50, 25, -75, 10, 5, -15),
        idle("post_reset_gap"),
    ]

    # Fixed seed keeps the random batches identical from run to run.
    prng = random.Random(45045)
    for n in range(3):
        lanes = tuple(prng.randint(-2048, 2047) for _ in range(8))
        stimulus.append(drive(f"rand{n}", *lanes))

    # Trailing idle cycles after the last batch.
    stimulus.extend(idle(f"flush_idle{n}") for n in range(5))
    return stimulus


def simulate(sequence: list[Cycle]) -> list[tuple[int, int]]:
    """Compute the expected ``(out_valid, out_data)`` pair for every cycle.

    A batch accepted on cycle ``i`` produces its packed prefix scan on cycle
    ``i + LATENCY``. A reset cycle discards every result still in flight,
    ignores its own ``in_valid``/``batch``, and outputs ``(0, 0)``. Results
    that would be due after the last cycle of *sequence* are never emitted.

    Args:
        sequence: Stimulus cycles in the order they are applied.

    Returns:
        One tuple per input cycle: ``(1, data)`` when a result is due on that
        cycle, otherwise ``(0, 0)``.

    Raises:
        ValueError: If a non-reset cycle sets ``in_valid`` but has no batch.
    """
    # Maps the cycle index at which a result is due to its packed data.
    pending: dict[int, int] = {}
    outputs: list[tuple[int, int]] = []

    for cycle_idx, cycle in enumerate(sequence):
        if cycle.rst:
            pending.clear()
            outputs.append((0, 0))
            continue

        # Emit this cycle's result (if any) before scheduling the new batch.
        due = pending.pop(cycle_idx, None)
        outputs.append((0, 0) if due is None else (1, due))

        if cycle.in_valid:
            # Raise explicitly rather than assert, so the check survives -O.
            if cycle.batch is None:
                raise ValueError(
                    f"cycle {cycle_idx} ({cycle.tag!r}) sets in_valid without a batch"
                )
            pending[cycle_idx + LATENCY] = prefix_scan(cycle.batch)

    return outputs


def emit_verilog(sequence: list[Cycle], outputs: list[tuple[int, int]]) -> str:
    """Render the stimulus and expected outputs as Verilog source text.

    The text declares ``NUM_TESTS`` and one array per signal, followed by an
    ``initial`` block that assigns every array element for every cycle. Input
    batches are packed into 12-bit lanes; a cycle without a batch gets zero
    input data.

    Args:
        sequence: Stimulus cycles.
        outputs: Expected ``(out_valid, out_data)`` per cycle, indexed like
            *sequence*.

    Returns:
        The Verilog text, lines joined by newlines with no trailing newline.
    """
    text = [
        "// Generated by generate_golden.py and copied here. Do not edit by hand.",
        f"localparam integer NUM_TESTS = {len(sequence)};",
        "reg test_rst [0:NUM_TESTS-1];",
        "reg test_in_valid [0:NUM_TESTS-1];",
        "reg [95:0] test_in_data [0:NUM_TESTS-1];",
        "reg expected_out_valid [0:NUM_TESTS-1];",
        "reg [119:0] expected_out_data [0:NUM_TESTS-1];",
        "reg [255:0] test_tag [0:NUM_TESTS-1];",
        "",
        "initial begin",
    ]

    for i, step in enumerate(sequence):
        if step.batch is None:
            stimulus = 0
        else:
            stimulus = pack_signed(list(step.batch), 12)
        valid, data = outputs[i]
        text.extend(
            (
                f"    test_rst[{i}] = 1'b{step.rst};",
                f"    test_in_valid[{i}] = 1'b{step.in_valid};",
                f"    test_in_data[{i}] = 96'h{stimulus:024x};",
                f"    expected_out_valid[{i}] = 1'b{valid};",
                f"    expected_out_data[{i}] = 120'h{data:030x};",
                f'    test_tag[{i}] = "{step.tag}";',
            )
        )

    text.append("end")
    return "\n".join(text)


def self_check() -> None:
    """Sanity-check ``prefix_scan`` against hand-computed prefix sums.

    Raises:
        AssertionError: If any packed scan differs from its expected sums.
    """
    known_cases = (
        ((1, 2, 3, 4, 5, 6, 7, 8), [1, 3, 6, 10, 15, 21, 28, 36]),
        ((-3, 7, -2, 1, -8, 4, 0, 5), [-3, 4, 2, 3, -5, -1, -1, 4]),
        ((0, 0, 0, 0, 0, 0, 0, 13), [0, 0, 0, 0, 0, 0, 0, 13]),
    )
    for batch, expected_sums in known_cases:
        assert prefix_scan(batch) == pack_signed(expected_sums, 15)


def main() -> None:
    """Run the self-check, then print the generated Verilog to stdout."""
    self_check()
    stimulus = build_sequence()
    print(emit_verilog(stimulus, simulate(stimulus)))


# Generate and print the vectors only when run as a script, not on import.
if __name__ == "__main__":
    main()
