diff options
Diffstat (limited to 'zjit/src/hir.rs')
-rw-r--r-- | zjit/src/hir.rs | 108 |
1 files changed, 92 insertions, 16 deletions
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index f4ecb46383..da56738231 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -141,6 +141,7 @@ impl<'a> std::fmt::Display for InvariantPrinter<'a> { write!(f, "BOPRedefined(")?; match klass { INTEGER_REDEFINED_OP_FLAG => write!(f, "INTEGER_REDEFINED_OP_FLAG")?, + ARRAY_REDEFINED_OP_FLAG => write!(f, "ARRAY_REDEFINED_OP_FLAG")?, _ => write!(f, "{klass}")?, } write!(f, ", ")?; @@ -156,6 +157,7 @@ impl<'a> std::fmt::Display for InvariantPrinter<'a> { BOP_LE => write!(f, "BOP_LE")?, BOP_GT => write!(f, "BOP_GT")?, BOP_GE => write!(f, "BOP_GE")?, + BOP_MAX => write!(f, "BOP_MAX")?, _ => write!(f, "{bop}")?, } write!(f, ")") @@ -310,6 +312,7 @@ pub enum Insn { NewArray { elements: Vec<InsnId>, state: InsnId }, ArraySet { array: InsnId, idx: usize, val: InsnId }, ArrayDup { val: InsnId, state: InsnId }, + ArrayMax { elements: Vec<InsnId>, state: InsnId }, // Check if the value is truthy and "return" a C boolean. In reality, we will likely fuse this // with IfTrue/IfFalse in the backend to generate jcc. @@ -441,6 +444,15 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { } Ok(()) } + Insn::ArrayMax { elements, .. } => { + write!(f, "ArrayMax")?; + let mut prefix = " "; + for element in elements { + write!(f, "{prefix}{element}")?; + prefix = ", "; + } + Ok(()) + } Insn::ArraySet { array, idx, val } => { write!(f, "ArraySet {array}, {idx}, {val}") } Insn::ArrayDup { val, .. } => { write!(f, "ArrayDup {val}") } Insn::StringCopy { val } => { write!(f, "StringCopy {val}") } @@ -619,7 +631,7 @@ impl<T: Copy + Into<usize> + PartialEq> UnionFind<T> { } /// Find the set representative for `insn` without doing path compression. - pub fn find_const(&self, insn: T) -> T { + fn find_const(&self, insn: T) -> T { let mut result = insn; loop { match self.at(result) { @@ -645,7 +657,7 @@ pub struct Function { // TODO: get method name and source location from the ISEQ insns: Vec<Insn>, - union_find: UnionFind<InsnId>, + union_find: std::cell::RefCell<UnionFind<InsnId>>, insn_types: Vec<Type>, blocks: Vec<Block>, entry_block: BlockId, @@ -657,7 +669,7 @@ impl Function { iseq, insns: vec![], insn_types: vec![], - union_find: UnionFind::new(), + union_find: UnionFind::new().into(), blocks: vec![Block::default()], entry_block: BlockId(0), } @@ -740,7 +752,14 @@ impl Function { macro_rules! find { ( $x:expr ) => { { - self.union_find.find_const($x) + self.union_find.borrow_mut().find($x) + } + }; + } + macro_rules! find_vec { + ( $x:expr ) => { + { + $x.iter().map(|arg| find!(*arg)).collect() } }; } @@ -749,15 +768,15 @@ impl Function { { BranchEdge { target: $edge.target, - args: $edge.args.iter().map(|x| self.union_find.find_const(*x)).collect(), + args: find_vec!($edge.args), } } }; } - let insn_id = self.union_find.find_const(insn_id); + let insn_id = self.union_find.borrow_mut().find(insn_id); use Insn::*; match &self.insns[insn_id.0] { - result@(PutSelf | Const {..} | Param {..} | NewArray {..} | GetConstantPath {..} + result@(PutSelf | Const {..} | Param {..} | GetConstantPath {..} | PatchPoint {..}) => result.clone(), Snapshot { state: FrameState { iseq, insn_idx, pc, stack, locals } } => Snapshot { @@ -816,18 +835,20 @@ impl Function { ArrayDup { val , state } => ArrayDup { val: find!(*val), state: *state }, CCall { cfun, args, name, return_type } => CCall { cfun: *cfun, args: args.iter().map(|arg| find!(*arg)).collect(), name: *name, return_type: *return_type }, Defined { .. } => todo!("find(Defined)"), + NewArray { elements, state } => NewArray { elements: find_vec!(*elements), state: find!(*state) }, + ArrayMax { elements, state } => ArrayMax { elements: find_vec!(*elements), state: find!(*state) }, } } /// Replace `insn` with the new instruction `replacement`, which will get appended to `insns`. fn make_equal_to(&mut self, insn: InsnId, replacement: InsnId) { // Don't push it to the block - self.union_find.make_equal_to(insn, replacement); + self.union_find.borrow_mut().make_equal_to(insn, replacement); } fn type_of(&self, insn: InsnId) -> Type { assert!(self.insns[insn.0].has_output()); - self.insn_types[self.union_find.find_const(insn).0] + self.insn_types[self.union_find.borrow_mut().find(insn).0] } /// Check if the type of `insn` is a subtype of `ty`. @@ -882,6 +903,7 @@ impl Function { Insn::PutSelf => types::BasicObject, Insn::Defined { .. } => types::BasicObject, Insn::GetConstantPath { .. } => types::BasicObject, + Insn::ArrayMax { .. } => types::BasicObject, } } @@ -1309,9 +1331,13 @@ impl Function { necessary[insn_id.0] = true; match self.find(insn_id) { Insn::PutSelf | Insn::Const { .. } | Insn::Param { .. } - | Insn::NewArray { .. } | Insn::PatchPoint(..) - | Insn::GetConstantPath { .. } => + | Insn::PatchPoint(..) | Insn::GetConstantPath { .. } => {} + Insn::ArrayMax { elements, state } + | Insn::NewArray { elements, state } => { + worklist.extend(elements); + worklist.push_back(state); + } Insn::StringCopy { val } | Insn::StringIntern { val } | Insn::Return { val } @@ -1636,6 +1662,7 @@ pub enum CallType { pub enum ParseError { StackUnderflow(FrameState), UnknownOpcode(String), + UnknownNewArraySend(String), UnhandledCallType(CallType), } @@ -1755,6 +1782,26 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> { elements.reverse(); state.stack_push(fun.push_insn(block, Insn::NewArray { elements, state: exit_id })); } + YARVINSN_opt_newarray_send => { + let count = get_arg(pc, 0).as_usize(); + let method = get_arg(pc, 1).as_u32(); + let mut elements = vec![]; + for _ in 0..count { + elements.push(state.stack_pop()?); + } + elements.reverse(); + let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state.clone() }); + let (bop, insn) = match method { + VM_OPT_NEWARRAY_SEND_MAX => (BOP_MAX, Insn::ArrayMax { elements, state: exit_id }), + VM_OPT_NEWARRAY_SEND_MIN => return Err(ParseError::UnknownNewArraySend("min".into())), + VM_OPT_NEWARRAY_SEND_HASH => return Err(ParseError::UnknownNewArraySend("hash".into())), + VM_OPT_NEWARRAY_SEND_PACK => return Err(ParseError::UnknownNewArraySend("pack".into())), + VM_OPT_NEWARRAY_SEND_PACK_BUFFER => return Err(ParseError::UnknownNewArraySend("pack_buffer".into())), + _ => return Err(ParseError::UnknownNewArraySend(format!("{method}"))), + }; + fun.push_insn(block, Insn::PatchPoint(Invariant::BOPRedefined { klass: ARRAY_REDEFINED_OP_FLAG, bop })); + state.stack_push(fun.push_insn(block, insn)); + } YARVINSN_duparray => { let val = fun.push_insn(block, Insn::Const { val: Const::Value(get_arg(pc, 0)) }); let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state.clone() }); @@ -1978,19 +2025,19 @@ mod union_find_tests { } #[test] - fn test_find_const_returns_target() { + fn test_find_returns_target() { let mut uf = UnionFind::new(); uf.make_equal_to(3, 4); - assert_eq!(uf.find_const(3usize), 4); + assert_eq!(uf.find(3usize), 4); } #[test] - fn test_find_const_returns_transitive_target() { + fn test_find_returns_transitive_target() { let mut uf = UnionFind::new(); uf.make_equal_to(3, 4); uf.make_equal_to(4, 5); - assert_eq!(uf.find_const(3usize), 5); - assert_eq!(uf.find_const(4usize), 5); + assert_eq!(uf.find(3usize), 5); + assert_eq!(uf.find(4usize), 5); } #[test] @@ -2831,6 +2878,35 @@ mod tests { Return v10 "#]]); } + + #[test] + fn test_opt_newarray_send_max_no_elements() { + eval(" + def test = [].max + "); + // TODO(max): Rewrite to nil + assert_method_hir("test", expect![[r#" + fn test: + bb0(): + PatchPoint BOPRedefined(ARRAY_REDEFINED_OP_FLAG, BOP_MAX) + v3:BasicObject = ArrayMax + Return v3 + "#]]); + } + + #[test] + fn test_opt_newarray_send_max() { + eval(" + def test(a,b) = [a,b].max + "); + assert_method_hir("test", expect![[r#" + fn test: + bb0(v0:BasicObject, v1:BasicObject): + PatchPoint BOPRedefined(ARRAY_REDEFINED_OP_FLAG, BOP_MAX) + v5:BasicObject = ArrayMax v0, v1 + Return v5 + "#]]); + } } #[cfg(test)] |