{-# LANGUAGE TemplateHaskell #-}
module Data.Comp.Derive.Equality
(
EqF(..),
makeEqF
) where
import Data.Comp.Derive.Utils
import Language.Haskell.TH hiding (Cxt, match)
-- | Equality for type constructors: given an 'Eq' instance for the argument
-- type @a@, 'eqF' decides equality of two values of type @f a@.
-- Instances can be generated with 'makeEqF'.
class EqF f where
  -- | Equality on @f a@, parameterised by equality ('==') on @a@.
  eqF :: Eq a => f a -> f a -> Bool
-- | Derive an instance of 'EqF' for the type constructor named by the
-- argument. The type must take at least one type argument.
--
-- The last type argument plays the role of @a@ in @f a@ and is left
-- abstract. Every other type argument becomes a type variable in the
-- instance head and receives a constraint built by @mkClassP ''Eq@.
-- For example, a type @T b c@ yields an instance head @EqF (T b)@ with an
-- 'Eq' constraint on @b@.
--
-- The generated 'eqF' has one clause per constructor. That clause compares
-- the corresponding fields of its two arguments pairwise with '==' and is
-- 'True' for a constructor without fields. It is followed by a catch-all
-- chosen by constructor count:
--
-- * two or more constructors: @eqF _ _ = False@ (mismatched constructors);
-- * exactly one constructor: no catch-all (the single clause is exhaustive);
-- * no constructors: @eqF _ _ = True@, mirroring GHC's derived 'Eq' for
--   empty data types (a binding with zero equations would be rejected).
--
-- Fails in 'Q', which reports an error at the splice site, when
-- @abstractNewtypeQ@ returns 'Nothing' for the reified name, or when the
-- type has no type arguments.
makeEqF :: Name -> Q [Dec]
makeEqF fname = do
  info <- abstractNewtypeQ $ reify fname
  case info of
    Nothing ->
      fail $ "makeEqF: cannot derive EqF for " ++ show fname
    Just (DataInfo _cxt name args constrs _deriving) -> do
      -- Drop the last type argument: it is the @a@ in @f a@. Matching on the
      -- reversed list replaces the partial 'init', which would crash with an
      -- uninformative message for a type without arguments.
      argTypes <- case reverse args of
        [] ->
          fail $ "makeEqF: " ++ show fname
            ++ " must take at least one type argument"
        (_ : revInit) ->
          return $ map (VarT . tyVarBndrName) (reverse revInit)
      let complType = foldl AppT (ConT name) argTypes
          preCond = map (mkClassP ''Eq . (: [])) argTypes
          classType = AppT (ConT ''EqF) complType
      eqFDecl <- funD 'eqF (eqFClauses constrs)
      return [mkInstanceD preCond classType [eqFDecl]]
  where
    -- One clause per constructor, followed by the catch-all (if any).
    eqFClauses :: [Con] -> [Q Clause]
    eqFClauses constrs =
      map (genEqClause . abstractConType) constrs ++ defEqClause constrs

    -- Catch-all clause chosen by constructor count (see the doc above).
    -- Pattern matching avoids walking the whole list just to compare its
    -- length, and gives empty data types a valid (non-empty) binding.
    defEqClause :: [Con] -> [Q Clause]
    defEqClause [] = [wildClause [| True |]]
    defEqClause [_] = []
    defEqClause _ = [wildClause [| False |]]

    -- Builds the clause @eqF _ _ = result@.
    wildClause :: Q Exp -> Q Clause
    wildClause result = clause [wildP, wildP] (normalB result) []

    -- Clause for one constructor, given its name and field count:
    -- @eqF (C x1 .. xn) (C y1 .. yn) = and [x1 == y1, .., xn == yn]@,
    -- or @True@ when the constructor has no fields.
    genEqClause :: (Name, Int) -> Q Clause
    genEqClause (constr, n) = do
      xs <- newNames n "x"
      ys <- newNames n "y"
      let conPat vs = ConP constr [] (map VarP vs)
          fieldEqs = zipWith (\x y -> [| $(varE x) == $(varE y) |]) xs ys
      body <- if null fieldEqs
                then [| True |]
                else [| and $(listE fieldEqs) |]
      return $ Clause [conPat xs, conPat ys] (NormalB body) []