diff options
author | Max Bernstein <[email protected]> | 2025-04-30 19:28:22 -0400 |
---|---|---|
committer | GitHub <[email protected]> | 2025-04-30 16:28:22 -0700 |
commit | 5411b504a5e75d553a423b4e5dbe63b9c45e906f (patch) | |
tree | 6b5dc22e1727e7b48d934017f395192bf5afb419 | |
parent | 7866e124a852c344b5762eb917c03a1f95d9058d (diff) |
ZJIT: Use RefCell to allow path compression in union-find (#13218)
Use RefCell to allow path compression in union-find
When I wrote the original version I didn't understand the interior
mutability pattern, but now I do! With this commit, we should have a
more optimal union-find implementation.
Notes
Notes:
Merged-By: k0kubun <[email protected]>
-rw-r--r-- | zjit/src/hir.rs | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index f4ecb46383..d54617efe4 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -619,7 +619,7 @@ impl<T: Copy + Into<usize> + PartialEq> UnionFind<T> { } /// Find the set representative for `insn` without doing path compression. - pub fn find_const(&self, insn: T) -> T { + fn find_const(&self, insn: T) -> T { let mut result = insn; loop { match self.at(result) { @@ -645,7 +645,7 @@ pub struct Function { // TODO: get method name and source location from the ISEQ insns: Vec<Insn>, - union_find: UnionFind<InsnId>, + union_find: std::cell::RefCell<UnionFind<InsnId>>, insn_types: Vec<Type>, blocks: Vec<Block>, entry_block: BlockId, @@ -657,7 +657,7 @@ impl Function { iseq, insns: vec![], insn_types: vec![], - union_find: UnionFind::new(), + union_find: UnionFind::new().into(), blocks: vec![Block::default()], entry_block: BlockId(0), } @@ -740,7 +740,7 @@ impl Function { macro_rules! find { ( $x:expr ) => { { - self.union_find.find_const($x) + self.union_find.borrow_mut().find($x) } }; } @@ -749,12 +749,12 @@ impl Function { { BranchEdge { target: $edge.target, - args: $edge.args.iter().map(|x| self.union_find.find_const(*x)).collect(), + args: $edge.args.iter().map(|x| find!(*x)).collect(), } } }; } - let insn_id = self.union_find.find_const(insn_id); + let insn_id = self.union_find.borrow_mut().find(insn_id); use Insn::*; match &self.insns[insn_id.0] { result@(PutSelf | Const {..} | Param {..} | NewArray {..} | GetConstantPath {..} @@ -822,12 +822,12 @@ impl Function { /// Replace `insn` with the new instruction `replacement`, which will get appended to `insns`. fn make_equal_to(&mut self, insn: InsnId, replacement: InsnId) { // Don't push it to the block - self.union_find.make_equal_to(insn, replacement); + self.union_find.borrow_mut().make_equal_to(insn, replacement); } fn type_of(&self, insn: InsnId) -> Type { assert!(self.insns[insn.0].has_output()); - self.insn_types[self.union_find.find_const(insn).0] + self.insn_types[self.union_find.borrow_mut().find(insn).0] } /// Check if the type of `insn` is a subtype of `ty`. @@ -1978,19 +1978,19 @@ mod union_find_tests { } #[test] - fn test_find_const_returns_target() { + fn test_find_returns_target() { let mut uf = UnionFind::new(); uf.make_equal_to(3, 4); - assert_eq!(uf.find_const(3usize), 4); + assert_eq!(uf.find(3usize), 4); } #[test] - fn test_find_const_returns_transitive_target() { + fn test_find_returns_transitive_target() { let mut uf = UnionFind::new(); uf.make_equal_to(3, 4); uf.make_equal_to(4, 5); - assert_eq!(uf.find_const(3usize), 5); - assert_eq!(uf.find_const(4usize), 5); + assert_eq!(uf.find(3usize), 5); + assert_eq!(uf.find(4usize), 5); } #[test] |