#!/usr/bin/env python3
"""Generate golden vectors for the fixed-method softmax approximation problem."""

from __future__ import annotations

import random
from typing import List, Sequence, Tuple


def to_signed8(val: int) -> int:
    """Reinterpret a raw byte as an 8-bit two's-complement integer.

    Values below 128 pass through unchanged; values of 128 and above have
    256 subtracted (so 128 -> -128 and 255 -> -1). No range check is done.
    """
    if val < 128:
        return val
    return val - 256


def q4_4_to_float(val: int) -> float:
    """Convert a raw Q4.4 byte to its real value.

    The byte is first read as a signed 8-bit integer, then scaled by 1/16
    to account for the four fractional bits.
    """
    signed_steps = to_signed8(val)
    return signed_steps / 16.0


def exp_piecewise(diff: int) -> int:
    """Exact required piecewise approximation of the exponential.

    Args:
        diff: Raw signed Q4.4 step count (input minus the maximum input).

    Returns:
        The segment value for ``diff``, clamped to the range 0..255.

    The segments are listed from the most negative range upwards. The
    ``>> 4`` terms operate on negative products; Python's right shift
    floors toward negative infinity.
    """
    if diff <= -64:
        raw = 0
    elif diff <= -32:
        raw = 34 + (((diff + 32) * 34) >> 4)
    elif diff <= -16:
        raw = 92 + (((diff + 16) * 92) >> 4)
    elif diff <= -8:
        raw = 144 + (diff + 8) * 9
    elif diff <= -4:
        raw = 192 + (diff + 4) * 12
    else:
        raw = 256 + diff * 16

    # Saturate to an unsigned 8-bit result.
    return max(0, min(255, raw))


def softmax_fixed(xs: Sequence[int]) -> List[int]:
    """Compute the fixed-point softmax approximation for raw byte inputs.

    Each input byte is read as a signed 8-bit value, offset by the maximum
    so every difference is <= 0, and mapped through ``exp_piecewise``. Each
    weight is then scaled by 256, floor-divided by the weight total, and
    clamped to 255.

    Args:
        xs: Raw input bytes.

    Returns:
        One integer per input in 0..255, or ``[64, 64, 64, 64]`` if every
        weight is zero.
    """
    signed_inputs = [to_signed8(raw) for raw in xs]
    peak = max(signed_inputs)
    weights = [exp_piecewise(value - peak) for value in signed_inputs]
    total = sum(weights)

    # Guard against division by zero. NOTE(review): the maximum element has
    # a difference of 0, which exp_piecewise maps to 255, so this branch
    # looks unreachable as written -- confirm against the spec.
    if total == 0:
        return [64, 64, 64, 64]

    return [min((weight << 8) // total, 255) for weight in weights]


def make_test_cases() -> List[Tuple[List[int], str]]:
    """Build the ordered list of ``(inputs, name)`` test vectors.

    The list starts with 24 hand-written directed vectors and ends with 24
    pseudo-random vectors named ``random_00`` .. ``random_23``. Every
    vector holds four raw bytes.
    """
    directed: List[Tuple[List[int], str]] = [
        # All four inputs identical (224 is -32 as a signed byte).
        ([0, 0, 0, 0], "all_zero"),
        ([64, 64, 64, 64], "all_equal_pos"),
        ([224, 224, 224, 224], "all_equal_neg"),
        # One maximal input, tried in each position.
        ([127, 0, 0, 0], "single_dominant_pos"),
        ([0, 127, 0, 0], "single_dominant_pos_1"),
        ([0, 0, 127, 0], "single_dominant_pos_2"),
        ([0, 0, 0, 127], "single_dominant_pos_3"),
        # Inputs one step apart.
        ([127, 126, 125, 124], "close_cluster_high"),
        ([16, 15, 14, 13], "close_cluster_mid"),
        # Differences of -4/-8/-16 from the maximum, then one step to
        # either side of each of those segment boundaries.
        ([0, 252, 248, 240], "boundary_exact"),
        ([0, 253, 249, 241], "boundary_just_above"),
        ([0, 251, 247, 239], "boundary_just_below"),
        # Wide spreads, mixed signs and orderings.
        ([0, 224, 192, 128], "wide_negative_span"),
        ([32, 240, 48, 224], "mixed_signed_values"),
        ([255, 254, 253, 252], "negative_close_cluster"),
        ([127, 63, 31, 0], "descending_positive"),
        ([0, 16, 32, 48], "ascending_positive"),
        ([128, 144, 160, 176], "ascending_negative"),
        ([127, 128, 127, 128], "alternating_extremes"),
        # The same four values in different positions.
        ([80, 79, 78, 16], "three_similar_one_low"),
        ([80, 16, 79, 78], "three_similar_permuted"),
        ([16, 80, 79, 78], "three_similar_permuted2"),
        ([128, 192, 224, 0], "neg_to_zero"),
        ([8, 89, 52, 129], "mixed_seed_like"),
    ]

    # A fixed seed makes the random vectors deterministic.
    rng = random.Random(12345)
    randomized = [
        ([rng.randrange(256) for _ in range(4)], f"random_{index:02d}")
        for index in range(24)
    ]

    return directed + randomized


def emit_verilog_arrays(test_cases: Sequence[Tuple[Sequence[int], str]]) -> None:
    """Print Verilog array declarations and an initial block for the vectors.

    Writes to standard output: a ``NUM_TESTS`` localparam, four stimulus
    arrays (``test_x0`` .. ``test_x3``), four expected-output arrays
    (``exp_y0`` .. ``exp_y3``), and an ``initial`` block that assigns every
    entry. Expected outputs come from ``softmax_fixed``.

    Args:
        test_cases: ``(inputs, name)`` pairs; each ``inputs`` must hold
            exactly four values.
    """
    total = len(test_cases)
    top_index = total - 1

    print(f"localparam NUM_TESTS = {total};")
    # A blank line followed by four per-lane declarations, first for the
    # stimulus arrays and then for the expected-output arrays.
    for prefix in ("test_x", "exp_y"):
        print()
        for lane in range(4):
            print(f"reg [7:0] {prefix}{lane} [0:{top_index}];")
    print()
    print("initial begin")

    for idx, (inputs, label) in enumerate(test_cases):
        outputs = softmax_fixed(inputs)
        # Unpacking into exactly four names rejects vectors of another width
        # before anything is printed for this test.
        x0, x1, x2, x3 = inputs
        y0, y1, y2, y3 = outputs
        stimulus = " ".join(
            f"test_x{lane}[{idx}] = 8'd{value};"
            for lane, value in enumerate((x0, x1, x2, x3))
        )
        expected = " ".join(
            f"exp_y{lane}[{idx}] = 8'd{value};"
            for lane, value in enumerate((y0, y1, y2, y3))
        )
        print(f"    // Test {idx}: {label}")
        print(f"    {stimulus}")
        print(f"    {expected}")

    print("end")


def emit_summary(test_cases: Sequence[Tuple[Sequence[int], str]]) -> None:
    """Print a human-readable summary table as ``//`` comment lines.

    For every test this prints its index and name, the inputs converted
    from Q4.4 to real values, the outputs divided by 256, and the integer
    sum of the raw outputs.

    Args:
        test_cases: ``(inputs, name)`` pairs; each ``inputs`` must hold at
            least four values (only the first four are printed).
    """
    # A format width is a minimum, not a cap, so a name longer than the
    # column would push its row out of alignment. Size the column to the
    # longest name, never narrower than 22 characters.
    name_width = max([22, *(len(name) for _, name in test_cases)])

    print()
    print("// Summary:")
    for i, (xs, name) in enumerate(test_cases):
        ys = softmax_fixed(xs)
        xf = [q4_4_to_float(v) for v in xs]
        yp = [v / 256.0 for v in ys]
        print(
            f"// {i:2d}. {name:{name_width}s} "
            f"x=[{xf[0]:6.2f},{xf[1]:6.2f},{xf[2]:6.2f},{xf[3]:6.2f}] "
            f"-> y=[{yp[0]:.3f},{yp[1]:.3f},{yp[2]:.3f},{yp[3]:.3f}] sum={sum(ys)}"
        )


def main() -> None:
    """Print the complete golden-vector listing to standard output.

    Output order: a three-line comment header, a blank line, the Verilog
    arrays with their initial block, then the summary table.
    """
    header_lines = (
        "// Golden vectors for 012_softmax_approx",
        "// Generated by: python3 generate_golden.py",
        "// Algorithm: fixed piecewise softmax from spec.yaml",
    )

    # Build the vectors before printing anything.
    cases = make_test_cases()

    for line in header_lines:
        print(line)
    print()
    emit_verilog_arrays(cases)
    emit_summary(cases)

# Generate the listing only when run as a script, not when imported.
if __name__ == "__main__":
    main()
