`timescale 1ns / 1ps

// Self-checking testbench for hamming_secded_decoder.
//
// Every cycle the testbench drives (rst, in_valid, codeword_in), waits for the
// next rising clock edge, and compares the DUT outputs against a behavioural
// golden model advanced with the same inputs. The golden model (golden_advance /
// decode_expected below) encodes these expectations of the DUT:
//   * outputs reflect the inputs sampled at the previous rising edge
//     (one cycle of latency);
//   * rst takes priority over in_valid and clears all outputs;
//   * while out_valid is low both error flags must be 0 (data_out is not checked);
//   * on a detected double error, data_out carries the uncorrected data bits.
//
// Codeword layout used by the helper functions (cw[k] is Hamming position k+1):
//   cw[0], cw[1], cw[3], cw[7]             Hamming check bits (positions 1, 2, 4, 8)
//   cw[2], cw[4], cw[5], cw[6],
//   cw[8], cw[9], cw[10], cw[11]           data bits d[0]..d[7]
//   cw[12]                                 overall parity: ^cw[12:0] == 0 when clean
module tb_hamming_secded_decoder;
    localparam CLK_HALF_PERIOD = 5;

    reg clk;
    reg rst;
    reg in_valid;
    reg [12:0] codeword_in;

    wire out_valid;
    wire [7:0] data_out;
    wire single_error_corrected;
    wire double_error_detected;

    integer errors = 0;       // mismatches reported by check_outputs
    integer tests_run = 0;    // number of cycles whose outputs were checked
    integer cycle_count = 0;  // rising edges consumed by step_cycle
    integer i;
    integer pos;
    integer mix_idx;

    // Expected DUT outputs for the current cycle, produced by golden_advance.
    reg golden_out_valid;
    reg [7:0] golden_out_data;
    reg golden_out_single;
    reg golden_out_double;

    // Stimulus tables, filled in at the start of the main initial block.
    reg [7:0] clean_patterns [0:11];
    reg [7:0] single_patterns [0:1];
    reg [7:0] mixed_patterns [0:3];
    integer double_a [0:11];  // first flipped position (1-based) of each double-error vector
    integer double_b [0:11];  // second flipped position (1-based) of each double-error vector

    hamming_secded_decoder uut (
        .clk(clk),
        .rst(rst),
        .in_valid(in_valid),
        .codeword_in(codeword_in),
        .out_valid(out_valid),
        .data_out(data_out),
        .single_error_corrected(single_error_corrected),
        .double_error_detected(double_error_detected)
    );

    // Free-running clock; first rising edge at t = CLK_HALF_PERIOD.
    initial begin
        clk = 1'b0;
        forever #CLK_HALF_PERIOD clk = ~clk;
    end

    // 4-bit Hamming syndrome of cw[11:0]. Syndrome bit j is the XOR of every
    // position whose 1-based index has bit j set, so after a single flip among
    // positions 1..12 the syndrome equals that position. cw[12] is not included.
    function [3:0] calc_syndrome;
        input [12:0] cw;
        begin
            calc_syndrome[0] = cw[0] ^ cw[2] ^ cw[4] ^ cw[6] ^ cw[8] ^ cw[10];
            calc_syndrome[1] = cw[1] ^ cw[2] ^ cw[5] ^ cw[6] ^ cw[9] ^ cw[10];
            calc_syndrome[2] = cw[3] ^ cw[4] ^ cw[5] ^ cw[6] ^ cw[11];
            calc_syndrome[3] = cw[7] ^ cw[8] ^ cw[9] ^ cw[10] ^ cw[11];
        end
    endfunction

    // Gather the eight data bits (non-power-of-two positions 3,5,6,7,9,10,11,12)
    // into a byte, d[7] first.
    function [7:0] extract_data;
        input [12:0] cw;
        begin
            extract_data = {cw[11], cw[10], cw[9], cw[8], cw[6], cw[5], cw[4], cw[2]};
        end
    endfunction

    // Build a clean 13-bit SECDED codeword for `data`: place the data bits,
    // compute the four Hamming check bits, then the overall parity in cw[12].
    function [12:0] encode_codeword;
        input [7:0] data;
        reg [12:0] cw;
        begin
            cw = 13'd0;
            cw[2] = data[0];
            cw[4] = data[1];
            cw[5] = data[2];
            cw[6] = data[3];
            cw[8] = data[4];
            cw[9] = data[5];
            cw[10] = data[6];
            cw[11] = data[7];

            cw[0] = cw[2] ^ cw[4] ^ cw[6] ^ cw[8] ^ cw[10];
            cw[1] = cw[2] ^ cw[5] ^ cw[6] ^ cw[9] ^ cw[10];
            cw[3] = cw[4] ^ cw[5] ^ cw[6] ^ cw[11];
            cw[7] = cw[8] ^ cw[9] ^ cw[10] ^ cw[11];
            // Overall parity makes ^cw[12:0] == 0 for an error-free word.
            cw[12] = ^cw[11:0];

            encode_codeword = cw;
        end
    endfunction

    // Return `cw` with the bit at 1-based position `bit_pos` inverted
    // (1..12 are Hamming positions, 13 is the overall parity bit cw[12]).
    function [12:0] flip_one_bit;
        input [12:0] cw;
        input integer bit_pos;
        reg [12:0] flipped;
        begin
            flipped = cw;
            flipped[bit_pos - 1] = ~flipped[bit_pos - 1];
            flip_one_bit = flipped;
        end
    endfunction

    // Reference SECDED decode of `cw`:
    //   syndrome != 0, odd parity  -> single error, flip bit (syndrome-1)
    //   syndrome == 0, odd parity  -> single error in cw[12], data unaffected
    //   syndrome != 0, even parity -> double error, data left uncorrected
    //   syndrome == 0, even parity -> clean
    // NOTE(review): a single-bit flip cannot yield syndromes 13..15 with odd
    // parity, and the stimulus below never produces that case; for it the flip
    // index would be 12..14, and indices 13/14 lie outside cw[12:0].
    task decode_expected;
        input [12:0] cw;
        output [7:0] decoded_data;
        output decoded_single;
        output decoded_double;
        reg [3:0] syndrome;
        reg total_parity_odd;
        reg [12:0] corrected;
        begin
            syndrome = calc_syndrome(cw);
            total_parity_odd = ^cw;
            corrected = cw;
            decoded_single = 1'b0;
            decoded_double = 1'b0;

            if ((syndrome != 4'd0) && total_parity_odd) begin
                corrected[syndrome - 4'd1] = ~corrected[syndrome - 4'd1];
                decoded_single = 1'b1;
            end else if ((syndrome == 4'd0) && total_parity_odd) begin
                decoded_single = 1'b1;
            end else if ((syndrome != 4'd0) && !total_parity_odd) begin
                decoded_double = 1'b1;
            end

            decoded_data = extract_data(corrected);
        end
    endtask

    // Update the golden_out_* registers to what the DUT should present after a
    // rising edge that sampled (next_rst, next_in_valid, next_codeword).
    task golden_advance;
        input next_rst;
        input next_in_valid;
        input [12:0] next_codeword;
        reg [7:0] next_data;
        reg next_single;
        reg next_double;
        begin
            if (next_rst) begin
                golden_out_valid = 1'b0;
                golden_out_data = 8'h00;
                golden_out_single = 1'b0;
                golden_out_double = 1'b0;
            end else if (next_in_valid) begin
                decode_expected(next_codeword, next_data, next_single, next_double);
                golden_out_valid = 1'b1;
                golden_out_data = next_data;
                golden_out_single = next_single;
                golden_out_double = next_double;
            end else begin
                golden_out_valid = 1'b0;
                golden_out_data = 8'h00;
                golden_out_single = 1'b0;
                golden_out_double = 1'b0;
            end
        end
    endtask

    // Compare the DUT outputs against golden_out_* and count mismatches.
    // `!==` is used so X/Z on a DUT output is reported as a mismatch.
    task check_outputs;
        begin
            tests_run = tests_run + 1;

            if (out_valid !== golden_out_valid) begin
                $display("ERROR [cycle %0d]: out_valid mismatch. Expected %b, got %b",
                         cycle_count, golden_out_valid, out_valid);
                errors = errors + 1;
            end

            if (golden_out_valid) begin
                if (data_out !== golden_out_data) begin
                    $display("ERROR [cycle %0d]: data_out mismatch. Expected %02h, got %02h",
                             cycle_count, golden_out_data, data_out);
                    errors = errors + 1;
                end
                if (single_error_corrected !== golden_out_single) begin
                    $display("ERROR [cycle %0d]: single_error_corrected mismatch. Expected %b, got %b",
                             cycle_count, golden_out_single, single_error_corrected);
                    errors = errors + 1;
                end
                if (double_error_detected !== golden_out_double) begin
                    $display("ERROR [cycle %0d]: double_error_detected mismatch. Expected %b, got %b",
                             cycle_count, golden_out_double, double_error_detected);
                    errors = errors + 1;
                end
            end else begin
                // data_out is intentionally not checked while out_valid is low.
                if (single_error_corrected !== 1'b0) begin
                    $display("ERROR [cycle %0d]: single_error_corrected should be 0 when out_valid is low",
                             cycle_count);
                    errors = errors + 1;
                end
                if (double_error_detected !== 1'b0) begin
                    $display("ERROR [cycle %0d]: double_error_detected should be 0 when out_valid is low",
                             cycle_count);
                    errors = errors + 1;
                end
            end
        end
    endtask

    // Drive one set of inputs for exactly one clock cycle, then check the
    // outputs produced by that edge.
    task step_cycle;
        input next_rst;
        input next_in_valid;
        input [12:0] next_codeword;
        begin
            rst = next_rst;
            in_valid = next_in_valid;
            codeword_in = next_codeword;

            @(posedge clk);
            // Sample 1 time unit after the edge so the DUT's registered outputs
            // have settled and inputs are never changed on the edge itself.
            #1;

            cycle_count = cycle_count + 1;
            golden_advance(next_rst, next_in_valid, next_codeword);
            // Verilog-2001/2005 task enables may not use an empty "()" list.
            check_outputs;
        end
    endtask

    // One cycle with reset asserted and no valid input.
    task reset_cycle;
        begin
            step_cycle(1'b1, 1'b0, 13'd0);
        end
    endtask

    // One idle cycle: reset deasserted, in_valid low, codeword zero.
    task idle_cycle;
        begin
            step_cycle(1'b0, 1'b0, 13'd0);
        end
    endtask

    // One cycle presenting `cw` with in_valid high and reset deasserted.
    task drive_codeword;
        input [12:0] cw;
        begin
            step_cycle(1'b0, 1'b1, cw);
        end
    endtask

    initial begin
        rst = 1'b0;
        in_valid = 1'b0;
        codeword_in = 13'd0;

        golden_out_valid = 1'b0;
        golden_out_data = 8'h00;
        golden_out_single = 1'b0;
        golden_out_double = 1'b0;

        clean_patterns[0] = 8'h00;
        clean_patterns[1] = 8'hFF;
        clean_patterns[2] = 8'h55;
        clean_patterns[3] = 8'hAA;
        clean_patterns[4] = 8'h01;
        clean_patterns[5] = 8'h80;
        clean_patterns[6] = 8'h0F;
        clean_patterns[7] = 8'hF0;
        clean_patterns[8] = 8'h3C;
        clean_patterns[9] = 8'hC3;
        clean_patterns[10] = 8'h69;
        clean_patterns[11] = 8'h96;

        single_patterns[0] = 8'hA5;
        single_patterns[1] = 8'h3C;

        mixed_patterns[0] = 8'h12;
        mixed_patterns[1] = 8'h34;
        mixed_patterns[2] = 8'h56;
        mixed_patterns[3] = 8'h78;

        double_a[0] = 3;   double_b[0] = 5;
        double_a[1] = 3;   double_b[1] = 6;
        double_a[2] = 5;   double_b[2] = 10;
        double_a[3] = 12;  double_b[3] = 13;
        double_a[4] = 1;   double_b[4] = 2;
        double_a[5] = 1;   double_b[5] = 13;
        double_a[6] = 4;   double_b[6] = 9;
        double_a[7] = 8;   double_b[7] = 11;
        double_a[8] = 2;   double_b[8] = 7;
        double_a[9] = 6;   double_b[9] = 12;
        double_a[10] = 7;  double_b[10] = 13;
        double_a[11] = 9;  double_b[11] = 10;

        // Two reset cycles, then two idle cycles.
        reset_cycle;
        reset_cycle;
        idle_cycle;
        idle_cycle;

        // Clean codewords, with an idle bubble after every third word.
        for (i = 0; i < 12; i = i + 1) begin
            drive_codeword(encode_codeword(clean_patterns[i]));
            if ((i % 3) == 2)
                idle_cycle;
        end
        idle_cycle;

        // Every single-bit error position (1..13) for each single pattern.
        for (i = 0; i < 2; i = i + 1) begin
            for (pos = 1; pos <= 13; pos = pos + 1)
                drive_codeword(flip_one_bit(encode_codeword(single_patterns[i]), pos));
            idle_cycle;
        end

        // Back-to-back double-bit errors at the (double_a, double_b) positions.
        for (i = 0; i < 12; i = i + 1)
            drive_codeword(flip_one_bit(flip_one_bit(encode_codeword(clean_patterns[i]),
                                                     double_a[i]),
                                        double_b[i]));
        idle_cycle;

        // Back-to-back clean / single / double / clean words.
        drive_codeword(encode_codeword(mixed_patterns[0]));
        drive_codeword(flip_one_bit(encode_codeword(mixed_patterns[1]), 11));
        drive_codeword(flip_one_bit(flip_one_bit(encode_codeword(mixed_patterns[2]), 1), 13));
        drive_codeword(encode_codeword(mixed_patterns[3]));
        idle_cycle;

        // Isolated words separated by idle cycles.
        drive_codeword(encode_codeword(8'h5A));
        idle_cycle;
        drive_codeword(flip_one_bit(encode_codeword(8'hC9), 13));
        idle_cycle;
        idle_cycle;
        drive_codeword(flip_one_bit(flip_one_bit(encode_codeword(8'h96), 2), 11));
        idle_cycle;

        // Reset asserted directly after a valid word, then traffic resumes.
        drive_codeword(encode_codeword(8'hDE));
        reset_cycle;
        idle_cycle;
        drive_codeword(flip_one_bit(encode_codeword(8'hAD), 4));
        drive_codeword(flip_one_bit(flip_one_bit(encode_codeword(8'hBE), 3), 8));
        idle_cycle;

        // Single errors at positions 1..4 on clean_patterns[4..7], back to back.
        for (mix_idx = 0; mix_idx < 4; mix_idx = mix_idx + 1)
            drive_codeword(flip_one_bit(encode_codeword(clean_patterns[mix_idx + 4]), mix_idx + 1));

        idle_cycle;
        idle_cycle;

        $display("=======================================");
        $display("  Hamming SECDED Decoder Testbench");
        $display("=======================================");
        $display("Tests Run: %0d", tests_run);

        if (errors == 0)
            $display("TEST_RESULT: PASS");
        else
            $display("TEST_RESULT: FAIL (%0d errors)", errors);

        $finish;
    end
endmodule
