#!/usr/bin/env python3
"""Generate cycle-accurate golden vectors for timestamp_delta_encoder."""

from __future__ import annotations

from dataclasses import dataclass


# Number of cycles between a valid input being accepted and its encoded
# result appearing on the outputs; simulate() schedules results this far ahead.
LATENCY = 1


@dataclass(frozen=True)
class Cycle:
    """Stimulus applied during a single cycle of the generated test.

    Instances are immutable (``frozen=True``).
    """

    # Human-readable label for the cycle; emitted into the ``test_tag`` array.
    tag: str
    # Reset input; a non-zero value makes simulate() clear its state this cycle.
    rst: int
    # Input-valid flag; when non-zero (and rst is 0) the timestamp is sampled.
    in_valid: int
    # Timestamp driven on the input; emitted as a 32-bit hex literal.
    in_timestamp: int


def classify_payload(has_prev: bool, prev_timestamp: int, timestamp: int) -> tuple[int, int]:
    """Return the ``(tag, value)`` pair encoding one timestamp sample.

    Without a previous sample the full timestamp (truncated to 32 bits) is
    returned with tag ``0b10``.  Otherwise the value is the difference to
    ``prev_timestamp`` modulo 2**32, and the tag reflects its magnitude:
    ``0b00`` for values up to 0xFF, ``0b01`` for values up to 0xFFFF and
    ``0b10`` for anything larger.
    """
    word_mask = 0xFFFF_FFFF

    if has_prev:
        # Masking after the subtraction makes the delta wrap like 32-bit hardware.
        value = (timestamp - prev_timestamp) & word_mask
        if value > 0xFFFF:
            tag = 0b10
        elif value > 0xFF:
            tag = 0b01
        else:
            tag = 0b00
    else:
        value = timestamp & word_mask
        tag = 0b10

    return tag, value


def build_sequence() -> list[Cycle]:
    """Return the fixed, ordered list of stimulus cycles for the test."""
    # Each row is (tag, rst, in_valid, in_timestamp), in Cycle field order.
    rows = (
        # Reset phase, including a valid input presented while reset is high.
        ("startup_reset", 1, 0, 0x0000_0000),
        ("reset_ignored_input", 1, 1, 0x1234_5678),
        ("post_reset_idle", 0, 0, 0x0000_0000),
        # First sample followed by deltas on either side of the 0xFF and
        # 0xFFFF thresholds.
        ("first_sample", 0, 1, 0x0000_1000),
        ("zero_delta", 0, 1, 0x0000_1000),
        ("delta_255", 0, 1, 0x0000_10FF),
        ("delta_256", 0, 1, 0x0000_11FF),
        ("delta_65535", 0, 1, 0x0001_11FE),
        ("delta_65536", 0, 1, 0x0002_11FE),
        # A cycle without a valid input, then more samples.
        ("idle_gap", 0, 0, 0x0000_0000),
        ("small_after_gap", 0, 1, 0x0002_1210),
        # A sample immediately followed by a reset cycle.
        ("flush_candidate", 0, 1, 0x0003_0000),
        ("flush_reset", 1, 0, 0x0000_0000),
        ("post_flush_idle", 0, 0, 0x0000_0000),
        # Fresh stream after the second reset.
        ("post_reset_first", 0, 1, 0xABCD_0000),
        ("post_reset_zero", 0, 1, 0xABCD_0000),
        ("rand_small_5", 0, 1, 0xABCD_0005),
        ("rand_medium_288", 0, 1, 0xABCD_0125),
        ("rand_zero", 0, 1, 0xABCD_0125),
        ("rand_full_65536", 0, 1, 0xABCE_0125),
        ("rand_idle_gap", 0, 0, 0x0000_0000),
        ("rand_small_17", 0, 1, 0xABCE_0136),
        ("rand_medium_65535", 0, 1, 0xABCF_0135),
        ("rand_small_3", 0, 1, 0xABCF_0138),
        ("rand_full_70000", 0, 1, 0xABD0_12A8),
        # Trailing cycles without input so the last result can appear.
        ("drain_last", 0, 0, 0x0000_0000),
        ("final_idle", 0, 0, 0x0000_0000),
    )
    return [Cycle(*row) for row in rows]


def simulate(sequence: list[Cycle]) -> list[tuple[int, int, int]]:
    """Run the reference model over ``sequence``, one output per cycle.

    Each returned tuple is ``(out_valid, out_tag, out_value)``.  A valid input
    accepted in cycle ``i`` produces its result in cycle ``i + LATENCY``.  A
    reset cycle outputs the idle tuple, forgets the previous timestamp and
    drops any results still waiting to be output.
    """
    idle = (0, 0b00, 0)
    seen_sample = False
    last_timestamp = 0
    in_flight: dict[int, tuple[int, int]] = {}
    results: list[tuple[int, int, int]] = []

    for now, step in enumerate(sequence):
        if step.rst:
            seen_sample, last_timestamp = False, 0
            in_flight.clear()
            results.append(idle)
            continue

        # Emit whatever was scheduled for this cycle, if anything.
        due = in_flight.pop(now, None)
        results.append(idle if due is None else (1, *due))

        if not step.in_valid:
            continue

        # Schedule the encoded sample before updating the delta reference.
        in_flight[now + LATENCY] = classify_payload(seen_sample, last_timestamp, step.in_timestamp)
        seen_sample = True
        last_timestamp = step.in_timestamp & 0xFFFF_FFFF

    return results


def emit_verilog(sequence: list[Cycle], outputs: list[tuple[int, int, int]]) -> str:
    """Render stimulus and expected outputs as a Verilog snippet.

    The text declares one array per signal, sized by ``NUM_TESTS``, followed
    by an ``initial`` block assigning every element.  ``outputs[i]`` must hold
    the ``(out_valid, out_tag, out_value)`` tuple for ``sequence[i]``.  The
    returned string has no trailing newline.
    """
    header = [
        "// Generated by generate_golden.py and copied here. Do not edit by hand.",
        f"localparam integer NUM_TESTS = {len(sequence)};",
        "reg test_rst [0:NUM_TESTS-1];",
        "reg test_in_valid [0:NUM_TESTS-1];",
        "reg [31:0] test_in_timestamp [0:NUM_TESTS-1];",
        "reg expected_out_valid [0:NUM_TESTS-1];",
        "reg [1:0] expected_out_tag [0:NUM_TESTS-1];",
        "reg [31:0] expected_out_value [0:NUM_TESTS-1];",
        "reg [255:0] test_tag [0:NUM_TESTS-1];",
        "",
        "initial begin",
    ]

    assignments: list[str] = []
    for n, stim in enumerate(sequence):
        valid, tag_bits, value = outputs[n]
        assignments.extend(
            (
                f"    test_rst[{n}] = 1'b{stim.rst};",
                f"    test_in_valid[{n}] = 1'b{stim.in_valid};",
                f"    test_in_timestamp[{n}] = 32'h{stim.in_timestamp:08x};",
                f"    expected_out_valid[{n}] = 1'b{valid};",
                f"    expected_out_tag[{n}] = 2'b{tag_bits:02b};",
                f"    expected_out_value[{n}] = 32'h{value:08x};",
                f'    test_tag[{n}] = "{stim.tag}";',
            )
        )

    return "\n".join([*header, *assignments, "end"])


def self_check() -> None:
    """Validate the reference model against hand-computed expectations.

    First checks ``classify_payload`` directly at the tag boundaries and
    across 32-bit wrap-around, then runs ``simulate`` over the sequence from
    ``build_sequence`` and compares the outputs observed at selected cycles.

    Raises:
        AssertionError: if any expectation does not hold.  The error is
            raised explicitly rather than through ``assert`` statements so
            the checks still run when Python is started with ``-O``.
    """

    def expect(label: str, actual: object, wanted: object) -> None:
        # Explicit raise: `assert` statements are removed under `python -O`.
        if actual != wanted:
            raise AssertionError(f"{label}: expected {wanted!r}, got {actual!r}")

    # (arguments to classify_payload, expected (tag, value))
    payload_cases = [
        ((False, 0, 0x0000_1000), (0b10, 0x0000_1000)),
        ((True, 0x0000_1000, 0x0000_1000), (0b00, 0x0000_0000)),
        ((True, 0x0000_1000, 0x0000_10FF), (0b00, 0x0000_00FF)),
        ((True, 0x0000_10FF, 0x0000_11FF), (0b01, 0x0000_0100)),
        # Upper edge of the 16-bit range.
        ((True, 0x0000_11FF, 0x0001_11FE), (0b01, 0x0000_FFFF)),
        ((True, 0x0001_11FE, 0x0002_11FE), (0b10, 0x0001_0000)),
        # The timestamp wrapped past 2**32; the delta must stay small.
        ((True, 0xFFFF_FFF0, 0x0000_0010), (0b00, 0x0000_0020)),
    ]
    for args, wanted in payload_cases:
        expect(f"classify_payload{args}", classify_payload(*args), wanted)

    sequence = build_sequence()
    outputs = simulate(sequence)
    expect("number of simulated outputs", len(outputs), len(sequence))

    # Outputs are looked up by tag below, so a duplicated tag would silently
    # hide one of the cycles.
    tags = [cycle.tag for cycle in sequence]
    expect("number of distinct cycle tags", len(set(tags)), len(tags))
    checks = dict(zip(tags, outputs))

    # The output seen at a tagged cycle belongs to the sample accepted one
    # cycle earlier (LATENCY is 1 in this file), hence the apparent offset
    # between tag names and expected values.
    expected_by_tag = {
        "startup_reset": (0, 0b00, 0),
        "zero_delta": (1, 0b10, 0x0000_1000),
        "delta_255": (1, 0b00, 0x0000_0000),
        "delta_256": (1, 0b00, 0x0000_00FF),
        "delta_65535": (1, 0b01, 0x0000_0100),
        "delta_65536": (1, 0b01, 0x0000_FFFF),
        "flush_reset": (0, 0b00, 0),
        "post_reset_zero": (1, 0b10, 0xABCD_0000),
        "rand_medium_288": (1, 0b00, 0x0000_0005),
        "rand_full_70000": (1, 0b00, 0x0000_0003),
        "drain_last": (1, 0b10, 0x0001_1170),
        "final_idle": (0, 0b00, 0),
    }
    for tag, wanted in expected_by_tag.items():
        expect(f"output at cycle {tag!r}", checks[tag], wanted)


def main() -> None:
    """Self-check the model, then print the generated Verilog to stdout."""
    self_check()
    cycles = build_sequence()
    print(emit_verilog(cycles, simulate(cycles)))


# Generate and print the vectors only when run as a script, not on import.
if __name__ == "__main__":
    main()
