{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Algorithm.EqSat.Simplify
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  [email protected]
-- Stability   :  experimental
-- Portability :
--
-- Module containing the algebraic rules and simplification function.
--
-----------------------------------------------------------------------------
module Algorithm.EqSat.Simplify ( Rule(..), simplifyEqSatDefault, applyMergeOnlyDftl, rewrites, rewritesParams, rewriteBasic, rewritesFun, rewritesSimple ) where

import Algorithm.EqSat (eqSat, applySingleMergeOnlyEqSat)
import Algorithm.EqSat.Egraph
import Algorithm.EqSat.DB
  ( ClassOrVar,
    Pattern (Fixed, VarPat),
    Rule (..),
    getInt,
  )
import Control.Monad.State.Strict (evalState)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM
import Data.Map (Map)
import qualified Data.Map as Map
import Data.SRTree

-- | Side condition for a conditional rewrite rule (attached with ':|').
-- Arguments, in order: the pattern to inspect, the substitution obtained
-- when the rule matched (pattern variables appear as @Right (fromEnum c)@
-- keys), and the current e-graph. The result says whether the rule may fire.
type ConstrFun =
     Pattern
  -> Map ClassOrVar ClassOrVar
  -> EGraph
  -> Bool

-- | Turn a predicate on an e-class's constant-analysis data ('Consts') into
-- a rule side condition.
--
-- Only variable patterns ('VarPat') can be inspected. The variable's binding
-- is read from the match substitution (pattern variable @c@ is stored under
-- the key @Right (fromEnum c)@), converted to an e-class id with 'getInt',
-- and the predicate is applied to that class's '_consts' field.
--
-- The condition is 'False' for any other pattern. It is also 'False', rather
-- than a crash from a partial lookup, when the variable is not bound by the
-- substitution or the class id is absent from '_eClass'. A 'False' result
-- simply prevents the rule from firing.
constrainOnVal :: (Consts -> Bool) -> Pattern -> Map ClassOrVar ClassOrVar -> EGraph -> Bool
constrainOnVal f (VarPat c) subst eg =
    maybe False f $ do
      bound <- Map.lookup (Right (fromEnum c)) subst   -- unbound variable => Nothing
      cls   <- IM.lookup (getInt bound) (_eClass eg)   -- unknown class id  => Nothing
      pure (_consts (_info cls))
constrainOnVal _ _ _ _ = False

-- TODO: add helper functions to avoid repeating the same pattern when creating constraints
--
-- | Side condition: the e-class bound to the variable carries a known
-- numeric constant ('ConstVal'). Any other analysis result (for example a
-- parameter, 'ParamIx') makes the condition fail.
isConstPt :: ConstrFun
isConstPt = constrainOnVal holdsConstant
  where
    holdsConstant (ConstVal _) = True
    holdsConstant _            = False

-- | Side condition: the e-class bound to the variable is a known constant
-- that is strictly positive. Classes without a known constant fail
-- (contrast with 'isPositive', which lets them pass).
isConstPos :: ConstrFun
isConstPos = constrainOnVal strictlyPositiveConst
  where
    strictlyPositiveConst (ConstVal v) = v > 0
    strictlyPositiveConst _            = False

-- | Side condition: the e-class bound to the variable is not a parameter
-- ('ParamIx'). Constants and every other class pass.
isNotParam :: ConstrFun
isNotParam = constrainOnVal (not . isParameter)
  where
    isParameter (ParamIx _) = True
    isParameter _           = False

-- | Side condition: the e-class bound to the variable is not (numerically)
-- zero. A known constant must have magnitude strictly above @1e-9@; classes
-- without a known constant pass optimistically.
isNotZero :: ConstrFun
isNotZero = constrainOnVal awayFromZero
  where
    awayFromZero (ConstVal v) = abs v > 1e-9
    awayFromZero _            = True

-- | Side condition: the e-class bound to the variable is a known constant
-- that is a whole, even number. Classes without a known constant pass
-- (optimistic default).
isEven :: ConstrFun
isEven = constrainOnVal evenOrUnknown
  where
    evenOrUnknown (ConstVal v) = isWhole v && even (round v :: Integer)
    evenOrUnknown _            = True

    -- whole number <=> rounding up and down agree
    isWhole v = (ceiling v :: Integer) == floor v

-- | Side condition: the e-class bound to the variable is a known constant
-- with no fractional part. Classes without a known constant pass
-- (optimistic default).
isInteger :: ConstrFun
isInteger = constrainOnVal wholeOrUnknown
  where
    wholeOrUnknown (ConstVal v) = floor v == (ceiling v :: Integer)
    wholeOrUnknown _            = True

-- | Side condition: the e-class bound to the variable is positive.
-- A known constant must be strictly greater than zero. Classes without a
-- known constant pass (contrast with 'isConstPos', which rejects them).
isPositive :: ConstrFun
isPositive = constrainOnVal positiveOrUnknown
  where
    positiveOrUnknown (ConstVal v) = v > 0
    positiveOrUnknown _            = True

-- | Side condition: the e-class bound to the variable is a finite value.
-- A known constant must be neither NaN nor infinite. Classes without a
-- known constant pass.
isValid :: ConstrFun
isValid = constrainOnVal finiteOrUnknown
  where
    finiteOrUnknown (ConstVal v) = not (isNaN v) && not (isInfinite v)
    finiteOrUnknown _            = True

-- | Basic algebraic rewrite rules:
--
--   * commutativity and associativity of @+@ and @*@;
--   * distribution and factoring of common terms;
--   * regrouping of subtraction and division.
--
-- Rules that move a divisor are guarded by 'isNotZero'. The rules that
-- factor a constant coefficient out of a sum or difference are also
-- guarded, with 'isConstPt' and 'isNotZero'.
rewriteBasic :: [Rule]
rewriteBasic =
    [ x * y :=> y * x
    , x + y :=> y + x
      -- disabled: (x ** y) * (x ** z) :=> x ** (y + z)   (guard isPositive x was also commented out)
      -- disabled: powabs x y * powabs x z :=> powabs x (y + z)
    , (x + y) + z :=> x + (y + z)
      -- disabled: (x + y) - z :=> x + (y - z)   TODO: check that it is not needed
    , (x * y) * z :=> x * (y * z)
    , (x * y) + (x * z) :=> x * (y + z)
    , x - (y + z) :=> (x - y) - z   -- TODO: check whether this is needed
    , x - (y - z) :=> (x - y) + z   -- TODO
    , (x * y) / z :=> (x / z) * y :| isNotZero z   -- TODO: inv(x) <=> x^-1, x/y <=> x*y^-1
    , x * (y / z) :=> (x / z) * y :| isNotZero z   -- same TODO as above
    , x / (y * z) :=> (x / z) / y :| isNotZero z   -- TODO: 0^-1 check
    , (w * x) + (z * x) :=> (w + z) * x   -- candidate guard: isConstPt w, isConstPt z
    , (w * x) - (z * x) :=> (w - z) * x   -- TODO: handle sub; candidate guard: isConstPt w, isConstPt z
    , (w * x) / (z * y) :=> (w / z) * (x / y)   -- TODO: handle with power; candidate guard: isConstPt w, isConstPt z, isNotZero z
      -- TODO: a + b*y :=> b * (a/b + y) :| isNotZero b
    , (x * y) + (z * w) :=> x * (y + (z / x) * w) :| isConstPt x :| isConstPt z :| isNotZero x
      -- disabled: "a" * (("x" * "y") + ("z" * "w")) :=> ("a" * "x") * ("y" + ("z" / "x") * "w") :| isConstPt "a" :| isConstPt "x" :| isConstPt "z" :| isNotZero "x"
    , (x * y) - (z * w) :=> x * (y - (z / x) * w) :| isConstPt x :| isConstPt z :| isNotZero x
    , (x * y) * (z * w) :=> (x * z) * (y * w) :| isConstPt x :| isConstPt z
      -- disabled (GABRIEL): "x" + "y" :=> "y" * ("x" * "y" ** (-1) + 1) :| isNotZero "y"
      -- disabled (GABRIEL): "x" + "y" * "z" :=> "y" * ("x" * "y" ** (-1) + "z") :| isNotZero "y"
    ]
  where
    -- Pattern variables, built from string literals via OverloadedStrings.
    w, x, y, z :: Pattern
    w = "w"
    x = "x"
    y = "y"
    z = "z"

-- | Rewrite rules for nonlinear functions: @log@/@exp@, @abs@, @recip@ and
-- powers.
rewritesFun :: [Rule]
rewritesFun =
    [
      log (exp "x") :==: exp (log "x")
    , log (exp "x")  :=> "x"
    -- , exp (log "x")  :=> "x" -- :| isPositive "x" ??? exp(log(x)), x, log(exp(0))
    , log ("x" * "y") :=> log "x" + log "y" :| isConstPos "x" :| isConstPos "y"
    -- , log ("x" / "y") :=> log "x" - log "y" :| isConstPos "x" :| isConstPos "y"
    -- log (x ** y) = y * log x only holds for x > 0. Guard it like the other
    -- power rules, so that a known non-positive constant base is not rewritten.
    , log ("x" ** "y") :=> "y" * log "x" :| isPositive "x"
    , log (powabs "x" "y") :=> "y" * log (abs "x")
    --, sqrt ("x" ** "y") :=> "x" ** ("y" / 2) :| isEven "y"
    -- , sqrt ("y" * "x") :=> sqrt "y" * sqrt "x"
    --, sqrt ("y" / "x") :=> sqrt "y" / sqrt "x"
    , abs ("x" * "y") :=> abs "x" * abs "y" -- :| isConstPt "x"
    , abs ("x" ** "y") :=> abs "x" ** "y"
    , abs ("x" - "y") :=> abs ("y" - "x")
    --, sqrt ("z" * ("x" - "y")) :=> sqrt (negate "z") * sqrt ("y" - "x")
    --, sqrt ("z" * ("x" + "y")) :=> sqrt "z" * sqrt ("x" + "y")
    , recip (recip "x") :=> "x" :| isNotZero "x"
    -- Listed once: repeating an identical rule only repeats the matching work.
    -- NOTE(review): unguarded; candidate guards are bothSameSign "x" "y" or
    -- isInteger "z".
    , ("x" * "y") ** "z" :==: ("x" ** "z") * ("y" ** "z")
    --, recip "x" :==: "x" ** (-1) -- GABRIEL
    --, "x" / "y" :==: "x" * "y" ** (-1) -- GABRIEL
    , abs "x" ** "y" :=> "x" ** "y" :| isEven "y"
    ]

-- | Rules that reduce redundant parameters. They cover identities such as
-- @0 + x@ and @x ** 1@, and they merge repeated powers of the same base.
constReduction :: [Rule]
constReduction =
    [
      0 + "x" :=> "x"
    -- , "x" - 0 :=> "x"
    --, 1 * "x" :=> "x"
    -- , 0 / "x" :=> 0 :| isNotZero "x"
    --, "x" - "x" :=> 0 :| isNotParam "x"
    --, "x" / "x" :=> 1 :| isNotZero "x" :| isNotParam "x"
    , "x" ** 1 :=> "x"
    , powabs "x" 1 :=> abs "x"

    -- , "x" * (1 / "x") :=> 1 :| isNotParam "x" :| isNotZero "x"
    -- , negate ("x" * "y") :=> (negate "x") * "y" :| isConstPt "x"

    , ("x" ** "y") * ("x" ** "z") :==: "x" ** ("y" + "z") :| isPositive "x"
    -- Product of powers with the same base: the exponents y and z add up.
    , (powabs "x" "y") * (powabs "x" "z") :=> powabs "x" ("y" + "z")
    , ("x" ** "y") ** "z" :==: "x" ** ("y" * "z") :| isPositive "x"
    , powabs (powabs "x" "y") "z" :=> powabs "x" ("y" * "z")
    , ("x" * "y") ** "z" :==: ("x" ** "z") * ("y" ** "z") :| isPositive "x" :| isPositive "y"

    --, "x" ** "y" * "x" ** "z" :==: "x" ** ("y" + "z") :| isInteger "y" :| isInteger "z"  :| isNotZero "x"
    --, ("x" ** "y") ** "z" :==: "x" ** ("y" * "z") :| isInteger "y" :| isInteger "z" :| isNotZero "x"
    --, ("x" * "y") ** "z" :==: "x" ** "z" * "y" ** "z" :| isInteger "z" :| isNotZero "x" :| isNotZero "y"

    ]

-- | Simplification rules whose sides contain numeric constants
-- (e.g. @x * x => x ** 2@, @x - x => 0@, @0 * x => 0@).
rewritesWithConstant :: [Rule]
rewritesWithConstant =
    [
      "x" * "x" :=> "x" ** 2
    , "x" - "x" :=> 0
    , "x" / "x" :=> 1 :| isNotZero "x"
    , ("x" ** "y") * "x" :=> "x" ** ("y" + 1) :| isPositive "x"
    , 1 ** "x" :=> 1
    , powabs 1 "x" :=> 1
    , log (sqrt "x") :=> 0.5 * log "x" :| isNotParam "x"
    , "x" ** (1/2)   :==: sqrt "x" -- <==>
    , powabs "x" (1/2) :=> sqrt (abs "x")
    , "x" ** (1/3) :==: Fixed (Uni Cbrt "x")
    , 0 * "x" :=> 0 :| isValid "x" -- :| isNotParam "x"
    , 0 ** "x" :=> 0 :| isPositive "x"
    -- powabs 0 x is 0 only for x > 0 (it is 1 for x = 0 and infinite for
    -- x < 0), so guard it exactly like the 0 ** x rule above.
    , powabs 0 "x" :=> 0 :| isPositive "x"
    , 0 - "x" :=> negate "x"
    , "x" + negate "y" :==: "x" - "y"
    ]
-- | Rules that collapse a sub-expression into the parameter node
-- @Fixed (Param 0)@ rather than into a concrete constant.
rewritesWithParam :: [Rule]
rewritesWithParam =
    [ -- disabled: x * x :=> x ** Fixed (Param 0)
      x - x      :=> param0
    , x / x      :=> param0 :| isNotZero x
    , 1 ** x     :=> param0
    , powabs 1 x :=> param0
      -- disabled: log (sqrt x) :=> Fixed (Param 0) * log x :| isNotParam x
    ]
  where
    -- Pattern variable, built from a string literal via OverloadedStrings.
    x :: Pattern
    x = "x"

    param0 :: Pattern
    param0 = Fixed (Param 0)

-- | A compact rule set. It covers:
--
--   * commutativity and associativity of @+@ and @*@;
--   * factoring of common multiplicands;
--   * rearrangement of divisions (guarded by 'isNotZero');
--   * log\/exp\/abs\/recip identities;
--   * a few rules whose right-hand side contains the parameter pattern
--     @'Fixed' ('Param' 0)@.
--
-- The order of the rules is kept as written.
rewritesSimple :: [Rule]
rewritesSimple =
    [ -- commutativity
      "x" * "y" :=> "y" * "x"
    , "x" + "y" :=> "y" + "x"
      -- merge powers sharing a base (disabled guard: :| isPositive "x")
    , ("x" ** "y") * ("x" ** "z") :=> "x" ** ("y" + "z")
      -- associativity
    , ("x" + "y") + "z" :=> "x" + ("y" + "z")
    , ("x" * "y") * "z" :=> "x" * ("y" * "z")
      -- factor out a common left multiplicand
    , ("x" * "y") + ("x" * "z") :=> "x" * ("y" + "z")
      -- TODO: check whether the two subtraction rules below are needed
    , "x" - ("y" + "z") :=> ("x" - "y") - "z"
    , "x" - ("y" - "z") :=> ("x" - "y") + "z"
      -- move divisions around; every rule is guarded on the divisor "z"
      -- TODO: inv(x) <=> x^-1 , x/y <=> x*y^-1 ; 0 ^-1 check
    , ("x" * "y") / "z" :=> ("x" / "z") * "y" :| isNotZero "z"
    , "x" * ("y" / "z") :=> ("x" / "z") * "y" :| isNotZero "z"
    , "x" / ("y" * "z") :=> ("x" / "z") / "y" :| isNotZero "z"
      -- factor out a common right multiplicand
      -- (disabled guard: :| isConstPt "w" :| isConstPt "z")
    , ("w" * "x") + ("z" * "x") :=> ("w" + "z") * "x"
      -- TODO: handle sub (disabled guard: :| isConstPt "w" :| isConstPt "z")
    , ("w" * "x") - ("z" * "x") :=> ("w" - "z") * "x"
    , ("w" * "x") / ("z" * "y") :=> ("w" / "z") * ("x" / "y")
      -- log / exp
    , log (exp "x")     :=> "x"
    , exp (log "x")     :=> "x"
    , log ("x" * "y")   :=> log "x" + log "y"
    , log ("x" ** "y")  :=> "y" * log "x"
      -- abs
    , abs ("x" * "y")   :=> abs "x" * abs "y"
    , abs ("x" ** "y")  :=> abs "x" ** "y"
    , abs ("x" - "y")   :=> abs ("y" - "x")
      -- double reciprocal
    , recip (recip "x") :=> "x" :| isNotZero "x"
      -- rules whose right-hand side uses the parameter pattern
    , "x" * "x"         :=> "x" ** param0
    , "x" - "x"         :=> param0
    , "x" / "x"         :=> param0 :| isNotZero "x"
    , 1 ** "x"          :=> param0
    , log (sqrt "x")    :=> param0 * log "x" :| isNotParam "x"
    ]
  where
    -- the fixed pattern @Param 0@ shared by several right-hand sides
    param0 :: Pattern
    param0 = Fixed (Param 0)
-- | Builds a fixed pattern node for the binary 'PowerAbs' operator applied to
-- the two given sub-patterns (left operand first).
powabs :: Pattern -> Pattern -> Pattern
powabs lhs rhs = Fixed $ Bin PowerAbs lhs rhs

-- | Default cost function for simplification.
--
-- A node's cost is its own weight plus the 'Int' values stored in its
-- children. Those values are presumably the already-computed costs of the
-- child e-classes; confirm against 'eqSat'. The weights are:
--
--   * variables: 1
--   * numeric constants and parameters: 3
--   * binary operators: 2 plus both children (the operator kind is ignored)
--   * unary functions: 3 plus the child
--
-- TODO: replace with a lexicographic ordering:
-- num_params:
--   length:
--      terminal < nonterminal:
--        symbol comparison (constants, parameters, variables x0, x10, x2)
--          op priorities (+, -, *, inv_div, pow, abs, exp, log, log10, sqrt)
--            univariates
myCost :: SRTree Int -> Int
myCost (Var _)     = 1
myCost (Const _)   = 3
myCost (Param _)   = 3
-- the operator is not inspected, so it is matched with a wildcard
-- (avoids an unused-binding warning under -Wall)
myCost (Bin _ l r) = 2 + l + r
myCost (Uni _ t)   = 3 + t

-- | All rewrite rules: the concatenation, in this order, of 'rewriteBasic',
-- 'constReduction', 'rewritesFun' and 'rewritesWithConstant'.
rewrites :: [Rule]
rewrites = mconcat
    [ rewriteBasic
    , constReduction
    , rewritesFun
    , rewritesWithConstant
    ]
-- | Rewrite rules for parameterized expressions: the concatenation, in this
-- order, of 'rewriteBasic', 'constReduction', 'rewritesFun' and
-- 'rewritesWithParam'.
rewritesParams :: [Rule]
rewritesParams = concat
    [ rewriteBasic
    , constReduction
    , rewritesFun
    , rewritesWithParam
    ]

-- | Simplify an expression with equality saturation using the default
-- settings:
--
--   * the 'rewrites' rule set;
--   * the 'myCost' cost function;
--   * 30 iterations.
--
-- Delegates to 'simplifyEqSat', so the "run 'eqSat' on an empty e-graph"
-- logic lives in a single place.
simplifyEqSatDefault :: Fix SRTree -> Fix SRTree
simplifyEqSatDefault = simplifyEqSat rewrites myCost defaultIterations
  where
    -- iteration count passed through to 'eqSat'
    defaultIterations :: Int
    defaultIterations = 30

-- | Simplify an expression with caller-chosen settings. 'eqSat' runs with the
-- given rule set, cost function and iteration count, starting from
-- 'emptyGraph'; the final e-graph state is discarded and the expression
-- returned by 'eqSat' is the result.
simplifyEqSat :: [Rule] -> CostFun -> Int -> Fix SRTree -> Fix SRTree
simplifyEqSat rules cost iterations expr =
    evalState (eqSat expr rules cost iterations) emptyGraph

-- | Apply a single merge-only step ('applySingleMergeOnlyEqSat') to the
-- current e-graph, using the full 'rewrites' rule set and the given cost
-- function.
applyMergeOnlyDftl :: Monad m => CostFun -> EGraphST m ()
applyMergeOnlyDftl cost = cost `applySingleMergeOnlyEqSat` rewrites