-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.EqSat.Info
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  [email protected]
-- Stability   :  experimental
-- Portability :
--
-- Functions related to info/data calculation in Equality Graph data structure
-- Heavily based on hegg (https://github.com/alt-romes/hegg by alt-romes)
--
-----------------------------------------------------------------------------

module Algorithm.EqSat.Info where

import Control.Lens ( over )
import Control.Monad --(forM, forM_, when, foldM, void)
import Control.Monad.State
import Data.AEq (AEq ((~==)))
import Data.IntMap (IntMap) -- , delete, empty, insert, toList)
import qualified Data.IntMap as IntMap
import Data.Map (Map)
import qualified Data.Map as Map
import Data.SRTree
import Data.SRTree.Eval (evalFun, evalOp, PVector)
import Data.HashSet (HashSet)
import qualified Data.HashSet as Set
import qualified Data.IntSet as IntSet
import Algorithm.EqSat.Egraph
import Data.AEq (AEq ((~==)))
import Algorithm.EqSat.Queries
import Data.Maybe
import qualified Data.Set as TrueSet
import Data.Sequence (Seq(..), (><))

import Debug.Trace

-- * Data related functions 

-- | Join the analysis data of two e-classes that are being merged.
--
-- The e-class with the lower cost wins (ties go to the first argument):
-- its best e-node, constant information, description length, parameter
-- vectors and size are kept. The resulting cost is the minimum of both
-- costs and the resulting fitness is the maximum of both fitnesses, where
-- a missing fitness ('Nothing') loses against any 'Just'.
--
-- TODO: instead of folding, just do not apply rules
-- list of values instead of single value
--
-- NOTE(review): '_fitness' comes from whichever side has the higher
-- fitness, while '_best' and '_theta' come from the cheaper side, so the
-- stored fitness is not necessarily the one achieved by the stored
-- parameters. Confirm this is intended.
joinData :: EClassData -> EClassData -> EClassData
joinData (EData c1 b1 cn1 fit1 dl1 p1 sz1) (EData c2 b2 cn2 fit2 dl2 p2 sz2) =
  EData (min c1 c2)
        (pick b1 b2)
        (pick cn1 cn2)
        (max fit1 fit2)   -- Nothing < Just _, so a known fitness always wins
        (pick dl1 dl2)
        (pick p1 p2)
        (pick sz1 sz2)
  where
    -- the first e-class wins ties on cost
    firstIsCheaper = c1 <= c2

    -- select the field of the cheaper e-class
    pick :: a -> a -> a
    pick x y = if firstIsCheaper then x else y

-- | Calculate e-node data (constant values and cost).
--
-- Builds the 'EClassData' of an e-node:
--
-- * constant information from 'calculateConsts' (on the e-node as given);
-- * cost from 'calculateCost' on the canonized e-node;
-- * size equal to one plus the '_size' of each child's e-class.
--
-- Fitness and description length start as 'Nothing' and the parameter
-- list starts empty.
makeAnalysis :: Monad m => CostFun -> ENode -> EGraphST m EClassData
makeAnalysis costFun enode = do
  constInfo  <- calculateConsts enode
  canonNode  <- canonize enode
  nodeCost   <- calculateCost costFun canonNode
  childSizes <- traverse childSize (childrenOf canonNode)
  pure (EData nodeCost canonNode constInfo Nothing Nothing [] (sum childSizes + 1))
  where
    -- '_size' stored in the analysis data of the given e-class
    childSize ecId = gets (_size . _info . (IntMap.! ecId) . _eClass)

-- | Minimum '_height' among the e-classes of an e-node's children, or @0@
-- when the e-node has no children.
--
-- Each child id is canonicalized before it is looked up in '_eClass', the
-- same way the height helpers of 'calculateHeights' do. An id that has
-- since been merged into another e-class therefore resolves to its current
-- representative instead of indexing a stale entry.
getChildrenMinHeight :: Monad m => ENode -> EGraphST m Int
getChildrenMinHeight enode = do
  heights <- mapM childHeight (childrenOf enode)
  pure $ case heights of
    [] -> 0            -- a leaf has no children to take the minimum of
    hs -> minimum hs
  where
    -- height of the canonical e-class of a child id
    childHeight ecRaw = do
      ec <- canonical ecRaw
      gets (_height . (IntMap.! ec) . _eClass)

-- | Update the '_height' of every e-class to its breadth-first distance from
-- the root e-classes returned by 'findRootClasses'.
--
-- Every e-class is first set to the number of e-classes, which is an upper
-- bound. The roots are then set to @0@, and their descendants are visited
-- level by level. Each newly reached e-class gets the minimum of its
-- current height and the current level.
--
-- Won't work if there's no root: every height stays at the upper bound.
calculateHeights :: Monad m => EGraphST m ()
calculateHeights = do
  roots    <- findRootClasses
  classIds <- gets (IntMap.keys . _eClass)
  let upperBound = length classIds
  mapM_ (assignHeight upperBound) classIds  -- start every e-class at the max possible height
  mapM_ (assignHeight 0) roots              -- root e-classes have height zero
  bfs roots (TrueSet.fromList roots) 1      -- the next level is height 1
  where
    -- overwrite the height of the canonical e-class of the given id
    assignHeight newH eIdRaw = do
      eId <- canonical eIdRaw
      cls <- getEClass eId
      modify' $ over eClass (IntMap.insert eId (over height (const newH) cls))

    -- set the height to the minimum between the current one and newH
    lowerHeight newH eIdRaw = do
      eId <- canonical eIdRaw
      cur <- _height <$> getEClass eId
      assignHeight (min cur newH) eId

    -- child e-class ids of every e-node stored in the (canonical) e-class
    getChildrenEC :: Monad m => EClassId -> EGraphST m [EClassId]
    getChildrenEC ecRaw = do
      ec <- canonical ecRaw
      gets (concatMap nodeChildren . _eNodes . (IntMap.! ec) . _eClass)

    -- a -1 slot is treated as "no child"
    nodeChildren (_, -1, -1, _) = []
    nodeChildren (_, e1, -1, _) = [e1]
    nodeChildren (_, e1, e2, _) = [e1, e2]

    -- breadth-first sweep: frontier is the current level, visited every id seen so far
    bfs [] _ _ = pure ()
    bfs frontier visited level = do
      nested <- mapM getChildrenEC frontier
      let fresh  = TrueSet.fromList (concat nested) TrueSet.\\ visited  -- unvisited children only
          freshL = TrueSet.toList fresh
      mapM_ (lowerHeight level) freshL
      bfs freshL (TrueSet.union visited fresh) (level + 1)

-- | Calculate the cost of an e-node.
--
-- Each child e-class id is replaced by that e-class's current '_cost', and
-- the cost function @f@ is applied to the resulting tree.
calculateCost :: Monad m => CostFun -> SRTree EClassId -> EGraphST m Cost
calculateCost f t = do
  childCosts <- mapM costOf (childrenOf t)
  pure (f (replaceChildren childCosts t))
  where
    -- current cost recorded for the given e-class
    costOf ecId = _cost . _info <$> getEClass ecId

-- | Check whether an e-node evaluates to a constant.
--
-- Each child e-class id is replaced by that e-class's '_consts', and the
-- result is folded with 'combineConsts'. A NaN constant is returned as
-- @ConstVal NaN@.
--
-- NOTE(review): confirm that NaN results should stay 'ConstVal' rather than
-- becoming 'NotConst'.
calculateConsts :: Monad m => SRTree EClassId -> EGraphST m Consts
calculateConsts t = do
  childConsts <- traverse (fmap (_consts . _info) . getEClass) (childrenOf t)
  let result = combineConsts (replaceChildren childConsts t)
  -- Force the result, including a constant's value, before returning it so
  -- that no unevaluated thunk ends up stored in the e-class data.
  case result of
    ConstVal x -> x `seq` pure result
    _          -> pure result

-- | Constant-folding step of the analysis.
--
-- Takes an e-node whose children have already been replaced by their
-- 'Consts' and returns the 'Consts' of the node itself:
--
-- * a literal is a 'ConstVal', a parameter a 'ParamIx', a variable 'NotConst';
-- * a unary function of a 'ConstVal' is evaluated with 'evalFun'; any other
--   argument ('ParamIx' or 'NotConst') is returned unchanged;
-- * a binary operator of two 'ConstVal's is evaluated with 'evalOp', two
--   'ParamIx' keep the smaller index, and any other mix is 'NotConst'.
--
-- NOTE(review): @f(ParamIx)@ stays a 'ParamIx' but @ParamIx `op` ConstVal@
-- becomes 'NotConst'. Confirm this asymmetry is intended.
combineConsts :: SRTree Consts -> Consts
combineConsts node = case node of
  Const x    -> ConstVal x
  Param ix   -> ParamIx ix
  Var _      -> NotConst
  Uni f arg  -> unary f arg
  Bin op l r -> binary op l r
  where
    unary f (ConstVal x) = ConstVal (evalFun f x)
    unary _ other        = other

    binary _  (ParamIx ix) (ParamIx iy) = ParamIx (min ix iy)
    binary op (ConstVal x) (ConstVal y) = ConstVal (evalOp op x y)
    binary _  _            _            = NotConst

-- | Record the fitness @fit@ and the fitted parameters @params@ of an e-class.
--
-- The e-class id is canonicalized first. The e-class's '_fitness' and
-- '_theta' are overwritten, and the e-class database is then updated:
--
-- * first evaluation ('_fitness' was 'Nothing'): the id is removed from
--   'unevaluated' and inserted into 'fitRangeDB'. It is also inserted into
--   the 'sizeFitDB' bucket of the e-class's '_size', creating an empty
--   bucket first if none exists;
-- * re-evaluation: the old fitness entry of the id in 'fitRangeDB' is
--   replaced by the new one.
--
-- NOTE(review): on re-evaluation 'sizeFitDB' is left untouched and keeps the
-- old fitness for this id. Confirm whether it should be updated as well.
insertFitness :: Monad m => EClassId -> Double -> [PVector] -> EGraphST m ()
insertFitness eId' fit params = do
  eId <- canonical eId'
  ec  <- gets ((IntMap.! eId) . _eClass)
  let oldFit  = _fitness (_info ec)
      newInfo = (_info ec){_fitness = Just fit, _theta = params}
      sz      = _size newInfo
  modify' $ over eClass (IntMap.insert eId (ec{_info = newInfo}))
  modify' $ case oldFit of
    Nothing  -> over (eDB . unevaluated) (IntSet.delete eId)
              . over (eDB . fitRangeDB) (insertRange eId fit)
              . over (eDB . sizeFitDB) (IntMap.adjust (insertRange eId fit) sz . IntMap.insertWith (><) sz Empty)
    Just old -> over (eDB . fitRangeDB) (insertRange eId fit . removeRange eId old)

-- | Record the description length @fit'@ of an e-class.
--
-- The e-class id is canonicalized first, as 'insertFitness' does, so an id
-- whose e-class has since been merged updates its representative instead
-- of indexing a stale or missing entry of '_eClass'.
--
-- '_dl' stores @fit'@ as given. 'dlRangeDB' and the 'sizeDLDB' bucket of
-- the e-class's '_size' receive the negated value @negate fit'@. The
-- 'sizeDLDB' bucket is created empty first if none exists. The negation is
-- presumably so that both databases share fitness's "higher is better"
-- ordering (TODO confirm).
--
-- NOTE(review): unlike 'insertFitness', any previous entry for this id is
-- not removed, so calling this twice for the same e-class leaves two
-- entries in 'dlRangeDB' and 'sizeDLDB'.
insertDL :: Monad m => EClassId -> Double -> EGraphST m ()
insertDL eId' fit' = do
  eId <- canonical eId'
  ec  <- gets ((IntMap.! eId) . _eClass)
  let fit     = negate fit'
      sz      = _size (_info ec)
      newInfo = (_info ec){_dl = Just fit'}
  modify' $ over eClass (IntMap.insert eId (ec{_info = newInfo}))
  modify' $ over (eDB . dlRangeDB) (insertRange eId fit)
          . over (eDB . sizeDLDB) (IntMap.adjust (insertRange eId fit) sz . IntMap.insertWith (><) sz Empty)