`timescale 1ns / 1ps

// AES-128 Encryption - Reference Implementation
// Iterative architecture, 10 rounds per FIPS-197

// -----------------------------------------------------------------------------
// aes128_encrypt
//
// Iterative AES-128 encryptor with a 32-bit word-serial bus.
//
// Bus protocol (all control inputs are re-sampled on the falling clock edge
// and take effect on the following rising edge):
//   * Key load  : while the core is idle, every rising edge on which the
//                 sampled load_key is high shifts data_in into the low word
//                 of the key register. Four words are needed, most-significant
//                 word first.
//   * Data load : same mechanism with load_data for the plaintext block.
//   * Priority  : load_key > load_data > start when several are high at once.
//   * start     : performs the initial AddRoundKey and enters ENCRYPT with
//                 busy high. ENCRYPT lasts 10 rising edges (9 full rounds and
//                 the final round without MixColumns).
//   * Read-out  : when done is high, data_out is re-registered every clock
//                 from the ciphertext word selected by an internal counter
//                 (most-significant word first). Each rising edge with the
//                 sampled read_data high advances that counter; the fourth
//                 one clears done and returns the core to idle.
//
// Reset: the input-sampling registers are cleared asynchronously by rst, the
// core state is cleared synchronously, so rst must be held across at least
// one rising clk edge to reset the core.
//
// Byte order: bit [127:120] of the key/state vectors is AES byte 0 and bit
// [7:0] is AES byte 15 (the state is column-major, as in FIPS-197).
// -----------------------------------------------------------------------------
module aes128_encrypt (
    input wire clk,
    input wire rst,
    input wire load_key,
    input wire load_data,
    input wire [31:0] data_in,
    input wire start,
    output reg [31:0] data_out,
    input wire read_data,
    output reg done,
    output reg busy
);

    // State machine encodings. Only three of the eight 3-bit codes are legal;
    // the remaining ones are routed back to IDLE by the default branch below.
    localparam IDLE = 3'd0, ENCRYPT = 3'd1, WAIT_READ = 3'd2;
    reg [2:0] state;

    // Key and data registers
    reg [127:0] key_reg;     // cipher key as loaded by the user (never modified by a run)
    reg [127:0] state_reg;   // plaintext -> intermediate state -> ciphertext
    reg [127:0] round_key;   // round key used by the most recent AddRoundKey
    reg [3:0] round;         // 0..9: index of the key-expansion step to perform next
    // load_cnt counts accepted load words but is not read by any logic; it is
    // kept only so it stays observable in waveforms.
    reg [1:0] load_cnt, read_cnt;
    // Falling-edge samples of the bus inputs.
    reg load_key_s, load_data_s, start_s, read_data_s;
    reg [31:0] data_in_s;

    // AES S-box (FIPS-197, Figure 7) as a combinational ROM.
    // Unknown (X/Z) addresses yield X so that uninitialised inputs remain
    // visible in simulation.
    function [7:0] sbox;
        input [7:0] x;
    begin
        case (x)
            8'h00: sbox = 8'h63; 8'h01: sbox = 8'h7c; 8'h02: sbox = 8'h77; 8'h03: sbox = 8'h7b;
            8'h04: sbox = 8'hf2; 8'h05: sbox = 8'h6b; 8'h06: sbox = 8'h6f; 8'h07: sbox = 8'hc5;
            8'h08: sbox = 8'h30; 8'h09: sbox = 8'h01; 8'h0a: sbox = 8'h67; 8'h0b: sbox = 8'h2b;
            8'h0c: sbox = 8'hfe; 8'h0d: sbox = 8'hd7; 8'h0e: sbox = 8'hab; 8'h0f: sbox = 8'h76;
            8'h10: sbox = 8'hca; 8'h11: sbox = 8'h82; 8'h12: sbox = 8'hc9; 8'h13: sbox = 8'h7d;
            8'h14: sbox = 8'hfa; 8'h15: sbox = 8'h59; 8'h16: sbox = 8'h47; 8'h17: sbox = 8'hf0;
            8'h18: sbox = 8'had; 8'h19: sbox = 8'hd4; 8'h1a: sbox = 8'ha2; 8'h1b: sbox = 8'haf;
            8'h1c: sbox = 8'h9c; 8'h1d: sbox = 8'ha4; 8'h1e: sbox = 8'h72; 8'h1f: sbox = 8'hc0;
            8'h20: sbox = 8'hb7; 8'h21: sbox = 8'hfd; 8'h22: sbox = 8'h93; 8'h23: sbox = 8'h26;
            8'h24: sbox = 8'h36; 8'h25: sbox = 8'h3f; 8'h26: sbox = 8'hf7; 8'h27: sbox = 8'hcc;
            8'h28: sbox = 8'h34; 8'h29: sbox = 8'ha5; 8'h2a: sbox = 8'he5; 8'h2b: sbox = 8'hf1;
            8'h2c: sbox = 8'h71; 8'h2d: sbox = 8'hd8; 8'h2e: sbox = 8'h31; 8'h2f: sbox = 8'h15;
            8'h30: sbox = 8'h04; 8'h31: sbox = 8'hc7; 8'h32: sbox = 8'h23; 8'h33: sbox = 8'hc3;
            8'h34: sbox = 8'h18; 8'h35: sbox = 8'h96; 8'h36: sbox = 8'h05; 8'h37: sbox = 8'h9a;
            8'h38: sbox = 8'h07; 8'h39: sbox = 8'h12; 8'h3a: sbox = 8'h80; 8'h3b: sbox = 8'he2;
            8'h3c: sbox = 8'heb; 8'h3d: sbox = 8'h27; 8'h3e: sbox = 8'hb2; 8'h3f: sbox = 8'h75;
            8'h40: sbox = 8'h09; 8'h41: sbox = 8'h83; 8'h42: sbox = 8'h2c; 8'h43: sbox = 8'h1a;
            8'h44: sbox = 8'h1b; 8'h45: sbox = 8'h6e; 8'h46: sbox = 8'h5a; 8'h47: sbox = 8'ha0;
            8'h48: sbox = 8'h52; 8'h49: sbox = 8'h3b; 8'h4a: sbox = 8'hd6; 8'h4b: sbox = 8'hb3;
            8'h4c: sbox = 8'h29; 8'h4d: sbox = 8'he3; 8'h4e: sbox = 8'h2f; 8'h4f: sbox = 8'h84;
            8'h50: sbox = 8'h53; 8'h51: sbox = 8'hd1; 8'h52: sbox = 8'h00; 8'h53: sbox = 8'hed;
            8'h54: sbox = 8'h20; 8'h55: sbox = 8'hfc; 8'h56: sbox = 8'hb1; 8'h57: sbox = 8'h5b;
            8'h58: sbox = 8'h6a; 8'h59: sbox = 8'hcb; 8'h5a: sbox = 8'hbe; 8'h5b: sbox = 8'h39;
            8'h5c: sbox = 8'h4a; 8'h5d: sbox = 8'h4c; 8'h5e: sbox = 8'h58; 8'h5f: sbox = 8'hcf;
            8'h60: sbox = 8'hd0; 8'h61: sbox = 8'hef; 8'h62: sbox = 8'haa; 8'h63: sbox = 8'hfb;
            8'h64: sbox = 8'h43; 8'h65: sbox = 8'h4d; 8'h66: sbox = 8'h33; 8'h67: sbox = 8'h85;
            8'h68: sbox = 8'h45; 8'h69: sbox = 8'hf9; 8'h6a: sbox = 8'h02; 8'h6b: sbox = 8'h7f;
            8'h6c: sbox = 8'h50; 8'h6d: sbox = 8'h3c; 8'h6e: sbox = 8'h9f; 8'h6f: sbox = 8'ha8;
            8'h70: sbox = 8'h51; 8'h71: sbox = 8'ha3; 8'h72: sbox = 8'h40; 8'h73: sbox = 8'h8f;
            8'h74: sbox = 8'h92; 8'h75: sbox = 8'h9d; 8'h76: sbox = 8'h38; 8'h77: sbox = 8'hf5;
            8'h78: sbox = 8'hbc; 8'h79: sbox = 8'hb6; 8'h7a: sbox = 8'hda; 8'h7b: sbox = 8'h21;
            8'h7c: sbox = 8'h10; 8'h7d: sbox = 8'hff; 8'h7e: sbox = 8'hf3; 8'h7f: sbox = 8'hd2;
            8'h80: sbox = 8'hcd; 8'h81: sbox = 8'h0c; 8'h82: sbox = 8'h13; 8'h83: sbox = 8'hec;
            8'h84: sbox = 8'h5f; 8'h85: sbox = 8'h97; 8'h86: sbox = 8'h44; 8'h87: sbox = 8'h17;
            8'h88: sbox = 8'hc4; 8'h89: sbox = 8'ha7; 8'h8a: sbox = 8'h7e; 8'h8b: sbox = 8'h3d;
            8'h8c: sbox = 8'h64; 8'h8d: sbox = 8'h5d; 8'h8e: sbox = 8'h19; 8'h8f: sbox = 8'h73;
            8'h90: sbox = 8'h60; 8'h91: sbox = 8'h81; 8'h92: sbox = 8'h4f; 8'h93: sbox = 8'hdc;
            8'h94: sbox = 8'h22; 8'h95: sbox = 8'h2a; 8'h96: sbox = 8'h90; 8'h97: sbox = 8'h88;
            8'h98: sbox = 8'h46; 8'h99: sbox = 8'hee; 8'h9a: sbox = 8'hb8; 8'h9b: sbox = 8'h14;
            8'h9c: sbox = 8'hde; 8'h9d: sbox = 8'h5e; 8'h9e: sbox = 8'h0b; 8'h9f: sbox = 8'hdb;
            8'ha0: sbox = 8'he0; 8'ha1: sbox = 8'h32; 8'ha2: sbox = 8'h3a; 8'ha3: sbox = 8'h0a;
            8'ha4: sbox = 8'h49; 8'ha5: sbox = 8'h06; 8'ha6: sbox = 8'h24; 8'ha7: sbox = 8'h5c;
            8'ha8: sbox = 8'hc2; 8'ha9: sbox = 8'hd3; 8'haa: sbox = 8'hac; 8'hab: sbox = 8'h62;
            8'hac: sbox = 8'h91; 8'had: sbox = 8'h95; 8'hae: sbox = 8'he4; 8'haf: sbox = 8'h79;
            8'hb0: sbox = 8'he7; 8'hb1: sbox = 8'hc8; 8'hb2: sbox = 8'h37; 8'hb3: sbox = 8'h6d;
            8'hb4: sbox = 8'h8d; 8'hb5: sbox = 8'hd5; 8'hb6: sbox = 8'h4e; 8'hb7: sbox = 8'ha9;
            8'hb8: sbox = 8'h6c; 8'hb9: sbox = 8'h56; 8'hba: sbox = 8'hf4; 8'hbb: sbox = 8'hea;
            8'hbc: sbox = 8'h65; 8'hbd: sbox = 8'h7a; 8'hbe: sbox = 8'hae; 8'hbf: sbox = 8'h08;
            8'hc0: sbox = 8'hba; 8'hc1: sbox = 8'h78; 8'hc2: sbox = 8'h25; 8'hc3: sbox = 8'h2e;
            8'hc4: sbox = 8'h1c; 8'hc5: sbox = 8'ha6; 8'hc6: sbox = 8'hb4; 8'hc7: sbox = 8'hc6;
            8'hc8: sbox = 8'he8; 8'hc9: sbox = 8'hdd; 8'hca: sbox = 8'h74; 8'hcb: sbox = 8'h1f;
            8'hcc: sbox = 8'h4b; 8'hcd: sbox = 8'hbd; 8'hce: sbox = 8'h8b; 8'hcf: sbox = 8'h8a;
            8'hd0: sbox = 8'h70; 8'hd1: sbox = 8'h3e; 8'hd2: sbox = 8'hb5; 8'hd3: sbox = 8'h66;
            8'hd4: sbox = 8'h48; 8'hd5: sbox = 8'h03; 8'hd6: sbox = 8'hf6; 8'hd7: sbox = 8'h0e;
            8'hd8: sbox = 8'h61; 8'hd9: sbox = 8'h35; 8'hda: sbox = 8'h57; 8'hdb: sbox = 8'hb9;
            8'hdc: sbox = 8'h86; 8'hdd: sbox = 8'hc1; 8'hde: sbox = 8'h1d; 8'hdf: sbox = 8'h9e;
            8'he0: sbox = 8'he1; 8'he1: sbox = 8'hf8; 8'he2: sbox = 8'h98; 8'he3: sbox = 8'h11;
            8'he4: sbox = 8'h69; 8'he5: sbox = 8'hd9; 8'he6: sbox = 8'h8e; 8'he7: sbox = 8'h94;
            8'he8: sbox = 8'h9b; 8'he9: sbox = 8'h1e; 8'hea: sbox = 8'h87; 8'heb: sbox = 8'he9;
            8'hec: sbox = 8'hce; 8'hed: sbox = 8'h55; 8'hee: sbox = 8'h28; 8'hef: sbox = 8'hdf;
            8'hf0: sbox = 8'h8c; 8'hf1: sbox = 8'ha1; 8'hf2: sbox = 8'h89; 8'hf3: sbox = 8'h0d;
            8'hf4: sbox = 8'hbf; 8'hf5: sbox = 8'he6; 8'hf6: sbox = 8'h42; 8'hf7: sbox = 8'h68;
            8'hf8: sbox = 8'h41; 8'hf9: sbox = 8'h99; 8'hfa: sbox = 8'h2d; 8'hfb: sbox = 8'h0f;
            8'hfc: sbox = 8'hb0; 8'hfd: sbox = 8'h54; 8'hfe: sbox = 8'hbb; 8'hff: sbox = 8'h16;
            default: sbox = 8'hxx;
        endcase
    end
    endfunction

    // GF(2^8) multiply by 2: shift left and, if the top bit was set, reduce
    // by XOR-ing with 0x1b.
    function [7:0] xtime;
        input [7:0] x;
        xtime = (x[7]) ? ((x << 1) ^ 8'h1b) : (x << 1);
    endfunction

    // SubBytes: apply the S-box independently to each of the 16 state bytes.
    function [127:0] sub_bytes;
        input [127:0] s;
        sub_bytes = {sbox(s[127:120]), sbox(s[119:112]), sbox(s[111:104]), sbox(s[103:96]),
                     sbox(s[95:88]),   sbox(s[87:80]),   sbox(s[79:72]),   sbox(s[71:64]),
                     sbox(s[63:56]),   sbox(s[55:48]),   sbox(s[47:40]),   sbox(s[39:32]),
                     sbox(s[31:24]),   sbox(s[23:16]),   sbox(s[15:8]),    sbox(s[7:0])};
    endfunction

    // ShiftRows. The state is column-major, with byte i at s[127-8*i -: 8]:
    //     [ 0  4  8 12 ]
    //     [ 1  5  9 13 ]
    //     [ 2  6 10 14 ]
    //     [ 3  7 11 15 ]
    // Row r is rotated left by r positions, so output column c takes its
    // row-r byte from input column (c + r) mod 4. Each line below is one
    // output column; the trailing comment lists the source byte numbers.
    function [127:0] shift_rows;
        input [127:0] s;
        shift_rows = {s[127:120], s[87:80],   s[47:40],   s[7:0],      //  0  5 10 15
                      s[95:88],   s[55:48],   s[15:8],    s[103:96],   //  4  9 14  3
                      s[63:56],   s[23:16],   s[111:104], s[71:64],    //  8 13  2  7
                      s[31:24],   s[119:112], s[79:72],   s[39:32]};   // 12  1  6 11
    endfunction

    // MixColumns for one column {a0, a1, a2, a3} (a0 in the top byte).
    // Multiplies by the matrix [2 3 1 1; 1 2 3 1; 1 1 2 3; 3 1 1 2] over
    // GF(2^8); 3*a is computed as xtime(a) ^ a.
    function [31:0] mix_column;
        input [31:0] col;
        reg [7:0] a0, a1, a2, a3, r0, r1, r2, r3;
    begin
        a0 = col[31:24]; a1 = col[23:16]; a2 = col[15:8]; a3 = col[7:0];
        r0 = xtime(a0) ^ (xtime(a1) ^ a1) ^ a2 ^ a3;      // 2*a0 + 3*a1 + a2 + a3
        r1 = a0 ^ xtime(a1) ^ (xtime(a2) ^ a2) ^ a3;      // a0 + 2*a1 + 3*a2 + a3
        r2 = a0 ^ a1 ^ xtime(a2) ^ (xtime(a3) ^ a3);      // a0 + a1 + 2*a2 + 3*a3
        r3 = (xtime(a0) ^ a0) ^ a1 ^ a2 ^ xtime(a3);      // 3*a0 + a1 + a2 + 2*a3
        mix_column = {r0, r1, r2, r3};
    end
    endfunction

    // MixColumns for the full state: each 32-bit column is mixed independently.
    function [127:0] mix_columns;
        input [127:0] s;
        mix_columns = {mix_column(s[127:96]), mix_column(s[95:64]),
                       mix_column(s[63:32]), mix_column(s[31:0])};
    endfunction

    // Key schedule step: derive the next 128-bit round key from the current one.
    //   k   - current round key (w0 in the top word)
    //   rnd - step index; 0 produces round key 1, 9 produces round key 10.
    //         Values above 9 fall into the same Rcon entry as 9.
    function [127:0] key_expand;
        input [127:0] k;
        input [3:0] rnd;
        reg [31:0] w0, w1, w2, w3, temp;
        reg [7:0] rc;
    begin
        w0 = k[127:96]; w1 = k[95:64]; w2 = k[63:32]; w3 = k[31:0];
        case (rnd)
            0: rc = 8'h01; 1: rc = 8'h02; 2: rc = 8'h04; 3: rc = 8'h08;
            4: rc = 8'h10; 5: rc = 8'h20; 6: rc = 8'h40; 7: rc = 8'h80;
            8: rc = 8'h1b; default: rc = 8'h36;
        endcase
        // RotWord(w3), SubWord, then XOR the round constant into the top byte.
        temp = {sbox(w3[23:16]) ^ rc, sbox(w3[15:8]), sbox(w3[7:0]), sbox(w3[31:24])};
        // Each new word chains off the new word to its left.
        w0 = w0 ^ temp;
        w1 = w1 ^ w0;
        w2 = w2 ^ w1;
        w3 = w3 ^ w2;
        key_expand = {w0, w1, w2, w3};
    end
    endfunction

    // Round key for the round about to be executed. Computed once and shared
    // by the round_key register update and the AddRoundKey step.
    wire [127:0] next_round_key = key_expand(round_key, round);

    // Sample bus controls/data on the opposite clock edge to avoid race with
    // edge-driven testbench stimulus. These registers reset asynchronously.
    always @(negedge clk or posedge rst) begin
        if (rst) begin
            load_key_s  <= 1'b0;
            load_data_s <= 1'b0;
            start_s     <= 1'b0;
            read_data_s <= 1'b0;
            data_in_s   <= 32'b0;
        end else begin
            load_key_s  <= load_key;
            load_data_s <= load_data;
            start_s     <= start;
            read_data_s <= read_data;
            data_in_s   <= data_in;
        end
    end

    // Control FSM and datapath registers (synchronous reset).
    always @(posedge clk) begin
        if (rst) begin
            state <= IDLE;
            key_reg <= 0; state_reg <= 0; round_key <= 0;
            round <= 0; load_cnt <= 0; read_cnt <= 0;
            data_out <= 0; done <= 0; busy <= 0;
        end else begin

            case (state)
                IDLE: begin
                    done <= 0;
                    busy <= 0;
                    if (load_key_s) begin
                        // Shift in MSW-first: after four words the first one
                        // sits in key_reg[127:96].
                        key_reg <= {key_reg[95:0], data_in_s};
                        load_cnt <= load_cnt + 1;
                    end
                    else if (load_data_s) begin
                        state_reg <= {state_reg[95:0], data_in_s};
                        load_cnt <= load_cnt + 1;
                    end
                    else if (start_s) begin
                        busy <= 1;
                        round <= 0;
                        round_key <= key_reg;
                        state_reg <= state_reg ^ key_reg; // Round 0: AddRoundKey
                        state <= ENCRYPT;
                    end
                end

                ENCRYPT: begin
                    // The round key advances on every ENCRYPT cycle.
                    round_key <= next_round_key;
                    if (round < 9) begin
                        // Rounds 1-9
                        state_reg <= mix_columns(shift_rows(sub_bytes(state_reg))) ^ next_round_key;
                        round <= round + 1;
                    end else begin
                        // Round 10 (no MixColumns)
                        state_reg <= shift_rows(sub_bytes(state_reg)) ^ next_round_key;
                        done <= 1;
                        busy <= 0;
                        read_cnt <= 0;
                        state <= WAIT_READ;
                    end
                end

                WAIT_READ: begin
                    done <= 1;
                    // data_out follows read_cnt with one clock of latency.
                    case (read_cnt)
                        2'd0: data_out <= state_reg[127:96];
                        2'd1: data_out <= state_reg[95:64];
                        2'd2: data_out <= state_reg[63:32];
                        2'd3: data_out <= state_reg[31:0];
                    endcase
                    if (read_data_s) begin
                        read_cnt <= read_cnt + 1;
                        if (read_cnt == 2'd3) begin
                            // Fourth word consumed: release the result.
                            done <= 0;
                            state <= IDLE;
                        end
                    end
                end

                default: begin
                    // Illegal state encoding (3..7): recover to IDLE instead
                    // of hanging until the next reset.
                    done <= 0;
                    busy <= 0;
                    state <= IDLE;
                end
            endcase
        end
    end

endmodule
