#!/usr/bin/env python3
"""
Golden model for G-Share branch predictor.
Used to verify testbench correctness and generate expected outputs.
"""

class GSharePredictor:
    """G-Share branch predictor golden model.

    The default configuration matches the RTL: 7-bit PC, 7-bit global
    history register (GHR) and a 128-entry pattern history table (PHT) of
    2-bit saturating counters. The PHT is indexed by ``PC XOR history``,
    truncated to ``history_bits`` bits, and the counter MSB is the
    taken/not-taken prediction.

    Attributes:
        pht: list of counter values in ``0..3``, one per PHT entry.
        ghr: global history register; the most recent outcome is bit 0.
        history_bits: width of the GHR and of the PHT index.
    """

    # 2-bit counter encodings.
    COUNTER_MAX = 3    # 2'b11: strongly taken
    COUNTER_RESET = 1  # 2'b01: weakly not-taken

    def __init__(self, history_bits=7):
        """Build a predictor in its reset state.

        Args:
            history_bits: GHR width and PHT index width; the PHT gets
                ``2 ** history_bits`` entries. Defaults to 7 (the RTL size).

        Raises:
            ValueError: if ``history_bits`` is less than 1.
        """
        if history_bits < 1:
            raise ValueError(f"history_bits must be >= 1, got {history_bits}")
        self.history_bits = history_bits
        self._mask = (1 << history_bits) - 1
        self.reset()

    def reset(self):
        """Asynchronous reset: PHT to 01 (weakly not-taken), GHR to 0."""
        self.pht = [self.COUNTER_RESET] * (self._mask + 1)
        self.ghr = 0

    def _get_index(self, pc, history):
        """Return the PHT index: ``pc XOR history`` truncated to the index width."""
        return (pc ^ history) & self._mask

    def _shift_history(self, history, taken):
        """Return ``{history[n-2:0], taken}`` for an n-bit history register.

        ``taken`` is normalised to a single bit so that a truthy non-0/1
        value cannot set higher GHR bits. The PHT update already treats
        ``taken`` by truthiness, so both paths now agree.
        """
        return ((history << 1) | (1 if taken else 0)) & self._mask

    def _counter_inc(self, val):
        """Saturating increment (max 3)."""
        return min(val + 1, self.COUNTER_MAX)

    def _counter_dec(self, val):
        """Saturating decrement (min 0)."""
        return max(val - 1, 0)

    def predict(self, pc):
        """Look up a prediction for ``pc`` without changing any state.

        Returns:
            ``(predict_taken, predict_history)``: the counter MSB (0 or 1)
            and the GHR value used for the lookup, which should be passed
            back to :meth:`train`. The GHR update happens at the next clock
            edge; call :meth:`update_ghr_for_prediction` to model it.
        """
        index = self._get_index(pc, self.ghr)
        predict_taken = (self.pht[index] >> 1) & 1  # MSB of 2-bit counter
        return predict_taken, self.ghr

    def update_ghr_for_prediction(self, predict_taken):
        """Update GHR at clock edge after prediction (speculative shift-in)."""
        self.ghr = self._shift_history(self.ghr, predict_taken)

    def train(self, pc, history, taken, mispredicted):
        """Train the predictor with a resolved branch outcome.

        Args:
            pc: branch PC.
            history: GHR value returned by :meth:`predict` for this branch.
                It selects the same PHT entry the prediction used.
            taken: actual outcome (truthy = taken).
            mispredicted: when truthy, the GHR is recovered to
                ``{history[n-2:0], taken}``, discarding speculative history.
        """
        index = self._get_index(pc, history)

        # Update PHT counter
        if taken:
            self.pht[index] = self._counter_inc(self.pht[index])
        else:
            self.pht[index] = self._counter_dec(self.pht[index])

        # Recover GHR if mispredicted
        if mispredicted:
            self.ghr = self._shift_history(history, taken)

    def cycle(self, predict_valid, predict_pc, train_valid, train_pc,
              train_history, train_taken, train_mispredicted):
        """Simulate one clock cycle of the RTL.

        Mirrors the RTL ordering:
        - the prediction reads the PHT/GHR as they were before this edge,
        - the training PHT update lands at the clock edge,
        - a training misprediction recovery takes priority over the
          speculative GHR shift from a same-cycle prediction.

        Returns:
            ``(predict_taken, predict_history)``, or ``(None, None)`` when
            ``predict_valid`` is false.
        """
        predict_taken_out = predict_history_out = None

        # Combinational lookup: sees state from before any edge update.
        if predict_valid:
            predict_taken_out, predict_history_out = self.predict(predict_pc)

        # Clock edge: PHT training, plus GHR recovery on a mispredict.
        if train_valid:
            self.train(train_pc, train_history, train_taken, train_mispredicted)

        # Speculative GHR shift only when no recovery happened this cycle.
        if predict_valid and not (train_valid and train_mispredicted):
            self.update_ghr_for_prediction(predict_taken_out)

        return predict_taken_out, predict_history_out


def test_golden_model():
    """Self-check the golden model against hand-computed G-Share behaviour.

    Each numbered section resets the predictor, drives it through
    ``GSharePredictor.cycle`` and asserts on PHT/GHR state or on the
    prediction outputs.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: on the first failed check.
    """
    pred = GSharePredictor()

    print("=== Test 1: Initial state after reset ===")
    pred.reset()
    assert pred.ghr == 0, f"GHR should be 0, got {pred.ghr}"
    assert all(p == 1 for p in pred.pht), "All PHT entries should be 01"
    print("PASS: Initial state correct")

    print("\n=== Test 2: Simple prediction ===")
    pred.reset()
    taken, history = pred.cycle(1, 0x00, 0, 0, 0, 0, 0)
    assert taken == 0, f"Initial prediction should be not-taken (01 counter), got {taken}"
    assert history == 0, f"History should be 0, got {history}"
    print(f"PASS: predict_taken={taken}, predict_history={history}")

    print("\n=== Test 3: GHR update after prediction ===")
    # After predicting not-taken, GHR should shift in 0
    pred.cycle(1, 0x00, 0, 0, 0, 0, 0)
    assert pred.ghr == 0, f"GHR should be 0 (shifted in 0), got {pred.ghr}"
    print(f"PASS: GHR={pred.ghr} after not-taken prediction")

    print("\n=== Test 4: Training to taken ===")
    pred.reset()
    # Two increments take PHT[0] from 01 to 11; the third checks saturation.
    for _ in range(3):
        pred.cycle(0, 0, 1, 0x00, 0, 1, 0)
    assert pred.pht[0] == 3, f"PHT[0] should be 11 (3), got {pred.pht[0]}"
    print(f"PASS: PHT[0]={pred.pht[0]} after training to taken 3 times")

    print("\n=== Test 5: Predict taken after training ===")
    pred.reset()
    for _ in range(2):  # 01 -> 10 -> 11 (strongly taken)
        pred.cycle(0, 0, 1, 0x00, 0, 1, 0)
    taken, history = pred.cycle(1, 0x00, 0, 0, 0, 0, 0)
    assert taken == 1, f"Should predict taken after training, got {taken}"
    # The taken prediction must be shifted into the GHR at the clock edge.
    assert pred.ghr == 1, f"GHR should be 1 (shifted in 1), got {pred.ghr}"
    print(f"PASS: predict_taken={taken}")

    print("\n=== Test 6: Misprediction recovery ===")
    pred.reset()
    # Make the GHR non-zero first, so recovery is shown to overwrite
    # speculative history: train PC=0x10 (history 0) to weakly taken, then
    # predict it so a 1 is shifted into the GHR.
    pred.cycle(0, 0, 1, 0x10, 0, 1, 0)
    taken, _ = pred.cycle(1, 0x10, 0, 0, 0, 0, 0)
    assert taken == 1, f"PC=0x10 should predict taken, got {taken}"
    ghr_before = pred.ghr
    assert ghr_before == 1, f"GHR should be 1 before recovery, got {ghr_before}"

    # Now train with misprediction - should recover GHR to {history[5:0], taken}
    train_history = 0b1010101
    train_taken = 1
    pred.cycle(0, 0, 1, 0x00, train_history, train_taken, 1)
    expected_ghr = ((train_history << 1) | train_taken) & 0x7F
    assert pred.ghr == expected_ghr, f"GHR should be {expected_ghr}, got {pred.ghr}"
    print(f"PASS: GHR recovered from {ghr_before} to {pred.ghr} (expected {expected_ghr})")

    print("\n=== Test 7: Training priority over prediction for GHR ===")
    pred.reset()
    # Both train (mispredict) and predict in same cycle
    # Training should take priority for GHR update
    train_history = 0b0101010
    train_taken = 0
    taken, history = pred.cycle(1, 0x00, 1, 0x00, train_history, train_taken, 1)
    # The same-cycle prediction still sees the pre-edge PHT/GHR state.
    assert (taken, history) == (0, 0), \
        f"Prediction should see pre-edge state (0, 0), got {(taken, history)}"
    expected_ghr = ((train_history << 1) | train_taken) & 0x7F
    assert pred.ghr == expected_ghr, f"GHR should be {expected_ghr} (train priority), got {pred.ghr}"
    print(f"PASS: Training takes priority, GHR={pred.ghr}")

    print("\n=== Test 8: XOR indexing ===")
    pred.reset()
    # Train PC=0x55, history=0x2A -> index = 0x55 ^ 0x2A = 0x7F
    pred.cycle(0, 0, 1, 0x55, 0x2A, 1, 0)
    assert pred.pht[0x7F] == 2, f"PHT[0x7F] should be 2, got {pred.pht[0x7F]}"
    assert pred.pht[0] == 1, f"PHT[0] should still be 1, got {pred.pht[0]}"
    print(f"PASS: XOR indexing correct, PHT[0x7F]={pred.pht[0x7F]}")

    print("\n=== All tests passed! ===")
    return True


def generate_test_vectors(test_cases=None):
    """Run stimulus through the golden model and record expected outputs.

    The stimulus is applied one tuple per clock cycle, starting from reset.

    Args:
        test_cases: optional iterable of 7-tuples ``(predict_valid,
            predict_pc, train_valid, train_pc, train_history, train_taken,
            train_mispredict)``. Defaults to a small built-in directed
            sequence covering prediction, training, misprediction recovery
            and simultaneous train + predict.

    Returns:
        A list with one dict per cycle:
        ``'inputs'`` is the stimulus tuple as given;
        ``'outputs'`` is ``(predict_taken, predict_history)``, which is
        ``(None, None)`` when ``predict_valid`` is 0;
        ``'ghr_after'`` is the GHR value after the clock edge.

    Raises:
        ValueError: if a stimulus entry does not have exactly 7 fields.
    """
    if test_cases is None:
        # Format: (predict_valid, predict_pc, train_valid, train_pc, train_history, train_taken, train_mispredict)
        test_cases = [
            # Basic predictions
            (1, 0x00, 0, 0, 0, 0, 0),
            (1, 0x01, 0, 0, 0, 0, 0),
            (1, 0x7F, 0, 0, 0, 0, 0),
            # Training
            (0, 0, 1, 0x00, 0x00, 1, 0),
            (0, 0, 1, 0x00, 0x00, 1, 0),
            # Predict after training
            (1, 0x00, 0, 0, 0, 0, 0),
            # Misprediction recovery
            (0, 0, 1, 0x10, 0x55, 0, 1),
            (1, 0x00, 0, 0, 0, 0, 0),
            # Simultaneous train and predict
            (1, 0x20, 1, 0x10, 0x00, 1, 0),
        ]

    pred = GSharePredictor()  # constructor already applies reset
    vectors = []
    for tc in test_cases:
        # Explicit unpack so a malformed entry fails with a clear ValueError.
        pv, ppc, tv, tpc, th, tt, tm = tc
        pt, ph = pred.cycle(pv, ppc, tv, tpc, th, tt, tm)
        vectors.append({
            'inputs': tc,
            'outputs': (pt, ph),
            'ghr_after': pred.ghr,
        })

    return vectors


if __name__ == "__main__":
    # Script mode: run the self-checks first (an AssertionError aborts the
    # run), then print the directed stimulus with its expected outputs.
    test_golden_model()
    separator = "=" * 50
    print("\n" + separator)
    print("Generating test vectors...")
    vectors = generate_test_vectors()
    for cycle_no, vec in enumerate(vectors):
        line = (
            f"  Cycle {cycle_no}: in={vec['inputs']}, "
            f"out={vec['outputs']}, ghr_after={vec['ghr_after']}"
        )
        print(line)
