`timescale 1ns / 1ps

// Self-checking testbench for fft4_frame: a 4-point complex FFT with
// 12-bit signed complex inputs and 14-bit signed complex outputs.
//
// Timing contract exercised (see run_frame_and_check):
//   * Four clock edges with in_valid=1 load samples x[0]..x[3];
//     out_valid must read 0 after each of them.
//   * The next four clock edges (in_valid=0) must each present one bin,
//     X[0], X[1], X[2], X[3] in order, with out_valid=1.
//   * After the edge following X[3], out_valid must be back to 0.
//   * Holding rst for one clock edge must leave out_valid at 0, both when a
//     frame is only partially loaded and in the middle of an output burst.
//
// Uses $urandom_range and a zero-argument function, so it must be
// compiled as SystemVerilog.
module tb_fft4_frame;
    reg clk;
    reg rst;
    reg in_valid;
    reg signed [11:0] in_real;
    reg signed [11:0] in_imag;

    wire out_valid;
    wire signed [13:0] out_real;
    wire signed [13:0] out_imag;

    // Number of randomly generated frames checked after the directed cases.
    localparam integer NUM_RANDOM_FRAMES = 100;

    integer errors = 0;     // total mismatches reported
    integer tests_run = 0;  // clock edges stepped; every step is followed by one check task
    integer i;
    integer case_id;        // scenario number printed in every error message

    // Reference FFT output for the frame currently under test, indexed by bin.
    reg signed [13:0] expected_real [0:3];
    reg signed [13:0] expected_imag [0:3];

    // Random stimulus for one frame: sample n is (rand_rn, rand_in).
    reg signed [11:0] rand_r0;
    reg signed [11:0] rand_i0;
    reg signed [11:0] rand_r1;
    reg signed [11:0] rand_i1;
    reg signed [11:0] rand_r2;
    reg signed [11:0] rand_i2;
    reg signed [11:0] rand_r3;
    reg signed [11:0] rand_i3;

    fft4_frame uut (
        .clk(clk),
        .rst(rst),
        .in_valid(in_valid),
        .in_real(in_real),
        .in_imag(in_imag),
        .out_valid(out_valid),
        .out_real(out_real),
        .out_imag(out_imag)
    );

    // Free-running clock with a 10 ns period (first rising edge at 5 ns).
    initial begin
        clk = 1'b0;
        forever #5 clk = ~clk;
    end

    // Sign-extend a 12-bit sample to the 14-bit output width.
    function automatic signed [13:0] ext14;
        input signed [11:0] value;
        begin
            ext14 = {{2{value[11]}}, value};
        end
    endfunction

    // Uniformly random 12-bit signed sample covering the full range
    // [-2048, 2047]. Only the low 12 bits of the 32-bit difference are kept.
    function automatic signed [11:0] rand_sample;
        integer tmp;
        begin
            tmp = $urandom_range(0, 4095) - 2048;
            rand_sample = tmp[11:0];
        end
    endfunction

    // Drive the DUT inputs, wait for the next rising edge, then settle 1 ns
    // so that the following check observes post-edge register values and
    // the next input change happens away from the clock edge.
    task step_cycle;
        input next_rst;
        input next_in_valid;
        input signed [11:0] next_in_real;
        input signed [11:0] next_in_imag;
        begin
            rst = next_rst;
            in_valid = next_in_valid;
            in_real = next_in_real;
            in_imag = next_in_imag;

            @(posedge clk);
            #1;
            tests_run = tests_run + 1;
        end
    endtask

    // Flag an error unless out_valid is exactly 0. Because of the !==
    // comparison, X and Z are errors too, so the observed value is printed.
    task check_no_output;
        input [255:0] tag;
        begin
            if (out_valid !== 1'b0) begin
                $display("ERROR [%s case %0d]: out_valid expected 0, got %b", tag, case_id, out_valid);
                errors = errors + 1;
            end
        end
    endtask

    // Check that out_valid is exactly 1 and that the output bin equals
    // expected_real/expected_imag[idx]. Each mismatch counts as one error.
    task check_output;
        input [255:0] tag;
        input integer idx;
        begin
            if (out_valid !== 1'b1) begin
                $display("ERROR [%s case %0d]: out_valid expected 1 for bin %0d, got %b", tag, case_id, idx, out_valid);
                errors = errors + 1;
            end
            if (out_real !== expected_real[idx]) begin
                $display("ERROR [%s case %0d]: out_real mismatch for bin %0d. Expected %0d, got %0d",
                         tag, case_id, idx, expected_real[idx], out_real);
                errors = errors + 1;
            end
            if (out_imag !== expected_imag[idx]) begin
                $display("ERROR [%s case %0d]: out_imag mismatch for bin %0d. Expected %0d, got %0d",
                         tag, case_id, idx, expected_imag[idx], out_imag);
                errors = errors + 1;
            end
        end
    endtask

    // Reference 4-point forward DFT:
    //   X[k] = sum_{n=0..3} x[n] * exp(-j*2*pi*k*n/4) = sum_n x[n] * (-j)^(k*n)
    // Multiplying by -j maps (a + jb) to (b - ja), and by +j to (-b + ja),
    // so every bin is a signed sum/difference of the input components.
    // Every result lies in [-8192, 8190], so 14 signed bits never overflow.
    task compute_expected;
        input signed [11:0] x0r;
        input signed [11:0] x0i;
        input signed [11:0] x1r;
        input signed [11:0] x1i;
        input signed [11:0] x2r;
        input signed [11:0] x2i;
        input signed [11:0] x3r;
        input signed [11:0] x3i;
        begin
            // k = 0: plain sum.
            expected_real[0] = ext14(x0r) + ext14(x1r) + ext14(x2r) + ext14(x3r);
            expected_imag[0] = ext14(x0i) + ext14(x1i) + ext14(x2i) + ext14(x3i);

            // k = 1: twiddles 1, -j, -1, +j.
            expected_real[1] = ext14(x0r) + ext14(x1i) - ext14(x2r) - ext14(x3i);
            expected_imag[1] = ext14(x0i) - ext14(x1r) - ext14(x2i) + ext14(x3r);

            // k = 2: twiddles 1, -1, 1, -1.
            expected_real[2] = ext14(x0r) - ext14(x1r) + ext14(x2r) - ext14(x3r);
            expected_imag[2] = ext14(x0i) - ext14(x1i) + ext14(x2i) - ext14(x3i);

            // k = 3: twiddles 1, +j, -1, -j.
            expected_real[3] = ext14(x0r) - ext14(x1i) - ext14(x2r) + ext14(x3i);
            expected_imag[3] = ext14(x0i) + ext14(x1r) - ext14(x2i) - ext14(x3r);
        end
    endtask

    // Load one 4-sample frame and verify the full output burst against the
    // reference, following the timing contract in the module header.
    task run_frame_and_check;
        input [255:0] tag;
        input signed [11:0] x0r;
        input signed [11:0] x0i;
        input signed [11:0] x1r;
        input signed [11:0] x1i;
        input signed [11:0] x2r;
        input signed [11:0] x2i;
        input signed [11:0] x3r;
        input signed [11:0] x3i;
        begin
            compute_expected(x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i);

            // Input phase: no output may appear while samples are loaded.
            step_cycle(1'b0, 1'b1, x0r, x0i);
            check_no_output(tag);
            step_cycle(1'b0, 1'b1, x1r, x1i);
            check_no_output(tag);
            step_cycle(1'b0, 1'b1, x2r, x2i);
            check_no_output(tag);
            step_cycle(1'b0, 1'b1, x3r, x3i);
            check_no_output(tag);

            // Output phase: bins 0..3 on consecutive edges, then idle.
            step_cycle(1'b0, 1'b0, 12'sd0, 12'sd0);
            check_output(tag, 0);
            step_cycle(1'b0, 1'b0, 12'sd0, 12'sd0);
            check_output(tag, 1);
            step_cycle(1'b0, 1'b0, 12'sd0, 12'sd0);
            check_output(tag, 2);
            step_cycle(1'b0, 1'b0, 12'sd0, 12'sd0);
            check_output(tag, 3);
            step_cycle(1'b0, 1'b0, 12'sd0, 12'sd0);
            check_no_output(tag);
        end
    endtask

    // Hold rst for one clock edge, then release it for one idle edge,
    // checking that out_valid stays 0 in both cycles.
    task do_reset;
        begin
            step_cycle(1'b1, 1'b0, 12'sd0, 12'sd0);
            check_no_output("reset");
            step_cycle(1'b0, 1'b0, 12'sd0, 12'sd0);
            check_no_output("post_reset_idle");
        end
    endtask

    initial begin
        rst = 1'b0;
        in_valid = 1'b0;
        in_real = 12'sd0;
        in_imag = 12'sd0;
        case_id = 0;

        do_reset;

        case_id = 1;
        run_frame_and_check("all_zero",
                            12'sd0, 12'sd0,
                            12'sd0, 12'sd0,
                            12'sd0, 12'sd0,
                            12'sd0, 12'sd0);

        case_id = 2;
        run_frame_and_check("impulse",
                            12'sd25, -12'sd11,
                            12'sd0,  12'sd0,
                            12'sd0,  12'sd0,
                            12'sd0,  12'sd0);

        case_id = 3;
        run_frame_and_check("dc_complex",
                            12'sd17, -12'sd9,
                            12'sd17, -12'sd9,
                            12'sd17, -12'sd9,
                            12'sd17, -12'sd9);

        // x[n] = 30 * j^n = 30 * exp(+j*2*pi*1*n/4): all energy in bin 1
        // under the forward-DFT convention of compute_expected (X[1] = 120).
        case_id = 4;
        run_frame_and_check("bin1_tone",
                            12'sd30,  12'sd0,
                            12'sd0,   12'sd30,
                            -12'sd30, 12'sd0,
                            12'sd0,  -12'sd30);

        case_id = 5;
        run_frame_and_check("edge_values",
                            12'sd2047, -12'sd2048,
                            -12'sd2048, 12'sd2047,
                            12'sd1024, -12'sd1024,
                            -12'sd511,  12'sd511);

        // Reset while only two samples of a frame have been loaded.
        case_id = 6;
        step_cycle(1'b0, 1'b1, 12'sd100, -12'sd50);
        check_no_output("partial_before_reset");
        step_cycle(1'b0, 1'b1, -12'sd20, 12'sd7);
        check_no_output("partial_before_reset");
        step_cycle(1'b1, 1'b0, 12'sd0, 12'sd0);
        check_no_output("reset_clears_partial");
        step_cycle(1'b0, 1'b0, 12'sd0, 12'sd0);
        check_no_output("idle_after_partial_reset");

        // Reset right after bin 0 has been emitted: bins 1..3 must not appear.
        case_id = 7;
        compute_expected(12'sd8, 12'sd3,
                         12'sd2, -12'sd5,
                         -12'sd7, 12'sd1,
                         12'sd4, 12'sd6);
        step_cycle(1'b0, 1'b1, 12'sd8, 12'sd3);
        check_no_output("reset_mid_output");
        step_cycle(1'b0, 1'b1, 12'sd2, -12'sd5);
        check_no_output("reset_mid_output");
        step_cycle(1'b0, 1'b1, -12'sd7, 12'sd1);
        check_no_output("reset_mid_output");
        step_cycle(1'b0, 1'b1, 12'sd4, 12'sd6);
        check_no_output("reset_mid_output");
        step_cycle(1'b0, 1'b0, 12'sd0, 12'sd0);
        check_output("reset_mid_output", 0);
        step_cycle(1'b1, 1'b0, 12'sd0, 12'sd0);
        check_no_output("reset_mid_output");
        step_cycle(1'b0, 1'b0, 12'sd0, 12'sd0);
        check_no_output("reset_mid_output");

        // x[n] = 30 * (-j)^n = 30 * exp(+j*2*pi*3*n/4): all energy in bin 3
        // (X[3] = 120). This is the stimulus formerly mislabelled "bin1_tone".
        case_id = 8;
        run_frame_and_check("bin3_tone",
                            12'sd30,  12'sd0,
                            12'sd0,  -12'sd30,
                            -12'sd30, 12'sd0,
                            12'sd0,   12'sd30);

        for (i = 0; i < NUM_RANDOM_FRAMES; i = i + 1) begin
            case_id = 100 + i;
            rand_r0 = rand_sample();
            rand_i0 = rand_sample();
            rand_r1 = rand_sample();
            rand_i1 = rand_sample();
            rand_r2 = rand_sample();
            rand_i2 = rand_sample();
            rand_r3 = rand_sample();
            rand_i3 = rand_sample();

            run_frame_and_check("random_frame",
                                rand_r0, rand_i0,
                                rand_r1, rand_i1,
                                rand_r2, rand_i2,
                                rand_r3, rand_i3);
        end

        $display("\n===========================================");
        if (errors == 0)
            $display("TEST_RESULT: PASS");
        else
            $display("TEST_RESULT: FAIL (%0d errors)", errors);
        $display("Executed %0d timed checks", tests_run);
        $display("===========================================");

        $finish;
    end
endmodule
