`timescale 1ns / 1ps

// =============================================================================
// CNN 3x3 Convolution Layer - Reference Implementation
// 28x28 input, 4 output filters, stride=1, pad=1, bias + ReLU
// Uses line buffers for sliding window, parallel MACs for 4 filters
// =============================================================================

// -----------------------------------------------------------------------------
// cnn_conv3x3
//
// Streaming 3x3 convolution over an IMG_W x IMG_H image of unsigned 8-bit
// pixels, producing four filter results per pixel position. Each result is
// bias-added, clamped to [0, 255] (ReLU + saturation) and packed into one
// 32-bit word as {filter3, filter2, filter1, filter0}.
//
// Handshakes:
//   * Input : a pixel is consumed on a cycle where pixel_valid && pixel_ready.
//             Pixels are interpreted in raster order (row-major).
//   * Output: a word is consumed on a cycle where out_valid && out_ready.
//   * pixel_ready depends combinationally on out_ready, so a stalled output
//     also stalls the input on the same cycle.
//
// Sequencing: a pulse on `start` while idle begins a frame; `done` is asserted
// for one cycle after the final output word has been consumed.
//
// Positions outside the image contribute zero to the window (zero padding).
// -----------------------------------------------------------------------------
module cnn_conv3x3 (
    input wire clk,
    input wire rst,
    input wire start,
    input wire [7:0] pixel_in,
    input wire pixel_valid,
    output wire pixel_ready,
    output reg [31:0] out_pixel,
    output reg out_valid,
    input wire out_ready,
    output reg done
);

    // -------------------------------------------------------------------------
    // Constants
    // -------------------------------------------------------------------------
    // Hardcoded kernels and biases (must match spec and generate_golden.py).
    // Each kernel packs nine signed 8-bit weights; tap t occupies bits
    // [8*t +: 8], with t = 0 the top-left weight and t = 8 the bottom-right.
    // Filter 0: Horizontal edge [-1,-2,-1; 0,0,0; 1,2,1]
    // Filter 1: Vertical edge   [-1,0,1; -2,0,2; -1,0,1]
    // Filter 2: Blur            [1,2,1; 2,4,2; 1,2,1]
    // Filter 3: Sharpen         [0,-1,0; -1,5,-1; 0,-1,0]
    localparam [71:0] KERNEL_0 = 72'h010201000000FFFEFF;
    localparam [71:0] KERNEL_1 = 72'h0100FF0200FE0100FF;
    localparam [71:0] KERNEL_2 = 72'h010201020402010201;
    localparam [71:0] KERNEL_3 = 72'h00FF00FF05FF00FF00;
    // Four signed 8-bit biases, filter f in bits [8*f +: 8]: [0, 0, 8, 10]
    localparam [31:0] BIAS = 32'h0A080000;

    // Image geometry
    localparam IMG_W = 28;
    localparam IMG_H = 28;
    localparam IMG_SIZE = IMG_W * IMG_H;

    // Controller states
    localparam IDLE = 2'd0;
    localparam PROCESSING = 2'd1;
    localparam FINISHING = 2'd2;
    localparam DONE_STATE = 2'd3;

    // -------------------------------------------------------------------------
    // Arithmetic helpers
    // -------------------------------------------------------------------------

    // Multiply-accumulate of nine unsigned pixels against nine signed weights.
    //   window : nine 8-bit unsigned pixels, tap t in bits [8*t +: 8]
    //   kernel : nine 8-bit signed weights,  tap t in bits [8*t +: 8]
    // Returns the signed 16-bit sum of products.
    function automatic signed [15:0] compute_mac;
        input [71:0] window;
        input [71:0] kernel;
        integer t;
        reg signed [15:0] acc;
        reg signed [8:0] px;   // pixel zero-extended by one bit so it stays non-negative when signed
        reg signed [7:0] wt;
        begin
            acc = 16'sd0;
            for (t = 0; t < 9; t = t + 1) begin
                px = {1'b0, window[8*t +: 8]};
                wt = kernel[8*t +: 8];
                acc = acc + px * wt;
            end
            compute_mac = acc;
        end
    endfunction

    // Add the signed bias to a MAC result, then clamp to [0, 255].
    function [7:0] relu_sat;
        input signed [15:0] mac;
        input signed [7:0] bias;
        reg signed [15:0] sum;
        begin
            sum = mac + bias;
            if (sum < 0)
                relu_sat = 8'd0;
            else if (sum > 255)
                relu_sat = 8'd255;
            else
                relu_sat = sum[7:0];
        end
    endfunction

    // -------------------------------------------------------------------------
    // State
    // -------------------------------------------------------------------------
    reg [1:0] state;

    // Input position: pixels accepted so far and the raster coordinates of
    // the next pixel to be written.
    reg [9:0] in_count;   // 0 to IMG_SIZE
    reg [4:0] in_row;
    reg [4:0] in_col;

    // Output position: words produced so far and the raster coordinates of
    // the next output to be produced.
    reg [9:0] out_count;
    reg [4:0] out_row;
    reg [4:0] out_col;

    // Three-row line buffer; image row r is stored in slot (r % 3).
    reg [7:0] line_buf [0:2][0:27];

    // High while a produced word is waiting to be consumed. It is set and
    // cleared together with out_valid.
    reg output_pending;

    integer i, j;

    // -------------------------------------------------------------------------
    // Handshake
    // -------------------------------------------------------------------------
    wire input_handshake = pixel_valid && pixel_ready;
    wire output_handshake = out_valid && out_ready;

    // Accept input only while processing, while pixels remain, and while the
    // output stage is not blocked by an unconsumed word.
    assign pixel_ready = (state == PROCESSING) && (in_count < IMG_SIZE) && (!output_pending || out_ready);

    // -------------------------------------------------------------------------
    // Combinational datapath for the output at (out_row, out_col)
    // -------------------------------------------------------------------------
    wire at_top    = (out_row == 0);
    wire at_bottom = (out_row == IMG_H - 1);
    wire at_left   = (out_col == 0);
    wire at_right  = (out_col == IMG_W - 1);

    // Line-buffer slots of the rows above / at / below the output row.
    // (out_row + 2) % 3 equals (out_row - 1) % 3 without underflow at row 0.
    wire [1:0] slot_up  = (out_row + 2) % 3;
    wire [1:0] slot_mid = out_row % 3;
    wire [1:0] slot_dn  = (out_row + 1) % 3;

    // Neighbour columns, clamped so the buffer is never indexed out of range;
    // at an image edge the value read is discarded in favour of zero padding.
    wire [4:0] col_l = at_left  ? 5'd0    : out_col - 5'd1;
    wire [4:0] col_r = at_right ? out_col : out_col + 5'd1;

    // 3x3 window, wRC = row offset R (0 = above), column offset C (0 = left).
    wire [7:0] w00 = (at_top || at_left)     ? 8'd0 : line_buf[slot_up][col_l];
    wire [7:0] w01 = at_top                  ? 8'd0 : line_buf[slot_up][out_col];
    wire [7:0] w02 = (at_top || at_right)    ? 8'd0 : line_buf[slot_up][col_r];
    wire [7:0] w10 = at_left                 ? 8'd0 : line_buf[slot_mid][col_l];
    wire [7:0] w11 =                                  line_buf[slot_mid][out_col];
    wire [7:0] w12 = at_right                ? 8'd0 : line_buf[slot_mid][col_r];
    wire [7:0] w20 = (at_bottom || at_left)  ? 8'd0 : line_buf[slot_dn][col_l];
    wire [7:0] w21 = at_bottom               ? 8'd0 : line_buf[slot_dn][out_col];
    wire [7:0] w22 = (at_bottom || at_right) ? 8'd0 : line_buf[slot_dn][col_r];

    // Packed so that tap t of the window lines up with tap t of each kernel.
    wire [71:0] window = {w22, w21, w20, w12, w11, w10, w02, w01, w00};

    // One MAC + bias + clamp per filter, all evaluated in parallel.
    wire [7:0] out0 = relu_sat(compute_mac(window, KERNEL_0), BIAS[7:0]);
    wire [7:0] out1 = relu_sat(compute_mac(window, KERNEL_1), BIAS[15:8]);
    wire [7:0] out2 = relu_sat(compute_mac(window, KERNEL_2), BIAS[23:16]);
    wire [7:0] out3 = relu_sat(compute_mac(window, KERNEL_3), BIAS[31:24]);

    // Raster index of the last input pixel this window depends on: the
    // bottom-right in-image neighbour (the row below does not exist on the
    // last image row, the column to the right does not exist on the last
    // column).
    wire [4:0] last_row_needed = at_bottom ? out_row : out_row + 5'd1;
    wire [9:0] needed_idx = last_row_needed * IMG_W + col_r;

    // The window is complete once that pixel has been accepted (or the whole
    // frame has been received).
    wire have_enough = (in_count > needed_idx) || (in_count >= IMG_SIZE);

    // -------------------------------------------------------------------------
    // Controller
    // -------------------------------------------------------------------------
    always @(posedge clk) begin
        if (rst) begin
            state <= IDLE;
            in_count <= 0;
            in_row <= 0;
            in_col <= 0;
            out_count <= 0;
            out_row <= 0;
            out_col <= 0;
            out_valid <= 0;
            out_pixel <= 0;
            done <= 0;
            output_pending <= 0;

            // Clear line buffers
            for (i = 0; i < 3; i = i + 1)
                for (j = 0; j < IMG_W; j = j + 1)
                    line_buf[i][j] <= 0;
        end
        else begin
            done <= 0;  // Default: done is a single-cycle pulse

            case (state)
                IDLE: begin
                    // Re-arm all counters for a new frame. The line buffers
                    // are not cleared here: have_enough prevents any read of
                    // data that has not been written during this frame.
                    if (start) begin
                        state <= PROCESSING;
                        in_count <= 0;
                        in_row <= 0;
                        in_col <= 0;
                        out_count <= 0;
                        out_row <= 0;
                        out_col <= 0;
                        output_pending <= 0;
                    end
                end

                PROCESSING: begin
                    // Store an accepted pixel and advance the input position.
                    if (input_handshake) begin
                        line_buf[in_row % 3][in_col] <= pixel_in;

                        in_count <= in_count + 1;
                        if (in_col == IMG_W - 1) begin
                            in_col <= 0;
                            in_row <= in_row + 1;
                        end
                        else begin
                            in_col <= in_col + 1;
                        end
                    end

                    // Retire the pending word when the consumer takes it.
                    // This must stay ahead of the block below so that a new
                    // word produced in the same cycle overrides the clear.
                    if (output_handshake) begin
                        output_pending <= 0;
                        out_valid <= 0;
                    end

                    // Produce the next word when the output register is free
                    // (or being freed this cycle) and its window is complete.
                    if ((!output_pending || output_handshake) &&
                        (out_count < IMG_SIZE) && have_enough) begin
                        out_pixel <= {out3, out2, out1, out0};
                        out_valid <= 1;
                        output_pending <= 1;

                        out_count <= out_count + 1;
                        if (out_col == IMG_W - 1) begin
                            out_col <= 0;
                            out_row <= out_row + 1;
                        end
                        else begin
                            out_col <= out_col + 1;
                        end

                        // The word just produced is the last of the frame.
                        if (out_count == IMG_SIZE - 1) begin
                            state <= FINISHING;
                        end
                    end
                end

                FINISHING: begin
                    // Wait for last output to be consumed
                    if (output_handshake || !output_pending) begin
                        out_valid <= 0;
                        output_pending <= 0;
                        done <= 1;
                        state <= DONE_STATE;
                    end
                end

                DONE_STATE: begin
                    state <= IDLE;
                end
            endcase
        end
    end

endmodule
