#!/usr/bin/env python3
"""
Generate golden test vectors for IPv4 Header Parser testbench.
Creates various test cases including valid/invalid headers and edge cases.
"""

import random

def create_ipv4_header(version=4, ihl=5, tos=0, total_length=40,
                        identification=0, flags=0, fragment_offset=0,
                        ttl=64, protocol=6, checksum=0,
                        src_ip=(192, 168, 1, 1), dst_ip=(10, 0, 0, 1)):
    """
    Pack IPv4 header fields into a single integer, first byte most significant.

    Every scalar field is truncated to its on-wire width before packing
    (version/IHL: 4 bits, flags: 3 bits, fragment offset: 13 bits,
    TOS/TTL/protocol: 8 bits, total length/identification/checksum: 16 bits),
    so an out-of-range argument can never spill into a neighbouring field.
    The checksum is stored exactly as given; it is not computed here.

    src_ip and dst_ip are iterables of octets, each masked to 8 bits. With
    four octets per address the result is a 20-byte (160-bit) header. The
    number of octets is not checked.

    Returns the packed header as an int.
    """
    # Flags occupy the top 3 bits of the 16-bit word, the offset the low 13.
    flags_frag = ((flags & 0x7) << 13) | (fragment_offset & 0x1FFF)

    header_bytes = [
        # Byte 0: Version (high nibble) + IHL (low nibble). The version is
        # masked too, otherwise a value above 15 would widen the header.
        ((version & 0xF) << 4) | (ihl & 0xF),
        # Byte 1: TOS
        tos & 0xFF,
        # Bytes 2-3: Total Length
        (total_length >> 8) & 0xFF,
        total_length & 0xFF,
        # Bytes 4-5: Identification
        (identification >> 8) & 0xFF,
        identification & 0xFF,
        # Bytes 6-7: Flags + Fragment Offset
        (flags_frag >> 8) & 0xFF,
        flags_frag & 0xFF,
        # Byte 8: TTL
        ttl & 0xFF,
        # Byte 9: Protocol
        protocol & 0xFF,
        # Bytes 10-11: Header Checksum
        (checksum >> 8) & 0xFF,
        checksum & 0xFF,
    ]

    # Bytes 12-15: Source IP, bytes 16-19: Destination IP
    header_bytes.extend(octet & 0xFF for octet in src_ip)
    header_bytes.extend(octet & 0xFF for octet in dst_ip)

    # Every entry is already in 0..255, so the list converts directly.
    return int.from_bytes(bytes(header_bytes), "big")

def ip_to_int(ip_tuple):
    """
    Convert IP tuple (a, b, c, d) to a 32-bit integer, a being the top byte.

    Each octet is masked to 8 bits, the same truncation create_ipv4_header
    applies when it packs addresses, so the value returned here always
    matches the address bits stored in the header and always fits in 32 bits.
    Only the first four elements are read; a shorter sequence raises
    IndexError.
    """
    value = 0
    for index in range(4):
        value = (value << 8) | (ip_tuple[index] & 0xFF)
    return value

def generate_test_cases():
    """
    Generate test cases for the IPv4 parser.

    Returns a list of 25 dicts, each with the keys 'name', 'version', 'ihl',
    'total_length', 'ttl', 'protocol', 'src_ip', 'dst_ip' and
    'expected_valid'. In every case built here 'expected_valid' is 1 exactly
    when version is 4 and ihl is 5.

    Tests 0-10 are hand-written, tests 11-20 are random valid headers and
    tests 21-24 are random invalid headers. The random cases come from a
    private generator with a fixed seed, so the output is the same on every
    call and the module-level random state is left untouched.
    """
    default_src = (192, 168, 1, 1)
    default_dst = (10, 0, 0, 1)

    # Columns: name, version, ihl, total_length, ttl, protocol,
    #          src_ip, dst_ip, expected_valid
    directed = [
        # Test 0: Basic valid header (TCP)
        ('valid_tcp', 4, 5, 40, 64, 6, default_src, default_dst, 1),
        # Test 1: Valid UDP header
        ('valid_udp', 4, 5, 60, 128, 17, (8, 8, 8, 8), (1, 1, 1, 1), 1),
        # Test 2: Valid ICMP header
        ('valid_icmp', 4, 5, 84, 255, 1, (172, 16, 0, 1), (172, 16, 0, 2), 1),
        # Test 3: Invalid version (IPv6 = 6)
        ('invalid_version', 6, 5, 40, 64, 6, default_src, default_dst, 0),
        # Test 4: Invalid IHL (has options, IHL=6)
        ('invalid_ihl_6', 4, 6, 44, 64, 6, default_src, default_dst, 0),
        # Test 5: Invalid IHL (too short, IHL=4)
        ('invalid_ihl_4', 4, 4, 40, 64, 6, default_src, default_dst, 0),
        # Test 6: TTL = 1 (about to expire)
        ('ttl_one', 4, 5, 40, 1, 6, default_src, default_dst, 1),
        # Test 7: TTL = 0 (expired, still valid header)
        ('ttl_zero', 4, 5, 40, 0, 6, default_src, default_dst, 1),
        # Test 8: Max total length (65535)
        ('max_length', 4, 5, 65535, 64, 6, default_src, default_dst, 1),
        # Test 9: All zeros IP (invalid but header is syntactically valid)
        ('zero_ips', 4, 5, 20, 64, 6, (0, 0, 0, 0), (0, 0, 0, 0), 1),
        # Test 10: Broadcast destination
        ('broadcast', 4, 5, 100, 32, 17,
         (192, 168, 1, 100), (255, 255, 255, 255), 1),
    ]

    test_cases = [
        {
            'name': name,
            'version': version, 'ihl': ihl,
            'total_length': total_length, 'ttl': ttl, 'protocol': protocol,
            'src_ip': src_ip, 'dst_ip': dst_ip,
            'expected_valid': expected_valid,
        }
        for (name, version, ihl, total_length, ttl, protocol,
             src_ip, dst_ip, expected_valid) in directed
    ]

    # A dedicated generator draws the same sequence as seeding the module
    # RNG with 12345 would, without reseeding it for the rest of the program.
    rng = random.Random(12345)

    def random_ip():
        return tuple(rng.randint(0, 255) for _ in range(4))

    # Test 11-20: Random valid headers.
    # The order of the draws below determines the generated values.
    for i in range(10):
        test_cases.append({
            'name': f'random_{i}',
            'version': 4, 'ihl': 5,
            'total_length': rng.randint(20, 65535),
            'ttl': rng.randint(0, 255),
            'protocol': rng.choice([1, 6, 17, 47, 50, 51, 89]),
            'src_ip': random_ip(),
            'dst_ip': random_ip(),
            'expected_valid': 1
        })

    # Test 21-24: Random invalid headers.
    # Neither choice list contains a valid value (version 4 / IHL 5), so
    # expected_valid always evaluates to 0 here.
    for i in range(4):
        v = rng.choice([0, 5, 6, 7, 15])  # Invalid versions
        ihl = rng.choice([0, 1, 4, 6, 15])  # Invalid IHL
        test_cases.append({
            'name': f'random_invalid_{i}',
            'version': v, 'ihl': ihl,
            'total_length': rng.randint(20, 1500),
            'ttl': rng.randint(0, 255),
            'protocol': rng.randint(0, 255),
            'src_ip': random_ip(),
            'dst_ip': random_ip(),
            'expected_valid': 1 if (v == 4 and ihl == 5) else 0
        })

    return test_cases

def main():
    """
    Print the golden test vectors to stdout as Verilog-style source text.

    The output consists of a comment banner, a NUM_TESTS localparam, one
    array declaration per stimulus/expected-value signal, and an
    `initial begin ... end` block that assigns every array entry for every
    test case returned by generate_test_cases().
    """
    cases = generate_test_cases()
    count = len(cases)
    last_index = count - 1

    banner = [
        "// ==============================================",
        "// IPv4 Header Parser Golden Test Vectors",
        "// Generated by: python3 generate_golden.py",
        "// ==============================================",
        "",
        f"localparam NUM_TESTS = {count};",
        "",
        "// Test data",
    ]
    print("\n".join(banner))

    # Declarations are padded so that the array names line up in a column.
    declarations = (
        "reg [159:0] tb_header",
        "reg         exp_valid",
        "reg [15:0]  exp_total_length",
        "reg [7:0]   exp_ttl",
        "reg [7:0]   exp_protocol",
        "reg [31:0]  exp_src_ip",
        "reg [31:0]  exp_dst_ip",
    )
    for declaration in declarations:
        print(f"{declaration} [0:{last_index}];")
    print()
    print("initial begin")

    for idx, case in enumerate(cases):
        packed = create_ipv4_header(
            version=case['version'], ihl=case['ihl'],
            total_length=case['total_length'], ttl=case['ttl'],
            protocol=case['protocol'],
            src_ip=case['src_ip'], dst_ip=case['dst_ip']
        )
        src_word = ip_to_int(case['src_ip'])
        dst_word = ip_to_int(case['dst_ip'])

        # One assignment per array; the header is 40 hex digits (160 bits).
        assignments = [
            f"    // Test {idx}: {case['name']}",
            f"    tb_header[{idx}] = 160'h{packed:040x};",
            f"    exp_valid[{idx}] = 1'b{case['expected_valid']};",
            f"    exp_total_length[{idx}] = 16'd{case['total_length']};",
            f"    exp_ttl[{idx}] = 8'd{case['ttl']};",
            f"    exp_protocol[{idx}] = 8'd{case['protocol']};",
            f"    exp_src_ip[{idx}] = 32'h{src_word:08x};",
            f"    exp_dst_ip[{idx}] = 32'h{dst_word:08x};",
            "",
        ]
        print("\n".join(assignments))

    print("end")

# Run the generator only when executed as a script; importing this module
# prints nothing.
if __name__ == "__main__":
    main()
