diff options
author | Max Bernstein <[email protected]> | 2025-04-25 14:04:45 -0400 |
---|---|---|
committer | GitHub <[email protected]> | 2025-04-25 11:04:45 -0700 |
commit | e0545a02503983e8824d0fb5972c15d51093d927 (patch) | |
tree | 3919b37bbbdc18e8088e6f10f0123e932109b923 | |
parent | 871d07a20e3de00bdd15a2d522e9a4220889fe60 (diff) |
ZJIT: Bail out of HIR translation if we can't handle a send flag (#13182)
Bail out of HIR translation if we can't handle a send flag
Notes
Notes:
Merged-By: k0kubun <[email protected]>
-rw-r--r-- | zjit/src/cruby.rs | 3 | ||||
-rw-r--r-- | zjit/src/hir.rs | 133 |
2 files changed, 134 insertions, 2 deletions
diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs index 82b77e0402..b813cf6ba4 100644 --- a/zjit/src/cruby.rs +++ b/zjit/src/cruby.rs @@ -887,12 +887,15 @@ mod manual_defs { // From vm_callinfo.h - uses calculation that seems to confuse bindgen pub const VM_CALL_ARGS_SIMPLE: u32 = 1 << VM_CALL_ARGS_SIMPLE_bit; pub const VM_CALL_ARGS_SPLAT: u32 = 1 << VM_CALL_ARGS_SPLAT_bit; + pub const VM_CALL_ARGS_SPLAT_MUT: u32 = 1 << VM_CALL_ARGS_SPLAT_MUT_bit; pub const VM_CALL_ARGS_BLOCKARG: u32 = 1 << VM_CALL_ARGS_BLOCKARG_bit; pub const VM_CALL_FORWARDING: u32 = 1 << VM_CALL_FORWARDING_bit; pub const VM_CALL_FCALL: u32 = 1 << VM_CALL_FCALL_bit; pub const VM_CALL_KWARG: u32 = 1 << VM_CALL_KWARG_bit; pub const VM_CALL_KW_SPLAT: u32 = 1 << VM_CALL_KW_SPLAT_bit; + pub const VM_CALL_KW_SPLAT_MUT: u32 = 1 << VM_CALL_KW_SPLAT_MUT_bit; pub const VM_CALL_TAILCALL: u32 = 1 << VM_CALL_TAILCALL_bit; + pub const VM_CALL_SUPER : u32 = 1 << VM_CALL_SUPER_bit; pub const VM_CALL_ZSUPER : u32 = 1 << VM_CALL_ZSUPER_bit; pub const VM_CALL_OPT_SEND : u32 = 1 << VM_CALL_OPT_SEND_bit; diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 196b438f17..17f7cb0e54 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -1410,7 +1410,7 @@ impl<'a> std::fmt::Display for FunctionPrinter<'a> { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct FrameState { iseq: IseqPtr, insn_idx: usize, @@ -1556,10 +1556,26 @@ fn compute_jump_targets(iseq: *const rb_iseq_t) -> Vec<u32> { result } -#[derive(Debug)] +#[derive(Debug, PartialEq)] +pub enum CallType { + Splat, + BlockArg, + Kwarg, + KwSplat, + Tailcall, + Super, + Zsuper, + OptSend, + KwSplatMut, + SplatMut, + Forwarding, +} + +#[derive(Debug, PartialEq)] pub enum ParseError { StackUnderflow(FrameState), UnknownOpcode(String), + UnhandledCallType(CallType), } fn num_lead_params(iseq: *const rb_iseq_t) -> usize { @@ -1573,6 +1589,22 @@ fn num_locals(iseq: *const rb_iseq_t) -> usize { (unsafe { get_iseq_body_local_table_size(iseq) }) as usize } +/// If we can't handle the type of send (yet), bail out. +fn filter_translatable_calls(flag: u32) -> Result<(), ParseError> { + if (flag & VM_CALL_KW_SPLAT_MUT) != 0 { return Err(ParseError::UnhandledCallType(CallType::KwSplatMut)); } + if (flag & VM_CALL_ARGS_SPLAT_MUT) != 0 { return Err(ParseError::UnhandledCallType(CallType::SplatMut)); } + if (flag & VM_CALL_ARGS_SPLAT) != 0 { return Err(ParseError::UnhandledCallType(CallType::Splat)); } + if (flag & VM_CALL_KW_SPLAT) != 0 { return Err(ParseError::UnhandledCallType(CallType::KwSplat)); } + if (flag & VM_CALL_ARGS_BLOCKARG) != 0 { return Err(ParseError::UnhandledCallType(CallType::BlockArg)); } + if (flag & VM_CALL_KWARG) != 0 { return Err(ParseError::UnhandledCallType(CallType::Kwarg)); } + if (flag & VM_CALL_TAILCALL) != 0 { return Err(ParseError::UnhandledCallType(CallType::Tailcall)); } + if (flag & VM_CALL_SUPER) != 0 { return Err(ParseError::UnhandledCallType(CallType::Super)); } + if (flag & VM_CALL_ZSUPER) != 0 { return Err(ParseError::UnhandledCallType(CallType::Zsuper)); } + if (flag & VM_CALL_OPT_SEND) != 0 { return Err(ParseError::UnhandledCallType(CallType::OptSend)); } + if (flag & VM_CALL_FORWARDING) != 0 { return Err(ParseError::UnhandledCallType(CallType::Forwarding)); } + Ok(()) +} + /// Compile ISEQ into High-level IR pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { let mut fun = Function::new(iseq); @@ -1757,6 +1789,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { // NB: opt_neq has two cd; get_arg(0) is for eq and get_arg(1) is for neq let cd: *const rb_call_data = get_arg(pc, 1).as_ptr(); let call_info = unsafe { rb_get_call_data_ci(cd) }; + filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?; let argc = unsafe { vm_ci_argc((*cd).ci) }; @@ -1797,6 +1830,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { YARVINSN_opt_send_without_block => { let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); let call_info = unsafe { rb_get_call_data_ci(cd) }; + filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?; let argc = unsafe { vm_ci_argc((*cd).ci) }; @@ -1819,6 +1853,7 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { let cd: *const rb_call_data = get_arg(pc, 0).as_ptr(); let blockiseq: IseqPtr = get_arg(pc, 1).as_iseq(); let call_info = unsafe { rb_get_call_data_ci(cd) }; + filter_translatable_calls(unsafe { rb_vm_ci_flag(call_info) })?; let argc = unsafe { vm_ci_argc((*cd).ci) }; let method_name = unsafe { @@ -2134,6 +2169,16 @@ mod tests { expected_hir.assert_eq(&actual_hir); } + #[track_caller] + fn assert_compile_fails(method: &str, reason: ParseError) { + let iseq = crate::cruby::with_rubyvm(|| get_method_iseq(method)); + unsafe { crate::cruby::rb_zjit_profile_disable(iseq) }; + let result = iseq_to_hir(iseq); + assert!(result.is_err(), "Expected an error but succesfully compiled to HIR"); + assert_eq!(result.unwrap_err(), reason); + } + + #[test] fn test_putobject() { eval("def test = 123"); @@ -2610,6 +2655,90 @@ mod tests { Return v13 "#]]); } + + #[test] + fn test_cant_compile_splat() { + eval(" + def test(a) = foo(*a) + "); + assert_compile_fails("test", ParseError::UnknownOpcode("splatarray".into())) + } + + #[test] + fn test_cant_compile_block_arg() { + eval(" + def test(a) = foo(&a) + "); + assert_compile_fails("test", ParseError::UnhandledCallType(CallType::BlockArg)) + } + + #[test] + fn test_cant_compile_kwarg() { + eval(" + def test(a) = foo(a: 1) + "); + assert_compile_fails("test", ParseError::UnhandledCallType(CallType::Kwarg)) + } + + #[test] + fn test_cant_compile_kw_splat() { + eval(" + def test(a) = foo(**a) + "); + assert_compile_fails("test", ParseError::UnhandledCallType(CallType::KwSplat)) + } + + // TODO(max): Figure out how to generate a call with TAILCALL flag + + #[test] + fn test_cant_compile_super() { + eval(" + def test = super() + "); + assert_compile_fails("test", ParseError::UnknownOpcode("invokesuper".into())) + } + + #[test] + fn test_cant_compile_zsuper() { + eval(" + def test = super + "); + assert_compile_fails("test", ParseError::UnknownOpcode("invokesuper".into())) + } + + #[test] + fn test_cant_compile_super_forward() { + eval(" + def test(...) = super(...) + "); + assert_compile_fails("test", ParseError::UnknownOpcode("invokesuperforward".into())) + } + + // TODO(max): Figure out how to generate a call with OPT_SEND flag + + #[test] + fn test_cant_compile_kw_splat_mut() { + eval(" + def test(a) = foo **a, b: 1 + "); + assert_compile_fails("test", ParseError::UnknownOpcode("putspecialobject".into())) + } + + #[test] + fn test_cant_compile_splat_mut() { + eval(" + def test(*) = foo *, 1 + "); + assert_compile_fails("test", ParseError::UnknownOpcode("splatarray".into())) + } + + #[test] + fn test_cant_compile_forwarding() { + eval(" + def test(...) = foo(...) + "); + assert_compile_fails("test", ParseError::UnknownOpcode("sendforward".into())) + } } #[cfg(test)] |