#!/usr/bin/env python3
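"""Generate hardcoded testbench vectors for problem 047.

Running this script simulates a reference model of the 8-entry TCAM over a
fixed sequence of stimulus steps and prints one `run_step(...)` call per step
to stdout. This helper is kept in the problem folder for auditability, but
the generated values are intended to be copied directly into `testbench.v`.
"""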

from dataclasses import dataclass


@dataclass(frozen=True)
class Step:
    """Stimulus for one clock edge of the TCAM testbench.

    Fields are in the positional order used by ``build_steps`` and by the
    ``run_step(...)`` argument list that ``main`` prints. Each numeric field
    must fit the width of the Verilog literal it is emitted as (``1'b``,
    ``3'd`` or ``8'h``). ``__post_init__`` rejects anything else, so a typo
    in the vector table fails loudly. Otherwise it would produce a malformed
    or silently truncated testbench line.

    Raises:
        TypeError: if ``tag`` is not a str or a numeric field is not a
            plain int (bool is rejected too).
        ValueError: if a numeric field does not fit its bit width, or
            ``tag`` contains a double quote or backslash.
    """

    tag: str         # label printed as the first run_step argument
    rst: int         # 1 bit: clears every entry; lookups read as a miss
    lookup_key: int  # 8 bits: key presented to the lookup port
    wr_en: int       # 1 bit: perform a write/clear at this edge
    wr_addr: int     # 3 bits: entry index targeted by the write
    wr_clear: int    # 1 bit: with wr_en, invalidate the entry instead
    wr_key: int      # 8 bits: key stored on write
    wr_mask: int     # 8 bits: only bits set here take part in the match
    wr_value: int    # 8 bits: value returned on a hit

    def __post_init__(self):
        if not isinstance(self.tag, str):
            raise TypeError(f"Step tag must be a str, got {self.tag!r}")
        # The tag is emitted inside a Verilog "..." literal; a quote or
        # backslash would break the generated run_step call.
        if '"' in self.tag or "\\" in self.tag:
            raise ValueError(
                f"Step tag {self.tag!r} contains a quote or backslash"
            )

        widths = (
            ("rst", 1),
            ("lookup_key", 8),
            ("wr_en", 1),
            ("wr_addr", 3),
            ("wr_clear", 1),
            ("wr_key", 8),
            ("wr_mask", 8),
            ("wr_value", 8),
        )
        for name, bits in widths:
            value = getattr(self, name)
            # bool is an int subclass but formats as True/False, which is
            # not a valid Verilog digit string.
            if isinstance(value, bool) or not isinstance(value, int):
                raise TypeError(
                    f"Step {self.tag!r}: {name} must be an int, got {value!r}"
                )
            if not 0 <= value < (1 << bits):
                raise ValueError(
                    f"Step {self.tag!r}: {name}={value!r} does not fit "
                    f"in {bits} bit(s)"
                )


def tcam_lookup(state, rst: int, lookup_key: int):
    """Return the ``(hit, index, value)`` triple a lookup of *lookup_key* yields.

    *state* is a sequence of entry dicts carrying ``valid``, ``key``,
    ``mask`` and ``value``. Entries are scanned in index order, so the
    lowest-numbered valid entry whose key agrees with *lookup_key* on every
    bit set in its ``mask`` (within the low 8 bits) wins. While *rst* is
    asserted, or when no entry matches, the result is ``(0, 0, 0)``.
    """
    if rst:
        return 0, 0, 0

    # Lazily yield a hit triple for each matching entry; next() takes the
    # first one, i.e. the highest-priority (lowest-index) match.
    matches = (
        (1, slot, row["value"])
        for slot, row in enumerate(state)
        if row["valid"]
        and not ((lookup_key ^ row["key"]) & row["mask"] & 0xFF)
    )
    return next(matches, (0, 0, 0))


def _clear_entry(entry):
    """Zero every field of one TCAM entry dict in place."""
    entry["valid"] = 0
    entry["key"] = 0
    entry["mask"] = 0
    entry["value"] = 0


def apply_edge(state, step: "Step"):
    """Advance the TCAM model *state* in place by one clock edge of *step*.

    Precedence, highest first:
      * ``rst`` clears every entry; a simultaneous write is ignored.
      * ``wr_en`` with ``wr_clear`` clears entry ``wr_addr``.
      * ``wr_en`` alone stores ``wr_key``/``wr_mask``/``wr_value``
        (each truncated to 8 bits) into entry ``wr_addr`` and marks it
        valid.
    With neither ``rst`` nor ``wr_en`` the state is left unchanged.

    Raises:
        IndexError: if a write or clear targets an address outside
            ``state``. A negative address is rejected as well, instead of
            silently wrapping to an entry counted from the end.
    """
    if step.rst:
        for entry in state:
            _clear_entry(entry)
        return

    if not step.wr_en:
        return

    if not 0 <= step.wr_addr < len(state):
        raise IndexError(
            f"wr_addr {step.wr_addr} out of range for a "
            f"{len(state)}-entry TCAM"
        )

    entry = state[step.wr_addr]
    if step.wr_clear:
        _clear_entry(entry)
        return

    entry["valid"] = 1
    entry["key"] = step.wr_key & 0xFF
    entry["mask"] = step.wr_mask & 0xFF
    entry["value"] = step.wr_value & 0xFF


def build_steps():
    """Return the ordered list of testbench stimulus steps.

    Order matters: each step's expected lookup results depend on the TCAM
    contents left behind by every step before it.
    """
    # Columns: (tag, rst, lookup_key, wr_en, wr_addr, wr_clear,
    #           wr_key, wr_mask, wr_value)
    rows = (
        # Reset holds the table empty and ignores writes.
        ("rst_idle0", 1, 0x00, 0, 0, 0, 0x00, 0x00, 0x00),
        ("rst_ign_write", 1, 0xA5, 1, 2, 0, 0xA5, 0xFF, 0x11),
        ("empty_post_reset", 0, 0xA5, 0, 0, 0, 0x00, 0x00, 0x00),
        # Exact match in slot 2.
        ("write_exact_s2", 0, 0xA5, 1, 2, 0, 0xA5, 0xFF, 0x11),
        ("query_exact_hit", 0, 0xA5, 0, 0, 0, 0x00, 0x00, 0x00),
        ("query_exact_miss", 0, 0xA4, 0, 0, 0, 0x00, 0x00, 0x00),
        # Masked match in slot 5.
        ("write_masked_s5", 0, 0xB7, 1, 5, 0, 0xB2, 0xF0, 0x22),
        ("query_masked_hit", 0, 0xBF, 0, 0, 0, 0x00, 0x00, 0x00),
        ("query_masked_miss", 0, 0xA2, 0, 0, 0, 0x00, 0x00, 0x00),
        # Match-anything wildcard in slot 1.
        ("write_wildcard_s1", 0, 0x5C, 1, 1, 0, 0x00, 0x00, 0x33),
        ("query_wildcard", 0, 0xE1, 0, 0, 0, 0x00, 0x00, 0x00),
        # Lower index wins among multiple matches; clearing reveals the next.
        ("write_priority_s0", 0, 0xE7, 1, 0, 0, 0xE0, 0xF0, 0x44),
        ("query_priority_s0", 0, 0xE3, 0, 0, 0, 0x00, 0x00, 0x00),
        ("write_third_match_s3", 0, 0xE6, 1, 3, 0, 0xE5, 0xF0, 0x55),
        ("clear_priority_s0", 0, 0xE1, 1, 0, 1, 0x00, 0x00, 0x00),
        # Overwriting an occupied slot replaces its match behaviour.
        ("overwrite_s1_exact", 0, 0x12, 1, 1, 0, 0x12, 0xFF, 0x66),
        ("query_after_overwrite_miss", 0, 0x34, 0, 0, 0, 0x00, 0x00, 0x00),
        # Clearing slot 4 exposes the narrower masked match in slot 7.
        ("write_s6_exact", 0, 0x3C, 1, 6, 0, 0x3C, 0xFF, 0x77),
        ("write_s4_masked", 0, 0x3B, 1, 4, 0, 0x30, 0xF0, 0x99),
        ("write_s7_masked", 0, 0x3B, 1, 7, 0, 0x38, 0xFC, 0x88),
        ("clear_s4_reveal_s7", 0, 0x3B, 1, 4, 1, 0x00, 0x00, 0x00),
        # Mid-stream reset wipes everything; the table is usable afterwards.
        ("rst_midstream", 1, 0x12, 1, 2, 0, 0x5A, 0xFF, 0xAB),
        ("post_reset_empty", 0, 0x12, 0, 0, 0, 0x00, 0x00, 0x00),
        ("write_post_reset", 0, 0x5A, 1, 2, 0, 0x5A, 0xFF, 0xAB),
        ("query_post_reset_hit", 0, 0x5A, 0, 0, 0, 0x00, 0x00, 0x00),
    )
    return [Step(*row) for row in rows]


def _format_run_step(step, pre, post):
    """Render one Verilog ``run_step(...)`` call for *step*.

    *pre* and *post* are the ``(hit, index, value)`` lookup results before
    and after the clock edge.
    """
    return (
        "        run_step("
        f"\"{step.tag}\", "
        f"1'b{step.rst}, 8'h{step.lookup_key:02x}, "
        f"1'b{step.wr_en}, 3'd{step.wr_addr}, 1'b{step.wr_clear}, "
        f"8'h{step.wr_key:02x}, 8'h{step.wr_mask:02x}, 8'h{step.wr_value:02x}, "
        f"1'b{pre[0]}, 3'd{pre[1]}, 8'h{pre[2]:02x}, "
        f"1'b{post[0]}, 3'd{post[1]}, 8'h{post[2]:02x});"
    )


def _check_trace(trace):
    """Compare a simulated trace against hand-audited expectations.

    *trace* is a list of ``(step, pre, post)`` tuples as built by ``main``.

    Raises:
        ValueError: if two steps share a tag. The emitted vectors would be
            ambiguous, and the per-tag lookup below would hide one of them.
        AssertionError: listing every audited step whose ``(pre, post)``
            result differs from the expectation or is missing. This is
            raised explicitly rather than via ``assert`` so the checks
            still run under ``python -O``.
    """
    seen = set()
    duplicates = []
    for step, _, _ in trace:
        if step.tag in seen:
            duplicates.append(step.tag)
        seen.add(step.tag)
    if duplicates:
        raise ValueError(f"duplicate step tags: {', '.join(duplicates)}")

    # Small hand-audited sanity checks before freezing the vectors:
    # tag -> (pre-edge lookup, post-edge lookup), each as (hit, index, value).
    expected = {
        "rst_ign_write": ((0, 0, 0), (0, 0, 0)),
        "write_exact_s2": ((0, 0, 0), (1, 2, 0x11)),
        "write_masked_s5": ((0, 0, 0), (1, 5, 0x22)),
        "write_wildcard_s1": ((0, 0, 0), (1, 1, 0x33)),
        "write_priority_s0": ((1, 1, 0x33), (1, 0, 0x44)),
        "clear_priority_s0": ((1, 0, 0x44), (1, 1, 0x33)),
        "clear_s4_reveal_s7": ((1, 4, 0x99), (1, 7, 0x88)),
        "rst_midstream": ((0, 0, 0), (0, 0, 0)),
        "write_post_reset": ((0, 0, 0), (1, 2, 0xAB)),
    }
    actual = {step.tag: (pre, post) for step, pre, post in trace}
    mismatches = [
        f"{tag}: expected {want}, got {actual.get(tag, 'missing step')}"
        for tag, want in expected.items()
        if actual.get(tag) != want
    ]
    if mismatches:
        raise AssertionError(
            "vector sanity check failed:\n  " + "\n  ".join(mismatches)
        )


def main():
    """Simulate every step, sanity-check the trace, and print the vectors.

    For each step the TCAM model is queried before the clock edge (``pre``),
    the edge is applied, and the model is queried again (``post``). Both
    results are embedded in the emitted ``run_step(...)`` line. Nothing is
    printed if any hand-audited check fails.
    """
    state = [
        {"valid": 0, "key": 0, "mask": 0, "value": 0}
        for _ in range(8)
    ]

    trace = []
    for step in build_steps():
        pre = tcam_lookup(state, step.rst, step.lookup_key)
        apply_edge(state, step)
        post = tcam_lookup(state, step.rst, step.lookup_key)
        trace.append((step, pre, post))

    _check_trace(trace)

    print("\n".join(_format_run_step(step, pre, post) for step, pre, post in trace))


# Script entry point: generate and print the vectors only when run directly;
# importing this module just defines the model and helpers.
if __name__ == "__main__":
    main()
