summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTakashi Kokubun <[email protected]>2023-02-16 22:29:58 -0800
committerTakashi Kokubun <[email protected]>2023-03-05 23:28:59 -0800
commit2cc4f506bac0748277b41a4a5eb6f0ec41dd7344 (patch)
tree447fda3d7e563383f04292504b526d786475b906
parent2603d7a0b7b0e698bed8910a8ad5edefe236de77 (diff)
Implement optimized send
Notes
Notes: Merged: https://github.com/ruby/ruby/pull/7448
-rw-r--r--lib/ruby_vm/mjit/assembler.rb57
-rw-r--r--lib/ruby_vm/mjit/compiler.rb13
-rw-r--r--lib/ruby_vm/mjit/exit_compiler.rb8
-rw-r--r--lib/ruby_vm/mjit/insn_compiler.rb197
-rw-r--r--lib/ruby_vm/mjit/jit_state.rb13
-rw-r--r--mjit_c.c1
-rw-r--r--mjit_c.h7
-rw-r--r--mjit_c.rb22
-rwxr-xr-xtool/mjit/bindgen.rb2
9 files changed, 247 insertions, 73 deletions
diff --git a/lib/ruby_vm/mjit/assembler.rb b/lib/ruby_vm/mjit/assembler.rb
index abf0ae6c08..989d069776 100644
--- a/lib/ruby_vm/mjit/assembler.rb
+++ b/lib/ruby_vm/mjit/assembler.rb
@@ -1,5 +1,8 @@
# frozen_string_literal: true
module RubyVM::MJIT
+ # 8-bit memory access
+ class BytePtr < Data.define(:reg, :disp); end
+
# 32-bit memory access
class DwordPtr < Data.define(:reg, :disp); end
@@ -66,7 +69,7 @@ module RubyVM::MJIT
def add(dst, src)
case [dst, src]
# ADD r/m64, imm8 (Mod 00: [reg])
- in [[Symbol => dst_reg], Integer => src_imm] if r64?(dst_reg) && imm8?(src_imm)
+ in [Array[Symbol => dst_reg], Integer => src_imm] if r64?(dst_reg) && imm8?(src_imm)
# REX.W + 83 /0 ib
# MI: Operand 1: ModRM:r/m (r, w), Operand 2: imm8/16/32
insn(
@@ -132,7 +135,7 @@ module RubyVM::MJIT
imm: imm32(src_imm),
)
# AND r64, r/m64 (Mod 01: [reg]+disp8)
- in [Symbol => dst_reg, [Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp)
+ in [Symbol => dst_reg, Array[Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp)
# REX.W + 23 /r
# RM: Operand 1: ModRM:reg (r, w), Operand 2: ModRM:r/m (r)
insn(
@@ -274,7 +277,7 @@ module RubyVM::MJIT
mod_rm: ModRM[mod: Mod11, reg: dst_reg, rm: src_reg],
)
# CMOVZ r64, r/m64 (Mod 01: [reg]+disp8)
- in [Symbol => dst_reg, [Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp)
+ in [Symbol => dst_reg, Array[Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp)
# REX.W + 0F 44 /r
# RM: Operand 1: ModRM:reg (r, w), Operand 2: ModRM:r/m (r)
insn(
@@ -290,6 +293,16 @@ module RubyVM::MJIT
def cmp(left, right)
case [left, right]
+ # CMP r/m8, imm8 (Mod 01: [reg]+disp8)
+ in [BytePtr[reg: left_reg, disp: left_disp], Integer => right_imm] if r64?(left_reg) && imm8?(left_disp) && imm8?(right_imm)
+ # 80 /7 ib
+ # MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32
+ insn(
+ opcode: 0x80,
+ mod_rm: ModRM[mod: Mod01, reg: 7, rm: left_reg],
+ disp: left_disp,
+ imm: imm8(right_imm),
+ )
# CMP r/m32, imm32 (Mod 01: [reg]+disp8)
in [DwordPtr[reg: left_reg, disp: left_disp], Integer => right_imm] if imm8?(left_disp) && imm32?(right_imm)
# 81 /7 id
@@ -301,7 +314,7 @@ module RubyVM::MJIT
imm: imm32(right_imm),
)
# CMP r/m64, imm8 (Mod 01: [reg]+disp8)
- in [[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if r64?(left_reg) && imm8?(left_disp) && imm8?(right_imm)
+ in [Array[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if r64?(left_reg) && imm8?(left_disp) && imm8?(right_imm)
# REX.W + 83 /7 ib
# MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32
insn(
@@ -321,8 +334,18 @@ module RubyVM::MJIT
mod_rm: ModRM[mod: Mod11, reg: 7, rm: left_reg],
imm: imm8(right_imm),
)
+ # CMP r/m64, imm32 (Mod 11: reg)
+ in [Symbol => left_reg, Integer => right_imm] if r64?(left_reg) && imm32?(right_imm)
+ # REX.W + 81 /7 id
+ # MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32
+ insn(
+ prefix: REX_W,
+ opcode: 0x81,
+ mod_rm: ModRM[mod: Mod11, reg: 7, rm: left_reg],
+ imm: imm32(right_imm),
+ )
# CMP r/m64, r64 (Mod 01: [reg]+disp8)
- in [[Symbol => left_reg, Integer => left_disp], Symbol => right_reg] if r64?(right_reg)
+ in [Array[Symbol => left_reg, Integer => left_disp], Symbol => right_reg] if r64?(right_reg)
# REX.W + 39 /r
# MR: Operand 1: ModRM:r/m (r), Operand 2: ModRM:reg (r)
insn(
@@ -393,7 +416,7 @@ module RubyVM::MJIT
# E9 cd
insn(opcode: 0xe9, imm: rel32(dst_addr))
# JMP r/m64 (Mod 01: [reg]+disp8)
- in [Symbol => dst_reg, Integer => dst_disp] if imm8?(dst_disp)
+ in Array[Symbol => dst_reg, Integer => dst_disp] if imm8?(dst_disp)
# FF /4
insn(opcode: 0xff, mod_rm: ModRM[mod: Mod01, reg: 4, rm: dst_reg], disp: dst_disp)
# JMP r/m64 (Mod 11: reg)
@@ -456,7 +479,7 @@ module RubyVM::MJIT
def lea(dst, src)
case [dst, src]
# LEA r64,m (Mod 01: [reg]+disp8)
- in [Symbol => dst_reg, [Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp)
+ in [Symbol => dst_reg, Array[Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp)
# REX.W + 8D /r
# RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r)
insn(
@@ -466,7 +489,7 @@ module RubyVM::MJIT
disp: imm8(src_disp),
)
# LEA r64,m (Mod 10: [reg]+disp32)
- in [Symbol => dst_reg, [Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm32?(src_disp)
+ in [Symbol => dst_reg, Array[Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm32?(src_disp)
# REX.W + 8D /r
# RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r)
insn(
@@ -485,7 +508,7 @@ module RubyVM::MJIT
in Symbol => dst_reg
case src
# MOV r64, r/m64 (Mod 00: [reg])
- in [Symbol => src_reg] if r64?(dst_reg) && r64?(src_reg)
+ in Array[Symbol => src_reg] if r64?(dst_reg) && r64?(src_reg)
# REX.W + 8B /r
# RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r)
insn(
@@ -494,7 +517,7 @@ module RubyVM::MJIT
mod_rm: ModRM[mod: Mod00, reg: dst_reg, rm: src_reg],
)
# MOV r64, r/m64 (Mod 01: [reg]+disp8)
- in [Symbol => src_reg, Integer => src_disp] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp)
+ in Array[Symbol => src_reg, Integer => src_disp] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp)
# REX.W + 8B /r
# RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r)
insn(
@@ -504,7 +527,7 @@ module RubyVM::MJIT
disp: src_disp,
)
# MOV r64, r/m64 (Mod 10: [reg]+disp16)
- in [Symbol => src_reg, Integer => src_disp] if r64?(dst_reg) && r64?(src_reg) && imm32?(src_disp)
+ in Array[Symbol => src_reg, Integer => src_disp] if r64?(dst_reg) && r64?(src_reg) && imm32?(src_disp)
# REX.W + 8B /r
# RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r)
insn(
@@ -523,7 +546,7 @@ module RubyVM::MJIT
mod_rm: ModRM[mod: Mod11, reg: dst_reg, rm: src_reg],
)
# MOV r32 r/m32 (Mod 01: [reg]+disp8)
- in [Symbol => src_reg, Integer => src_disp] if r32?(dst_reg) && imm8?(src_disp)
+ in Array[Symbol => src_reg, Integer => src_disp] if r32?(dst_reg) && imm8?(src_disp)
# 8B /r
# RM: Operand 1: ModRM:reg (w), Operand 2: ModRM:r/m (r)
insn(
@@ -563,7 +586,7 @@ module RubyVM::MJIT
else
raise NotImplementedError, "mov: not-implemented operands: #{dst.inspect}, #{src.inspect}"
end
- in [Symbol => dst_reg]
+ in Array[Symbol => dst_reg]
case src
# MOV r/m64, imm32 (Mod 00: [reg])
in Integer => src_imm if r64?(dst_reg) && imm32?(src_imm)
@@ -587,7 +610,7 @@ module RubyVM::MJIT
else
raise NotImplementedError, "mov: not-implemented operands: #{dst.inspect}, #{src.inspect}"
end
- in [Symbol => dst_reg, Integer => dst_disp]
+ in Array[Symbol => dst_reg, Integer => dst_disp]
# Optimize encoding when disp is 0
return mov([dst_reg], src) if dst_disp == 0
@@ -645,7 +668,7 @@ module RubyVM::MJIT
def or(dst, src)
case [dst, src]
# OR r64, r/m64 (Mod 01: [reg]+disp8)
- in [Symbol => dst_reg, [Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp)
+ in [Symbol => dst_reg, Array[Symbol => src_reg, Integer => src_disp]] if r64?(dst_reg) && r64?(src_reg) && imm8?(src_disp)
# REX.W + 0B /r
# RM: Operand 1: ModRM:reg (r, w), Operand 2: ModRM:r/m (r)
insn(
@@ -734,7 +757,7 @@ module RubyVM::MJIT
def test(left, right)
case [left, right]
# TEST r/m8*, imm8 (Mod 01: [reg]+disp8)
- in [[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm8?(right_imm) && right_imm >= 0
+ in [Array[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm8?(right_imm) && right_imm >= 0
# REX + F6 /0 ib
# MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32
insn(
@@ -744,7 +767,7 @@ module RubyVM::MJIT
imm: imm8(right_imm),
)
# TEST r/m64, imm32 (Mod 01: [reg]+disp8)
- in [[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm32?(right_imm)
+ in [Array[Symbol => left_reg, Integer => left_disp], Integer => right_imm] if imm32?(right_imm)
# REX.W + F7 /0 id
# MI: Operand 1: ModRM:r/m (r), Operand 2: imm8/16/32
insn(
diff --git a/lib/ruby_vm/mjit/compiler.rb b/lib/ruby_vm/mjit/compiler.rb
index 5aac3626fa..b34df4c392 100644
--- a/lib/ruby_vm/mjit/compiler.rb
+++ b/lib/ruby_vm/mjit/compiler.rb
@@ -170,6 +170,17 @@ module RubyVM::MJIT
insn = self.class.decode_insn(iseq.body.iseq_encoded[index])
jit.pc = (iseq.body.iseq_encoded + index).to_i
+ # If previous instruction requested to record the boundary
+ if jit.record_boundary_patch_point
+ # Generate an exit to this instruction and record it
+ exit_pos = Assembler.new.then do |ocb_asm|
+ @exit_compiler.compile_side_exit(jit.pc, ctx, ocb_asm)
+ @ocb.write(ocb_asm)
+ end
+ Invariants.record_global_inval_patch(asm, exit_pos)
+ jit.record_boundary_patch_point = false
+ end
+
case status = @insn_compiler.compile(jit, ctx, asm, insn)
when KeepCompiling
index += insn.len
@@ -177,7 +188,7 @@ module RubyVM::MJIT
# TODO: pad nops if entry exit exists
break
when CantCompile
- @exit_compiler.compile_side_exit(jit, ctx, asm)
+ @exit_compiler.compile_side_exit(jit.pc, ctx, asm)
break
else
raise "compiling #{insn.name} returned unexpected status: #{status.inspect}"
diff --git a/lib/ruby_vm/mjit/exit_compiler.rb b/lib/ruby_vm/mjit/exit_compiler.rb
index 20645fdb9e..bf85a340d7 100644
--- a/lib/ruby_vm/mjit/exit_compiler.rb
+++ b/lib/ruby_vm/mjit/exit_compiler.rb
@@ -63,15 +63,15 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def compile_side_exit(jit, ctx, asm)
+ def compile_side_exit(pc, ctx, asm)
# Increment per-insn exit counter
- incr_insn_exit(jit.pc, asm)
+ incr_insn_exit(pc, asm)
# Fix pc/sp offsets for the interpreter
- save_pc_and_sp(jit.pc, ctx.dup, asm) # dup to avoid sp_offset update
+ save_pc_and_sp(pc, ctx.dup, asm) # dup to avoid sp_offset update
# Restore callee-saved registers
- asm.comment("exit to interpreter on #{pc_to_insn(jit.pc).name}")
+ asm.comment("exit to interpreter on #{pc_to_insn(pc).name}")
asm.pop(SP)
asm.pop(EC)
asm.pop(CFP)
diff --git a/lib/ruby_vm/mjit/insn_compiler.rb b/lib/ruby_vm/mjit/insn_compiler.rb
index 1df6809c2b..a06030fd6f 100644
--- a/lib/ruby_vm/mjit/insn_compiler.rb
+++ b/lib/ruby_vm/mjit/insn_compiler.rb
@@ -966,7 +966,7 @@ module RubyVM::MJIT
recv_opnd = ctx.stack_opnd(1)
not_array_exit = counted_exit(side_exit, :optaref_recv_not_array)
- if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_array_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_array_exit) == CantCompile
return CantCompile
end
@@ -1001,7 +1001,7 @@ module RubyVM::MJIT
# Guard that the receiver is a Hash
not_hash_exit = counted_exit(side_exit, :optaref_recv_not_hash)
- if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_hash_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_hash_exit) == CantCompile
return CantCompile
end
@@ -1051,12 +1051,12 @@ module RubyVM::MJIT
side_exit = side_exit(jit, ctx)
# Guard receiver is an Array
- if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile
return CantCompile
end
# Guard key is a fixnum
- if jit_guard_known_class(jit, ctx, asm, comptime_key.class, key, comptime_key, side_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_key.class, key, comptime_key, side_exit) == CantCompile
return CantCompile
end
@@ -1090,7 +1090,7 @@ module RubyVM::MJIT
side_exit = side_exit(jit, ctx)
# Guard receiver is a Hash
- if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile
return CantCompile
end
@@ -1425,7 +1425,7 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_guard_known_class(jit, ctx, asm, known_klass, obj_opnd, comptime_obj, side_exit, limit: 5)
+ def jit_guard_known_klass(jit, ctx, asm, known_klass, obj_opnd, comptime_obj, side_exit, limit: 5)
# Only memory operand is supported for now
assert_equal(true, obj_opnd.is_a?(Array))
@@ -1445,9 +1445,13 @@ module RubyVM::MJIT
asm.comment('guard object is fixnum')
asm.test(obj_opnd, C.RUBY_FIXNUM_FLAG)
jit_chain_guard(:jz, jit, ctx, asm, side_exit, limit:)
- elsif known_klass == Symbol
- asm.incr_counter(:send_guard_symbol)
- return CantCompile
+ elsif known_klass == Symbol && static_symbol?(comptime_obj)
+ # We will guard STATIC vs DYNAMIC as though they were separate classes
+ # DYNAMIC symbols can be handled by the general else case below
+ asm.comment('guard object is static symbol')
+ assert_equal(8, C.RUBY_SPECIAL_SHIFT)
+ asm.cmp(BytePtr[*obj_opnd], C.RUBY_SYMBOL_FLAG)
+ jit_chain_guard(:jne, jit, ctx, asm, side_exit, limit:)
elsif known_klass == Float
asm.incr_counter(:send_guard_float)
return CantCompile
@@ -1471,7 +1475,7 @@ module RubyVM::MJIT
# Bail if receiver class is different from known_klass
klass_opnd = [obj_opnd, C.RBasic.offsetof(:klass)]
- asm.comment('guard known class')
+ asm.comment("guard known class #{known_klass}")
asm.mov(:rcx, to_value(known_klass))
asm.cmp(klass_opnd, :rcx)
jit_chain_guard(:jne, jit, ctx, asm, side_exit, limit:)
@@ -1593,7 +1597,7 @@ module RubyVM::MJIT
end
# Guard that a is a String
- if jit_guard_known_class(jit, ctx, asm, comptime_a.class, a_opnd, comptime_a, side_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_a.class, a_opnd, comptime_a, side_exit) == CantCompile
return false
end
@@ -1609,7 +1613,7 @@ module RubyVM::MJIT
# Otherwise guard that b is a T_STRING (from type info) or String (from runtime guard)
# Note: any T_STRING is valid here, but we check for a ::String for simplicity
# To pass a mutable static variable (rb_cString) requires an unsafe block
- if jit_guard_known_class(jit, ctx, asm, comptime_b.class, b_opnd, comptime_b, side_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_b.class, b_opnd, comptime_b, side_exit) == CantCompile
return false
end
@@ -1639,6 +1643,7 @@ module RubyVM::MJIT
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
def jit_prepare_routine_call(jit, ctx, asm)
+ jit.record_boundary_patch_point = true
jit_save_pc(jit, asm)
jit_save_sp(jit, ctx, asm)
end
@@ -1672,6 +1677,17 @@ module RubyVM::MJIT
reset_depth.chain_depth = 0
next_pc = jit.pc + jit.insn.len * C.VALUE.size
+
+ # We are at the end of the current instruction. Record the boundary.
+ if jit.record_boundary_patch_point
+ exit_pos = Assembler.new.then do |ocb_asm|
+ @exit_compiler.compile_side_exit(next_pc, ctx, ocb_asm)
+ @ocb.write(ocb_asm)
+ end
+ Invariants.record_global_inval_patch(asm, exit_pos)
+ jit.record_boundary_patch_point = false
+ end
+
stub_next_block(jit.iseq, next_pc, reset_depth, asm, comment: 'jump_to_next_insn')
end
@@ -1781,7 +1797,8 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_method(jit, ctx, asm, mid, argc, flags)
+ # @param send_shift [Integer] The number of shifts needed for VM_CALL_OPT_SEND
+ def jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift: 0)
# Specialize on a compile-time receiver, and split a block for chain guards
unless jit.at_current_insn?
defer_compilation(jit, ctx, asm)
@@ -1791,22 +1808,22 @@ module RubyVM::MJIT
# Generate a side exit
side_exit = side_exit(jit, ctx)
- # Calculate a receiver index
+ # kw_splat is not supported yet
if flags & C.VM_CALL_KW_SPLAT != 0
- # recv_index calculation may not work for this
asm.incr_counter(:send_kw_splat)
return CantCompile
end
- recv_index = argc # TODO: +1 for VM_CALL_ARGS_BLOCKARG
# Get a compile-time receiver and its class
- comptime_recv = jit.peek_at_stack(recv_index)
+ recv_idx = argc + (flags & C.VM_CALL_ARGS_BLOCKARG != 0 ? 1 : 0)
+ recv_idx += send_shift
+ comptime_recv = jit.peek_at_stack(recv_idx)
comptime_recv_klass = C.rb_class_of(comptime_recv)
# Guard the receiver class (part of vm_search_method_fastpath)
- recv_opnd = ctx.stack_opnd(recv_index)
+ recv_opnd = ctx.stack_opnd(recv_idx)
megamorphic_exit = counted_exit(side_exit, :send_klass_megamorphic)
- if jit_guard_known_class(jit, ctx, asm, comptime_recv_klass, recv_opnd, comptime_recv, megamorphic_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_recv_klass, recv_opnd, comptime_recv, megamorphic_exit) == CantCompile
return CantCompile
end
@@ -1838,27 +1855,27 @@ module RubyVM::MJIT
# Invalidate on redefinition (part of vm_search_method_fastpath)
Invariants.assume_method_lookup_stable(jit, cme)
- jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd)
+ jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd, send_shift:)
end
# vm_call_method_each_type
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd)
+ def jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd, send_shift:)
case cme.def.type
when C.VM_METHOD_TYPE_ISEQ
- jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc)
+ jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc, send_shift:)
when C.VM_METHOD_TYPE_NOTIMPLEMENTED
asm.incr_counter(:send_notimplemented)
return CantCompile
when C.VM_METHOD_TYPE_CFUNC
- jit_call_cfunc(jit, ctx, asm, cme, flags, argc)
+ jit_call_cfunc(jit, ctx, asm, cme, flags, argc, send_shift:)
when C.VM_METHOD_TYPE_ATTRSET
asm.incr_counter(:send_attrset)
return CantCompile
when C.VM_METHOD_TYPE_IVAR
- jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd)
+ jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd, send_shift:)
when C.VM_METHOD_TYPE_MISSING
asm.incr_counter(:send_missing)
return CantCompile
@@ -1869,7 +1886,7 @@ module RubyVM::MJIT
asm.incr_counter(:send_alias)
return CantCompile
when C.VM_METHOD_TYPE_OPTIMIZED
- jit_call_optimized(jit, ctx, asm, cme, flags, argc)
+ jit_call_optimized(jit, ctx, asm, cme, flags, argc, send_shift:)
when C.VM_METHOD_TYPE_UNDEF
asm.incr_counter(:send_undef)
return CantCompile
@@ -1889,7 +1906,7 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc)
+ def jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc, send_shift:)
iseq = def_iseq_ptr(cme.def)
opt_pc = jit_callee_setup_arg(jit, ctx, asm, flags, argc, iseq)
if opt_pc == CantCompile
@@ -1902,19 +1919,24 @@ module RubyVM::MJIT
asm.incr_counter(:send_tailcall)
return CantCompile
end
- jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq)
+ jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq, send_shift:)
end
# vm_call_iseq_setup_normal (vm_call_iseq_setup_2 -> vm_call_iseq_setup_normal)
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq)
+ def jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq, send_shift:)
+ # We will not have side exits from here. Adjust the stack.
+ if flags & C.VM_CALL_OPT_SEND != 0
+ jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:)
+ end
+
# Save caller SP and PC before pushing a callee frame for backtrace and side exits
asm.comment('save SP to caller CFP')
- # Not setting this to SP register. This cfp->sp will be copied to SP on leave insn.
- sp_index = -(1 + argc) # Pop receiver and arguments for side exits # TODO: subtract one more for VM_CALL_ARGS_BLOCKARG
- asm.lea(:rax, ctx.sp_opnd(C.VALUE.size * sp_index))
+ recv_idx = argc + (flags & C.VM_CALL_ARGS_BLOCKARG != 0 ? 1 : 0)
+ # Skip setting this to SP register. This cfp->sp will be copied to SP on leave insn.
+ asm.lea(:rax, ctx.sp_opnd(C.VALUE.size * -(1 + recv_idx))) # Pop receiver and arguments to prepare for side exits
asm.mov([CFP, C.rb_control_frame_t.offsetof(:sp)], :rax)
jit_save_pc(jit, asm, comment: 'save PC to caller CFP')
@@ -1937,7 +1959,7 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_cfunc(jit, ctx, asm, cme, flags, argc)
+ def jit_call_cfunc(jit, ctx, asm, cme, flags, argc, send_shift:)
if jit_caller_setup_arg(jit, ctx, asm, flags) == CantCompile
return CantCompile
end
@@ -1945,14 +1967,14 @@ module RubyVM::MJIT
return CantCompile
end
- jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc)
+ jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc, send_shift:)
end
# jit_call_cfunc_with_frame
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc)
+ def jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc, send_shift:)
cfunc = cme.def.body.cfunc
if argc + 1 > 6
@@ -1981,6 +2003,11 @@ module RubyVM::MJIT
return CantCompile
end
+ # We will not have side exits from here. Adjust the stack.
+ if flags & C.VM_CALL_OPT_SEND != 0
+ jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:)
+ end
+
# Check interrupts before SP motion to safely side-exit with the original SP.
jit_check_ints(jit, ctx, asm)
@@ -2026,11 +2053,11 @@ module RubyVM::MJIT
EndBlock
end
- # vm_call_ivar
+ # vm_call_ivar (+ part of vm_call_method_each_type)
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd)
+ def jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd, send_shift:)
if flags & C.VM_CALL_ARGS_SPLAT != 0
asm.incr_counter(:send_ivar_splat)
return CantCompile
@@ -2041,6 +2068,7 @@ module RubyVM::MJIT
return CantCompile
end
+ # We don't support jit_call_opt_send_shift_stack for this yet.
if flags & C.VM_CALL_OPT_SEND != 0
asm.incr_counter(:send_ivar_opt_send)
return CantCompile
@@ -2048,7 +2076,7 @@ module RubyVM::MJIT
ivar_id = cme.def.body.attr.id
- if flags & C.VM_CALL_OPT_SEND != 0
+ if flags & C.VM_CALL_ARGS_BLOCKARG != 0
asm.incr_counter(:send_ivar_blockarg)
return CantCompile
end
@@ -2060,11 +2088,10 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_optimized(jit, ctx, asm, cme, flags, argc)
+ def jit_call_optimized(jit, ctx, asm, cme, flags, argc, send_shift:)
case cme.def.body.optimized.type
when C.OPTIMIZED_METHOD_TYPE_SEND
- asm.incr_counter(:send_optimized_send)
- return CantCompile
+ jit_call_opt_send(jit, ctx, asm, cme, flags, argc, send_shift:)
when C.OPTIMIZED_METHOD_TYPE_CALL
asm.incr_counter(:send_optimized_call)
return CantCompile
@@ -2083,6 +2110,87 @@ module RubyVM::MJIT
end
end
+ # vm_call_opt_send
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def jit_call_opt_send(jit, ctx, asm, cme, flags, argc, send_shift:)
+ if jit_caller_setup_arg(jit, ctx, asm, flags) == CantCompile
+ return CantCompile
+ end
+
+ if argc == 0
+ asm.incr_counter(:send_optimized_send_no_args)
+ return CantCompile
+ end
+
+ argc -= 1
+ # We aren't handling `send(:send, ...)` yet. This might work, but not tested yet.
+ if send_shift > 0
+ asm.incr_counter(:send_optimized_send_send)
+ return CantCompile
+ end
+ # Ideally, we want to shift the stack here, but it's not safe until you reach the point
+ # where you never exit. `send_shift` signals to lazily shift the stack by this amount.
+ send_shift += 1
+
+ kw_splat = flags & C.VM_CALL_KW_SPLAT != 0
+ jit_call_symbol(jit, ctx, asm, cme, C.VM_CALL_FCALL, argc, kw_splat, send_shift:)
+ end
+
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:)
+ # We don't support `send(:send, ...)` for now.
+ assert_equal(1, send_shift)
+
+ asm.comment('shift stack')
+ (0...argc).reverse_each do |i|
+ opnd = ctx.stack_opnd(i)
+ opnd2 = ctx.stack_opnd(i + 1)
+ asm.mov(:rax, opnd)
+ asm.mov(opnd2, :rax)
+ end
+
+ ctx.stack_pop(1)
+ end
+
+ # vm_call_symbol
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def jit_call_symbol(jit, ctx, asm, cme, flags, argc, kw_splat, send_shift:)
+ flags |= C.VM_CALL_OPT_SEND | (kw_splat ? C.VM_CALL_KW_SPLAT : 0)
+
+ comptime_symbol = jit.peek_at_stack(argc)
+ if comptime_symbol.class != String && !static_symbol?(comptime_symbol)
+ asm.incr_counter(:send_optimized_send_not_sym_or_str)
+ return CantCompile
+ end
+
+ mid = C.get_symbol_id(comptime_symbol)
+ if mid == 0
+ asm.incr_counter(:send_optimized_send_null_mid)
+ return CantCompile
+ end
+
+ asm.comment("Guard #{comptime_symbol.inspect} is on stack")
+ mid_changed_exit = counted_exit(side_exit(jit, ctx), :send_optimized_send_mid_changed)
+ if jit_guard_known_klass(jit, ctx, asm, comptime_symbol.class, ctx.stack_opnd(argc), comptime_symbol, mid_changed_exit) == CantCompile
+ return CantCompile
+ end
+ asm.mov(C_ARGS[0], ctx.stack_opnd(argc))
+ asm.call(C.rb_get_symbol_id)
+ asm.cmp(C_RET, mid)
+ jit_chain_guard(:jne, jit, ctx, asm, mid_changed_exit)
+
+ if flags & C.VM_CALL_FCALL != 0
+ return jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift:)
+ end
+
+ raise NotImplementedError # unreachable for now
+ end
+
# vm_push_frame
#
# Frame structure:
@@ -2253,8 +2361,11 @@ module RubyVM::MJIT
end
def fixnum?(obj)
- flag = C.RUBY_FIXNUM_FLAG
- (C.to_value(obj) & flag) == flag
+ (C.to_value(obj) & C.RUBY_FIXNUM_FLAG) == C.RUBY_FIXNUM_FLAG
+ end
+
+ def static_symbol?(obj)
+ (C.to_value(obj) & 0xff) == C.RUBY_SYMBOL_FLAG
end
# @param jit [RubyVM::MJIT::JITState]
@@ -2296,7 +2407,7 @@ module RubyVM::MJIT
return side_exit
end
asm = Assembler.new
- @exit_compiler.compile_side_exit(jit, ctx, asm)
+ @exit_compiler.compile_side_exit(jit.pc, ctx, asm)
jit.side_exits[jit.pc] = @ocb.write(asm)
end
diff --git a/lib/ruby_vm/mjit/jit_state.rb b/lib/ruby_vm/mjit/jit_state.rb
index b90d2f8020..3e5fe996fa 100644
--- a/lib/ruby_vm/mjit/jit_state.rb
+++ b/lib/ruby_vm/mjit/jit_state.rb
@@ -1,12 +1,13 @@
module RubyVM::MJIT
class JITState < Struct.new(
- :iseq, # @param `RubyVM::MJIT::CPointer::Struct_rb_iseq_t`
- :pc, # @param [Integer] The JIT target PC
- :cfp, # @param `RubyVM::MJIT::CPointer::Struct_rb_control_frame_t` The JIT source CFP (before MJIT is called)
- :block, # @param [RubyVM::MJIT::Block]
- :side_exits, # @param [Hash{ Integer => Integer }] { PC => address }
+ :iseq, # @param `RubyVM::MJIT::CPointer::Struct_rb_iseq_t`
+ :pc, # @param [Integer] The JIT target PC
+ :cfp, # @param `RubyVM::MJIT::CPointer::Struct_rb_control_frame_t` The JIT source CFP (before MJIT is called)
+ :block, # @param [RubyVM::MJIT::Block]
+ :side_exits, # @param [Hash{ Integer => Integer }] { PC => address }
+ :record_boundary_patch_point, # @param [TrueClass,FalseClass]
)
- def initialize(side_exits: {}, **) = super
+ def initialize(side_exits: {}, record_boundary_patch_point: false, **) = super
def insn
Compiler.decode_insn(C.VALUE.new(pc).*)
diff --git a/mjit_c.c b/mjit_c.c
index f9f43ffe17..ec2eb4bd9c 100644
--- a/mjit_c.c
+++ b/mjit_c.c
@@ -117,6 +117,7 @@ mjit_for_each_iseq(rb_execution_context_t *ec, VALUE self, VALUE block)
}
extern bool rb_simple_iseq_p(const rb_iseq_t *iseq);
+extern ID rb_get_symbol_id(VALUE name);
#include "mjit_c.rbinc"
diff --git a/mjit_c.h b/mjit_c.h
index 6506c46d52..57228ea961 100644
--- a/mjit_c.h
+++ b/mjit_c.h
@@ -144,13 +144,18 @@ MJIT_RUNTIME_COUNTERS(
send_ivar_opt_send,
send_ivar_blockarg,
- send_optimized_send,
send_optimized_call,
send_optimized_block_call,
send_optimized_struct_aref,
send_optimized_struct_aset,
send_optimized_unknown_type,
+ send_optimized_send_no_args,
+ send_optimized_send_not_sym_or_str,
+ send_optimized_send_mid_changed,
+ send_optimized_send_null_mid,
+ send_optimized_send_send,
+
send_guard_symbol,
send_guard_float,
diff --git a/mjit_c.rb b/mjit_c.rb
index ef515f5bea..4b5ee5a4e1 100644
--- a/mjit_c.rb
+++ b/mjit_c.rb
@@ -180,6 +180,14 @@ module RubyVM::MJIT # :nodoc: all
Primitive.cexpr! 'SIZET2NUM((size_t)rb_hash_aset)'
end
+ def get_symbol_id(name)
+ Primitive.cexpr! 'SIZET2NUM((size_t)rb_get_symbol_id(name))'
+ end
+
+ def rb_get_symbol_id
+ Primitive.cexpr! 'SIZET2NUM((size_t)rb_get_symbol_id)'
+ end
+
#========================================================================================
#
# Old stuff
@@ -629,6 +637,14 @@ module RubyVM::MJIT # :nodoc: all
Primitive.cexpr! %q{ ULONG2NUM(RUBY_IMMEDIATE_MASK) }
end
+ def C.RUBY_SPECIAL_SHIFT
+ Primitive.cexpr! %q{ ULONG2NUM(RUBY_SPECIAL_SHIFT) }
+ end
+
+ def C.RUBY_SYMBOL_FLAG
+ Primitive.cexpr! %q{ ULONG2NUM(RUBY_SYMBOL_FLAG) }
+ end
+
def C.RUBY_T_ARRAY
Primitive.cexpr! %q{ ULONG2NUM(RUBY_T_ARRAY) }
end
@@ -1165,12 +1181,16 @@ module RubyVM::MJIT # :nodoc: all
send_ivar_splat: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_ivar_splat)")],
send_ivar_opt_send: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_ivar_opt_send)")],
send_ivar_blockarg: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_ivar_blockarg)")],
- send_optimized_send: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send)")],
send_optimized_call: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_call)")],
send_optimized_block_call: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_block_call)")],
send_optimized_struct_aref: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_struct_aref)")],
send_optimized_struct_aset: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_struct_aset)")],
send_optimized_unknown_type: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_unknown_type)")],
+ send_optimized_send_no_args: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send_no_args)")],
+ send_optimized_send_not_sym_or_str: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send_not_sym_or_str)")],
+ send_optimized_send_mid_changed: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send_mid_changed)")],
+ send_optimized_send_null_mid: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send_null_mid)")],
+ send_optimized_send_send: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_optimized_send_send)")],
send_guard_symbol: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_guard_symbol)")],
send_guard_float: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), send_guard_float)")],
getivar_megamorphic: [CType::Immediate.parse("size_t"), Primitive.cexpr!("OFFSETOF((*((struct rb_mjit_runtime_counters *)NULL)), getivar_megamorphic)")],
diff --git a/tool/mjit/bindgen.rb b/tool/mjit/bindgen.rb
index dc2e76dd68..6e67702ea4 100755
--- a/tool/mjit/bindgen.rb
+++ b/tool/mjit/bindgen.rb
@@ -424,6 +424,8 @@ generator = BindingGenerator.new(
INVALID_SHAPE_ID
OBJ_TOO_COMPLEX_SHAPE_ID
RUBY_FIXNUM_FLAG
+ RUBY_SYMBOL_FLAG
+ RUBY_SPECIAL_SHIFT
RUBY_IMMEDIATE_MASK
RARRAY_EMBED_LEN_MASK
RARRAY_EMBED_LEN_SHIFT