#!/usr/bin/env python3
"""Generate cycle-accurate golden vectors for sort8_ranker."""

from __future__ import annotations

from dataclasses import dataclass
import random


# Pipeline depth used by simulate(): a batch accepted on cycle i is expected on
# the outputs on cycle i + LATENCY. Presumably this must equal the
# sort8_ranker RTL's latency — confirm against the DUT when either changes.
LATENCY = 6


@dataclass(frozen=True)
class Cycle:
    """Stimulus for one clock cycle of the sort8_ranker testbench.

    ``emit_verilog`` writes these fields verbatim: ``rst``/``in_valid`` as
    ``1'b`` literals and ``batch`` packed into the 128-bit ``test_in_data``.
    They are therefore validated on construction, so a bad entry fails fast.
    Without this, it would yield malformed Verilog or a silently-wrapped lane
    value that ``sort_batch`` orders differently from its packed form.

    Raises:
        ValueError: if ``rst``/``in_valid`` is not the int 0 or 1, if
            ``in_valid`` is set without a batch, or if a batch is not eight
            signed 16-bit integers.
    """

    tag: str  # label emitted into the test_tag array
    rst: int  # 1 asserts reset for this cycle, 0 otherwise
    in_valid: int  # 1 drives ``batch`` as a valid input this cycle, 0 otherwise
    batch: tuple[int, ...] | None  # eight signed 16-bit lane values, or None

    def __post_init__(self) -> None:
        for name in ("rst", "in_valid"):
            bit = getattr(self, name)
            # bool is an int subclass but would render as "True"/"False" in 1'b{...}.
            if isinstance(bit, bool) or bit not in (0, 1):
                raise ValueError(f"{self.tag}: {name} must be 0 or 1, got {bit!r}")

        if self.batch is None:
            if self.in_valid:
                raise ValueError(f"{self.tag}: in_valid=1 requires a batch")
            return

        # The input bus is 128 bits = 8 lanes x 16 bits.
        if len(self.batch) != 8:
            raise ValueError(
                f"{self.tag}: batch must have 8 lanes, got {len(self.batch)}"
            )
        for lane, value in enumerate(self.batch):
            # Out-of-range values would be masked when packed but sorted
            # unmasked, making the expected output inconsistent with the input.
            if not isinstance(value, int) or not -32768 <= value <= 32767:
                raise ValueError(
                    f"{self.tag}: lane {lane} value {value!r} is not a signed 16-bit int"
                )


def pack_values(values: list[int]) -> int:
    """Pack lane values into one integer, 16 bits per lane, lane 0 in the low bits.

    Each value is reduced to its low 16 bits, so negative numbers appear in
    two's-complement form (for example -1 becomes 0xFFFF).
    """
    # Every lane occupies a disjoint 16-bit field, so adding equals OR-ing.
    return sum((word & 0xFFFF) << (slot * 16) for slot, word in enumerate(values))


def pack_indices(indices: list[int]) -> int:
    """Pack lane indices into one integer, 3 bits per lane, lane 0 in the low bits.

    Each index is reduced to its low 3 bits before packing.
    """
    # Fields never overlap, so a plain sum builds the same bit pattern as OR.
    return sum((rank & 0x7) << (slot * 3) for slot, rank in enumerate(indices))


def sort_batch(values: tuple[int, ...]) -> tuple[int, int]:
    """Sort one batch ascending and pack it as the expected output used by simulate().

    Lanes are ordered by value; equal values keep input-lane order (lower
    lane first). Returns ``(packed_values, packed_indices)``. The first
    element is the sorted values packed with ``pack_values``. The second
    packs, for each output position, the input lane that value came from,
    using ``pack_indices``.
    """
    # Tuples compare value first, then lane, which gives the lane tie-break.
    keyed = [(value, lane) for lane, value in enumerate(values)]
    keyed.sort()
    return (
        pack_values([value for value, _ in keyed]),
        pack_indices([lane for _, lane in keyed]),
    )


def build_sequence() -> list[Cycle]:
    """Return the full stimulus, one ``Cycle`` per clock.

    The sequence runs in this order:
    - a reset preamble, including a reset cycle that also drives ``in_valid``;
    - directed batches, then a reset while they are in flight;
    - a second directed burst and four seeded random batches;
    - a reset cycle (``reset_during_out``);
    - directed and seeded random batches after that reset;
    - eight idle cycles at the end.

    Random lanes come from ``random.Random(43043)``, so the output is
    reproducible. The order of the draws below is significant.

    With ``LATENCY = 6`` and the reset handling in ``simulate()``, some
    results never reach the outputs:
    - alt_extreme, dup_mix, signed_mix, negative_mix and late_burst0 are
      flushed by flush_reset1;
    - burst2_4, burst2_5 and discard_rand0-3 are flushed by reset_during_out.

    NOTE(review): the discard_* names suggest dropping is intended for the
    second reset. Confirm it is also intended for the directed cases ahead
    of flush_reset1.
    """
    seq: list[Cycle] = [
        # Reset preamble; the second cycle drives in_valid while in reset.
        Cycle("startup_reset", 1, 0, None),
        Cycle("reset_ign_valid", 1, 1, (1, 2, 3, 4, 5, 6, 7, 8)),
        Cycle("post_reset_idle", 0, 0, None),
        # Directed batches.
        Cycle("ascending", 0, 1, (-20, -5, -1, 0, 7, 9, 12, 30)),
        Cycle("descending", 0, 1, (30, 12, 9, 7, 0, -1, -5, -20)),
        Cycle("all_equal", 0, 1, (4, 4, 4, 4, 4, 4, 4, 4)),
        Cycle("alt_extreme", 0, 1, (-32768, 32767, -32768, 32767, -32768, 32767, -32768, 32767)),
        Cycle("input_bubble", 0, 0, None),
        Cycle("dup_mix", 0, 1, (5, -1, 5, 3, -1, 8, 3, 0)),
        Cycle("signed_mix", 0, 1, (32767, -32768, 1, -1, 12345, -12345, 0, 32767)),
        Cycle("negative_mix", 0, 1, (-8, -2, -5, -2, -9, -1, -7, -3)),
        Cycle("late_burst0", 0, 1, (15, 15, -4, 99, -4, 15, 0, -100)),
        # Reset while the batches above are still in flight.
        Cycle("flush_reset1", 1, 0, None),
        Cycle("idle_after_r1", 0, 0, None),
        # Back-to-back directed burst.
        Cycle("burst2_0", 0, 1, (9, 4, 7, 1, 3, 2, 8, 6)),
        Cycle("burst2_1", 0, 1, (-16, 5, -16, 5, -16, 5, 0, 0)),
        Cycle("burst2_2", 0, 1, (100, -100, 50, -50, 25, -25, 0, 75)),
        Cycle("burst2_3", 0, 1, (11, 10, 9, 8, 7, 6, 5, 4)),
        Cycle("burst2_4", 0, 1, (1, 32767, 2, 32766, 3, -32768, 4, -32767)),
        Cycle("burst2_5", 0, 1, (42, 42, 41, 41, 40, 40, 39, 39)),
    ]

    rng = random.Random(43043)
    pool = [-32768, -255, -7, -1, 0, 1, 5, 42, 99, 32767]

    def draw_lane(prefer_pool: bool) -> int:
        # rng.random() is consumed only for pool-biased batches (short-circuit),
        # so unbiased batches use exactly one randint per lane.
        if prefer_pool and rng.random() < 0.65:
            return rng.choice(pool)
        return rng.randint(-32768, 32767)

    def draw_batch(prefer_pool: bool) -> tuple[int, ...]:
        return tuple(draw_lane(prefer_pool) for _ in range(8))

    seq.extend(Cycle(f"discard_rand{k}", 0, 1, draw_batch(True)) for k in range(4))
    seq += [
        Cycle("reset_during_out", 1, 1, draw_batch(False)),
        Cycle("idle_after_r2", 0, 0, None),
        Cycle("post3_dup", 0, 1, (6, 1, 6, 1, 6, 1, 6, 1)),
        Cycle("post3_mix", 0, 1, (-300, 200, -100, 0, 100, 200, -300, 50)),
    ]
    # Alternate pool-biased and uniform random batches.
    seq.extend(Cycle(f"keep_rand{k}", 0, 1, draw_batch(k % 2 == 0)) for k in range(4))
    # Trailing idle cycles.
    seq.extend(Cycle(f"flush_idle{k}", 0, 0, None) for k in range(8))
    return seq


def simulate(sequence: list[Cycle]) -> list[tuple[int, int, int]]:
    """Model the DUT cycle by cycle and return its expected outputs.

    Returns one ``(out_valid, out_data, out_index)`` tuple per cycle of
    *sequence*. The rules are:
    - A batch accepted on cycle ``i`` (``in_valid`` set, ``rst`` clear) is
      sorted with ``sort_batch`` and reported on cycle ``i + LATENCY``.
    - A cycle with ``rst`` set outputs zeros, ignores its own input and
      discards every batch still in flight.
    - Results that would land past the end of *sequence* are dropped, so pad
      the sequence with idle cycles to observe them.

    Raises:
        ValueError: if a non-reset cycle has ``in_valid`` set but no batch.
    """
    # In-flight results keyed by the absolute cycle index they emerge on.
    pending: dict[int, tuple[int, int]] = {}
    outputs: list[tuple[int, int, int]] = []

    for cycle_idx, cycle in enumerate(sequence):
        if cycle.rst:
            outputs.append((0, 0, 0))
            pending.clear()
            continue

        result = pending.pop(cycle_idx, None)
        outputs.append((0, 0, 0) if result is None else (1, *result))

        if cycle.in_valid:
            # Explicit check: an ``assert`` here would be stripped by ``python -O``.
            if cycle.batch is None:
                raise ValueError(
                    f"cycle {cycle_idx} ({cycle.tag!r}): in_valid=1 without a batch"
                )
            pending[cycle_idx + LATENCY] = sort_batch(cycle.batch)

    return outputs


def emit_verilog(sequence: list[Cycle], outputs: list[tuple[int, int, int]]) -> str:
    """Render stimulus and expected outputs as Verilog array declarations plus an ``initial`` block.

    Entry ``i`` of *outputs* is ``(out_valid, out_data, out_index)`` for
    ``sequence[i]``, as produced by ``simulate``. The returned text declares
    one array per signal, sized ``NUM_TESTS``, assigns every element, and
    has no trailing newline.

    Raises:
        ValueError: if the two lists differ in length, a value does not fit
            its declared width, or a tag cannot be written as a Verilog string
            literal that fits a ``reg [255:0]`` entry.
    """
    if len(outputs) != len(sequence):
        raise ValueError(
            f"got {len(outputs)} output rows for {len(sequence)} cycles"
        )

    tag_chars = 256 // 8  # test_tag entries are reg [255:0]; 8 bits per character

    def hex_literal(value: int, bits: int) -> str:
        # Refuse values the declared width would silently truncate.
        if not 0 <= value < (1 << bits):
            raise ValueError(f"{value:#x} does not fit in {bits} bits")
        return f"{bits}'h{value:0{bits // 4}x}"

    lines = [
        "// Generated by generate_golden.py. Do not edit by hand.",
        f"localparam integer NUM_TESTS = {len(sequence)};",
        "reg test_rst [0:NUM_TESTS-1];",
        "reg test_in_valid [0:NUM_TESTS-1];",
        "reg [127:0] test_in_data [0:NUM_TESTS-1];",
        "reg expected_out_valid [0:NUM_TESTS-1];",
        "reg [127:0] expected_out_data [0:NUM_TESTS-1];",
        "reg [23:0] expected_out_index [0:NUM_TESTS-1];",
        "reg [255:0] test_tag [0:NUM_TESTS-1];",
        "",
        "initial begin",
    ]

    for idx, (cycle, (out_valid, out_data, out_index)) in enumerate(zip(sequence, outputs)):
        tag = cycle.tag
        # A longer tag would be truncated to fit the reg. A quote, a backslash,
        # or a non-printable/non-ASCII character would break the string literal.
        if (
            len(tag) > tag_chars
            or not tag.isascii()
            or not tag.isprintable()
            or '"' in tag
            or "\\" in tag
        ):
            raise ValueError(
                f"cycle {idx}: tag {tag!r} must be printable ASCII, at most "
                f"{tag_chars} characters, without quotes or backslashes"
            )
        in_data = 0 if cycle.batch is None else pack_values(list(cycle.batch))
        lines.extend(
            [
                f"    test_rst[{idx}] = 1'b{cycle.rst};",
                f"    test_in_valid[{idx}] = 1'b{cycle.in_valid};",
                f"    test_in_data[{idx}] = {hex_literal(in_data, 128)};",
                f"    expected_out_valid[{idx}] = 1'b{out_valid};",
                f"    expected_out_data[{idx}] = {hex_literal(out_data, 128)};",
                f"    expected_out_index[{idx}] = {hex_literal(out_index, 24)};",
                f'    test_tag[{idx}] = "{tag}";',
            ]
        )
    lines.append("end")
    return "\n".join(lines)


def main() -> None:
    """Build the stimulus, compute expected outputs, and print the Verilog include to stdout."""
    stimulus = build_sequence()
    print(emit_verilog(stimulus, simulate(stimulus)))


if __name__ == "__main__":
    main()
