{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
module Data.SRTree.Eval
( evalTree
, evalOp
, evalFun
, cbrt
, inverseFunc
, invertibles
, evalInverse
, invright
, invleft
, replicateAs
, SRVector, PVector, SRMatrix
, compMode
)
where
import Data.Massiv.Array
import qualified Data.Massiv.Array as M
import Data.SRTree.Internal
import Data.SRTree.Recursion (Fix (..), cata)
-- | One-dimensional delayed (massiv @D@) array of 'Double'. This is the
-- type produced by 'evalTree' and by 'replicateAs'.
type SRVector = M.Array D Ix1 Double
-- | One-dimensional storable (massiv @S@) array of 'Double'. 'evalTree'
-- indexes it with the index carried by a @Param@ node.
type PVector = M.Array S Ix1 Double
-- | Two-dimensional storable (massiv @S@) array of 'Double'. 'evalTree'
-- takes the inner slice @xss <! ix@ of it for a @Var ix@ node, and
-- 'replicateAs' uses its first dimension as the length of the result.
type SRMatrix = M.Array S Ix2 Double
-- | Computation strategy exported by this module: massiv's sequential
-- strategy, 'M.Seq'.
compMode :: M.Comp
compMode = M.Seq
-- | Element-wise arithmetic for delayed massiv arrays of 'Double'.
--
-- The binary operators delegate to massiv's @!+!@, @!-!@ and @!*!@, the
-- unary methods to 'absA', 'signumA' and 'negateA'.
--
-- 'fromInteger' cannot be implemented: the instance has no way of knowing
-- which size the resulting array must have. Defining it in terms of itself
-- only produces a computation that never terminates, so it raises an
-- explicit error instead. Use 'replicateAs' to build a constant vector.
instance Index ix => Num (M.Array D ix Double) where
    (+)    = (!+!)
    (-)    = (!-!)
    (*)    = (!*!)
    abs    = absA
    signum = signumA
    negate = negateA
    fromInteger n =
        error $ "Data.SRTree.Eval: cannot build an array from the literal "
             ++ show n
             ++ ", the required size is unknown"
-- | Element-wise transcendental functions for delayed massiv arrays of
-- 'Double'.
--
-- Every unary method delegates to the massiv function of the same name with
-- an @A@ suffix ('expA', 'logA', ...), and @**@ delegates to @.**@.
-- 'logBase' is not overridden, so the class default applies.
--
-- 'pi' cannot be implemented because the size of the array to build is
-- unknown; defining it in terms of itself would never terminate, so it
-- raises an explicit error instead.
instance Index ix => Floating (M.Array D ix Double) where
    pi =
        error "Data.SRTree.Eval: cannot build an array for pi, the required size is unknown"
    exp   = expA
    log   = logA
    sqrt  = sqrtA
    sin   = sinA
    cos   = cosA
    tan   = tanA
    asin  = asinA
    acos  = acosA
    atan  = atanA
    sinh  = sinhA
    cosh  = coshA
    tanh  = tanhA
    asinh = asinhA
    acosh = acoshA
    atanh = atanhA
    (**)  = (.**)
-- | Element-wise division for delayed massiv arrays of 'Double'.
--
-- @/@ delegates to massiv's @!/!@ and 'recip' to 'recipA'.
--
-- 'fromRational' cannot be implemented because the size of the array to
-- build is unknown; defining it in terms of itself would never terminate,
-- so it raises an explicit error instead.
instance Index ix => Fractional (M.Array D ix Double) where
    (/)   = (!/!)
    recip = recipA
    fromRational q =
        error $ "Data.SRTree.Eval: cannot build an array from the literal "
             ++ show q
             ++ ", the required size is unknown"
-- | Builds a vector holding the constant @c@ in every position. The result
-- has as many elements as the first dimension of @xss@ and reuses the
-- computation strategy of @xss@.
replicateAs :: SRMatrix -> Double -> SRVector
replicateAs xss c = M.replicate (getComp xss) (Sz rows) c
  where
    -- only the first dimension of the matrix is needed
    Sz (rows :. _) = M.size xss
{-# INLINE replicateAs #-}
-- | Evaluates an expression tree, producing one vector of results.
--
-- * @xss@ is the data matrix; a @Var ix@ node evaluates to the inner slice
--   @xss <! ix@.
-- * @params@ holds the parameter values; a @Param ix@ node evaluates to
--   @params ! ix@ replicated with 'replicateAs'.
-- * @Const@ nodes are replicated the same way, @Uni@ nodes are evaluated
--   with 'evalFun' and @Bin@ nodes with 'evalOp'.
evalTree :: SRMatrix -> PVector -> Fix SRTree -> SRVector
evalTree xss params = cata alg
  where
    -- constant vector shaped after the data matrix
    constant = replicateAs xss

    alg (Var ix)     = xss <! ix
    alg (Param ix)   = constant (params ! ix)
    alg (Const c)    = constant c
    alg (Uni g t)    = evalFun g t
    alg (Bin op l r) = evalOp op l r
{-# INLINE evalTree #-}
-- | Maps a binary operator to the function that evaluates it.
--
-- @PowerAbs@ raises the absolute value of the left operand to the right
-- operand; @AQ@ computes @l / sqrt (1 + r * r)@.
evalOp :: Floating a => Op -> a -> a -> a
evalOp op = case op of
    Add      -> (+)
    Sub      -> (-)
    Mul      -> (*)
    Div      -> (/)
    Power    -> (**)
    PowerAbs -> \l r -> abs l ** r
    AQ       -> \l r -> l / sqrt (1 + r * r)
{-# INLINE evalOp #-}
-- | Maps a unary function symbol to the function that evaluates it.
--
-- @SqrtAbs@ and @LogAbs@ take the absolute value of the argument before
-- applying 'sqrt' and 'log'; @Cbrt@ uses 'cbrt'; @Square@ and @Cube@ are
-- integer powers.
evalFun :: Floating a => Function -> a -> a
evalFun f = case f of
    Id      -> id
    Abs     -> abs
    Sin     -> sin
    Cos     -> cos
    Tan     -> tan
    Sinh    -> sinh
    Cosh    -> cosh
    Tanh    -> tanh
    ASin    -> asin
    ACos    -> acos
    ATan    -> atan
    ASinh   -> asinh
    ACosh   -> acosh
    ATanh   -> atanh
    Sqrt    -> sqrt
    SqrtAbs -> sqrt . abs
    Cbrt    -> cbrt
    Square  -> (^ (2 :: Integer))
    Log     -> log
    LogAbs  -> log . abs
    Exp     -> exp
    Recip   -> recip
    Cube    -> (^ (3 :: Integer))
{-# INLINE evalFun #-}
-- | Cube root that keeps the sign of its argument: the magnitude
-- @abs x ** (1 / 3)@ multiplied by @signum x@.
cbrt :: Floating a => a -> a
cbrt x = signum x * magnitude
  where
    magnitude = abs x ** (1 / 3)
{-# INLINE cbrt #-}
-- | Returns the function symbol that undoes the given one, so that
-- evaluating the result with 'evalFun' matches 'evalInverse' on the
-- argument.
--
-- @Cbrt@ and @Cube@ are mapped to each other, in agreement with 'evalFun'
-- and 'evalInverse', which already treat them as mutual inverses.
--
-- Calls 'error' for every symbol without a matching inverse symbol
-- (@Abs@, @SqrtAbs@ and @LogAbs@).
inverseFunc :: Function -> Function
inverseFunc = \case
    Id     -> Id
    Sin    -> ASin
    Cos    -> ACos
    Tan    -> ATan
    Sinh   -> ASinh
    Cosh   -> ACosh
    Tanh   -> ATanh
    ASin   -> Sin
    ACos   -> Cos
    ATan   -> Tan
    ASinh  -> Sinh
    ACosh  -> Cosh
    ATanh  -> Tanh
    Sqrt   -> Square
    Square -> Sqrt
    Cbrt   -> Cube
    Cube   -> Cbrt
    Log    -> Exp
    Exp    -> Log
    Recip  -> Recip
    other  -> error $ show other ++ " has no support for inverse function"
{-# INLINE inverseFunc #-}
-- | Maps a unary function symbol to the function this module uses as its
-- inverse.
--
-- @SqrtAbs@ and @LogAbs@ are handled like @Sqrt@ and @Log@ (squaring and
-- 'exp'), and @Abs@ is mapped to 'abs' itself.
evalInverse :: Floating a => Function -> a -> a
evalInverse f = case f of
    Id      -> id
    Sin     -> asin
    Cos     -> acos
    Tan     -> atan
    Sinh    -> asinh
    Cosh    -> acosh
    Tanh    -> atanh
    ASin    -> sin
    ACos    -> cos
    ATan    -> tan
    ASinh   -> sinh
    ACosh   -> cosh
    ATanh   -> tanh
    Sqrt    -> (^ (2 :: Integer))
    SqrtAbs -> (^ (2 :: Integer))
    Square  -> sqrt
    Cbrt    -> (^ (3 :: Integer))
    Log     -> exp
    LogAbs  -> exp
    Exp     -> log
    Abs     -> abs
    Recip   -> recip
    Cube    -> cbrt
{-# INLINE evalInverse #-}
-- | Inverts a binary operator whose right operand is known: for the
-- operator @op@ and the right operand @v@, the returned function takes the
-- result @y@ of the operation and gives back a left operand @x@ such that
-- @evalOp op x v@ is @y@.
--
-- @Power@ and @PowerAbs@ share the same rule, @y ** (1 / v)@.
invright :: Floating a => Op -> a -> (a -> a)
invright op v = case op of
    Add      -> subtract v
    Sub      -> (+ v)
    Mul      -> (/ v)
    Div      -> (* v)
    Power    -> root
    PowerAbs -> root
    AQ       -> (* sqrt (1 + v * v))
  where
    -- v-th root of the argument
    root = (** (1 / v))
{-# INLINE invright #-}
-- | Inverts a binary operator whose left operand is known: for the operator
-- @op@ and the left operand @v@, the returned function takes the result @y@
-- of the operation and gives back a right operand @x@ such that
-- @evalOp op v x@ is @y@.
--
-- * @PowerAbs@: the operator computes @abs v ** x@, so the logarithm is
--   taken in base @abs v@. The absolute value of @y@ is still taken first,
--   which leaves the result for positive @v@ unchanged.
-- * @AQ@: the operator computes @v / sqrt (1 + x * x)@, hence
--   @x = sqrt ((v / y) * (v / y) - 1)@; the non-negative root is returned.
invleft :: Floating a => Op -> a -> (a -> a)
invleft op v = case op of
    Add      -> subtract v
    Sub      -> (+ v) . negate
    Mul      -> (/ v)
    Div      -> (v /)
    Power    -> logBase v
    PowerAbs -> logBase (abs v) . abs
    AQ       -> \y -> let q = v / y in sqrt (q * q - 1)
{-# INLINE invleft #-}
-- | Function symbols this module lists as invertible. Every entry is
-- accepted by 'inverseFunc'.
invertibles :: [Function]
invertibles =
    [ Id
    , Sin, Cos, Tan, Tanh
    , ASin, ACos, ATan, ATanh
    , Sqrt, Square
    , Log, Exp
    , Recip
    ]
{-# INLINE invertibles #-}