`timescale 1ns / 1ps

// fft4_frame: frame-based 4-point DFT on signed complex samples.
//
// Operation
//   * Samples are accepted one per clock while `in_valid` is high and the
//     module is not emitting a result frame. The first three samples of a
//     frame are stored; the fourth is consumed directly from the input pins
//     on the edge where all four bins are computed and registered.
//   * On the four clock edges that follow, bins 0..3 are driven on
//     `out_real`/`out_imag` (one bin per clock) with `out_valid` high.
//   * While those four bins are being emitted, `in_valid` is ignored: any
//     sample presented during that window is dropped. There is no ready
//     output, so the producer has to leave a four-cycle gap after the
//     fourth sample of every frame.
//   * When `out_valid` is low, `out_real`/`out_imag` hold their last value.
//
// Arithmetic
//   X[k] = sum_{n=0..3} x[n] * (-j)^(n*k), evaluated with a radix-2
//   butterfly so the even/odd partial sums are shared between bins.
//   Results are two bits wider than the samples, which is enough to hold
//   any four-term signed sum without overflow.
//
// Parameters
//   IN_WIDTH  - width in bits of each signed input component (>= 1).
//               Outputs are IN_WIDTH + 2 bits wide. The default of 12
//               gives 12-bit inputs and 14-bit outputs.
//
// Ports
//   clk       - clock; all state updates on the rising edge.
//   rst       - synchronous, active-high reset; clears all state and outputs.
//   in_valid  - qualifies in_real/in_imag for the current clock edge.
//   in_real   - real part of the input sample.
//   in_imag   - imaginary part of the input sample.
//   out_valid - high while out_real/out_imag carry a DFT bin.
//   out_real  - real part of the current output bin.
//   out_imag  - imaginary part of the current output bin.
module fft4_frame #(
    parameter integer IN_WIDTH = 12
) (
    input  wire                        clk,
    input  wire                        rst,
    input  wire                        in_valid,
    input  wire signed [IN_WIDTH-1:0]  in_real,
    input  wire signed [IN_WIDTH-1:0]  in_imag,
    output reg                         out_valid,
    output reg signed [IN_WIDTH+1:0]   out_real,
    output reg signed [IN_WIDTH+1:0]   out_imag
);
    // Two growth bits cover the worst-case magnitude of a four-term sum.
    localparam integer OUT_WIDTH = IN_WIDTH + 2;

    // Storage for the first three samples of the frame being collected.
    reg signed [IN_WIDTH-1:0] x0_real, x1_real, x2_real;
    reg signed [IN_WIDTH-1:0] x0_imag, x1_imag, x2_imag;

    // Registered DFT bins of the most recently completed frame.
    reg signed [OUT_WIDTH-1:0] bin0_real, bin1_real, bin2_real, bin3_real;
    reg signed [OUT_WIDTH-1:0] bin0_imag, bin1_imag, bin2_imag, bin3_imag;

    reg [1:0] sample_count;   // index of the next input sample within the frame
    reg [1:0] out_index;      // index of the next bin to emit
    reg       output_active;  // high while a result frame is being emitted

    // Explicit sign extension of all four samples to the result width. The
    // fourth sample (x3) is taken straight from the input pins.
    wire signed [OUT_WIDTH-1:0] x0r = {{2{x0_real[IN_WIDTH-1]}}, x0_real};
    wire signed [OUT_WIDTH-1:0] x0i = {{2{x0_imag[IN_WIDTH-1]}}, x0_imag};
    wire signed [OUT_WIDTH-1:0] x1r = {{2{x1_real[IN_WIDTH-1]}}, x1_real};
    wire signed [OUT_WIDTH-1:0] x1i = {{2{x1_imag[IN_WIDTH-1]}}, x1_imag};
    wire signed [OUT_WIDTH-1:0] x2r = {{2{x2_real[IN_WIDTH-1]}}, x2_real};
    wire signed [OUT_WIDTH-1:0] x2i = {{2{x2_imag[IN_WIDTH-1]}}, x2_imag};
    wire signed [OUT_WIDTH-1:0] x3r = {{2{in_real[IN_WIDTH-1]}}, in_real};
    wire signed [OUT_WIDTH-1:0] x3i = {{2{in_imag[IN_WIDTH-1]}}, in_imag};

    // Butterfly stage 1: combine the even-indexed pair (x0, x2) and the
    // odd-indexed pair (x1, x3).
    wire signed [OUT_WIDTH-1:0] even_sum_real  = x0r + x2r;
    wire signed [OUT_WIDTH-1:0] even_sum_imag  = x0i + x2i;
    wire signed [OUT_WIDTH-1:0] even_diff_real = x0r - x2r;
    wire signed [OUT_WIDTH-1:0] even_diff_imag = x0i - x2i;
    wire signed [OUT_WIDTH-1:0] odd_sum_real   = x1r + x3r;
    wire signed [OUT_WIDTH-1:0] odd_sum_imag   = x1i + x3i;
    wire signed [OUT_WIDTH-1:0] odd_diff_real  = x1r - x3r;
    wire signed [OUT_WIDTH-1:0] odd_diff_imag  = x1i - x3i;

    always @(posedge clk) begin
        if (rst) begin
            x0_real <= {IN_WIDTH{1'b0}};
            x1_real <= {IN_WIDTH{1'b0}};
            x2_real <= {IN_WIDTH{1'b0}};
            x0_imag <= {IN_WIDTH{1'b0}};
            x1_imag <= {IN_WIDTH{1'b0}};
            x2_imag <= {IN_WIDTH{1'b0}};

            bin0_real <= {OUT_WIDTH{1'b0}};
            bin1_real <= {OUT_WIDTH{1'b0}};
            bin2_real <= {OUT_WIDTH{1'b0}};
            bin3_real <= {OUT_WIDTH{1'b0}};
            bin0_imag <= {OUT_WIDTH{1'b0}};
            bin1_imag <= {OUT_WIDTH{1'b0}};
            bin2_imag <= {OUT_WIDTH{1'b0}};
            bin3_imag <= {OUT_WIDTH{1'b0}};

            sample_count  <= 2'd0;
            out_index     <= 2'd0;
            output_active <= 1'b0;
            out_valid     <= 1'b0;
            out_real      <= {OUT_WIDTH{1'b0}};
            out_imag      <= {OUT_WIDTH{1'b0}};
        end else begin
            // ---------------- Output side ----------------
            // out_valid simply follows output_active with one cycle of delay.
            out_valid <= output_active;

            if (output_active) begin
                case (out_index)
                    2'd0: begin
                        out_real <= bin0_real;
                        out_imag <= bin0_imag;
                    end
                    2'd1: begin
                        out_real <= bin1_real;
                        out_imag <= bin1_imag;
                    end
                    2'd2: begin
                        out_real <= bin2_real;
                        out_imag <= bin2_imag;
                    end
                    default: begin
                        out_real <= bin3_real;
                        out_imag <= bin3_imag;
                    end
                endcase

                // The 2-bit counter wraps 3 -> 0 on its own; the frame ends
                // once the last bin has been driven.
                out_index <= out_index + 2'd1;
                if (out_index == 2'd3) begin
                    output_active <= 1'b0;
                end
            end

            // ---------------- Input side ----------------
            // Mutually exclusive with the output branch above, so the two
            // sections never drive out_index/output_active on the same edge.
            if (!output_active && in_valid) begin
                case (sample_count)
                    2'd0: begin
                        x0_real <= in_real;
                        x0_imag <= in_imag;
                        sample_count <= 2'd1;
                    end
                    2'd1: begin
                        x1_real <= in_real;
                        x1_imag <= in_imag;
                        sample_count <= 2'd2;
                    end
                    2'd2: begin
                        x2_real <= in_real;
                        x2_imag <= in_imag;
                        sample_count <= 2'd3;
                    end
                    default: begin
                        // Butterfly stage 2. With E = x0 + x2, F = x0 - x2,
                        // G = x1 + x3 and H = x1 - x3:
                        //   X0 = E + G      X2 = E - G
                        //   X1 = F - j*H    X3 = F + j*H
                        // where j*H = -H_imag + j*H_real.
                        bin0_real <= even_sum_real + odd_sum_real;
                        bin0_imag <= even_sum_imag + odd_sum_imag;

                        bin1_real <= even_diff_real + odd_diff_imag;
                        bin1_imag <= even_diff_imag - odd_diff_real;

                        bin2_real <= even_sum_real - odd_sum_real;
                        bin2_imag <= even_sum_imag - odd_sum_imag;

                        bin3_real <= even_diff_real - odd_diff_imag;
                        bin3_imag <= even_diff_imag + odd_diff_real;

                        sample_count  <= 2'd0;
                        out_index     <= 2'd0;
                        output_active <= 1'b1;
                    end
                endcase
            end
        end
    end
endmodule
