summaryrefslogtreecommitdiff
path: root/lib/ruby_vm/mjit/insn_compiler.rb
diff options
context:
space:
mode:
authorTakashi Kokubun <[email protected]>2023-02-16 22:29:58 -0800
committerTakashi Kokubun <[email protected]>2023-03-05 23:28:59 -0800
commit2cc4f506bac0748277b41a4a5eb6f0ec41dd7344 (patch)
tree447fda3d7e563383f04292504b526d786475b906 /lib/ruby_vm/mjit/insn_compiler.rb
parent2603d7a0b7b0e698bed8910a8ad5edefe236de77 (diff)
Implement optimized send
Notes
Notes: Merged: https://github.com/ruby/ruby/pull/7448
Diffstat (limited to 'lib/ruby_vm/mjit/insn_compiler.rb')
-rw-r--r--lib/ruby_vm/mjit/insn_compiler.rb197
1 files changed, 154 insertions, 43 deletions
diff --git a/lib/ruby_vm/mjit/insn_compiler.rb b/lib/ruby_vm/mjit/insn_compiler.rb
index 1df6809c2b..a06030fd6f 100644
--- a/lib/ruby_vm/mjit/insn_compiler.rb
+++ b/lib/ruby_vm/mjit/insn_compiler.rb
@@ -966,7 +966,7 @@ module RubyVM::MJIT
recv_opnd = ctx.stack_opnd(1)
not_array_exit = counted_exit(side_exit, :optaref_recv_not_array)
- if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_array_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_array_exit) == CantCompile
return CantCompile
end
@@ -1001,7 +1001,7 @@ module RubyVM::MJIT
# Guard that the receiver is a Hash
not_hash_exit = counted_exit(side_exit, :optaref_recv_not_hash)
- if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_hash_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_hash_exit) == CantCompile
return CantCompile
end
@@ -1051,12 +1051,12 @@ module RubyVM::MJIT
side_exit = side_exit(jit, ctx)
# Guard receiver is an Array
- if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile
return CantCompile
end
# Guard key is a fixnum
- if jit_guard_known_class(jit, ctx, asm, comptime_key.class, key, comptime_key, side_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_key.class, key, comptime_key, side_exit) == CantCompile
return CantCompile
end
@@ -1090,7 +1090,7 @@ module RubyVM::MJIT
side_exit = side_exit(jit, ctx)
# Guard receiver is a Hash
- if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile
return CantCompile
end
@@ -1425,7 +1425,7 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_guard_known_class(jit, ctx, asm, known_klass, obj_opnd, comptime_obj, side_exit, limit: 5)
+ def jit_guard_known_klass(jit, ctx, asm, known_klass, obj_opnd, comptime_obj, side_exit, limit: 5)
# Only memory operand is supported for now
assert_equal(true, obj_opnd.is_a?(Array))
@@ -1445,9 +1445,13 @@ module RubyVM::MJIT
asm.comment('guard object is fixnum')
asm.test(obj_opnd, C.RUBY_FIXNUM_FLAG)
jit_chain_guard(:jz, jit, ctx, asm, side_exit, limit:)
- elsif known_klass == Symbol
- asm.incr_counter(:send_guard_symbol)
- return CantCompile
+ elsif known_klass == Symbol && static_symbol?(comptime_obj)
+ # We will guard STATIC vs DYNAMIC as though they were separate classes
+ # DYNAMIC symbols can be handled by the general else case below
+ asm.comment('guard object is static symbol')
+ assert_equal(8, C.RUBY_SPECIAL_SHIFT)
+ asm.cmp(BytePtr[*obj_opnd], C.RUBY_SYMBOL_FLAG)
+ jit_chain_guard(:jne, jit, ctx, asm, side_exit, limit:)
elsif known_klass == Float
asm.incr_counter(:send_guard_float)
return CantCompile
@@ -1471,7 +1475,7 @@ module RubyVM::MJIT
# Bail if receiver class is different from known_klass
klass_opnd = [obj_opnd, C.RBasic.offsetof(:klass)]
- asm.comment('guard known class')
+ asm.comment("guard known class #{known_klass}")
asm.mov(:rcx, to_value(known_klass))
asm.cmp(klass_opnd, :rcx)
jit_chain_guard(:jne, jit, ctx, asm, side_exit, limit:)
@@ -1593,7 +1597,7 @@ module RubyVM::MJIT
end
# Guard that a is a String
- if jit_guard_known_class(jit, ctx, asm, comptime_a.class, a_opnd, comptime_a, side_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_a.class, a_opnd, comptime_a, side_exit) == CantCompile
return false
end
@@ -1609,7 +1613,7 @@ module RubyVM::MJIT
# Otherwise guard that b is a T_STRING (from type info) or String (from runtime guard)
# Note: any T_STRING is valid here, but we check for a ::String for simplicity
# To pass a mutable static variable (rb_cString) requires an unsafe block
- if jit_guard_known_class(jit, ctx, asm, comptime_b.class, b_opnd, comptime_b, side_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_b.class, b_opnd, comptime_b, side_exit) == CantCompile
return false
end
@@ -1639,6 +1643,7 @@ module RubyVM::MJIT
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
def jit_prepare_routine_call(jit, ctx, asm)
+ jit.record_boundary_patch_point = true
jit_save_pc(jit, asm)
jit_save_sp(jit, ctx, asm)
end
@@ -1672,6 +1677,17 @@ module RubyVM::MJIT
reset_depth.chain_depth = 0
next_pc = jit.pc + jit.insn.len * C.VALUE.size
+
+ # We are at the end of the current instruction. Record the boundary.
+ if jit.record_boundary_patch_point
+ exit_pos = Assembler.new.then do |ocb_asm|
+ @exit_compiler.compile_side_exit(next_pc, ctx, ocb_asm)
+ @ocb.write(ocb_asm)
+ end
+ Invariants.record_global_inval_patch(asm, exit_pos)
+ jit.record_boundary_patch_point = false
+ end
+
stub_next_block(jit.iseq, next_pc, reset_depth, asm, comment: 'jump_to_next_insn')
end
@@ -1781,7 +1797,8 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_method(jit, ctx, asm, mid, argc, flags)
+ # @param send_shift [Integer] The number of shifts needed for VM_CALL_OPT_SEND
+ def jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift: 0)
# Specialize on a compile-time receiver, and split a block for chain guards
unless jit.at_current_insn?
defer_compilation(jit, ctx, asm)
@@ -1791,22 +1808,22 @@ module RubyVM::MJIT
# Generate a side exit
side_exit = side_exit(jit, ctx)
- # Calculate a receiver index
+ # kw_splat is not supported yet
if flags & C.VM_CALL_KW_SPLAT != 0
- # recv_index calculation may not work for this
asm.incr_counter(:send_kw_splat)
return CantCompile
end
- recv_index = argc # TODO: +1 for VM_CALL_ARGS_BLOCKARG
# Get a compile-time receiver and its class
- comptime_recv = jit.peek_at_stack(recv_index)
+ recv_idx = argc + (flags & C.VM_CALL_ARGS_BLOCKARG != 0 ? 1 : 0)
+ recv_idx += send_shift
+ comptime_recv = jit.peek_at_stack(recv_idx)
comptime_recv_klass = C.rb_class_of(comptime_recv)
# Guard the receiver class (part of vm_search_method_fastpath)
- recv_opnd = ctx.stack_opnd(recv_index)
+ recv_opnd = ctx.stack_opnd(recv_idx)
megamorphic_exit = counted_exit(side_exit, :send_klass_megamorphic)
- if jit_guard_known_class(jit, ctx, asm, comptime_recv_klass, recv_opnd, comptime_recv, megamorphic_exit) == CantCompile
+ if jit_guard_known_klass(jit, ctx, asm, comptime_recv_klass, recv_opnd, comptime_recv, megamorphic_exit) == CantCompile
return CantCompile
end
@@ -1838,27 +1855,27 @@ module RubyVM::MJIT
# Invalidate on redefinition (part of vm_search_method_fastpath)
Invariants.assume_method_lookup_stable(jit, cme)
- jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd)
+ jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd, send_shift:)
end
# vm_call_method_each_type
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd)
+ def jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd, send_shift:)
case cme.def.type
when C.VM_METHOD_TYPE_ISEQ
- jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc)
+ jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc, send_shift:)
when C.VM_METHOD_TYPE_NOTIMPLEMENTED
asm.incr_counter(:send_notimplemented)
return CantCompile
when C.VM_METHOD_TYPE_CFUNC
- jit_call_cfunc(jit, ctx, asm, cme, flags, argc)
+ jit_call_cfunc(jit, ctx, asm, cme, flags, argc, send_shift:)
when C.VM_METHOD_TYPE_ATTRSET
asm.incr_counter(:send_attrset)
return CantCompile
when C.VM_METHOD_TYPE_IVAR
- jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd)
+ jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd, send_shift:)
when C.VM_METHOD_TYPE_MISSING
asm.incr_counter(:send_missing)
return CantCompile
@@ -1869,7 +1886,7 @@ module RubyVM::MJIT
asm.incr_counter(:send_alias)
return CantCompile
when C.VM_METHOD_TYPE_OPTIMIZED
- jit_call_optimized(jit, ctx, asm, cme, flags, argc)
+ jit_call_optimized(jit, ctx, asm, cme, flags, argc, send_shift:)
when C.VM_METHOD_TYPE_UNDEF
asm.incr_counter(:send_undef)
return CantCompile
@@ -1889,7 +1906,7 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc)
+ def jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc, send_shift:)
iseq = def_iseq_ptr(cme.def)
opt_pc = jit_callee_setup_arg(jit, ctx, asm, flags, argc, iseq)
if opt_pc == CantCompile
@@ -1902,19 +1919,24 @@ module RubyVM::MJIT
asm.incr_counter(:send_tailcall)
return CantCompile
end
- jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq)
+ jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq, send_shift:)
end
# vm_call_iseq_setup_normal (vm_call_iseq_setup_2 -> vm_call_iseq_setup_normal)
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq)
+ def jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq, send_shift:)
+ # We will not have side exits from here. Adjust the stack.
+ if flags & C.VM_CALL_OPT_SEND != 0
+ jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:)
+ end
+
# Save caller SP and PC before pushing a callee frame for backtrace and side exits
asm.comment('save SP to caller CFP')
- # Not setting this to SP register. This cfp->sp will be copied to SP on leave insn.
- sp_index = -(1 + argc) # Pop receiver and arguments for side exits # TODO: subtract one more for VM_CALL_ARGS_BLOCKARG
- asm.lea(:rax, ctx.sp_opnd(C.VALUE.size * sp_index))
+ recv_idx = argc + (flags & C.VM_CALL_ARGS_BLOCKARG != 0 ? 1 : 0)
+ # Skip setting this to SP register. This cfp->sp will be copied to SP on leave insn.
+ asm.lea(:rax, ctx.sp_opnd(C.VALUE.size * -(1 + recv_idx))) # Pop receiver and arguments to prepare for side exits
asm.mov([CFP, C.rb_control_frame_t.offsetof(:sp)], :rax)
jit_save_pc(jit, asm, comment: 'save PC to caller CFP')
@@ -1937,7 +1959,7 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_cfunc(jit, ctx, asm, cme, flags, argc)
+ def jit_call_cfunc(jit, ctx, asm, cme, flags, argc, send_shift:)
if jit_caller_setup_arg(jit, ctx, asm, flags) == CantCompile
return CantCompile
end
@@ -1945,14 +1967,14 @@ module RubyVM::MJIT
return CantCompile
end
- jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc)
+ jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc, send_shift:)
end
# jit_call_cfunc_with_frame
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc)
+ def jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc, send_shift:)
cfunc = cme.def.body.cfunc
if argc + 1 > 6
@@ -1981,6 +2003,11 @@ module RubyVM::MJIT
return CantCompile
end
+ # We will not have side exits from here. Adjust the stack.
+ if flags & C.VM_CALL_OPT_SEND != 0
+ jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:)
+ end
+
# Check interrupts before SP motion to safely side-exit with the original SP.
jit_check_ints(jit, ctx, asm)
@@ -2026,11 +2053,11 @@ module RubyVM::MJIT
EndBlock
end
- # vm_call_ivar
+ # vm_call_ivar (+ part of vm_call_method_each_type)
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd)
+ def jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd, send_shift:)
if flags & C.VM_CALL_ARGS_SPLAT != 0
asm.incr_counter(:send_ivar_splat)
return CantCompile
@@ -2041,6 +2068,7 @@ module RubyVM::MJIT
return CantCompile
end
+ # We don't support jit_call_opt_send_shift_stack for this yet.
if flags & C.VM_CALL_OPT_SEND != 0
asm.incr_counter(:send_ivar_opt_send)
return CantCompile
@@ -2048,7 +2076,7 @@ module RubyVM::MJIT
ivar_id = cme.def.body.attr.id
- if flags & C.VM_CALL_OPT_SEND != 0
+ if flags & C.VM_CALL_ARGS_BLOCKARG != 0
asm.incr_counter(:send_ivar_blockarg)
return CantCompile
end
@@ -2060,11 +2088,10 @@ module RubyVM::MJIT
# @param jit [RubyVM::MJIT::JITState]
# @param ctx [RubyVM::MJIT::Context]
# @param asm [RubyVM::MJIT::Assembler]
- def jit_call_optimized(jit, ctx, asm, cme, flags, argc)
+ def jit_call_optimized(jit, ctx, asm, cme, flags, argc, send_shift:)
case cme.def.body.optimized.type
when C.OPTIMIZED_METHOD_TYPE_SEND
- asm.incr_counter(:send_optimized_send)
- return CantCompile
+ jit_call_opt_send(jit, ctx, asm, cme, flags, argc, send_shift:)
when C.OPTIMIZED_METHOD_TYPE_CALL
asm.incr_counter(:send_optimized_call)
return CantCompile
@@ -2083,6 +2110,87 @@ module RubyVM::MJIT
end
end
+ # vm_call_opt_send
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def jit_call_opt_send(jit, ctx, asm, cme, flags, argc, send_shift:)
+ if jit_caller_setup_arg(jit, ctx, asm, flags) == CantCompile
+ return CantCompile
+ end
+
+ if argc == 0
+ asm.incr_counter(:send_optimized_send_no_args)
+ return CantCompile
+ end
+
+ argc -= 1
+ # We aren't handling `send(:send, ...)` yet. This might work, but not tested yet.
+ if send_shift > 0
+ asm.incr_counter(:send_optimized_send_send)
+ return CantCompile
+ end
+ # Ideally, we want to shift the stack here, but it's not safe until you reach the point
+ # where you never exit. `send_shift` signals to lazily shift the stack by this amount.
+ send_shift += 1
+
+ kw_splat = flags & C.VM_CALL_KW_SPLAT != 0
+ jit_call_symbol(jit, ctx, asm, cme, C.VM_CALL_FCALL, argc, kw_splat, send_shift:)
+ end
+
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:)
+ # We don't support `send(:send, ...)` for now.
+ assert_equal(1, send_shift)
+
+ asm.comment('shift stack')
+ (0...argc).reverse_each do |i|
+ opnd = ctx.stack_opnd(i)
+ opnd2 = ctx.stack_opnd(i + 1)
+ asm.mov(:rax, opnd)
+ asm.mov(opnd2, :rax)
+ end
+
+ ctx.stack_pop(1)
+ end
+
+ # vm_call_symbol
+ # @param jit [RubyVM::MJIT::JITState]
+ # @param ctx [RubyVM::MJIT::Context]
+ # @param asm [RubyVM::MJIT::Assembler]
+ def jit_call_symbol(jit, ctx, asm, cme, flags, argc, kw_splat, send_shift:)
+ flags |= C.VM_CALL_OPT_SEND | (kw_splat ? C.VM_CALL_KW_SPLAT : 0)
+
+ comptime_symbol = jit.peek_at_stack(argc)
+ if comptime_symbol.class != String && !static_symbol?(comptime_symbol)
+ asm.incr_counter(:send_optimized_send_not_sym_or_str)
+ return CantCompile
+ end
+
+ mid = C.get_symbol_id(comptime_symbol)
+ if mid == 0
+ asm.incr_counter(:send_optimized_send_null_mid)
+ return CantCompile
+ end
+
+ asm.comment("Guard #{comptime_symbol.inspect} is on stack")
+ mid_changed_exit = counted_exit(side_exit(jit, ctx), :send_optimized_send_mid_changed)
+ if jit_guard_known_klass(jit, ctx, asm, comptime_symbol.class, ctx.stack_opnd(argc), comptime_symbol, mid_changed_exit) == CantCompile
+ return CantCompile
+ end
+ asm.mov(C_ARGS[0], ctx.stack_opnd(argc))
+ asm.call(C.rb_get_symbol_id)
+ asm.cmp(C_RET, mid)
+ jit_chain_guard(:jne, jit, ctx, asm, mid_changed_exit)
+
+ if flags & C.VM_CALL_FCALL != 0
+ return jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift:)
+ end
+
+ raise NotImplementedError # unreachable for now
+ end
+
# vm_push_frame
#
# Frame structure:
@@ -2253,8 +2361,11 @@ module RubyVM::MJIT
end
def fixnum?(obj)
- flag = C.RUBY_FIXNUM_FLAG
- (C.to_value(obj) & flag) == flag
+ (C.to_value(obj) & C.RUBY_FIXNUM_FLAG) == C.RUBY_FIXNUM_FLAG
+ end
+
+ def static_symbol?(obj)
+ (C.to_value(obj) & 0xff) == C.RUBY_SYMBOL_FLAG
end
# @param jit [RubyVM::MJIT::JITState]
@@ -2296,7 +2407,7 @@ module RubyVM::MJIT
return side_exit
end
asm = Assembler.new
- @exit_compiler.compile_side_exit(jit, ctx, asm)
+ @exit_compiler.compile_side_exit(jit.pc, ctx, asm)
jit.side_exits[jit.pc] = @ocb.write(asm)
end