diff options
author | Takashi Kokubun <[email protected]> | 2023-02-16 22:29:58 -0800 |
---|---|---|
committer | Takashi Kokubun <[email protected]> | 2023-03-05 23:28:59 -0800 |
commit | 2cc4f506bac0748277b41a4a5eb6f0ec41dd7344 (patch) | |
tree | 447fda3d7e563383f04292504b526d786475b906 | |
parent | 2603d7a0b7b0e698bed8910a8ad5edefe236de77 (diff) |
Implement optimized send
Notes
Notes:
Merged: https://github.com/ruby/ruby/pull/7448
-rw-r--r-- | lib/ruby_vm/mjit/assembler.rb | 57 | ||||
-rw-r--r-- | lib/ruby_vm/mjit/compiler.rb | 13 | ||||
-rw-r--r-- | lib/ruby_vm/mjit/exit_compiler.rb | 8 | ||||
-rw-r--r-- | lib/ruby_vm/mjit/insn_compiler.rb | 197 | ||||
-rw-r--r-- | lib/ruby_vm/mjit/jit_state.rb | 13 | ||||
-rw-r--r-- | mjit_c.c | 1 | ||||
-rw-r--r-- | mjit_c.h | 7 | ||||
-rw-r--r-- | mjit_c.rb | 22 | ||||
-rwxr-xr-x | tool/mjit/bindgen.rb | 2 |
9 files changed, 247 insertions, 73 deletions
diff --git a/lib/ruby_vm/mjit/assembler.rb b/lib/ruby_vm/mjit/assembler.rb index abf0ae6c08..989d069776 100644 --- a/lib/ruby_vm/mjit/assembler.rb +++ b/lib/ruby_vm/mjit/assembler.rb @@ -1,5 +1,8 @@ # frozen_string_literal: true module RubyVM::MJIT + # 8-bit memory access + class BytePtr < Data.define(:reg, :disp); end + # 32-bit memory access class DwordPtr < Data.define(:reg, :disp); end @@ -66,7 +69,7 @@ module RubyVM::MJIT def add(dst, src) case [dst, src] # ADD r/m64, imm8 (Mod 00: [reg]) - in [[Symbol => dst_reg], Integer => src_imm] if r64?(dst_reg) && imm8?(src_imm) + in [Array[Symbol => dst_reg], Integer => src_imm] if r64?(dst_reg) && imm8?(src_imm) # REX.W + 83 /0 ib # MI: Operand 1: ModRM:r/m (r, w), Operand 2: imm8/16/32 insn( @@ -132,7 +135,7 @@ module RubyVM::MJIT imm: imm32(src_imm), ) # AND r64, r/m64 (Mod 01: [reg]+disp8) - in [Symbol => dst_reg, [Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) + in [Symbol => dst_reg, Array[Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) # REX.W + 23 /r # RM: Operand 1: ModRM:reg (r, w), Operand 2: ModRM:r/m (r) insn( @@ -274,7 +277,7 @@ module RubyVM::MJIT mod_rm: ModRM[mod: Mod11, reg: dst_reg, rm: src_reg], ) # CMOVZ r64, r/m64 (Mod 01: [reg]+disp8) - in [Symbol => dst_reg, [Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) + in [Symbol => dst_reg, Array[Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) # REX.W + 0F 44 /r # RM: Operand 1: ModRM:reg (r, w), Operand 2: ModRM:r/m (r) insn( @@ -290,6 +293,16 @@ module RubyVM::MJIT def cmp(left, right) case [left, right] + # CMP r/m8, imm8 (Mod 01: [reg]+disp8) + in [BytePtr[reg: left_reg, disp: left_disp], Integer => right_imm] if r64?(left_reg) && imm8?(left_disp) && imm8?(right_imm) + # 80 /7 ib + # MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32 + insn( + opcode: 0x80, + mod_rm: ModRM[mod: Mod01, reg: 7, rm: left_reg], + disp: left_disp, + imm: imm8(right_imm), + ) # CMP r/m32, imm32 (Mod 01: [reg]+disp8) in [DwordPtr[reg: left_reg, disp: left_disp], Integer => right_imm] if imm8?(left_disp) && imm32?(right_imm) # 81 /7 id @@ -301,7 +314,7 @@ module RubyVM::MJIT imm: imm32(right_imm), ) # CMP r/m64, imm8 (Mod 01: [reg]+disp8) - in [[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if r64?(left_reg) && imm8?(left_disp) && imm8?(right_imm) + in [Array[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if r64?(left_reg) && imm8?(left_disp) && imm8?(right_imm) # REX.W + 83 /7 ib # MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32 insn( @@ -321,8 +334,18 @@ module RubyVM::MJIT mod_rm: ModRM[mod: Mod11, reg: 7, rm: left_reg], imm: imm8(right_imm), ) + # CMP r/m64, imm32 (Mod 11: reg) + in [Symbol => left_reg, Integer => right_imm] if r64?(left_reg) && imm32?(right_imm) + # REX.W + 81 /7 id + # MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32 + insn( + prefix: REX_W, + opcode: 0x81, + mod_rm: ModRM[mod: Mod11, reg: 7, rm: left_reg], + imm: imm32(right_imm), + ) # CMP r/m64, r64 (Mod 01: [reg]+disp8) - in [[Symbol => left_reg, Integer => left_disp], Symbol => right_reg] if r64?(right_reg) + in [Array[Symbol => left_reg, Integer => left_disp], Symbol => right_reg] if r64?(right_reg) # REX.W + 39 /r # MR: Operand 1: ModRM:r/m (r), Operand 2: ModRM:reg (r) insn( @@ -393,7 +416,7 @@ module RubyVM::MJIT # E9 cd insn(opcode: 0xe9, imm: rel32(dst_addr)) # JMP r/m64 (Mod 01: [reg]+disp8) - in [Symbol => dst_reg, Integer => dst_disp] if imm8?(dst_disp) + in Array[Symbol => dst_reg, Integer => dst_disp] if imm8?(dst_disp) # FF /4 insn(opcode: 0xff, mod_rm: ModRM[mod: Mod01, reg: 4, rm: dst_reg], disp: dst_disp) # JMP r/m64 (Mod 11: reg) @@ -456,7 +479,7 @@ module RubyVM::MJIT def lea(dst, src) case [dst, src] # LEA r64,m (Mod 01: [reg]+disp8) - in [Symbol => dst_reg, [Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) + in [Symbol => dst_reg, Array[Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) # REX.W + 8D /r # RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r) insn( @@ -466,7 +489,7 @@ module RubyVM::MJIT disp: imm8(src_disp), ) # LEA r64,m (Mod 10: [reg]+disp32) - in [Symbol => dst_reg, [Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm32?(src_disp) + in [Symbol => dst_reg, Array[Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm32?(src_disp) # REX.W + 8D /r # RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r) insn( @@ -485,7 +508,7 @@ module RubyVM::MJIT in Symbol => dst_reg case src # MOV r64, r/m64 (Mod 00: [reg]) - in [Symbol => src_reg] if r64?(dst_reg) && r64?(src_reg) + in Array[Symbol => src_reg] if r64?(dst_reg) && r64?(src_reg) # REX.W + 8B /r # RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r) insn( @@ -494,7 +517,7 @@ module RubyVM::MJIT mod_rm: ModRM[mod: Mod00, reg: dst_reg, rm: src_reg], ) # MOV r64, r/m64 (Mod 01: [reg]+disp8) - in [Symbol => src_reg, Integer => src_disp] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) + in Array[Symbol => src_reg, Integer => src_disp] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) # REX.W + 8B /r # RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r) insn( @@ -504,7 +527,7 @@ module RubyVM::MJIT disp: src_disp, ) # MOV r64, r/m64 (Mod 10: [reg]+disp16) - in [Symbol => src_reg, Integer => src_disp] if r64?(dst_reg) && r64?(src_reg) && imm32?(src_disp) + in Array[Symbol => src_reg, Integer => src_disp] if r64?(dst_reg) && r64?(src_reg) && imm32?(src_disp) # REX.W + 8B /r # RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r) insn( @@ -523,7 +546,7 @@ module RubyVM::MJIT mod_rm: ModRM[mod: Mod11, reg: dst_reg, rm: src_reg], ) # MOV r32 r/m32 (Mod 01: [reg]+disp8) - in [Symbol => src_reg, Integer => src_disp] if r32?(dst_reg) && imm8?(src_disp) + in Array[Symbol => src_reg, Integer => src_disp] if r32?(dst_reg) && imm8?(src_disp) # 8B /r # RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r) insn( @@ -563,7 +586,7 @@ module RubyVM::MJIT else raise NotImplementedError, "mov: not-implemented operands: #{dst.inspect}, #{src.inspect}" end - in [Symbol => dst_reg] + in Array[Symbol => dst_reg] case src # MOV r/m64, imm32 (Mod 00: [reg]) in Integer => src_imm if r64?(dst_reg) && imm32?(src_imm) @@ -587,7 +610,7 @@ module RubyVM::MJIT else raise NotImplementedError, "mov: not-implemented operands: #{dst.inspect}, #{src.inspect}" end - in [Symbol => dst_reg, Integer => dst_disp] + in Array[Symbol => dst_reg, Integer => dst_disp] # Optimize encoding when disp is 0 return mov([dst_reg], src) if dst_disp == 0 @@ -645,7 +668,7 @@ module RubyVM::MJIT def or(dst, src) case [dst, src] # OR r64, r/m64 (Mod 01: [reg]+disp8) - in [Symbol => dst_reg, [Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) + in [Symbol => dst_reg, Array[Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp) # REX.W + 0B /r # RM: Operand 1: ModRM:reg (r, w), Operand 2: ModRM:r/m (r) insn( @@ -734,7 +757,7 @@ module RubyVM::MJIT def test(left, right) case [left, right] # TEST r/m8*, imm8 (Mod 01: [reg]+disp8) - in [[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm8?(right_imm) && right_imm >= 0 + in [Array[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm8?(right_imm) && right_imm >= 0 # REX + F6 /0 ib # MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32 insn( @@ -744,7 +767,7 @@ module RubyVM::MJIT imm: imm8(right_imm), ) # TEST r/m64, imm32 (Mod 01: [reg]+disp8) - in [[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm32?(right_imm) + in [Array[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm32?(right_imm) # REX.W + F7 /0 id # MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32 insn( diff --git a/lib/ruby_vm/mjit/compiler.rb b/lib/ruby_vm/mjit/compiler.rb index 5aac3626fa..b34df4c392 100644 --- a/lib/ruby_vm/mjit/compiler.rb +++ b/lib/ruby_vm/mjit/compiler.rb @@ -170,6 +170,17 @@ module RubyVM::MJIT insn = self.class.decode_insn(iseq.body.iseq_encoded[index]) jit.pc = (iseq.body.iseq_encoded + index).to_i + # If previous instruction requested to record the boundary + if jit.record_boundary_patch_point + # Generate an exit to this instruction and record it + exit_pos = Assembler.new.then do |ocb_asm| + @exit_compiler.compile_side_exit(jit.pc, ctx, ocb_asm) + @ocb.write(ocb_asm) + end + Invariants.record_global_inval_patch(asm, exit_pos) + jit.record_boundary_patch_point = false + end + case status = @insn_compiler.compile(jit, ctx, asm, insn) when KeepCompiling index += insn.len @@ -177,7 +188,7 @@ module RubyVM::MJIT # TODO: pad nops if entry exit exists break when CantCompile - @exit_compiler.compile_side_exit(jit, ctx, asm) + @exit_compiler.compile_side_exit(jit.pc, ctx, asm) break else raise "compiling #{insn.name} returned unexpected status: #{status.inspect}" diff --git a/lib/ruby_vm/mjit/exit_compiler.rb b/lib/ruby_vm/mjit/exit_compiler.rb index 20645fdb9e..bf85a340d7 100644 --- a/lib/ruby_vm/mjit/exit_compiler.rb +++ b/lib/ruby_vm/mjit/exit_compiler.rb @@ -63,15 +63,15 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def compile_side_exit(jit, ctx, asm) + def compile_side_exit(pc, ctx, asm) # Increment per-insn exit counter - incr_insn_exit(jit.pc, asm) + incr_insn_exit(pc, asm) # Fix pc/sp offsets for the interpreter - save_pc_and_sp(jit.pc, ctx.dup, asm) # dup to avoid sp_offset update + save_pc_and_sp(pc, ctx.dup, asm) # dup to avoid sp_offset update # Restore callee-saved registers - asm.comment("exit to interpreter on #{pc_to_insn(jit.pc).name}") + asm.comment("exit to interpreter on #{pc_to_insn(pc).name}") asm.pop(SP) asm.pop(EC) asm.pop(CFP) diff --git a/lib/ruby_vm/mjit/insn_compiler.rb b/lib/ruby_vm/mjit/insn_compiler.rb index 1df6809c2b..a06030fd6f 100644 --- a/lib/ruby_vm/mjit/insn_compiler.rb +++ b/lib/ruby_vm/mjit/insn_compiler.rb @@ -966,7 +966,7 @@ module RubyVM::MJIT recv_opnd = ctx.stack_opnd(1) not_array_exit = counted_exit(side_exit, :optaref_recv_not_array) - if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_array_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_array_exit) == CantCompile return CantCompile end @@ -1001,7 +1001,7 @@ module RubyVM::MJIT # Guard that the receiver is a Hash not_hash_exit = counted_exit(side_exit, :optaref_recv_not_hash) - if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_hash_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_hash_exit) == CantCompile return CantCompile end @@ -1051,12 +1051,12 @@ module RubyVM::MJIT side_exit = side_exit(jit, ctx) # Guard receiver is an Array - if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile return CantCompile end # Guard key is a fixnum - if jit_guard_known_class(jit, ctx, asm, comptime_key.class, key, comptime_key, side_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_key.class, key, comptime_key, side_exit) == CantCompile return CantCompile end @@ -1090,7 +1090,7 @@ module RubyVM::MJIT side_exit = side_exit(jit, ctx) # Guard receiver is a Hash - if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile return CantCompile end @@ -1425,7 +1425,7 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_guard_known_class(jit, ctx, asm, known_klass, obj_opnd, comptime_obj, side_exit, limit: 5) + def jit_guard_known_klass(jit, ctx, asm, known_klass, obj_opnd, comptime_obj, side_exit, limit: 5) # Only memory operand is supported for now assert_equal(true, obj_opnd.is_a?(Array)) @@ -1445,9 +1445,13 @@ module RubyVM::MJIT asm.comment('guard object is fixnum') asm.test(obj_opnd, C.RUBY_FIXNUM_FLAG) jit_chain_guard(:jz, jit, ctx, asm, side_exit, limit:) - elsif known_klass == Symbol - asm.incr_counter(:send_guard_symbol) - return CantCompile + elsif known_klass == Symbol && static_symbol?(comptime_obj) + # We will guard STATIC vs DYNAMIC as though they were separate classes + # DYNAMIC symbols can be handled by the general else case below + asm.comment('guard object is static symbol') + assert_equal(8, C.RUBY_SPECIAL_SHIFT) + asm.cmp(BytePtr[*obj_opnd], C.RUBY_SYMBOL_FLAG) + jit_chain_guard(:jne, jit, ctx, asm, side_exit, limit:) elsif known_klass == Float asm.incr_counter(:send_guard_float) return CantCompile @@ -1471,7 +1475,7 @@ module RubyVM::MJIT # Bail if receiver class is different from known_klass klass_opnd = [obj_opnd, C.RBasic.offsetof(:klass)] - asm.comment('guard known class') + asm.comment("guard known class #{known_klass}") asm.mov(:rcx, to_value(known_klass)) asm.cmp(klass_opnd, :rcx) jit_chain_guard(:jne, jit, ctx, asm, side_exit, limit:) @@ -1593,7 +1597,7 @@ module RubyVM::MJIT end # Guard that a is a String - if jit_guard_known_class(jit, ctx, asm, comptime_a.class, a_opnd, comptime_a, side_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_a.class, a_opnd, comptime_a, side_exit) == CantCompile return false end @@ -1609,7 +1613,7 @@ module RubyVM::MJIT # Otherwise guard that b is a T_STRING (from type info) or String (from runtime guard) # Note: any T_STRING is valid here, but we check for a ::String for simplicity # To pass a mutable static variable (rb_cString) requires an unsafe block - if jit_guard_known_class(jit, ctx, asm, comptime_b.class, b_opnd, comptime_b, side_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_b.class, b_opnd, comptime_b, side_exit) == CantCompile return false end @@ -1639,6 +1643,7 @@ module RubyVM::MJIT # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] def jit_prepare_routine_call(jit, ctx, asm) + jit.record_boundary_patch_point = true jit_save_pc(jit, asm) jit_save_sp(jit, ctx, asm) end @@ -1672,6 +1677,17 @@ module RubyVM::MJIT reset_depth.chain_depth = 0 next_pc = jit.pc + jit.insn.len * C.VALUE.size + + # We are at the end of the current instruction. Record the boundary. + if jit.record_boundary_patch_point + exit_pos = Assembler.new.then do |ocb_asm| + @exit_compiler.compile_side_exit(next_pc, ctx, ocb_asm) + @ocb.write(ocb_asm) + end + Invariants.record_global_inval_patch(asm, exit_pos) + jit.record_boundary_patch_point = false + end + stub_next_block(jit.iseq, next_pc, reset_depth, asm, comment: 'jump_to_next_insn') end @@ -1781,7 +1797,8 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_method(jit, ctx, asm, mid, argc, flags) + # @param send_shift [Integer] The number of shifts needed for VM_CALL_OPT_SEND + def jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift: 0) # Specialize on a compile-time receiver, and split a block for chain guards unless jit.at_current_insn? defer_compilation(jit, ctx, asm) @@ -1791,22 +1808,22 @@ module RubyVM::MJIT # Generate a side exit side_exit = side_exit(jit, ctx) - # Calculate a receiver index + # kw_splat is not supported yet if flags & C.VM_CALL_KW_SPLAT != 0 - # recv_index calculation may not work for this asm.incr_counter(:send_kw_splat) return CantCompile end - recv_index = argc # TODO: +1 for VM_CALL_ARGS_BLOCKARG # Get a compile-time receiver and its class - comptime_recv = jit.peek_at_stack(recv_index) + recv_idx = argc + (flags & C.VM_CALL_ARGS_BLOCKARG != 0 ? 1 : 0) + recv_idx += send_shift + comptime_recv = jit.peek_at_stack(recv_idx) comptime_recv_klass = C.rb_class_of(comptime_recv) # Guard the receiver class (part of vm_search_method_fastpath) - recv_opnd = ctx.stack_opnd(recv_index) + recv_opnd = ctx.stack_opnd(recv_idx) megamorphic_exit = counted_exit(side_exit, :send_klass_megamorphic) - if jit_guard_known_class(jit, ctx, asm, comptime_recv_klass, recv_opnd, comptime_recv, megamorphic_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_recv_klass, recv_opnd, comptime_recv, megamorphic_exit) == CantCompile return CantCompile end @@ -1838,27 +1855,27 @@ module RubyVM::MJIT # Invalidate on redefinition (part of vm_search_method_fastpath) Invariants.assume_method_lookup_stable(jit, cme) - jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd) + jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd, send_shift:) end # vm_call_method_each_type # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd) + def jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd, send_shift:) case cme.def.type when C.VM_METHOD_TYPE_ISEQ - jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc) + jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc, send_shift:) when C.VM_METHOD_TYPE_NOTIMPLEMENTED asm.incr_counter(:send_notimplemented) return CantCompile when C.VM_METHOD_TYPE_CFUNC - jit_call_cfunc(jit, ctx, asm, cme, flags, argc) + jit_call_cfunc(jit, ctx, asm, cme, flags, argc, send_shift:) when C.VM_METHOD_TYPE_ATTRSET asm.incr_counter(:send_attrset) return CantCompile when C.VM_METHOD_TYPE_IVAR - jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd) + jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd, send_shift:) when C.VM_METHOD_TYPE_MISSING asm.incr_counter(:send_missing) return CantCompile @@ -1869,7 +1886,7 @@ module RubyVM::MJIT asm.incr_counter(:send_alias) return CantCompile when C.VM_METHOD_TYPE_OPTIMIZED - jit_call_optimized(jit, ctx, asm, cme, flags, argc) + jit_call_optimized(jit, ctx, asm, cme, flags, argc, send_shift:) when C.VM_METHOD_TYPE_UNDEF asm.incr_counter(:send_undef) return CantCompile @@ -1889,7 +1906,7 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc) + def jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc, send_shift:) iseq = def_iseq_ptr(cme.def) opt_pc = jit_callee_setup_arg(jit, ctx, asm, flags, argc, iseq) if opt_pc == CantCompile @@ -1902,19 +1919,24 @@ module RubyVM::MJIT asm.incr_counter(:send_tailcall) return CantCompile end - jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq) + jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq, send_shift:) end # vm_call_iseq_setup_normal (vm_call_iseq_setup_2 -> vm_call_iseq_setup_normal) # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq) + def jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq, send_shift:) + # We will not have side exits from here. Adjust the stack. + if flags & C.VM_CALL_OPT_SEND != 0 + jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:) + end + # Save caller SP and PC before pushing a callee frame for backtrace and side exits asm.comment('save SP to caller CFP') - # Not setting this to SP register. This cfp->sp will be copied to SP on leave insn. - sp_index = -(1 + argc) # Pop receiver and arguments for side exits # TODO: subtract one more for VM_CALL_ARGS_BLOCKARG - asm.lea(:rax, ctx.sp_opnd(C.VALUE.size * sp_index)) + recv_idx = argc + (flags & C.VM_CALL_ARGS_BLOCKARG != 0 ? 1 : 0) + # Skip setting this to SP register. This cfp->sp will be copied to SP on leave insn. + asm.lea(:rax, ctx.sp_opnd(C.VALUE.size * -(1 + recv_idx))) # Pop receiver and arguments to prepare for side exits asm.mov([CFP, C.rb_control_frame_t.offsetof(:sp)], :rax) jit_save_pc(jit, asm, comment: 'save PC to caller CFP') @@ -1937,7 +1959,7 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_cfunc(jit, ctx, asm, cme, flags, argc) + def jit_call_cfunc(jit, ctx, asm, cme, flags, argc, send_shift:) if jit_caller_setup_arg(jit, ctx, asm, flags) == CantCompile return CantCompile end @@ -1945,14 +1967,14 @@ module RubyVM::MJIT return CantCompile end - jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc) + jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc, send_shift:) end # jit_call_cfunc_with_frame # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc) + def jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc, send_shift:) cfunc = cme.def.body.cfunc if argc + 1 > 6 @@ -1981,6 +2003,11 @@ module RubyVM::MJIT return CantCompile end + # We will not have side exits from here. Adjust the stack. + if flags & C.VM_CALL_OPT_SEND != 0 + jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:) + end + # Check interrupts before SP motion to safely side-exit with the original SP. jit_check_ints(jit, ctx, asm) @@ -2026,11 +2053,11 @@ module RubyVM::MJIT EndBlock end - # vm_call_ivar + # vm_call_ivar (+ part of vm_call_method_each_type) # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd) + def jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd, send_shift:) if flags & C.VM_CALL_ARGS_SPLAT != 0 asm.incr_counter(:send_ivar_splat) return CantCompile @@ -2041,6 +2068,7 @@ module RubyVM::MJIT return CantCompile end + # We don't support jit_call_opt_send_shift_stack for this yet. if flags & C.VM_CALL_OPT_SEND != 0 asm.incr_counter(:send_ivar_opt_send) return CantCompile @@ -2048,7 +2076,7 @@ module RubyVM::MJIT ivar_id = cme.def.body.attr.id - if flags & C.VM_CALL_OPT_SEND != 0 + if flags & C.VM_CALL_ARGS_BLOCKARG != 0 asm.incr_counter(:send_ivar_blockarg) return CantCompile end @@ -2060,11 +2088,10 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_optimized(jit, ctx, asm, cme, flags, argc) + def jit_call_optimized(jit, ctx, asm, cme, flags, argc, send_shift:) case cme.def.body.optimized.type when C.OPTIMIZED_METHOD_TYPE_SEND - asm.incr_counter(:send_optimized_send) - return CantCompile + jit_call_opt_send(jit, ctx, asm, cme, flags, argc, send_shift:) when C.OPTIMIZED_METHOD_TYPE_CALL asm.incr_counter(:send_optimized_call) return CantCompile @@ -2083,6 +2110,87 @@ module RubyVM::MJIT end end + # vm_call_opt_send + # @param jit [RubyVM::MJIT::JITState] + # @param ctx [RubyVM::MJIT::Context] + # @param asm [RubyVM::MJIT::Assembler] + def jit_call_opt_send(jit, ctx, asm, cme, flags, argc, send_shift:) + if jit_caller_setup_arg(jit, ctx, asm, flags) == CantCompile + return CantCompile + end + + if argc == 0 + asm.incr_counter(:send_optimized_send_no_args) + return CantCompile + end + + argc -= 1 + # We aren't handling `send(:send, ...)` yet. This might work, but not tested yet. + if send_shift > 0 + asm.incr_counter(:send_optimized_send_send) + return CantCompile + end + # Ideally, we want to shift the stack here, but it's not safe until you reach the point + # where you never exit. `send_shift` signals to lazily shift the stack by this amount. + send_shift += 1 + + kw_splat = flags & C.VM_CALL_KW_SPLAT != 0 + jit_call_symbol(jit, ctx, asm, cme, C.VM_CALL_FCALL, argc, kw_splat, send_shift:) + end + + # @param ctx [RubyVM::MJIT::Context] + # @param asm [RubyVM::MJIT::Assembler] + def jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:) + # We don't support `send(:send, ...)` for now. + assert_equal(1, send_shift) + + asm.comment('shift stack') + (0...argc).reverse_each do |i| + opnd = ctx.stack_opnd(i) + opnd2 = ctx.stack_opnd(i + 1) + asm.mov(:rax, opnd) + asm.mov(opnd2, :rax) + end + + ctx.stack_pop(1) + end + + # vm_call_symbol + # @param jit [RubyVM::MJIT::JITState] + # @param ctx [RubyVM::MJIT::Context] + # @param asm [RubyVM::MJIT::Assembler] + def jit_call_symbol(jit, ctx, asm, cme, flags, argc, kw_splat, send_shift:) + flags |= C.VM_CALL_OPT_SEND | (kw_splat ? C.VM_CALL_KW_SPLAT : 0) + + comptime_symbol = jit.peek_at_stack(argc) + if comptime_symbol.class != String && !static_symbol?(comptime_symbol) + asm.incr_counter(:send_optimized_send_not_sym_or_str) + return CantCompile + end + + mid = C.get_symbol_id(comptime_symbol) + if mid == 0 + asm.incr_counter(:send_optimized_send_null_mid) + return CantCompile + end + + asm.comment("Guard #{comptime_symbol.inspect} is on stack") + mid_changed_exit = counted_exit(side_exit(jit, ctx), :send_optimized_send_mid_changed) + if jit_guard_known_klass(jit, ctx, asm, comptime_symbol.class, ctx.stack_opnd(argc), comptime_symbol, mid_changed_exit) == CantCompile + return CantCompile + end + asm.mov(C_ARGS[0], ctx.stack_opnd(argc)) + asm.call(C.rb_get_symbol_id) + asm.cmp(C_RET, mid) + jit_chain_guard(:jne, jit, ctx, asm, mid_changed_exit) + + if flags & C.VM_CALL_FCALL != 0 + return jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift:) + end + + raise NotImplementedError # unreachable for now + end + # vm_push_frame # # Frame structure: @@ -2253,8 +2361,11 @@ module RubyVM::MJIT end def fixnum?(obj) - flag = C.RUBY_FIXNUM_FLAG - (C.to_value(obj) & flag) == flag + (C.to_value(obj) & C.RUBY_FIXNUM_FLAG) == C.RUBY_FIXNUM_FLAG + end + + def static_symbol?(obj) + (C.to_value(obj) & 0xff) == C.RUBY_SYMBOL_FLAG end # @param jit [RubyVM::MJIT::JITState] @@ -2296,7 +2407,7 @@ module RubyVM::MJIT return side_exit end asm = Assembler.new - @exit_compiler.compile_side_exit(jit, ctx, asm) + @exit_compiler.compile_side_exit(jit.pc, ctx, asm) jit.side_exits[jit.pc] = @ocb.write(asm) end diff --git a/lib/ruby_vm/mjit/jit_state.rb b/lib/ruby_vm/mjit/jit_state.rb index b90d2f8020..3e5fe996fa 100644 --- a/lib/ruby_vm/mjit/jit_state.rb +++ b/lib/ruby_vm/mjit/jit_state.rb @@ -1,12 +1,13 @@ module RubyVM::MJIT class JITState < Struct.new( - :iseq, # @param `RubyVM::MJIT::CPointer::Struct_rb_iseq_t` - :pc, # @param [Integer] The JIT target PC - :cfp, # @param `RubyVM::MJIT::CPointer::Struct_rb_control_frame_t` The JIT source CFP (before MJIT is called) - :block, # @param [RubyVM::MJIT::Block] - :side_exits, # @param [Hash{ Integer => Integer }] { PC => address } + :iseq, # @param `RubyVM::MJIT::CPointer::Struct_rb_iseq_t` + :pc, # @param [Integer] The JIT target PC + :cfp, # @param `RubyVM::MJIT::CPointer::Struct_rb_control_frame_t` The JIT source CFP (before MJIT is called) + :block, # @param [RubyVM::MJIT::Block] + :side_exits, # @param [Hash{ Integer => Integer }] { PC => address } + :record_boundary_patch_point, # @param [TrueClass,FalseClass] ) - def initialize(side_exits: {}, **) = super + def initialize(side_exits: {}, record_boundary_patch_point: false, **) = super def insn Compiler.decode_insn(C.VALUE.new(pc).*) @@ -117,6 +117,7 @@ mjit_for_each_iseq(rb_execution_context_t *ec, VALUE self, VALUE block) } extern bool rb_simple_iseq_p(const rb_iseq_t *iseq); +extern ID rb_get_symbol_id(VALUE name); #include "mjit_c.rbinc" @@ -144,13 +144,18 @@ MJIT_RUNTIME_COUNTERS( send_ivar_opt_send, send_ivar_blockarg, - send_optimized_send, send_optimized_call, send_optimized_block_call, send_optimized_struct_aref, send_optimized_struct_aset, send_optimized_unknown_type, + send_optimized_send_no_args, + send_optimized_send_not_sym_or_str, + send_optimized_send_mid_changed, + send_optimized_send_null_mid, + send_optimized_send_send, + send_guard_symbol, send_guard_float, @@ -180,6 +180,14 @@ module RubyVM::MJIT # :nodoc: all Primitive.cexpr! 'SIZET2NUM((size_t)rb_hash_aset)' end + def get_symbol_id(name) + Primitive.cexpr! 'SIZET2NUM((size_t)rb_get_symbol_id(name))' + end + + def rb_get_symbol_id + Primitive.cexpr! 'SIZET2NUM((size_t)rb_get_symbol_id)' + end + #======================================================================================== # # Old stuff @@ -629,6 +637,14 @@ module RubyVM::MJIT # :nodoc: all Primitive.cexpr! %q{ ULONG2NUM(RUBY_IMMEDIATE_MASK) } end + def C.RUBY_SPECIAL_SHIFT + Primitive.cexpr! %q{ ULONG2NUM(RUBY_SPECIAL_SHIFT) } + end + + def C.RUBY_SYMBOL_FLAG + Primitive.cexpr! %q{ ULONG2NUM(RUBY_SYMBOL_FLAG) } + end + def C.RUBY_T_ARRAY Primitive.cexpr! %q{ ULONG2NUM(RUBY_T_ARRAY) } end @@ -1165,12 +1181,16 @@ module RubyVM::MJIT # :nodoc: all send_ivar_splat: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_ivar_splat)")], send_ivar_opt_send: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_ivar_opt_send)")], send_ivar_blockarg: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_ivar_blockarg)")], - send_optimized_send: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send)")], send_optimized_call: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_call)")], send_optimized_block_call: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_block_call)")], send_optimized_struct_aref: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_struct_aref)")], send_optimized_struct_aset: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_struct_aset)")], send_optimized_unknown_type: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_unknown_type)")], + send_optimized_send_no_args: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send_no_args)")], + send_optimized_send_not_sym_or_str: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send_not_sym_or_str)")], + send_optimized_send_mid_changed: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send_mid_changed)")], + send_optimized_send_null_mid: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send_null_mid)")], + send_optimized_send_send: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send_send)")], send_guard_symbol: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_guard_symbol)")], send_guard_float: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_guard_float)")], getivar_megamorphic: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), getivar_megamorphic)")], diff --git a/tool/mjit/bindgen.rb b/tool/mjit/bindgen.rb index dc2e76dd68..6e67702ea4 100755 --- a/tool/mjit/bindgen.rb +++ b/tool/mjit/bindgen.rb @@ -424,6 +424,8 @@ generator = BindingGenerator.new( INVALID_SHAPE_ID OBJ_TOO_COMPLEX_SHAPE_ID RUBY_FIXNUM_FLAG + RUBY_SYMBOL_FLAG + RUBY_SPECIAL_SHIFT RUBY_IMMEDIATE_MASK RARRAY_EMBED_LEN_MASK RARRAY_EMBED_LEN_SHIFT |