#!/usr/bin/env python3
"""Generate an exhaustive self-checking testbench for problem 044."""

from pathlib import Path


# Name of the design-under-test module instantiated (as `uut`) by the
# generated testbench.
MODULE_NAME = "bnn_dense8"
# Name given to the generated testbench module itself.
TESTBENCH_NAME = "tb_bnn_dense8"
# The generated Verilog is written next to this script as "testbench.v".
OUTPUT_PATH = Path(__file__).with_name("testbench.v")

# One weight word per neuron; list index n controls output bit n.
# expected_output() reads bits 0..15 of each word: bit k is paired with input
# bit k, and a set bit is treated as +1 while a clear bit is treated as -1.
WEIGHTS = [
    0xB3A6,
    0x5CD9,
    0xE3C5,
    0x35B6,
    0xCA79,
    0x6E1E,
    0x95E3,
    0xF42D,
]

# Per-neuron thresholds, in the same order as WEIGHTS: a neuron's output bit
# is set when its score is greater than or equal to its threshold.
THRESHOLDS = [2, 0, -4, 6, -2, 4, 8, 0]


def bipolar(bit: int) -> int:
    """Map a binary digit to its bipolar value: truthy -> +1, falsy -> -1."""
    if bit:
        return 1
    return -1


def expected_output(x_in: int) -> int:
    """Return the reference output word for the input vector *x_in*.

    For every (weight, threshold) pair, bits 0..15 of the input and of the
    weight are converted to +1/-1 with ``bipolar`` and multiplied pairwise;
    the products are summed into a score. Output bit ``n`` is set when the
    score of neuron ``n`` is at least that neuron's threshold.
    """
    result = 0
    for lane, (w_word, limit) in enumerate(zip(WEIGHTS, THRESHOLDS)):
        # Sum of the 16 bipolar products for this neuron.
        activation = sum(
            bipolar((x_in >> pos) & 1) * bipolar((w_word >> pos) & 1)
            for pos in range(16)
        )
        if activation >= limit:
            result |= 1 << lane
    return result


def emit_expected_table() -> list[str]:
    """Return one Verilog assignment line per 16-bit input vector.

    Each line stores the reference output for that vector, formatted as a
    two-digit uppercase hex literal, into the testbench's ``expected`` array.
    """
    return [
        f"        expected[{vector}] = 8'h{expected_output(vector):02X};"
        for vector in range(1 << 16)
    ]


def build_testbench() -> str:
    """Assemble and return the full Verilog source of the testbench.

    The text is built from fixed segments (declarations, DUT instance, the
    ``check_vector`` task, the start of the ``initial`` block), followed by
    the generated expected-value table and a closing segment that runs four
    smoke vectors, sweeps all 65536 inputs, prints a summary and finishes.
    Segments are joined with newlines; the final empty entry gives the file a
    trailing newline.
    """
    # Timescale, module header and testbench state.
    declarations = [
        "`timescale 1ns / 1ps",
        "",
        "module {};".format(TESTBENCH_NAME),
        "    reg [15:0] x_in;",
        "    wire [7:0] y_out;",
        "    reg [7:0] expected [0:65535];",
        "    integer i;",
        "    integer errors;",
        "    integer tests_run;",
        "",
    ]

    # Instantiation of the design under test.
    dut_instance = [
        "    {} uut (".format(MODULE_NAME),
        "        .x_in(x_in),",
        "        .y_out(y_out)",
        "    );",
        "",
    ]

    # Task that drives one vector, waits one time unit, and compares.
    checker_task = [
        "    task check_vector;",
        "        input [15:0] vec;",
        "        reg [7:0] exp;",
        "        begin",
        "            x_in = vec;",
        "            #1;",
        "            exp = expected[vec];",
        "            tests_run = tests_run + 1;",
        "            if (y_out !== exp) begin",
        '                $display("ERROR: x_in=%h expected=%h got=%h", vec, exp, y_out);',
        "                errors = errors + 1;",
        "            end",
        "        end",
        "    endtask",
        "",
    ]

    # Opening of the initial block: counter reset and banner.
    prologue = [
        "    initial begin",
        "        errors = 0;",
        "        tests_run = 0;",
        "        x_in = 16'h0000;",
        "",
        '        $display("===============================================");',
        '        $display("  Fixed-Weight Binary Dense Layer Testbench");',
        '        $display("===============================================");',
        "",
        "        // Exhaustive expected-output table generated offline from the spec.",
    ]

    # Smoke vectors, exhaustive sweep, summary and end of module.
    epilogue = [
        "",
        "        // Quick smoke vectors before the full exhaustive sweep.",
        "        check_vector(16'h0000);",
        "        check_vector(16'hFFFF);",
        "        check_vector(16'hAAAA);",
        "        check_vector(16'h5555);",
        "",
        "        for (i = 0; i < 65536; i = i + 1) begin",
        "            check_vector(i[15:0]);",
        "        end",
        "",
        '        $display("");',
        '        $display("===============================================");',
        '        $display("  Tests Run: %0d", tests_run);',
        '        $display("===============================================");',
        "",
        "        if (errors == 0) begin",
        '            $display("TEST_RESULT: PASS");',
        "        end else begin",
        '            $display("TEST_RESULT: FAIL (%0d errors)", errors);',
        "        end",
        "",
        "        $finish;",
        "    end",
        "endmodule",
        "",
    ]

    segments = (
        declarations
        + dut_instance
        + checker_task
        + prologue
        + emit_expected_table()
        + epilogue
    )
    return "\n".join(segments)


def main() -> None:
    """Generate the testbench source and write it to ``OUTPUT_PATH``."""
    verilog_source = build_testbench()
    OUTPUT_PATH.write_text(verilog_source)


# Generate the testbench file only when executed as a script, not on import.
if __name__ == "__main__":
    main()
