#!/usr/bin/env python3
"""Generate hardcoded testbench vectors for problem 051.

This helper is kept in the problem folder for auditability. The emitted
`run_step(...)` lines are intended to be copied directly into `testbench.v`.
"""

from dataclasses import dataclass


# All-ones bit masks for 8-, 16- and 27-bit fields. MASK_8 and MASK_16
# truncate data_in and coeff_data when the run_step lines are formatted.
MASK_8 = 0xFF
MASK_16 = 0xFFFF
MASK_27 = 0x7FF_FFFF


def to_signed(value: int, bits: int) -> int:
    """Interpret the low ``bits`` bits of ``value`` as two's complement.

    Higher-order bits of ``value`` are discarded first, so the result always
    lies in ``[-2**(bits-1), 2**(bits-1) - 1]``.
    """
    sign_bit = 1 << (bits - 1)
    low = value & ((sign_bit << 1) - 1)
    # Flipping the sign bit and subtracting its weight sign-extends without
    # a branch: values below the sign bit are unchanged, the rest go negative.
    return (low ^ sign_bit) - sign_bit


def to_hex(value: int, bits: int) -> str:
    """Format the low ``bits`` bits of ``value`` as zero-padded lowercase hex.

    The string has just enough digits to hold ``bits`` bits; negative values
    come out in their two's-complement form because they are wrapped first.
    """
    digit_count = -(-bits // 4)  # ceiling division: 4 bits per hex digit
    wrapped = value % (1 << bits)
    return format(wrapped, "x").zfill(digit_count)


@dataclass(frozen=True)
class Step:
    """Stimulus applied to the model inputs for a single testbench step.

    Instances are immutable. ``build_steps`` constructs them positionally,
    so the field order below is also the constructor argument order.
    """

    tag: str  # label printed as the first run_step argument; also the results key
    rst: int  # nonzero resets the model; every other input is ignored that step
    data_valid: int  # nonzero means data_in is accepted as a sample this step
    data_in: int  # sample value; the model sign-extends its low 8 bits
    coeff_wr_en: int  # nonzero writes coeff_data into the staged coefficient list
    coeff_addr: int  # index of the staged coefficient to overwrite
    coeff_data: int  # coefficient value; the model sign-extends its low 16 bits


class Model:
    """Step-by-step reference model used to compute expected outputs.

    State kept between steps:

    * ``active`` / ``staged`` -- two lists of eight signed coefficients.
      Writes always land in ``staged``; ``active`` is what samples are
      multiplied against until a swap happens.
    * ``history`` -- the seven most recently accepted samples, newest first.
    * ``pending_valid`` / ``pending_data`` -- the result computed during the
      previous step, reported by the next call to ``step``.
    * ``pending_dirty`` -- set once ``staged`` has been written and not yet
      copied into ``active``.
    * ``frame_pos`` -- position of the next accepted sample inside its
      eight-sample frame.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Put every piece of state back to its initial all-zero value."""
        self.active = [0 for _ in range(8)]
        self.staged = [0 for _ in range(8)]
        self.history = [0 for _ in range(7)]
        self.pending_valid = 0
        self.pending_data = 0
        self.pending_dirty = False
        self.frame_pos = 0

    def step(self, s: Step):
        """Advance one step and return ``(valid, data)`` expected this step.

        The returned pair is the result computed on the previous call, so
        outputs trail their sample by one step. A step with ``s.rst`` set
        clears all state, ignores every other input and returns ``(0, 0)``.
        """
        if s.rst:
            self.reset()
            return 0, 0

        # What is visible now was computed during the previous step.
        visible = (self.pending_valid, self.pending_data)

        accepted = bool(s.data_valid)
        # Staged coefficients take over on the first accepted sample of a
        # frame, and only if something was written since the last swap.
        swap_bank = accepted and self.frame_pos == 0 and self.pending_dirty

        if accepted:
            newest = to_signed(s.data_in, 8)
            taps = self.staged if swap_bank else self.active
            window = [newest, *self.history]
            total = sum(c * x for c, x in zip(taps, window))
            upcoming = (1, to_signed(total, 27))
        else:
            upcoming = (0, 0)

        # Copy before applying this step's write, so a write that coincides
        # with a swap is not part of the bank being activated.
        if swap_bank:
            self.active = list(self.staged)

        if s.coeff_wr_en:
            self.staged[s.coeff_addr] = to_signed(s.coeff_data, 16)
            # A write wins over the swap: the flag stays set for next frame.
            self.pending_dirty = True
        elif swap_bank:
            self.pending_dirty = False

        if accepted:
            self.history = [newest, *self.history[:6]]
            self.frame_pos = self.frame_pos + 1 if self.frame_pos != 7 else 0

        self.pending_valid, self.pending_data = upcoming
        return visible


def build_steps():
    """Return the ordered stimulus sequence as a fresh list of ``Step`` records.

    Each row is ``(tag, rst, data_valid, data_in, coeff_wr_en, coeff_addr,
    coeff_data)`` and is expanded positionally into ``Step``.
    """
    rows = (
        # Reset (once with data_valid and coeff_wr_en also high), then idle.
        ("rst_boot_0", 1, 0, 0x00, 0, 0, 0x0000),
        ("rst_ignore_both", 1, 1, 0x55, 1, 2, 0x1234),
        ("post_reset_idle", 0, 0, 0x00, 0, 0, 0x0000),
        # Coefficient writes before any sample; each address is written twice.
        ("pre_a0_temp", 0, 0, 0x00, 1, 0, 0x0005),
        ("pre_a0_final", 0, 0, 0x00, 1, 0, 0x0001),
        ("pre_a1_temp", 0, 0, 0x00, 1, 1, 0x000c),
        ("pre_a1_final", 0, 0, 0x00, 1, 1, 0x0000),
        ("pre_frame0_idle", 0, 0, 0x00, 0, 0, 0x0000),
        # Frame 0: eight samples with writes and idle steps interleaved.
        ("f0_s0", 0, 1, 0x05, 0, 0, 0x0000),
        ("f0_s1_wr_b0", 0, 1, 0xfd, 1, 0, 0x0002),
        ("f0_idle_wr_b1_temp", 0, 0, 0x00, 1, 1, 0x0009),
        ("f0_s2_wr_b1_final", 0, 1, 0x07, 1, 1, 0xffff),
        ("f0_s3", 0, 1, 0x00, 0, 0, 0x0000),
        ("f0_s4_wr_b2", 0, 1, 0xff, 1, 2, 0x0003),
        ("f0_idle_mid", 0, 0, 0x00, 0, 0, 0x0000),
        ("f0_s5", 0, 1, 0x04, 0, 0, 0x0000),
        ("f0_s6", 0, 1, 0xfb, 0, 0, 0x0000),
        ("f0_s7", 0, 1, 0x06, 0, 0, 0x0000),
        # Frame 1: its first sample coincides with a coefficient write.
        ("f1_s0_wr_c2", 0, 1, 0x02, 1, 2, 0x0004),
        ("f1_s1", 0, 1, 0x01, 0, 0, 0x0000),
        ("f1_s2_wr_c0", 0, 1, 0xfc, 1, 0, 0xfffd),
        ("f1_idle_wr_c1_temp", 0, 0, 0x00, 1, 1, 0x0005),
        ("f1_s3_wr_c1_final", 0, 1, 0x03, 1, 1, 0xffff),
        ("f1_s4", 0, 1, 0x00, 0, 0, 0x0000),
        ("f1_s5", 0, 1, 0xfe, 0, 0, 0x0000),
        ("f1_idle_mid", 0, 0, 0x00, 0, 0, 0x0000),
        ("f1_s6", 0, 1, 0x05, 0, 0, 0x0000),
        ("f1_s7", 0, 1, 0xff, 0, 0, 0x0000),
        # Frame 2: eight back-to-back samples, two of them with writes.
        ("f2_s0_wr_d0", 0, 1, 0x07, 1, 0, 0x0006),
        ("f2_s1", 0, 1, 0xf8, 0, 0, 0x0000),
        ("f2_s2_wr_d1", 0, 1, 0x02, 1, 1, 0x0002),
        ("f2_s3", 0, 1, 0x04, 0, 0, 0x0000),
        ("f2_s4", 0, 1, 0xfd, 0, 0, 0x0000),
        ("f2_s5", 0, 1, 0x01, 0, 0, 0x0000),
        ("f2_s6", 0, 1, 0x00, 0, 0, 0x0000),
        ("f2_s7", 0, 1, 0x06, 0, 0, 0x0000),
        # A write followed by reset, then a lone sample followed by reset.
        ("pre_frame3_wr_d2", 0, 0, 0x00, 1, 2, 0xfff9),
        ("rst_discard_d", 1, 1, 0x33, 1, 4, 0x2222),
        ("post_reset_idle_0", 0, 0, 0x00, 0, 0, 0x0000),
        ("flush_candidate", 0, 1, 0xf4, 0, 0, 0x0000),
        ("flush_reset", 1, 0, 0x00, 0, 0, 0x0000),
        ("post_flush_idle", 0, 0, 0x00, 0, 0, 0x0000),
        # All eight addresses written, including 0x7fff and 0x8000.
        ("pre_e0", 0, 0, 0x00, 1, 0, 0x7fff),
        ("pre_e1", 0, 0, 0x00, 1, 1, 0x8000),
        ("pre_e2", 0, 0, 0x00, 1, 2, 0x4000),
        ("pre_e3", 0, 0, 0x00, 1, 3, 0xc000),
        ("pre_e4", 0, 0, 0x00, 1, 4, 0x0400),
        ("pre_e5", 0, 0, 0x00, 1, 5, 0xfc00),
        ("pre_e6", 0, 0, 0x00, 1, 6, 0x0003),
        ("pre_e7", 0, 0, 0x00, 1, 7, 0xfffd),
        # Eight samples including 0x7f and 0x80, then two idle steps.
        ("ext_s0", 0, 1, 0x7f, 0, 0, 0x0000),
        ("ext_s1", 0, 1, 0x80, 0, 0, 0x0000),
        ("ext_s2", 0, 1, 0x7f, 0, 0, 0x0000),
        ("ext_s3", 0, 1, 0x80, 0, 0, 0x0000),
        ("ext_s4", 0, 1, 0x40, 0, 0, 0x0000),
        ("ext_s5", 0, 1, 0xc0, 0, 0, 0x0000),
        ("ext_s6", 0, 1, 0x01, 0, 0, 0x0000),
        ("ext_s7", 0, 1, 0xff, 0, 0, 0x0000),
        ("final_drain", 0, 0, 0x00, 0, 0, 0x0000),
        ("final_clear", 0, 0, 0x00, 0, 0, 0x0000),
    )
    return [Step(*row) for row in rows]


def main():
    """Run the model over every step and print the ``run_step(...)`` lines.

    Each step from ``build_steps`` is fed to a fresh ``Model``; the stimulus
    and the model's expected ``(valid, data)`` pair are formatted as one
    Verilog task call. Nothing is printed unless every hand-audited spot
    check agrees with the model.

    Raises:
        AssertionError: if any spot check disagrees with the model output.
            The message lists each offending tag with expected and actual
            values.
    """
    model = Model()
    lines = []
    results = {}

    for step in build_steps():
        exp_valid, exp_data = model.step(step)
        results[step.tag] = (exp_valid, exp_data)
        lines.append(
            "        run_step("
            f"\"{step.tag}\", "
            f"1'b{step.rst}, "
            f"1'b{step.data_valid}, 8'h{step.data_in & MASK_8:02x}, "
            f"1'b{step.coeff_wr_en}, 3'd{step.coeff_addr}, 16'h{step.coeff_data & MASK_16:04x}, "
            f"1'b{exp_valid}, 27'h{to_hex(exp_data, 27)});"
        )

    # Hand-audited spot checks for the intended semantic edges.
    spot_checks = {
        "f0_s1_wr_b0": (1, 5),
        "f0_s2_wr_b1_final": (0, 0),
        "f1_s0_wr_c2": (1, 6),
        "f1_s1": (1, -17),
        "f2_s0_wr_d0": (1, -13),
        "f2_s1": (1, 0),
        "rst_discard_d": (0, 0),
        "flush_reset": (0, 0),
        "ext_s1": (1, 4161409),
        "final_drain": (1, -2424572),
        "final_clear": (0, 0),
    }
    # Checked with an explicit raise rather than `assert` statements so the
    # audit still runs under `python -O`, and so a failure names the tag.
    mismatches = [
        f"{tag}: expected {want}, model produced {results.get(tag)}"
        for tag, want in spot_checks.items()
        if results.get(tag) != want
    ]
    if mismatches:
        raise AssertionError("spot-check mismatch: " + "; ".join(mismatches))

    print("\n".join(lines))


# Generate and print the vectors only when run as a script, not on import.
if __name__ == "__main__":
    main()
