#!/usr/bin/env python3
"""Generate cycle-accurate golden vectors for systolic_mac2x2."""

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Cycle:
    """Stimulus applied during one step of the test sequence.

    build_sequence() constructs instances positionally, so the field order
    below is part of the interface.
    """

    # Label for this step; emit_verilog() writes it into test_tag.
    tag: str
    # 1 asserts reset. simulate() clears all model state and ignores the
    # remaining fields on such a step.
    rst: int
    # 1 requests that `weights` replace the current weights.
    weight_load: int
    # (w00, w01, w10, w11) in the order mac() unpacks them, or None when no
    # weights are driven (pack_weights() emits 0 in that case).
    weights: tuple[int, int, int, int] | None
    # 1 marks in_data as a valid input sample.
    in_valid: int
    # Input sample; emit_verilog() encodes it as an 8-bit value.
    in_data: int


def pack_signed(value: int, width: int) -> int:
    """Return the low *width* bits of *value* as a non-negative integer.

    For a negative *value* the result is its two's-complement encoding in
    *width* bits.  Bits above *width* are discarded, so out-of-range values
    wrap rather than raise.
    """
    low_bits_mask = (1 << width) - 1
    return value & low_bits_mask


def pack_weights(weights: tuple[int, int, int, int] | None) -> int:
    """Pack the weights into one integer, one byte per weight.

    Element ``i`` of *weights* is encoded with ``pack_signed(..., 8)`` and
    placed in bits ``[8*i+7 : 8*i]``, so the first weight occupies the least
    significant byte.  ``None`` (no weights driven) packs to 0.
    """
    if weights is None:
        return 0
    # Each encoded weight fits in its own byte lane, so adding the shifted
    # lanes gives the same word as OR-ing them together.
    return sum(
        pack_signed(weight, 8) << (8 * lane)
        for lane, weight in enumerate(weights)
    )


def build_sequence() -> list[Cycle]:
    """Return the fixed stimulus sequence, one Cycle entry per simulated step.

    Positional fields of each entry: tag, rst, weight_load, weights,
    in_valid, in_data.  The phase comments below describe how simulate()
    treats each group of entries.
    """
    return [
        # Reset phase.  The second reset step also drives weight_load and
        # in_valid; simulate() ignores both while rst is set.
        Cycle("startup_reset", 1, 0, None, 0, 0),
        Cycle("reset_ignore_both", 1, 1, (9, -3, 5, 7), 1, 85),
        Cycle("post_reset_idle", 0, 0, None, 0, 0),
        # Identity weights.  in_valid is high on the load step, but simulate()
        # does not consume in_data on a weight_load step.
        Cycle("load_identity_ignore_input", 0, 1, (1, 0, 0, 1), 1, 99),
        # Two back-to-back vectors, each fed as x0 then x1, followed by two
        # idle steps on which the last vector's two results appear.
        Cycle("id_vec0_x0", 0, 0, None, 1, 7),
        Cycle("id_vec0_x1", 0, 0, None, 1, -3),
        Cycle("id_vec1_x0", 0, 0, None, 1, 1),
        Cycle("id_vec1_x1", 0, 0, None, 1, 2),
        Cycle("id_drain0", 0, 0, None, 0, 0),
        Cycle("id_drain1", 0, 0, None, 0, 0),
        # Dense weights, loaded with in_valid low.
        Cycle("load_dense", 0, 1, (2, 3, 4, 5), 0, 0),
        Cycle("dense_vec0_x0", 0, 0, None, 1, 1),
        Cycle("dense_vec0_x1", 0, 0, None, 1, 2),
        Cycle("dense_vec1_x0", 0, 0, None, 1, 3),
        Cycle("dense_vec1_x1", 0, 0, None, 1, -1),
        # A lone x0 is captured, then reset arrives while the second result of
        # dense_vec1 is still scheduled: simulate() drops both the captured x0
        # and the scheduled result, and ignores the weights/input on the reset
        # step.
        Cycle("partial_before_reset", 0, 0, None, 1, 9),
        Cycle("reset_flush_pending", 1, 1, (-7, 6, 5, -4), 1, -8),
        Cycle("idle_after_reset", 0, 0, None, 0, 0),
        # No weight load since the reset, so simulate() uses all-zero weights
        # for this vector.
        Cycle("zero_weight_vec_x0", 0, 0, None, 1, 4),
        Cycle("zero_weight_vec_x1", 0, 0, None, 1, 5),
        Cycle("zero_weight_drain0", 0, 0, None, 0, 0),
        Cycle("zero_weight_drain1", 0, 0, None, 0, 0),
        # Mixed-sign weights; the second vector uses the inputs 127 and -128.
        Cycle("load_mixed_ignore_input", 0, 1, (3, -2, -1, 4), 1, 77),
        Cycle("mixed_vec0_x0", 0, 0, None, 1, -5),
        Cycle("mixed_vec0_x1", 0, 0, None, 1, 6),
        Cycle("mixed_vec1_x0", 0, 0, None, 1, 127),
        Cycle("mixed_vec1_x1", 0, 0, None, 1, -128),
        Cycle("mixed_drain0", 0, 0, None, 0, 0),
        Cycle("mixed_drain1", 0, 0, None, 0, 0),
        # Four back-to-back vectors with a fixed weight set, then four idle
        # steps: the first two carry the last vector's results, the last two
        # carry no output.
        Cycle("load_fixed_ignore_input", 0, 1, (7, -8, 9, -10), 1, -44),
        Cycle("fixed_vec0_x0", 0, 0, None, 1, 11),
        Cycle("fixed_vec0_x1", 0, 0, None, 1, -12),
        Cycle("fixed_vec1_x0", 0, 0, None, 1, -13),
        Cycle("fixed_vec1_x1", 0, 0, None, 1, 14),
        Cycle("fixed_vec2_x0", 0, 0, None, 1, 23),
        Cycle("fixed_vec2_x1", 0, 0, None, 1, -17),
        Cycle("fixed_vec3_x0", 0, 0, None, 1, -31),
        Cycle("fixed_vec3_x1", 0, 0, None, 1, 29),
        Cycle("flush_idle0", 0, 0, None, 0, 0),
        Cycle("flush_idle1", 0, 0, None, 0, 0),
        Cycle("flush_idle2", 0, 0, None, 0, 0),
        Cycle("flush_idle3", 0, 0, None, 0, 0),
    ]


def mac(weights: tuple[int, int, int, int], x0: int, x1: int) -> tuple[int, int]:
    """Multiply the 2x2 weight matrix by the input vector ``(x0, x1)``.

    *weights* is row-major, ``(w00, w01, w10, w11)``.  Returns ``(y0, y1)``
    where ``y0 = w00*x0 + w01*x1`` and ``y1 = w10*x0 + w11*x1``.
    """
    (top_left, top_right, bottom_left, bottom_right) = weights
    y0 = top_left * x0 + top_right * x1
    y1 = bottom_left * x0 + bottom_right * x1
    return y0, y1


def simulate(sequence: list[Cycle]) -> list[tuple[int, int]]:
    """Run the reference model over *sequence* and return the expected outputs.

    Model behaviour, one step per entry of *sequence*:

    * ``rst`` set: the weights return to all zeros, a captured ``x0`` and all
      scheduled results are discarded, and the step's output is ``(0, 0)``.
      Every other field of the entry is ignored.
    * Otherwise the step first emits the result scheduled for it, if any, as
      ``(1, value)``; with nothing scheduled it emits ``(0, 0)``.
    * ``weight_load`` set: the entry's weights replace the current ones and
      ``in_valid``/``in_data`` are ignored for that step.
    * ``in_valid`` set: valid samples are consumed in pairs.  The first is held
      as ``x0``; the second is ``x1`` and triggers ``mac()``, whose ``y0`` is
      scheduled for the next step and ``y1`` for the step after that.  A held
      ``x0`` survives idle and weight-load steps.

    Results scheduled beyond the end of *sequence* are not returned.

    Args:
        sequence: Stimulus entries, in the order they are applied.

    Returns:
        One ``(out_valid, out_data)`` pair per entry of *sequence*.

    Raises:
        ValueError: If a non-reset entry sets ``weight_load`` but its
            ``weights`` is ``None``.
    """
    weights = (0, 0, 0, 0)
    partial_x0: int | None = None
    # Maps the index of a future step to the result it must emit.
    pending: dict[int, int] = {}
    outputs: list[tuple[int, int]] = []

    for cycle_idx, cycle in enumerate(sequence):
        if cycle.rst:
            pending.clear()
            partial_x0 = None
            weights = (0, 0, 0, 0)
            outputs.append((0, 0))
            continue

        # The output is decided before this step's inputs are looked at, so a
        # result is never visible on the step that produced it.
        if cycle_idx in pending:
            outputs.append((1, pending.pop(cycle_idx)))
        else:
            outputs.append((0, 0))

        if cycle.weight_load:
            # Explicit check instead of `assert`: asserts are compiled out
            # under `python -O`, which would let None through and fail later
            # inside mac() with an unrelated error.
            if cycle.weights is None:
                raise ValueError(
                    f"cycle {cycle_idx} ({cycle.tag!r}) sets weight_load "
                    "but provides no weights"
                )
            weights = cycle.weights
            continue

        if not cycle.in_valid:
            continue

        if partial_x0 is None:
            partial_x0 = cycle.in_data
        else:
            y0, y1 = mac(weights, partial_x0, cycle.in_data)
            # Internal invariant: two completions are at least two steps
            # apart (each needs its own x0 and x1), and this step's own slot
            # was consumed above, so both target slots are always free.
            assert (cycle_idx + 1) not in pending
            assert (cycle_idx + 2) not in pending
            pending[cycle_idx + 1] = y0
            pending[cycle_idx + 2] = y1
            partial_x0 = None

    return outputs


def emit_verilog(sequence: list[Cycle], outputs: list[tuple[int, int]]) -> str:
    """Render the stimulus and expected outputs as a Verilog snippet.

    The snippet declares ``NUM_TESTS`` and one array per signal, followed by
    an ``initial`` block that assigns every array element.  Weights are
    written as a 32-bit word from ``pack_weights()``, input samples as 8-bit
    values and expected data as 17-bit values from ``pack_signed()``.

    Args:
        sequence: Stimulus entries, one per test index.
        outputs: ``(out_valid, out_data)`` pairs; ``outputs[i]`` is the
            expectation for ``sequence[i]``.

    Returns:
        The snippet as newline-joined text with no trailing newline.
    """
    text = [
        "// Generated by generate_golden.py and copied here. Do not edit by hand.",
        f"localparam integer NUM_TESTS = {len(sequence)};",
        "reg test_rst [0:NUM_TESTS-1];",
        "reg test_weight_load [0:NUM_TESTS-1];",
        "reg [31:0] test_weight_data [0:NUM_TESTS-1];",
        "reg test_in_valid [0:NUM_TESTS-1];",
        "reg [7:0] test_in_data [0:NUM_TESTS-1];",
        "reg expected_out_valid [0:NUM_TESTS-1];",
        "reg [16:0] expected_out_data [0:NUM_TESTS-1];",
        "reg [255:0] test_tag [0:NUM_TESTS-1];",
        "",
        "initial begin",
    ]
    for slot, step in enumerate(sequence):
        valid_bit, result = outputs[slot]
        # One assignment per array for this test index.
        text.extend(
            (
                f"    test_rst[{slot}] = 1'b{step.rst};",
                f"    test_weight_load[{slot}] = 1'b{step.weight_load};",
                f"    test_weight_data[{slot}] = 32'h{pack_weights(step.weights):08x};",
                f"    test_in_valid[{slot}] = 1'b{step.in_valid};",
                f"    test_in_data[{slot}] = 8'h{pack_signed(step.in_data, 8):02x};",
                f"    expected_out_valid[{slot}] = 1'b{valid_bit};",
                f"    expected_out_data[{slot}] = 17'h{pack_signed(result, 17):05x};",
                f'    test_tag[{slot}] = "{step.tag}";',
            )
        )
    text.append("end")
    return "\n".join(text)


def self_check() -> None:
    """Check mac() against fixed expected results before generating vectors.

    Raises:
        AssertionError: If mac() returns anything other than the expected
            pair for one of the built-in cases.
    """
    # Each case is (weights, x0, x1, expected (y0, y1)).
    cases = (
        ((1, 0, 0, 1), 7, -3, (7, -3)),
        ((2, 3, 4, 5), 1, 2, (8, 14)),
        ((3, -2, -1, 4), -5, 6, (-27, 29)),
        ((7, -8, 9, -10), 23, -17, (297, 377)),
    )
    for weights, x0, x1, expected in cases:
        actual = mac(weights, x0, x1)
        # Raise explicitly instead of using `assert`: assert statements are
        # removed under `python -O`, which would turn this check into a no-op.
        if actual != expected:
            raise AssertionError(
                f"mac({weights}, {x0}, {x1}) returned {actual}, expected {expected}"
            )


def main() -> None:
    """Self-check the model, then print the generated Verilog to stdout."""
    self_check()
    stimulus = build_sequence()
    expected = simulate(stimulus)
    verilog_text = emit_verilog(stimulus, expected)
    print(verilog_text)


# Generate and print the vectors only when run as a script, not on import.
if __name__ == "__main__":
    main()
