`timescale 1ns / 1ps

// Sobel Edge Detection - Reference Implementation
// 64x64 input, Gx/Gy kernels, |Gx|+|Gy| magnitude

// sobel_edge
//
// Streaming 3x3 Sobel edge detector for a fixed 64x64, 8-bit image.
//
// Ports
//   clk, rst     : clock and synchronous, active-high reset.
//   start        : sampled in IDLE; begins one frame.
//   pixel_in     : input pixel, raster order (row-major).
//   pixel_valid  : producer has a pixel on pixel_in.
//   pixel_ready  : module accepts pixel_in this cycle when pixel_valid is
//                  also high. May be deasserted during a frame when the
//                  consumer stalls (see LEAD_LIMIT below).
//   pixel_out    : min(|Gx| + |Gy|, 255) for the current output pixel,
//                  raster order, zero padding outside the image.
//   out_valid    : pixel_out holds a result; held until out_ready is seen.
//   out_ready    : consumer accepts pixel_out this cycle.
//   done         : one-cycle pulse after all IMG_SIZE results were emitted
//                  and the last one has been (or is being) accepted.
module sobel_edge (
    input wire clk,
    input wire rst,
    input wire start,
    input wire [7:0] pixel_in,
    input wire pixel_valid,
    output reg pixel_ready,
    output reg [7:0] pixel_out,
    output reg out_valid,
    input wire out_ready,
    output reg done
);

    localparam IMG_W = 64;
    localparam IMG_H = 64;
    localparam IMG_SIZE = IMG_W * IMG_H;

    // The line buffer holds three rows, so writing pixel p overwrites the
    // pixel three rows above it. The last output whose window still reads
    // that pixel has raster index p - 2*IMG_W + 1, therefore a write is only
    // safe while (in_count - out_count) <= 2*IMG_W - 2. pixel_ready is
    // registered one cycle ahead from this limit. Outputs never need the
    // input to lead by more than IMG_W + 2, so the limit cannot deadlock.
    localparam LEAD_LIMIT = 2 * IMG_W - 1;

    localparam IDLE = 2'd0, RUNNING = 2'd1, FLUSH = 2'd2, DONE_ST = 2'd3;
    reg [1:0] state;

    // 13 bits: the counters must be able to hold IMG_SIZE (4096) itself,
    // otherwise "out_count >= IMG_SIZE" can never become true.
    reg [12:0] in_count, out_count;
    reg [5:0] in_row, in_col, out_row, out_col;

    // Three-row circular line buffer, row slot selected by (row % 3).
    reg [7:0] line_buf [0:2][0:63];

    // Scratch values: always assigned (blocking) before they are read within
    // the same clock edge, so they carry no state between cycles.
    reg [12:0] needed_idx;
    reg [12:0] lead_next;
    reg can_output;
    reg accept_pixel;

    integer i, j;

    always @(posedge clk) begin
        if (rst) begin
            state <= IDLE;
            in_count <= 0; out_count <= 0;
            in_row <= 0; in_col <= 0; out_row <= 0; out_col <= 0;
            pixel_ready <= 0; out_valid <= 0; pixel_out <= 0; done <= 0;
            for (i = 0; i < 3; i = i + 1)
                for (j = 0; j < 64; j = j + 1)
                    line_buf[i][j] <= 0;
        end else begin
            // done is a single-cycle pulse.
            done <= 0;

            // Clear out_valid on handshake; emit_output may set it again
            // later in this same edge (last nonblocking assignment wins).
            if (out_valid && out_ready) out_valid <= 0;

            case (state)
                IDLE: begin
                    if (start) begin
                        state <= RUNNING;
                        in_count <= 0; out_count <= 0;
                        in_row <= 0; in_col <= 0; out_row <= 0; out_col <= 0;
                        pixel_ready <= 1;
                    end
                end

                RUNNING: begin
                    // Accept input
                    accept_pixel = pixel_valid && pixel_ready;
                    if (accept_pixel) begin
                        line_buf[in_row % 3][in_col] <= pixel_in;
                        in_count <= in_count + 1;
                        if (in_col == IMG_W - 1) begin in_col <= 0; in_row <= in_row + 1; end
                        else in_col <= in_col + 1;
                    end

                    // Generate output when data available
                    if ((!out_valid || out_ready) && out_count < IMG_SIZE) begin
                        // For row R, col C: need pixel at (R+1, C+1) unless at boundary
                        if (out_row == IMG_H - 1)
                            needed_idx = out_row * IMG_W + out_col;
                        else
                            needed_idx = (out_row + 1) * IMG_W + (out_col < IMG_W-1 ? out_col + 1 : out_col);

                        can_output = (in_count > needed_idx);

                        if (can_output) begin
                            emit_output;
                        end
                    end

                    if (accept_pixel && in_count == IMG_SIZE - 1) begin
                        // Last pixel of the frame was just taken.
                        pixel_ready <= 0;
                        state <= FLUSH;
                    end else begin
                        // Input back-pressure. Upper bound on next cycle's
                        // lead: an output emitted this cycle is deliberately
                        // ignored, which only makes the bound conservative.
                        lead_next = in_count - out_count + (accept_pixel ? 13'd1 : 13'd0);
                        pixel_ready <= (lead_next < LEAD_LIMIT);
                    end
                end

                FLUSH: begin
                    // All input is buffered: continue outputting until done
                    if ((!out_valid || out_ready) && out_count < IMG_SIZE) begin
                        emit_output;
                    end

                    if (out_count >= IMG_SIZE && (!out_valid || out_ready)) begin
                        done <= 1;
                        state <= DONE_ST;
                    end
                end

                DONE_ST: state <= IDLE;
            endcase
        end
    end

    // Computes the Sobel magnitude for the pixel at (out_row, out_col) from
    // line_buf, drives pixel_out/out_valid, and advances the output position
    // and out_count. Window taps that fall outside the image read as zero.
    task emit_output;
        reg [7:0] w00, w01, w02, w10, w11, w12, w20, w21, w22;
        reg signed [11:0] gx, gy;
        reg [11:0] mag;
    begin
        // Get window with zero padding
        if (out_row == 0) begin w00 = 0; w01 = 0; w02 = 0; end
        else begin
            w00 = (out_col == 0) ? 0 : line_buf[(out_row-1) % 3][out_col-1];
            w01 = line_buf[(out_row-1) % 3][out_col];
            w02 = (out_col == IMG_W-1) ? 0 : line_buf[(out_row-1) % 3][out_col+1];
        end

        w10 = (out_col == 0) ? 0 : line_buf[out_row % 3][out_col-1];
        w11 = line_buf[out_row % 3][out_col];
        w12 = (out_col == IMG_W-1) ? 0 : line_buf[out_row % 3][out_col+1];

        if (out_row == IMG_H-1) begin w20 = 0; w21 = 0; w22 = 0; end
        else begin
            w20 = (out_col == 0) ? 0 : line_buf[(out_row+1) % 3][out_col-1];
            w21 = line_buf[(out_row+1) % 3][out_col];
            w22 = (out_col == IMG_W-1) ? 0 : line_buf[(out_row+1) % 3][out_col+1];
        end

        // Gx = [-1,0,1; -2,0,2; -1,0,1], Gy = [-1,-2,-1; 0,0,0; 1,2,1]
        // Taps are zero-extended to 12 bits so the signed sums (|g| <= 1020)
        // cannot overflow.
        gx = -$signed({4'd0,w00}) + $signed({4'd0,w02}) - 2*$signed({4'd0,w10}) + 2*$signed({4'd0,w12}) - $signed({4'd0,w20}) + $signed({4'd0,w22});
        gy = -$signed({4'd0,w00}) - 2*$signed({4'd0,w01}) - $signed({4'd0,w02}) + $signed({4'd0,w20}) + 2*$signed({4'd0,w21}) + $signed({4'd0,w22});

        // L1 magnitude |Gx| + |Gy|
        mag = ((gx < 0) ? -gx : gx) + ((gy < 0) ? -gy : gy);

        // Saturate to 8 bits
        pixel_out <= (mag > 255) ? 8'd255 : mag[7:0];
        out_valid <= 1;

        out_count <= out_count + 1;
        if (out_col == IMG_W-1) begin out_col <= 0; out_row <= out_row + 1; end
        else out_col <= out_col + 1;
    end
    endtask

endmodule
