#!/usr/bin/env python3
"""Generate cycle-accurate golden vectors for the streaming 5-point stencil benchmark."""

from __future__ import annotations

from dataclasses import dataclass
import random


# Tile dimensions shared by the tile builders, the stimulus generator and the
# simulator. The simulator maps its running sample index to (row, col) in
# row-major order using TILE_COLS.
TILE_ROWS = 8
TILE_COLS = 8


@dataclass(frozen=True)
class Cycle:
    """Stimulus applied during one testbench cycle.

    Instances are immutable; one entry is emitted into the generated Verilog
    arrays for every ``Cycle`` in the sequence.
    """

    # Human-readable label, emitted into the generated ``test_tag`` array.
    tag: str
    # 1 while reset is asserted; the simulator clears its state on such cycles.
    rst: int
    # 1 when ``sample`` is presented as valid input on this cycle.
    in_valid: int
    # 1 on the first accepted sample of a tile; restarts the sample counter.
    tile_start: int
    # Input sample value, or None when no data is driven (emitted as 0).
    sample: int | None


def to_twos(value: int, width: int) -> int:
    """Return *value* encoded as an unsigned ``width``-bit two's-complement pattern.

    Values outside the representable range wrap around modulo ``2**width``.
    """
    modulus = 1 << width
    # Python's % with a positive modulus always yields a non-negative result,
    # which is exactly the low ``width`` bits of the two's-complement encoding.
    return value % modulus


def constant_tile(value: int) -> list[list[int]]:
    """Return a TILE_ROWS x TILE_COLS tile with every sample set to *value*.

    Each row is a distinct list, so rows can be modified independently.
    """
    tile: list[list[int]] = []
    for _ in range(TILE_ROWS):
        tile.append([value] * TILE_COLS)
    return tile


def affine_tile(a_r: int, a_c: int, bias: int) -> list[list[int]]:
    """Return a tile whose sample at (r, c) is ``a_r * r + a_c * c + bias``."""
    tile: list[list[int]] = []
    for r in range(TILE_ROWS):
        # The row-dependent part is constant across a row.
        row_base = a_r * r + bias
        tile.append([row_base + a_c * c for c in range(TILE_COLS)])
    return tile


def impulse_tile(row: int, col: int, value: int) -> list[list[int]]:
    """Return an all-zero tile with a single sample of *value* at (row, col)."""
    grid = constant_tile(0)
    target_row = grid[row]
    target_row[col] = value
    return grid


def border_activity_tile() -> list[list[int]]:
    """Return a tile that is zero except for values placed on its border.

    Row 0 holds a descending ramp, column 7 holds an ascending ramp (which
    overrides the last entry of row 0), and two isolated values sit at
    (7, 1) and (6, 0).
    """
    grid = constant_tile(0)

    # Top edge first; the right-edge pass below overwrites its last element.
    grid[0] = [40 - 3 * c for c in range(TILE_COLS)]

    # Right edge.
    for r, row in enumerate(grid):
        row[7] = 5 * r - 20

    # Isolated points on the bottom and left edges.
    grid[7][1] = 27
    grid[6][0] = -31
    return grid


def extreme_tile() -> list[list[int]]:
    """Return a checkerboard tile alternating between 127 and -128.

    Positions where ``r + c`` is even hold 127; the others hold -128.
    """
    # Indexed by the parity of r + c.
    by_parity = (127, -128)
    tile: list[list[int]] = []
    for r in range(TILE_ROWS):
        tile.append([by_parity[(r + c) % 2] for c in range(TILE_COLS)])
    return tile


def random_tile(seed: int) -> list[list[int]]:
    """Return a reproducible tile of random samples in [-128, 127].

    A private generator seeded with *seed* is used, so the module-level
    ``random`` state is untouched. Samples are drawn in row-major order.
    """
    generator = random.Random(seed)
    tile: list[list[int]] = []
    for _ in range(TILE_ROWS):
        row: list[int] = []
        for _ in range(TILE_COLS):
            row.append(generator.randint(-128, 127))
        tile.append(row)
    return tile


def append_idle(seq: list[Cycle], tag: str, count: int = 1) -> None:
    """Append *count* idle cycles (no reset, no valid input) to *seq*.

    The cycles are tagged ``<tag>_idle0``, ``<tag>_idle1``, ...
    """
    seq.extend(
        Cycle(tag=f"{tag}_idle{n}", rst=0, in_valid=0, tile_start=0, sample=None)
        for n in range(count)
    )


def append_reset(seq: list[Cycle], tag: str, count: int = 1, ignored_sample: int | None = None) -> None:
    """Append *count* reset cycles to *seq*, tagged ``<tag>_rst<n>``.

    When *ignored_sample* is given, the first reset cycle also asserts
    ``in_valid`` and ``tile_start`` and carries that sample; every other
    reset cycle drives no input.
    """
    for n in range(count):
        label = f"{tag}_rst{n}"
        drives_input = n == 0 and ignored_sample is not None
        if drives_input:
            seq.append(Cycle(label, 1, 1, 1, ignored_sample))
        else:
            seq.append(Cycle(label, 1, 0, 0, None))


def append_tile(
    seq: list[Cycle],
    tile: list[list[int]],
    tag: str,
    bubble_before: set[int] | None = None,
    accepted_limit: int | None = None,
) -> None:
    """Append the cycles that stream *tile* into *seq* in row-major order.

    Args:
        seq: Sequence to extend in place.
        tile: Samples to stream, indexed as ``tile[row][col]``.
        tag: Prefix for cycle tags; sample cycles are tagged ``<tag>_NN`` and
            bubble cycles ``<tag>_gap<N>``, where N is the number of samples
            accepted so far.
        bubble_before: Accepted-sample indices before which one idle cycle
            (``in_valid`` = 0) is inserted. The set is only read, never
            modified.
        accepted_limit: If given, stop after this many samples have been
            appended, leaving the tile incomplete.
    """
    # Work on a read-only snapshot so the caller's set is never mutated and
    # can safely be reused for another tile.
    gaps = frozenset(bubble_before) if bubble_before else frozenset()
    accepted = 0
    for r in range(TILE_ROWS):
        for c in range(TILE_COLS):
            if accepted_limit is not None and accepted >= accepted_limit:
                return
            if accepted in gaps:
                seq.append(Cycle(f"{tag}_gap{accepted}", 0, 0, 0, None))
            seq.append(
                Cycle(
                    tag=f"{tag}_{accepted:02d}",
                    rst=0,
                    in_valid=1,
                    # Only the first accepted sample marks the start of a tile.
                    tile_start=1 if accepted == 0 else 0,
                    sample=tile[r][c],
                )
            )
            accepted += 1


def stencil_value(tile: list[list[int]], center_r: int, center_c: int) -> int:
    """Return the 5-point stencil of *tile* centred on (center_r, center_c).

    The result is the sum of the four direct neighbours (up, down, left,
    right) minus four times the centre sample.

    Raises:
        IndexError: If the centre is not an interior position, i.e. it lacks
            a neighbour on any side. Without this check a centre on row 0 or
            column 0 would silently wrap to the opposite edge through
            Python's negative indexing.
    """
    if not 1 <= center_r < len(tile) - 1:
        raise IndexError(f"stencil centre row {center_r} is not an interior row")
    if not 1 <= center_c < len(tile[center_r]) - 1:
        raise IndexError(f"stencil centre column {center_c} is not an interior column")
    return (
        tile[center_r - 1][center_c]
        + tile[center_r + 1][center_c]
        + tile[center_r][center_c - 1]
        + tile[center_r][center_c + 1]
        - 4 * tile[center_r][center_c]
    )


def expected_outputs_for_tile(tile: list[list[int]]) -> list[int]:
    """Return the stencil value of every interior position of *tile*.

    Interior positions are those with a neighbour on all four sides. They are
    visited in row-major order, so an R x C tile yields ``(R - 2) * (C - 2)``
    values (36 for an 8 x 8 tile). The bounds are taken from the tile itself
    rather than hard-coded, so tiles of other sizes are handled as well.
    """
    last_row = len(tile) - 1
    last_col = len(tile[0]) - 1 if tile else 0
    return [
        stencil_value(tile, center_r, center_c)
        for center_r in range(1, last_row)
        for center_c in range(1, last_col)
    ]


def build_sequence() -> list[Cycle]:
    """Build the complete stimulus sequence for the benchmark.

    The sequence consists of, in order: two reset cycles (the second one
    carrying a sample while reset is asserted) and an idle cycle; five full
    tiles streamed back to back (one of them with bubbles); a tile cut off
    after 22 samples followed by a reset and an idle cycle; and finally one
    random tile followed by three idle cycles.
    """
    cycles: list[Cycle] = []

    # Start-up: plain reset, reset with a sample presented, one idle cycle.
    append_reset(cycles, "startup", count=1)
    append_reset(cycles, "reset_ignore", count=1, ignored_sample=91)
    append_idle(cycles, "post_reset", count=1)

    # Full tiles as (tag, samples, bubble positions or None).
    full_tiles = (
        ("const23", constant_tile(23), None),
        ("affine", affine_tile(3, -2, 5), None),
        ("posimp", impulse_tile(3, 4, 7), {5, 17, 34, 51}),
        ("negimp", impulse_tile(2, 2, -9), None),
        ("border", border_activity_tile(), None),
    )
    for tag, samples, gaps in full_tiles:
        append_tile(cycles, samples, tag, bubble_before=gaps)

    # Partial tile interrupted by a reset.
    append_tile(cycles, extreme_tile(), "flush_partial", accepted_limit=22)
    cycles.append(Cycle(tag="flush_reset", rst=1, in_valid=0, tile_start=0, sample=None))
    append_idle(cycles, "post_flush", count=1)

    # One more full tile after the reset, then idle cycles to drain the output.
    append_tile(cycles, random_tile(52052), "random0")
    append_idle(cycles, "drain", count=3)

    return cycles


def simulate(sequence: list[Cycle]) -> list[tuple[int, int]]:
    """Compute the expected ``(out_valid, stencil_out)`` pair for every cycle.

    Model behaviour:

    * A reset cycle clears the sample counter, the tile buffer and any
      scheduled output; its own output is ``(0, 0)`` and any input presented
      during reset is ignored.
    * Accepted samples fill the tile in row-major order. ``tile_start``
      restarts the counter and clears the tile before the sample is stored.
    * When the sample at (row, col) with ``row >= 2`` and ``col >= 2`` is
      accepted, the stencil centred on (row - 1, col - 1) is scheduled to
      appear on the following cycle.

    Args:
        sequence: Stimulus cycles in the order they are applied.

    Returns:
        One ``(valid, value)`` tuple per input cycle; ``(0, 0)`` when no
        output is due.

    Raises:
        ValueError: If a cycle asserts ``in_valid`` without carrying a sample.
    """
    samples_per_tile = TILE_ROWS * TILE_COLS
    tile = constant_tile(0)
    sample_idx = 0
    # Maps the cycle index on which a result becomes visible to its value.
    pending: dict[int, int] = {}
    outputs: list[tuple[int, int]] = []

    for cycle_idx, cycle in enumerate(sequence):
        if cycle.rst:
            sample_idx = 0
            pending.clear()
            tile = constant_tile(0)
            outputs.append((0, 0))
            continue

        # Emit the result scheduled for this cycle, if any. This happens
        # whether or not a new sample is accepted on this cycle.
        if cycle_idx in pending:
            outputs.append((1, pending.pop(cycle_idx)))
        else:
            outputs.append((0, 0))

        if not cycle.in_valid:
            continue

        if cycle.tile_start:
            sample_idx = 0
            tile = constant_tile(0)

        # An explicit check rather than ``assert``: assertions are removed
        # under ``python -O``, which would let ``None`` into the tile buffer.
        if cycle.sample is None:
            raise ValueError(
                f"cycle {cycle_idx} ({cycle.tag!r}) asserts in_valid without a sample"
            )

        row, col = divmod(sample_idx, TILE_COLS)
        tile[row][col] = cycle.sample

        if row >= 2 and col >= 2:
            # Every neighbour of (row - 1, col - 1) has been stored by now;
            # the result becomes visible one cycle later.
            pending[cycle_idx + 1] = stencil_value(tile, row - 1, col - 1)

        # Wrap to the start of the next tile after the last sample.
        sample_idx = (sample_idx + 1) % samples_per_tile

    return outputs


def emit_verilog(sequence: list[Cycle], outputs: list[tuple[int, int]]) -> str:
    """Render the stimulus and expected outputs as a Verilog include snippet.

    The snippet declares ``NUM_TESTS`` and one array per signal, then fills
    every array entry inside a single ``initial`` block. Cycles without a
    sample drive ``test_sample_in`` with 0. Samples are written as 8-bit and
    expected values as 11-bit two's-complement hex literals.

    Args:
        sequence: Stimulus cycles.
        outputs: Expected ``(out_valid, stencil_out)`` pair for each cycle.

    Returns:
        The generated Verilog text, without a trailing newline.

    Raises:
        ValueError: If *outputs* does not contain exactly one entry per cycle.
            ``zip`` would otherwise truncate silently and leave array entries
            uninitialised while ``NUM_TESTS`` still counts them.
    """
    if len(outputs) != len(sequence):
        raise ValueError(
            f"expected one output per cycle: {len(sequence)} cycles, {len(outputs)} outputs"
        )

    lines: list[str] = [
        "// Generated by generate_golden.py. Do not edit by hand.",
        f"localparam integer NUM_TESTS = {len(sequence)};",
        "reg test_rst [0:NUM_TESTS-1];",
        "reg test_in_valid [0:NUM_TESTS-1];",
        "reg test_tile_start [0:NUM_TESTS-1];",
        "reg signed [7:0] test_sample_in [0:NUM_TESTS-1];",
        "reg expected_out_valid [0:NUM_TESTS-1];",
        "reg signed [10:0] expected_stencil_out [0:NUM_TESTS-1];",
        "reg [255:0] test_tag [0:NUM_TESTS-1];",
        "",
        "initial begin",
    ]

    for idx, (cycle, (exp_valid, exp_value)) in enumerate(zip(sequence, outputs)):
        # Cycles that drive no data still need a defined value in the array.
        sample = 0 if cycle.sample is None else cycle.sample
        lines.extend(
            (
                f"    test_rst[{idx}] = 1'b{cycle.rst};",
                f"    test_in_valid[{idx}] = 1'b{cycle.in_valid};",
                f"    test_tile_start[{idx}] = 1'b{cycle.tile_start};",
                f"    test_sample_in[{idx}] = 8'sh{to_twos(sample, 8):02x};",
                f"    expected_out_valid[{idx}] = 1'b{exp_valid};",
                f"    expected_stencil_out[{idx}] = 11'sh{to_twos(exp_value, 11):03x};",
                f'    test_tag[{idx}] = "{cycle.tag}";',
            )
        )
    lines.append("end")
    return "\n".join(lines)


def self_check() -> None:
    """Sanity-check the reference stencil against analytically known results.

    * A constant tile yields 36 outputs, all zero.
    * An affine (planar) tile yields all-zero outputs.
    * A single impulse of 7 at (3, 4) yields -28 at the impulse position and
      7 at each of its four neighbours, with zeros elsewhere.

    Raises:
        AssertionError: If any of the checks fails.
    """
    flat = expected_outputs_for_tile(constant_tile(23))
    assert len(flat) == 36
    assert not any(flat)

    plane = expected_outputs_for_tile(affine_tile(3, -2, 5))
    assert not any(plane)

    response = expected_outputs_for_tile(impulse_tile(3, 4, 7))
    observed: dict[int, int] = {}
    for position, value in enumerate(response):
        if value != 0:
            observed[position] = value

    # Outputs are laid out six per row; tile position (3, 4) is output
    # row 2, column 3.
    centre = 2 * 6 + 3
    wanted = {
        centre: -28,
        centre - 6: 7,
        centre + 6: 7,
        centre - 1: 7,
        centre + 1: 7,
    }
    assert observed == wanted


def main() -> None:
    """Run the self-check, then print the generated Verilog to stdout."""
    self_check()
    stimulus = build_sequence()
    expected = simulate(stimulus)
    verilog_text = emit_verilog(stimulus, expected)
    print(verilog_text)


# Generate the vectors only when run as a script, not when imported.
if __name__ == "__main__":
    main()
