#!/usr/bin/env python3
"""
Generate golden test vectors for the Hardmax/Argmax one-hot testbench.
For each input vector, compute the one-hot encoding of the index of the
maximum value (lowest index wins ties) and print the vectors as a Verilog
snippet on stdout.
"""

import numpy as np

def generate_test_vectors(num_random=20, seed=1337):
    """Build the list of ``(inputs, name)`` test cases for the 6-input hardmax.

    ``inputs`` is a list of six Python ints in the 8-bit signed range
    [-128, 127]; ``name`` is a short label used in the generated Verilog
    comments.

    Args:
        num_random: number of random vectors appended after the directed
            cases (default 20).
        seed: seed for the random vectors. The default reproduces the
            vectors this script has always emitted.

    Returns:
        list of ``(list[int], str)``: the directed edge cases first, then
        ``num_random`` random cases named ``random_0`` .. ``random_{n-1}``.
    """
    num_lanes = 6
    test_cases = []

    # 1. Edge case: all zeros
    test_cases.append(([0] * num_lanes, "all_zeros"))

    # 2. Distinct maximums: a single 100 in lane k, zeros elsewhere
    for k in range(num_lanes):
        vec = [100 if j == k else 0 for j in range(num_lanes)]
        test_cases.append((vec, f"max_x{k}"))

    # 3. Negative numbers
    test_cases.append(([-10, -20, -30, -40, -50, -60], "all_negative"))
    test_cases.append(([-60, -50, -40, -30, -20, -10], "max_x5_neg"))

    # 4. Tie breaking (lowest index wins)
    test_cases.append(([50, 50, 0, 0, 0, 0], "tie_0_1"))
    test_cases.append(([50, 0, 50, 0, 0, 0], "tie_0_2"))
    test_cases.append(([0, 0, 0, 50, 50, 50], "tie_3_4_5"))
    test_cases.append(([50, 50, 50, 50, 50, 50], "tie_all"))

    # 5. Close values
    test_cases.append(([10, 11, 10, 10, 10, 10], "close_x1"))
    test_cases.append(([10, 9, 8, 7, 6, 5], "descending"))
    test_cases.append(([5, 6, 7, 8, 9, 10], "ascending"))

    # 6. Random cases.
    # Use a private RandomState instead of np.random.seed(). It yields the
    # same legacy stream as seeding the global RNG, so the golden vectors are
    # unchanged, but it does not clobber NumPy's global RNG state for any
    # other code running in the same process.
    rng = np.random.RandomState(seed)
    for i in range(num_random):
        # randint's high bound is exclusive: values span [-128, 127].
        vals = [int(rng.randint(-128, 128)) for _ in range(num_lanes)]
        test_cases.append((vals, f"random_{i}"))

    return test_cases

def compute_onehot(inputs):
    """Return the one-hot word that selects the largest element of *inputs*.

    Bit ``k`` of the result is set when ``inputs[k]`` holds the maximum. If
    the maximum appears more than once, the lowest such ``k`` is selected.
    For the six-lane vectors used here, the result fits in 6 bits.

    Raises:
        ValueError: if *inputs* is empty (propagated from ``max``).
    """
    # .index() returns the first occurrence of the peak value, which gives
    # the lowest-index-wins tie rule for free.
    return 1 << inputs.index(max(inputs))

def main():
    """Print the golden test vectors as a Verilog snippet on stdout.

    The snippet declares ``NUM_TESTS``, one unsigned 8-bit array per input
    lane (``tb_x0``, ``tb_x1``, ...) and an ``expected_y`` array of one-hot
    results. It then assigns every entry inside a single ``initial`` block.
    The lane count and the one-hot width are taken from the generated
    vectors (6 for the current test set).

    Raises:
        ValueError: if no vectors are generated, if the vectors differ in
            length, or if any value lies outside the 8-bit signed range.
            Out-of-range values are rejected rather than masked: the
            ``& 0xFF`` encoding would wrap them, but ``compute_onehot``
            compares the unwrapped Python ints, so the expected one-hot
            would not match what the DUT actually sees.
    """
    data_width = 8
    int_min = -(1 << (data_width - 1))
    int_max = (1 << (data_width - 1)) - 1
    mask = (1 << data_width) - 1

    test_cases = generate_test_vectors()
    if not test_cases:
        raise ValueError("no test vectors generated")
    num_tests = len(test_cases)
    num_inputs = len(test_cases[0][0])

    # Validate everything before printing, so a bad vector never leaves a
    # partially written snippet on stdout.
    for inputs, name in test_cases:
        if len(inputs) != num_inputs:
            raise ValueError(
                f"test {name!r} has {len(inputs)} inputs, expected {num_inputs}"
            )
        bad = [x for x in inputs if not int_min <= x <= int_max]
        if bad:
            raise ValueError(
                f"test {name!r} has values outside [{int_min}, {int_max}]: {bad}"
            )

    print("// ==============================================")
    print("// Hardmax One-Hot Golden Test Vectors")
    print("// Generated by: python3 generate_golden.py")
    print("// ==============================================")
    print()

    print(f"localparam NUM_TESTS = {num_tests};")
    print()

    # Declare arrays
    print(f"// Input vectors ({data_width}-bit signed, stored as unsigned)")
    for j in range(num_inputs):
        print(f"reg [{data_width - 1}:0] tb_x{j} [0:{num_tests - 1}];")
    print()
    print(f"// Expected output ({num_inputs}-bit one-hot)")
    print(f"reg [{num_inputs - 1}:0] expected_y [0:{num_tests - 1}];")
    print()

    print("initial begin")
    for i, (inputs, name) in enumerate(test_cases):
        onehot = compute_onehot(inputs)
        # compute_onehot sets exactly one bit, so that bit's position is the
        # winning lane. Derive it here rather than recomputing the argmax.
        max_idx = onehot.bit_length() - 1
        # Store each signed lane as its two's-complement bit pattern
        # (e.g. -10 -> 8'd246) to match the unsigned reg arrays.
        assigns = " ".join(
            f"tb_x{j}[{i}] = {data_width}'d{x & mask};"
            for j, x in enumerate(inputs)
        )
        print(f"    // Test {i}: {name} (Inputs: {inputs}, Max Index: {max_idx})")
        print(f"    {assigns}")
        print(f"    expected_y[{i}] = {num_inputs}'b{onehot:0{num_inputs}b};")
        print()

    print("end")

# Emit the Verilog snippet only when run as a script, so the generator
# helpers can be imported elsewhere without printing anything.
if __name__ == "__main__":
    main()
