diff options
author | Takashi Kokubun <[email protected]> | 2023-02-16 22:29:58 -0800 |
---|---|---|
committer | Takashi Kokubun <[email protected]> | 2023-03-05 23:28:59 -0800 |
commit | 2cc4f506bac0748277b41a4a5eb6f0ec41dd7344 (patch) | |
tree | 447fda3d7e563383f04292504b526d786475b906 /lib/ruby_vm/mjit/insn_compiler.rb | |
parent | 2603d7a0b7b0e698bed8910a8ad5edefe236de77 (diff) |
Implement optimized send
Notes
Notes:
Merged: https://github.com/ruby/ruby/pull/7448
Diffstat (limited to 'lib/ruby_vm/mjit/insn_compiler.rb')
-rw-r--r-- | lib/ruby_vm/mjit/insn_compiler.rb | 197 |
1 files changed, 154 insertions, 43 deletions
diff --git a/lib/ruby_vm/mjit/insn_compiler.rb b/lib/ruby_vm/mjit/insn_compiler.rb index 1df6809c2b..a06030fd6f 100644 --- a/lib/ruby_vm/mjit/insn_compiler.rb +++ b/lib/ruby_vm/mjit/insn_compiler.rb @@ -966,7 +966,7 @@ module RubyVM::MJIT recv_opnd = ctx.stack_opnd(1) not_array_exit = counted_exit(side_exit, :optaref_recv_not_array) - if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_array_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_array_exit) == CantCompile return CantCompile end @@ -1001,7 +1001,7 @@ module RubyVM::MJIT # Guard that the receiver is a Hash not_hash_exit = counted_exit(side_exit, :optaref_recv_not_hash) - if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_hash_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv_opnd, comptime_recv, not_hash_exit) == CantCompile return CantCompile end @@ -1051,12 +1051,12 @@ module RubyVM::MJIT side_exit = side_exit(jit, ctx) # Guard receiver is an Array - if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile return CantCompile end # Guard key is a fixnum - if jit_guard_known_class(jit, ctx, asm, comptime_key.class, key, comptime_key, side_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_key.class, key, comptime_key, side_exit) == CantCompile return CantCompile end @@ -1090,7 +1090,7 @@ module RubyVM::MJIT side_exit = side_exit(jit, ctx) # Guard receiver is a Hash - if jit_guard_known_class(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_recv.class, recv, comptime_recv, side_exit) == CantCompile return CantCompile end @@ -1425,7 +1425,7 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_guard_known_class(jit, ctx, asm, known_klass, obj_opnd, comptime_obj, side_exit, limit: 5) + def jit_guard_known_klass(jit, ctx, asm, known_klass, obj_opnd, comptime_obj, side_exit, limit: 5) # Only memory operand is supported for now assert_equal(true, obj_opnd.is_a?(Array)) @@ -1445,9 +1445,13 @@ module RubyVM::MJIT asm.comment('guard object is fixnum') asm.test(obj_opnd, C.RUBY_FIXNUM_FLAG) jit_chain_guard(:jz, jit, ctx, asm, side_exit, limit:) - elsif known_klass == Symbol - asm.incr_counter(:send_guard_symbol) - return CantCompile + elsif known_klass == Symbol && static_symbol?(comptime_obj) + # We will guard STATIC vs DYNAMIC as though they were separate classes + # DYNAMIC symbols can be handled by the general else case below + asm.comment('guard object is static symbol') + assert_equal(8, C.RUBY_SPECIAL_SHIFT) + asm.cmp(BytePtr[*obj_opnd], C.RUBY_SYMBOL_FLAG) + jit_chain_guard(:jne, jit, ctx, asm, side_exit, limit:) elsif known_klass == Float asm.incr_counter(:send_guard_float) return CantCompile @@ -1471,7 +1475,7 @@ module RubyVM::MJIT # Bail if receiver class is different from known_klass klass_opnd = [obj_opnd, C.RBasic.offsetof(:klass)] - asm.comment('guard known class') + asm.comment("guard known class #{known_klass}") asm.mov(:rcx, to_value(known_klass)) asm.cmp(klass_opnd, :rcx) jit_chain_guard(:jne, jit, ctx, asm, side_exit, limit:) @@ -1593,7 +1597,7 @@ module RubyVM::MJIT end # Guard that a is a String - if jit_guard_known_class(jit, ctx, asm, comptime_a.class, a_opnd, comptime_a, side_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_a.class, a_opnd, comptime_a, side_exit) == CantCompile return false end @@ -1609,7 +1613,7 @@ module RubyVM::MJIT # Otherwise guard that b is a T_STRING (from type info) or String (from runtime guard) # Note: any T_STRING is valid here, but we check for a ::String for simplicity # To pass a mutable static variable (rb_cString) requires an unsafe block - if jit_guard_known_class(jit, ctx, asm, comptime_b.class, b_opnd, comptime_b, side_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_b.class, b_opnd, comptime_b, side_exit) == CantCompile return false end @@ -1639,6 +1643,7 @@ module RubyVM::MJIT # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] def jit_prepare_routine_call(jit, ctx, asm) + jit.record_boundary_patch_point = true jit_save_pc(jit, asm) jit_save_sp(jit, ctx, asm) end @@ -1672,6 +1677,17 @@ module RubyVM::MJIT reset_depth.chain_depth = 0 next_pc = jit.pc + jit.insn.len * C.VALUE.size + + # We are at the end of the current instruction. Record the boundary. + if jit.record_boundary_patch_point + exit_pos = Assembler.new.then do |ocb_asm| + @exit_compiler.compile_side_exit(next_pc, ctx, ocb_asm) + @ocb.write(ocb_asm) + end + Invariants.record_global_inval_patch(asm, exit_pos) + jit.record_boundary_patch_point = false + end + stub_next_block(jit.iseq, next_pc, reset_depth, asm, comment: 'jump_to_next_insn') end @@ -1781,7 +1797,8 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_method(jit, ctx, asm, mid, argc, flags) + # @param send_shift [Integer] The number of shifts needed for VM_CALL_OPT_SEND + def jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift: 0) # Specialize on a compile-time receiver, and split a block for chain guards unless jit.at_current_insn? defer_compilation(jit, ctx, asm) @@ -1791,22 +1808,22 @@ module RubyVM::MJIT # Generate a side exit side_exit = side_exit(jit, ctx) - # Calculate a receiver index + # kw_splat is not supported yet if flags & C.VM_CALL_KW_SPLAT != 0 - # recv_index calculation may not work for this asm.incr_counter(:send_kw_splat) return CantCompile end - recv_index = argc # TODO: +1 for VM_CALL_ARGS_BLOCKARG # Get a compile-time receiver and its class - comptime_recv = jit.peek_at_stack(recv_index) + recv_idx = argc + (flags & C.VM_CALL_ARGS_BLOCKARG != 0 ? 1 : 0) + recv_idx += send_shift + comptime_recv = jit.peek_at_stack(recv_idx) comptime_recv_klass = C.rb_class_of(comptime_recv) # Guard the receiver class (part of vm_search_method_fastpath) - recv_opnd = ctx.stack_opnd(recv_index) + recv_opnd = ctx.stack_opnd(recv_idx) megamorphic_exit = counted_exit(side_exit, :send_klass_megamorphic) - if jit_guard_known_class(jit, ctx, asm, comptime_recv_klass, recv_opnd, comptime_recv, megamorphic_exit) == CantCompile + if jit_guard_known_klass(jit, ctx, asm, comptime_recv_klass, recv_opnd, comptime_recv, megamorphic_exit) == CantCompile return CantCompile end @@ -1838,27 +1855,27 @@ module RubyVM::MJIT # Invalidate on redefinition (part of vm_search_method_fastpath) Invariants.assume_method_lookup_stable(jit, cme) - jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd) + jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd, send_shift:) end # vm_call_method_each_type # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd) + def jit_call_method_each_type(jit, ctx, asm, argc, flags, cme, comptime_recv, recv_opnd, send_shift:) case cme.def.type when C.VM_METHOD_TYPE_ISEQ - jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc) + jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc, send_shift:) when C.VM_METHOD_TYPE_NOTIMPLEMENTED asm.incr_counter(:send_notimplemented) return CantCompile when C.VM_METHOD_TYPE_CFUNC - jit_call_cfunc(jit, ctx, asm, cme, flags, argc) + jit_call_cfunc(jit, ctx, asm, cme, flags, argc, send_shift:) when C.VM_METHOD_TYPE_ATTRSET asm.incr_counter(:send_attrset) return CantCompile when C.VM_METHOD_TYPE_IVAR - jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd) + jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd, send_shift:) when C.VM_METHOD_TYPE_MISSING asm.incr_counter(:send_missing) return CantCompile @@ -1869,7 +1886,7 @@ module RubyVM::MJIT asm.incr_counter(:send_alias) return CantCompile when C.VM_METHOD_TYPE_OPTIMIZED - jit_call_optimized(jit, ctx, asm, cme, flags, argc) + jit_call_optimized(jit, ctx, asm, cme, flags, argc, send_shift:) when C.VM_METHOD_TYPE_UNDEF asm.incr_counter(:send_undef) return CantCompile @@ -1889,7 +1906,7 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc) + def jit_call_iseq_setup(jit, ctx, asm, cme, flags, argc, send_shift:) iseq = def_iseq_ptr(cme.def) opt_pc = jit_callee_setup_arg(jit, ctx, asm, flags, argc, iseq) if opt_pc == CantCompile @@ -1902,19 +1919,24 @@ module RubyVM::MJIT asm.incr_counter(:send_tailcall) return CantCompile end - jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq) + jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq, send_shift:) end # vm_call_iseq_setup_normal (vm_call_iseq_setup_2 -> vm_call_iseq_setup_normal) # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq) + def jit_call_iseq_setup_normal(jit, ctx, asm, cme, flags, argc, iseq, send_shift:) + # We will not have side exits from here. Adjust the stack. + if flags & C.VM_CALL_OPT_SEND != 0 + jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:) + end + # Save caller SP and PC before pushing a callee frame for backtrace and side exits asm.comment('save SP to caller CFP') - # Not setting this to SP register. This cfp->sp will be copied to SP on leave insn. - sp_index = -(1 + argc) # Pop receiver and arguments for side exits # TODO: subtract one more for VM_CALL_ARGS_BLOCKARG - asm.lea(:rax, ctx.sp_opnd(C.VALUE.size * sp_index)) + recv_idx = argc + (flags & C.VM_CALL_ARGS_BLOCKARG != 0 ? 1 : 0) + # Skip setting this to SP register. This cfp->sp will be copied to SP on leave insn. + asm.lea(:rax, ctx.sp_opnd(C.VALUE.size * -(1 + recv_idx))) # Pop receiver and arguments to prepare for side exits asm.mov([CFP, C.rb_control_frame_t.offsetof(:sp)], :rax) jit_save_pc(jit, asm, comment: 'save PC to caller CFP') @@ -1937,7 +1959,7 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_cfunc(jit, ctx, asm, cme, flags, argc) + def jit_call_cfunc(jit, ctx, asm, cme, flags, argc, send_shift:) if jit_caller_setup_arg(jit, ctx, asm, flags) == CantCompile return CantCompile end @@ -1945,14 +1967,14 @@ module RubyVM::MJIT return CantCompile end - jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc) + jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc, send_shift:) end # jit_call_cfunc_with_frame # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc) + def jit_call_cfunc_with_frame(jit, ctx, asm, cme, flags, argc, send_shift:) cfunc = cme.def.body.cfunc if argc + 1 > 6 @@ -1981,6 +2003,11 @@ module RubyVM::MJIT return CantCompile end + # We will not have side exits from here. Adjust the stack. + if flags & C.VM_CALL_OPT_SEND != 0 + jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:) + end + # Check interrupts before SP motion to safely side-exit with the original SP. jit_check_ints(jit, ctx, asm) @@ -2026,11 +2053,11 @@ module RubyVM::MJIT EndBlock end - # vm_call_ivar + # vm_call_ivar (+ part of vm_call_method_each_type) # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd) + def jit_call_ivar(jit, ctx, asm, cme, flags, argc, comptime_recv, recv_opnd, send_shift:) if flags & C.VM_CALL_ARGS_SPLAT != 0 asm.incr_counter(:send_ivar_splat) return CantCompile @@ -2041,6 +2068,7 @@ module RubyVM::MJIT return CantCompile end + # We don't support jit_call_opt_send_shift_stack for this yet. if flags & C.VM_CALL_OPT_SEND != 0 asm.incr_counter(:send_ivar_opt_send) return CantCompile @@ -2048,7 +2076,7 @@ module RubyVM::MJIT ivar_id = cme.def.body.attr.id - if flags & C.VM_CALL_OPT_SEND != 0 + if flags & C.VM_CALL_ARGS_BLOCKARG != 0 asm.incr_counter(:send_ivar_blockarg) return CantCompile end @@ -2060,11 +2088,10 @@ module RubyVM::MJIT # @param jit [RubyVM::MJIT::JITState] # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] - def jit_call_optimized(jit, ctx, asm, cme, flags, argc) + def jit_call_optimized(jit, ctx, asm, cme, flags, argc, send_shift:) case cme.def.body.optimized.type when C.OPTIMIZED_METHOD_TYPE_SEND - asm.incr_counter(:send_optimized_send) - return CantCompile + jit_call_opt_send(jit, ctx, asm, cme, flags, argc, send_shift:) when C.OPTIMIZED_METHOD_TYPE_CALL asm.incr_counter(:send_optimized_call) return CantCompile @@ -2083,6 +2110,87 @@ module RubyVM::MJIT end end + # vm_call_opt_send + # @param jit [RubyVM::MJIT::JITState] + # @param ctx [RubyVM::MJIT::Context] + # @param asm [RubyVM::MJIT::Assembler] + def jit_call_opt_send(jit, ctx, asm, cme, flags, argc, send_shift:) + if jit_caller_setup_arg(jit, ctx, asm, flags) == CantCompile + return CantCompile + end + + if argc == 0 + asm.incr_counter(:send_optimized_send_no_args) + return CantCompile + end + + argc -= 1 + # We aren't handling `send(:send, ...)` yet. This might work, but not tested yet. + if send_shift > 0 + asm.incr_counter(:send_optimized_send_send) + return CantCompile + end + # Ideally, we want to shift the stack here, but it's not safe until you reach the point + # where you never exit. `send_shift` signals to lazily shift the stack by this amount. + send_shift += 1 + + kw_splat = flags & C.VM_CALL_KW_SPLAT != 0 + jit_call_symbol(jit, ctx, asm, cme, C.VM_CALL_FCALL, argc, kw_splat, send_shift:) + end + + # @param ctx [RubyVM::MJIT::Context] + # @param asm [RubyVM::MJIT::Assembler] + def jit_call_opt_send_shift_stack(ctx, asm, argc, send_shift:) + # We don't support `send(:send, ...)` for now. + assert_equal(1, send_shift) + + asm.comment('shift stack') + (0...argc).reverse_each do |i| + opnd = ctx.stack_opnd(i) + opnd2 = ctx.stack_opnd(i + 1) + asm.mov(:rax, opnd) + asm.mov(opnd2, :rax) + end + + ctx.stack_pop(1) + end + + # vm_call_symbol + # @param jit [RubyVM::MJIT::JITState] + # @param ctx [RubyVM::MJIT::Context] + # @param asm [RubyVM::MJIT::Assembler] + def jit_call_symbol(jit, ctx, asm, cme, flags, argc, kw_splat, send_shift:) + flags |= C.VM_CALL_OPT_SEND | (kw_splat ? C.VM_CALL_KW_SPLAT : 0) + + comptime_symbol = jit.peek_at_stack(argc) + if comptime_symbol.class != String && !static_symbol?(comptime_symbol) + asm.incr_counter(:send_optimized_send_not_sym_or_str) + return CantCompile + end + + mid = C.get_symbol_id(comptime_symbol) + if mid == 0 + asm.incr_counter(:send_optimized_send_null_mid) + return CantCompile + end + + asm.comment("Guard #{comptime_symbol.inspect} is on stack") + mid_changed_exit = counted_exit(side_exit(jit, ctx), :send_optimized_send_mid_changed) + if jit_guard_known_klass(jit, ctx, asm, comptime_symbol.class, ctx.stack_opnd(argc), comptime_symbol, mid_changed_exit) == CantCompile + return CantCompile + end + asm.mov(C_ARGS[0], ctx.stack_opnd(argc)) + asm.call(C.rb_get_symbol_id) + asm.cmp(C_RET, mid) + jit_chain_guard(:jne, jit, ctx, asm, mid_changed_exit) + + if flags & C.VM_CALL_FCALL != 0 + return jit_call_method(jit, ctx, asm, mid, argc, flags, send_shift:) + end + + raise NotImplementedError # unreachable for now + end + # vm_push_frame # # Frame structure: @@ -2253,8 +2361,11 @@ module RubyVM::MJIT end def fixnum?(obj) - flag = C.RUBY_FIXNUM_FLAG - (C.to_value(obj) & flag) == flag + (C.to_value(obj) & C.RUBY_FIXNUM_FLAG) == C.RUBY_FIXNUM_FLAG + end + + def static_symbol?(obj) + (C.to_value(obj) & 0xff) == C.RUBY_SYMBOL_FLAG end # @param jit [RubyVM::MJIT::JITState] @@ -2296,7 +2407,7 @@ module RubyVM::MJIT return side_exit end asm = Assembler.new - @exit_compiler.compile_side_exit(jit, ctx, asm) + @exit_compiler.compile_side_exit(jit.pc, ctx, asm) jit.side_exits[jit.pc] = @ocb.write(asm) end |