// rv32_micro_core: single-cycle core for a small subset of RV32I.
//
// At most one instruction retires per rising clock edge. Supported encodings:
//   LUI; OP-IMM (ADDI only); OP (ADD, SUB, AND, OR, XOR); LW; SW; BEQ, BNE;
//   JAL; EBREAK (exactly 32'h00100073), which raises `halted`.
// Any other encoding, a misaligned LW/SW effective address, or a misaligned
// branch/jump target raises `trap` instead; pc is left on that instruction.
// Once `halted` or `trap` is set the core holds its state until `rst`
// (synchronous, active-high) is asserted.
//
// Both memory ports are consumed in the same cycle their address is driven:
// the instruction is decoded from imem_rdata while imem_addr == pc, and LW
// writes back dmem_rdata in the cycle dmem_addr is presented.
// NOTE(review): this presumes combinational-read memories - confirm against
// the memory model this core is paired with.
module rv32_micro_core (
    input  wire        clk,
    input  wire        rst,
    output wire [31:0] imem_addr,
    input  wire [31:0] imem_rdata,
    output wire [31:0] dmem_addr,
    output wire [31:0] dmem_wdata,
    output wire        dmem_we,
    input  wire [31:0] dmem_rdata,
    output reg         halted,
    output reg         trap
);
    // ---- Encoding constants ---------------------------------------------
    localparam [6:0] OPC_LOAD   = 7'b0000011;
    localparam [6:0] OPC_OPIMM  = 7'b0010011;
    localparam [6:0] OPC_STORE  = 7'b0100011;
    localparam [6:0] OPC_OP     = 7'b0110011;
    localparam [6:0] OPC_LUI    = 7'b0110111;
    localparam [6:0] OPC_BRANCH = 7'b1100011;
    localparam [6:0] OPC_JAL    = 7'b1101111;
    localparam [6:0] OPC_SYSTEM = 7'b1110011;

    localparam [2:0] F3_ADD  = 3'b000;
    localparam [2:0] F3_XOR  = 3'b100;
    localparam [2:0] F3_OR   = 3'b110;
    localparam [2:0] F3_AND  = 3'b111;
    localparam [2:0] F3_WORD = 3'b010;
    localparam [2:0] F3_BEQ  = 3'b000;
    localparam [2:0] F3_BNE  = 3'b001;

    localparam [6:0] F7_BASE = 7'b0000000;
    localparam [6:0] F7_ALT  = 7'b0100000;

    localparam [31:0] INSN_EBREAK = 32'h00100073;

    // ---- Architectural state --------------------------------------------
    reg [31:0] pc;
    reg [31:0] gpr [0:31];
    integer    r;

    // ---- Decode ----------------------------------------------------------
    wire [31:0] insn = imem_rdata;
    wire [6:0]  op   = insn[6:0];
    wire [2:0]  f3   = insn[14:12];
    wire [6:0]  f7   = insn[31:25];
    wire [4:0]  rd   = insn[11:7];
    wire [4:0]  rs1  = insn[19:15];
    wire [4:0]  rs2  = insn[24:20];

    // x0 always reads as zero, independent of what gpr[0] holds.
    wire [31:0] src1 = (rs1 == 5'd0) ? 32'd0 : gpr[rs1];
    wire [31:0] src2 = (rs2 == 5'd0) ? 32'd0 : gpr[rs2];

    // Immediates for the I, S, B, U and J instruction formats.
    wire [31:0] i_imm = {{20{insn[31]}}, insn[31:20]};
    wire [31:0] s_imm = {{20{insn[31]}}, insn[31:25], insn[11:7]};
    wire [31:0] b_imm = {{20{insn[31]}}, insn[7], insn[30:25], insn[11:8], 1'b0};
    wire [31:0] u_imm = {insn[31:12], 12'd0};
    wire [31:0] j_imm = {{12{insn[31]}}, insn[19:12], insn[20], insn[30:21], 1'b0};

    // ---- Execute (combinational) -----------------------------------------
    reg [31:0] npc;      // pc to commit if this instruction retires
    reg [31:0] wb_data;  // value for rd
    reg        wb_en;    // rd should be written
    reg        bad_op;   // encoding outside the supported subset
    reg        fault;    // any reason to trap this cycle (includes bad_op)
    reg        stop;     // EBREAK decoded
    reg [31:0] ea;       // LW/SW effective address; zero for other opcodes

    always @(*) begin
        npc     = pc + 32'd4;
        wb_en   = 1'b0;
        wb_data = 32'd0;
        bad_op  = 1'b0;
        fault   = 1'b0;
        stop    = 1'b0;
        ea      = 32'd0;

        case (op)
        OPC_LUI: begin
            wb_en   = 1'b1;
            wb_data = u_imm;
        end

        OPC_OPIMM: begin
            // Only ADDI is implemented in this group.
            if (f3 != F3_ADD) begin
                bad_op = 1'b1;
            end else begin
                wb_en   = 1'b1;
                wb_data = src1 + i_imm;
            end
        end

        OPC_OP: begin
            wb_en = 1'b1;
            case ({f7, f3})
            {F7_BASE, F3_ADD}: wb_data = src1 + src2;
            {F7_ALT,  F3_ADD}: wb_data = src1 - src2;
            {F7_BASE, F3_AND}: wb_data = src1 & src2;
            {F7_BASE, F3_OR }: wb_data = src1 | src2;
            {F7_BASE, F3_XOR}: wb_data = src1 ^ src2;
            default: begin
                wb_en  = 1'b0;
                bad_op = 1'b1;
            end
            endcase
        end

        OPC_LOAD: begin
            if (f3 != F3_WORD) begin
                bad_op = 1'b1;
            end else begin
                ea = src1 + i_imm;
                if (|ea[1:0]) begin
                    fault = 1'b1;           // misaligned LW
                end else begin
                    wb_en   = 1'b1;
                    wb_data = dmem_rdata;
                end
            end
        end

        OPC_STORE: begin
            if (f3 != F3_WORD) begin
                bad_op = 1'b1;
            end else begin
                ea    = src1 + s_imm;
                fault = |ea[1:0];           // misaligned SW
            end
        end

        OPC_BRANCH: begin
            if (f3 != F3_BEQ && f3 != F3_BNE) begin
                bad_op = 1'b1;
            end else if ((f3 == F3_BEQ) ? (src1 == src2) : (src1 != src2)) begin
                npc   = pc + b_imm;
                fault = |npc[1:0];          // misaligned branch target
            end
        end

        OPC_JAL: begin
            npc = pc + j_imm;
            if (|npc[1:0]) begin
                fault = 1'b1;               // misaligned jump target
            end else begin
                wb_en   = 1'b1;
                wb_data = pc + 32'd4;       // link address
            end
        end

        OPC_SYSTEM: begin
            // EBREAK halts; every other SYSTEM encoding is unsupported.
            stop   = (insn == INSN_EBREAK);
            bad_op = ~stop;
        end

        default: bad_op = 1'b1;
        endcase

        fault = fault | bad_op;
    end

    // ---- Memory / fetch ports --------------------------------------------
    assign imem_addr  = pc;
    assign dmem_addr  = ea;
    assign dmem_wdata = src2;
    // Write strobe only for a legal, aligned SW while the core is running.
    assign dmem_we    = ~rst & ~halted & ~trap & (op == OPC_STORE) & (f3 == F3_WORD) & ~fault;

    // ---- Commit ------------------------------------------------------------
    always @(posedge clk) begin
        if (rst) begin
            pc     <= 32'd0;
            halted <= 1'b0;
            trap   <= 1'b0;
            for (r = 0; r < 32; r = r + 1)
                gpr[r] <= 32'd0;
        end else begin
            if (!(halted || trap)) begin
                if (fault) begin
                    trap <= 1'b1;           // pc stays on the faulting instruction
                end else if (stop) begin
                    halted <= 1'b1;
                end else begin
                    pc <= npc;
                    if (wb_en && rd != 5'd0)
                        gpr[rd] <= wb_data;
                end
            end
            // Keep x0 pinned to zero whether or not anything retired.
            gpr[0] <= 32'd0;
        end
    end
endmodule
