// softmax_approx
//
// Purely combinational 4-input softmax approximation.
//
//   x0..x3 : 8-bit inputs, interpreted as two's-complement signed values
//            (they are sign-extended from bit 7 below).
//   y0..y3 : 8-bit unsigned outputs, y_i = floor(256 * e_i / sum(e)), saturated
//            to 255, where e_i is a piecewise-linear decreasing function of
//            (x_i - max(x)).
//
// NOTE(review): the knot values 256/192/144/92/34 at d = 0/-4/-8/-16/-32 are
// close to 256*exp(d/16); that is presumably the intended target curve --
// confirm against the design spec.
module softmax_approx (
    input  wire [7:0] x0,
    input  wire [7:0] x1,
    input  wire [7:0] x2,
    input  wire [7:0] x3,
    output wire [7:0] y0,
    output wire [7:0] y1,
    output wire [7:0] y2,
    output wire [7:0] y3
);

    // Sign-extend to 9 bits so that (x - max) in [-255, 0] cannot overflow.
    wire signed [8:0] sx0 = {x0[7], x0};
    wire signed [8:0] sx1 = {x1[7], x1};
    wire signed [8:0] sx2 = {x2[7], x2};
    wire signed [8:0] sx3 = {x3[7], x3};

    // Two-level signed maximum tree.
    wire signed [8:0] max01 = (sx0 > sx1) ? sx0 : sx1;
    wire signed [8:0] max23 = (sx2 > sx3) ? sx2 : sx3;
    wire signed [8:0] max_v  = (max01 > max23) ? max01 : max23;

    // Distance below the maximum; always <= 0, and == 0 for the max input.
    wire signed [8:0] d0 = sx0 - max_v;
    wire signed [8:0] d1 = sx1 - max_v;
    wire signed [8:0] d2 = sx2 - max_v;
    wire signed [8:0] d3 = sx3 - max_v;

    // Piecewise-linear decreasing weight of d (d <= 0), saturated to [0, 255].
    //
    // Knots (d -> value): 0 -> 256, -4 -> 192, -8 -> 144, -16 -> 92,
    // -32 -> 34, -64 -> 0.  Every segment uses the chord slope between its two
    // knots, so the curve is continuous and non-increasing as d decreases.
    // (Previously the last three segments used slope = knot/16, which made
    // the curve jump UP at d = -16 and d = -32 and broke output ordering.)
    function [7:0] exp_piecewise;
        input signed [8:0] d;
        reg signed [15:0] d16;
        reg signed [15:0] e;
        begin
            d16 = {{7{d[8]}}, d};
            if (d16 > -16'sd4)
                // 256 -> 192 over 4 steps: slope 16
                e = 16'sd256 + (d16 * 16'sd16);
            else if (d16 > -16'sd8)
                // 192 -> 144 over 4 steps: slope 12
                e = 16'sd192 + ((d16 + 16'sd4) * 16'sd12);
            else if (d16 > -16'sd16)
                // 144 -> 92 over 8 steps: slope 52/8 = 13/2
                e = 16'sd144 + (((d16 + 16'sd8) * 16'sd13) >>> 1);
            else if (d16 > -16'sd32)
                // 92 -> 34 over 16 steps: slope 58/16
                e = 16'sd92 + (((d16 + 16'sd16) * 16'sd58) >>> 4);
            else if (d16 > -16'sd64)
                // 34 -> 0 over 32 steps: slope 34/32
                e = 16'sd34 + (((d16 + 16'sd32) * 16'sd34) >>> 5);
            else
                e = 16'sd0;

            // Saturate into the 8-bit unsigned result range.
            // d == 0 yields 256, which saturates to 255.
            if (e < 0)
                exp_piecewise = 8'd0;
            else if (e > 255)
                exp_piecewise = 8'd255;
            else
                exp_piecewise = e[7:0];
        end
    endfunction

    wire [7:0] e0 = exp_piecewise(d0);
    wire [7:0] e1 = exp_piecewise(d1);
    wire [7:0] e2 = exp_piecewise(d2);
    wire [7:0] e3 = exp_piecewise(d3);

    // 4 * 255 = 1020 fits in 10 bits.
    wire [9:0] sum_e = {2'b00, e0} + {2'b00, e1} + {2'b00, e2} + {2'b00, e3};

    // Numerators pre-scaled by 256 so the quotient is an 8-bit fraction.
    wire [15:0] num0 = {e0, 8'b0};
    wire [15:0] num1 = {e1, 8'b0};
    wire [15:0] num2 = {e2, 8'b0};
    wire [15:0] num3 = {e3, 8'b0};

    wire [15:0] den = {6'b0, sum_e};

    // The maximum input always has d == 0 and therefore e == 255, so sum_e is
    // never 0 in practice; the guard is kept as a defensive divide-by-zero
    // fallback that yields a uniform 64/256 distribution.
    wire [15:0] q0 = (sum_e == 0) ? 16'd64 : (num0 / den);
    wire [15:0] q1 = (sum_e == 0) ? 16'd64 : (num1 / den);
    wire [15:0] q2 = (sum_e == 0) ? 16'd64 : (num2 / den);
    wire [15:0] q3 = (sum_e == 0) ? 16'd64 : (num3 / den);

    // A lone non-zero weight gives a quotient of exactly 256; saturate to 255.
    assign y0 = (q0 > 255) ? 8'd255 : q0[7:0];
    assign y1 = (q1 > 255) ? 8'd255 : q1[7:0];
    assign y2 = (q2 > 255) ? 8'd255 : q2[7:0];
    assign y3 = (q3 > 255) ? 8'd255 : q3[7:0];

endmodule
