#!/usr/bin/env python3
"""
Generate golden test vectors for CNN 3x3 convolution testbench.
Computes expected outputs for a fixed 28x28 input image with 4 output filters.
"""

import numpy as np

# Image dimensions (pixels)
IMG_H = 28  # rows
IMG_W = 28  # columns

# Convolution configuration
NUM_FILTERS = 4   # output channels; filter i uses KERNELS[i] and BIASES[i]
KERNEL_SIZE = 3   # kernels are KERNEL_SIZE x KERNEL_SIZE

# Fixed test input: a diagonal ramp whose pixels are easy to recompute by hand.
def generate_test_input():
    """Build the deterministic IMG_H x IMG_W uint8 test image.

    Pixel ``(r, c)`` holds ``(4*r + 2*c) mod 256``, so any value seen in a
    waveform can be checked against its coordinates directly.
    """
    row_term = np.arange(IMG_H)[:, np.newaxis] * 4
    col_term = np.arange(IMG_W)[np.newaxis, :] * 2
    # Broadcasting the column vector against the row vector yields the full grid.
    return ((row_term + col_term) % 256).astype(np.uint8)

# Fixed 3x3 kernels: 8-bit signed weights, written row-major per filter.
KERNELS = [
    np.array(rows, dtype=np.int8)
    for rows in (
        # Filter 0: horizontal-edge detector (Sobel; bottom row minus top row)
        ((-1, -2, -1),
         ( 0,  0,  0),
         ( 1,  2,  1)),
        # Filter 1: vertical-edge detector (Sobel; right column minus left column)
        ((-1,  0,  1),
         (-2,  0,  2),
         (-1,  0,  1)),
        # Filter 2: smoothing/blur. Weights sum to 16 and are not normalised,
        # so sums can exceed 255 before the output clamp.
        ((1, 2, 1),
         (2, 4, 2),
         (1, 2, 1)),
        # Filter 3: sharpening (identity minus the 4-neighbour Laplacian)
        (( 0, -1,  0),
         (-1,  5, -1),
         ( 0, -1,  0)),
    )
]

# Fixed biases, 8-bit signed; entry i belongs to KERNELS[i].
BIASES = np.array((0, 0, 8, 10), dtype=np.int8)

def conv2d_single_filter(img, kernel, bias, pad=1, stride=1):
    """
    Apply one filter to ``img``: zero-pad, cross-correlate, add bias, ReLU, saturate.

    The kernel is not flipped (cross-correlation, the usual CNN convention):
    output ``(r, c)`` is ``sum(kernel[i, j] * padded[r*stride + i, c*stride + j])``.
    Pixels and weights are widened to int32 before multiplying so that uint8
    pixels and int8 weights cannot wrap.

    Args:
        img: 2D input image.
        kernel: 2D weight array of shape (kh, kw).
        bias: scalar added to every accumulated sum (converted with ``int``).
        pad: number of zero rows/columns added on every side of ``img``.
        stride: step between neighbouring output positions.

    Returns:
        uint8 array of shape ((H + 2*pad - kh) // stride + 1,
        (W + 2*pad - kw) // stride + 1), values clamped to [0, 255].
    """
    in_h, in_w = img.shape
    weights = kernel.astype(np.int32)
    kh, kw = weights.shape

    padded = np.pad(img.astype(np.int32), pad, mode='constant', constant_values=0)
    n_rows = (in_h + 2 * pad - kh) // stride + 1
    n_cols = (in_w + 2 * pad - kw) // stride + 1
    acc = np.zeros((n_rows, n_cols), dtype=np.int32)

    # Top-left padded coordinate of every output window.
    row_starts = stride * np.arange(n_rows)
    col_starts = stride * np.arange(n_cols)

    # Sweep the kernel taps instead of the output pixels: each tap contributes
    # its weight times the padded image sampled at every window offset.
    for dy in range(kh):
        for dx in range(kw):
            acc += weights[dy, dx] * padded[np.ix_(row_starts + dy, col_starts + dx)]

    # Bias, then ReLU (lower bound 0) and 8-bit saturation (upper bound 255).
    return np.clip(acc + int(bias), 0, 255).astype(np.uint8)

def compute_all_outputs(img):
    """Run every configured filter over ``img``.

    Returns a list of NUM_FILTERS feature maps from ``conv2d_single_filter``,
    where entry ``i`` uses ``KERNELS[i]`` and ``BIASES[i]`` with the default
    padding and stride.
    """
    return [
        conv2d_single_filter(img, KERNELS[idx], BIASES[idx])
        for idx in range(NUM_FILTERS)
    ]

def pack_kernel_to_hex(kernel):
    """Pack a kernel's weights (row-major) into a sized Verilog hex literal.

    Each weight is stored as an 8-bit two's-complement byte, and weight 0
    (top-left) occupies the least-significant byte. The literal width is
    8 bits per weight, so a 3x3 kernel yields a 72-bit value such as
    ``72'h010201000000FFFEFF``.

    Raises:
        ValueError: if ``kernel`` contains no weights.
    """
    flat = np.asarray(kernel).flatten()
    if flat.size == 0:
        raise ValueError("kernel must contain at least one weight")
    # `& 0xFF` maps signed weights to their two's-complement byte; weight 0
    # first means little-endian byte order.
    val = int.from_bytes(bytes(int(w) & 0xFF for w in flat), "little")
    # Derive the width from the weight count instead of assuming 72 bits, so
    # non-3x3 kernels still get a correctly sized literal.
    width = 8 * flat.size
    return f"{width}'h{val:0{width // 4}X}"

def pack_biases_to_hex(biases):
    """Pack per-filter biases into a sized Verilog hex literal.

    Bias ``i`` is stored as an 8-bit two's-complement byte in bits
    ``[8*i+7 : 8*i]``. The literal width is 8 bits per bias, so four biases
    ``[0, 0, 8, 10]`` yield ``32'h0A080000``.

    Raises:
        ValueError: if ``biases`` is empty.
    """
    # `& 0xFF` maps signed biases to their two's-complement byte.
    raw = bytes(int(b) & 0xFF for b in biases)
    if not raw:
        raise ValueError("biases must contain at least one value")
    val = int.from_bytes(raw, "little")  # bias 0 in the least-significant byte
    # Width follows the number of biases instead of a hard-coded 32 bits.
    width = 8 * len(raw)
    return f"{width}'h{val:0{width // 4}X}"

def generate_verilog_arrays():
    """Print the verbose Verilog test-vector listing to stdout.

    Output, in order: one ``parameter KERNEL_<i>`` per kernel, the packed
    ``BIAS`` parameter, an ``initial`` block that loads ``test_image`` one
    pixel at a time, and an ``initial`` block that loads ``expected_output``.
    Each ``expected_output`` word packs the NUM_FILTERS results for one pixel,
    with filter ``f`` in bits ``[8*f+7 : 8*f]``. Indices are row-major
    (``r * IMG_W + c``).
    """
    img = generate_test_input()
    outputs = compute_all_outputs(img)

    n_pixels = IMG_H * IMG_W
    n_taps = KERNEL_SIZE * KERNEL_SIZE
    # Widths come from the configuration rather than literal 4/32/72 so the
    # declarations and comments stay correct if NUM_FILTERS/KERNEL_SIZE change.
    word_bits = 8 * NUM_FILTERS
    word_digits = word_bits // 4

    lines = [
        "// ============================================",
        "// AUTO-GENERATED TEST VECTORS - DO NOT EDIT",
        "// Run: python3 generate_golden.py",
        "// ============================================",
        "",
        # Kernel parameters
        f"// Kernel parameters ({8 * n_taps}-bit each, {n_taps} x 8-bit signed weights)",
    ]
    lines += [f"parameter KERNEL_{i} = {pack_kernel_to_hex(k)};" for i, k in enumerate(KERNELS)]
    lines += [
        "",
        # Bias parameter
        f"// Bias parameter ({word_bits}-bit, {NUM_FILTERS} x 8-bit signed)",
        f"parameter BIAS = {pack_biases_to_hex(BIASES)};",
        "",
        # Input image
        f"// Input image ({IMG_H}x{IMG_W} = {n_pixels} pixels, row-major)",
        f"reg [7:0] test_image [0:{n_pixels - 1}];",
        "initial begin",
    ]
    for r in range(IMG_H):
        for c in range(IMG_W):
            lines.append(f"    test_image[{r * IMG_W + c}] = 8'd{img[r, c]};")
    lines += [
        "end",
        "",
        # Expected outputs (all channels packed into one word per pixel)
        f"// Expected outputs ({IMG_H}x{IMG_W}x{NUM_FILTERS} = {n_pixels * NUM_FILTERS} values)",
        f"// Stored as {NUM_FILTERS}-channel packed values per pixel position",
        f"reg [{word_bits - 1}:0] expected_output [0:{n_pixels - 1}];",
        "initial begin",
    ]
    for r in range(IMG_H):
        for c in range(IMG_W):
            # Channels occupy disjoint bytes, so summing the shifted values
            # is equivalent to OR-ing them together.
            packed = sum(int(outputs[f][r, c]) << (8 * f) for f in range(NUM_FILTERS))
            lines.append(
                f"    expected_output[{r * IMG_W + c}] = {word_bits}'h{packed:0{word_digits}X};"
            )
    lines.append("end")

    # A single write; the text is identical to printing each line separately.
    print("\n".join(lines))


def generate_testbench_array():
    """Print ``expected_output`` as a compact SystemVerilog array literal.

    Used by the ``--array`` mode to produce text for pasting into a
    testbench. Words use the same per-pixel packing and row-major order as
    the verbose listing (filter ``f`` in bits ``[8*f+7 : 8*f]``), eight
    words per line, indented for embedding inside a module body.
    """
    img = generate_test_input()
    outputs = compute_all_outputs(img)

    word_bits = 8 * NUM_FILTERS
    word_digits = word_bits // 4
    per_line = 8

    # Collect all packed values in row-major pixel order.
    values = []
    for r in range(IMG_H):
        for c in range(IMG_W):
            # Channels occupy disjoint bytes, so the sum equals the bitwise OR.
            packed = sum(int(outputs[f][r, c]) << (8 * f) for f in range(NUM_FILTERS))
            values.append(f"{word_bits}'h{packed:0{word_digits}X}")

    print("    // Golden expected outputs from generate_golden.py")
    print("    // DO NOT EDIT - regenerate with: python3 generate_golden.py --array")
    # Derive the word width and array bounds from the data so they track
    # IMG_H, IMG_W and NUM_FILTERS instead of assuming [31:0] / [0:783].
    print(f"    reg [{word_bits - 1}:0] expected_output [0:{len(values) - 1}] = '{{")
    for start in range(0, len(values), per_line):
        chunk = ", ".join(values[start:start + per_line])
        suffix = "," if start + per_line < len(values) else ""
        print("        " + chunk + suffix)
    print("    };")


if __name__ == "__main__":
    import sys

    # A first argument of exactly "--array" selects the compact testbench
    # literal; anything else (including no arguments) prints the verbose listing.
    emit = generate_testbench_array if sys.argv[1:2] == ["--array"] else generate_verilog_arrays
    emit()
