Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typed structural equality #154

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

## next -- *TBA*

## 2.1.6.0.100

* Updates `Data.Parameterized.TH.GADT.structuralEquality` to add type
assertions to cover all type parameters. This change may require the
addition of the `ScopedTypeVariables` pragma to modules importing this code.

## 2.1.6.0 -- *2022 Dec 18*

* Added `FinMap`: an integer map with a statically-known maximum size.
Expand Down
2 changes: 1 addition & 1 deletion parameterized-utils.cabal
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Cabal-version: 2.2
Name: parameterized-utils
Version: 2.1.6.0.99
Version: 2.1.6.0.100
Author: Galois Inc.
Maintainer: [email protected]
stability: stable
Expand Down
2 changes: 1 addition & 1 deletion src/Data/Parameterized/Classes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ import Data.Type.Equality as Equality
import Data.Parameterized.Compose ()

-- We define these type alias here to avoid importing Control.Lens
-- modules, as this apparently causes problems with the safe Hasekll
-- modules, as this apparently causes problems with the safe Haskell
-- checking.
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s
type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s
Expand Down
128 changes: 108 additions & 20 deletions src/Data/Parameterized/TH/GADT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
------------------------------------------------------------------------
{-# LANGUAGE CPP #-}
{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE EmptyCase #-}
module Data.Parameterized.TH.GADT
( -- * Instance generators
-- $typePatterns
Expand All @@ -40,15 +41,16 @@ module Data.Parameterized.TH.GADT
, assocTypePats
) where

import Control.Monad
import Data.Maybe
import Data.Set (Set)
import Control.Monad
import Data.Function ( on )
import Data.Maybe
import Data.Set (Set)
import qualified Data.Set as Set
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
import Language.Haskell.TH
import Language.Haskell.TH.Datatype


import Data.Parameterized.Classes
import Data.Parameterized.Classes

------------------------------------------------------------------------
-- Template Haskell utilities
Expand Down Expand Up @@ -133,10 +135,72 @@ typeVars :: TypeSubstitution a => a -> Set Name
typeVars = Set.fromList . freeVariables


-- | @structuralEquality@ declares a structural equality predicate.
-- | @structuralEquality@ declares a structural equality predicate for a GADT.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this PR, but this comment could use some expansion. For instance, IIUC, this is mostly used to generate == for Eq instances, if so, the Haddock should probably say that.

structuralEquality :: TypeQ -> [(TypePat,ExpQ)] -> ExpQ
structuralEquality tpq pats =
[| \x y -> isJust ($(structuralTypeEquality tpq pats) x y) |]
structuralEquality tpq pats = do
d <- reifyDatatype =<< asTypeCon "structuralEquality" =<< tpq

-- tpq is some type of GADT: data X p1 p2 ... where ...
--
-- The general approach is to generate a structural type equality such that the
-- result is a Maybe (e :+: f) is Just Refl and then verify it is a Just value
-- to assert equality by generating (via template haskell):
--
-- \ x y -> isJust $(structuralTypeEquality ... x y)
--
-- However, that result presumes a `TestEquality f where testEquality :: f a ->
-- f b -> Maybe (a :~: b)`. If the GADT has a single type parameter, those
Comment on lines +151 to +152
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks to me like structuralEquality calls structuralTypeEquality, which generates (the syntax of) a function of type f a -> f b -> Maybe (a :~: b), meaning that using structuralEquality would not require that tpq have an instance of TestEquality.

-- types align and there is no problem. If the GADT has multiple type
-- variables, GHC is unsure of which we are making the TestEquality assertion
-- about and we need to help. We actually want to make that assertion over
-- _all_ of the parameters, so given:
Comment on lines +151 to +156
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the issue here is the fact that the GADT has multiple type variables, but rather it is a type inference issue. Would you be willing to rewrite this comment to explain this aspect of the problem? (The issue description that I left elsewhere in my review could serve as a template for the new comment.)

Also, this comment is pretty long. It might be helpful to split this out into a separate Note so that we don't break up the code with a large number of comment lines.

--
-- data D p1 p2 p3 where ...
--
-- the template haskell here should generate:
--
-- \ (x :: D xt1 xt2 xt3) (y :: D yt1 yt2 yt3) ->
-- isJust ( ($(structuralTypeEquality ... x y))
-- :: Maybe ( '(xt1, xt2, xt3) :~: '(yt1, yt2, yt3) )
-- )
--
-- This will perform the equality check in a way that obtains proof of equality
-- for all of the type parameters. This will require the ScopedTypeVariables
-- pragma, but GHC will happily suggest that if it's missing.
Comment on lines +160 to +169
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I'm quite skeptical of equating all of the type parameters. Consider this data type:

data S a b where
  MkS1 :: S a Bool
  MkS2 :: S a Double

This use of structuralTypeEquality will typecheck:

f :: S a1 b1 -> S a2 b2 -> Maybe (b1 :~: b2)
f = $(structuralTypeEquality [t|S|] [])

But this would not:

f :: S a1 b1 -> S a2 b2 -> Maybe ('(a1, b1) :~: '(a2, b2))
f = $(structuralTypeEquality [t|S|] [])

And indeed, it's not clear how this could typecheck, as matching on S's data constructors don't provide any way to scrutinize a1 or a2.

--
-- This is also useful for the equality test on the single parameter case:
--
-- data D p1 where ...
--
-- instance Eq (D a) where
-- (==) = $(structuralEquality [t|D|] []
--
-- Again, this will fail without the template haskell assertion of the target
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me, "assertion" implies "checked (and may fail) at run-time". I could see "ascription" or "annotation" here.

Suggested change
-- Again, this will fail without the template haskell assertion of the target
-- Again, this will fail without the template haskell ascription of the target

-- types matching the argument types.

gadtParams <- return $ datatypeInstTypes d
arg1Params <- fmap varT <$> newNames "xTy" (length gadtParams)
arg2Params <- fmap varT <$> newNames "yTy" (length gadtParams)
let arg1Ty = foldl appT (conT $ datatypeName d) arg1Params
let arg2Ty = foldl appT (conT $ datatypeName d) arg2Params
#if MIN_VERSION_base(4,14,0)
let mkSuperTy tyList = foldl appT (promotedTupleT (length tyList)) tyList
#else
let mkSuperTy tyList =
if length tyList < 2
then if length tyList == 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how much computation is really done, i.e., how lazy length is, but this form makes it abundantly clear that it is not at all necessary to compute the length:

Suggested change
then if length tyList == 0
then if null tyList

then error "Expected at least one type in structuralEquality"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error would be for cases like $(structuralTypeEquality [t|Bool]), right? Perhaps this would be more clear?

Suggested change
then error "Expected at least one type in structuralEquality"
then error "Expected at least one type parameter in structuralEquality"

else head tyList
Comment on lines +191 to +193
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternate way of phrasing this that avoids use of the partial function head:

Suggested change
then if length tyList == 0
then error "Expected at least one type in structuralEquality"
else head tyList
then
if case tyList of
[] -> error "Expected at least one type in structuralEquality"
(ty:_) -> tyList

else foldl appT (promotedTupleT (length tyList)) tyList
#endif
let arg1AllParamTy = mkSuperTy arg1Params
let arg2AllParamTy = mkSuperTy arg2Params

[| \(x :: $(arg1Ty)) (y :: $(arg2Ty)) ->
isJust ($(structuralTypeEquality_ True tpq pats) x y
:: Maybe ($(arg1AllParamTy) :~: $(arg2AllParamTy))
)
|]
Comment on lines +181 to +203
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit saddened by how complex this function has become, as structuralEquality really ought to be as simple as calling isJust $(structuralTypeEquality ...). Is there any reason that we couldn't just keep the implementation of structuralEquality as-is, but move the explicit type ascriptions to structuralTypeEquality? The type ascriptions aren't strictly required for structuralTypeEquality, but it wouldn't do any harm to add them, I think.


joinEqMaybe :: Name -> Name -> ExpQ -> ExpQ
joinEqMaybe x y r = do
Expand Down Expand Up @@ -181,26 +245,44 @@ matchEqArguments _ _ _ _ _ _ [] = error "Unexpected end of names."
mkSimpleEqF :: [Type] -- ^ Data declaration types
-> Set Name
-> [(TypePat,ExpQ)] -- ^ Patterns for matching arguments
-> ConstructorInfo
-> ConstructorInfo -- ^ The constructor we are concerned with
-> [Name]
-> ExpQ
-> Bool -- ^ wildcard case required
-> ExpQ
mkSimpleEqF dTypes bnd pats con xv yQ multipleCases = do
-> [ConstructorInfo] -- ^ All constructors (for determining if wildcard case required)
-> Bool -- ^ True if the equality arguments are the same type
-> ExpQ
mkSimpleEqF dTypes bnd pats con xv yQ multipleCases argsSameType = do
-- Get argument types for constructor.
let nm = constructorName con
(yp,yv) <- conPat con "y"
let rv = matchEqArguments dTypes pats nm bnd (constructorFields con) xv yv
let otherMatchingCons =
-- Determine the other constructors that should be matched relative to
-- `con`. If this is supplying code for `testEquality`, the input
-- signature is `f a -> f b -> ...` and will admit different types, so
-- all constructors should be checked, but if this is supplying code for
-- `Eq` or similar where the input signature is `a -> a -> ...`
-- (i.e. `argsSameType` is `True`), then only constructors that have the
-- same resulting type should be checked, otherwise GHC will emit
-- warnings/errors about "pattern not reached" for the case statement
-- being generated here.
let sameContext = (==) `on` constructorContext
in if argsSameType
then filter (sameContext con) multipleCases
else multipleCases
caseE yQ $ match (pure yp) (normalB rv) []
: [ match wildP (normalB [| Nothing |]) [] | multipleCases ]
: [ match wildP (normalB [| Nothing |]) []
| 1 < length otherMatchingCons
]

-- | Match equational form.
mkEqF :: DatatypeInfo -- ^ Data declaration.
-> [(TypePat,ExpQ)]
-> ConstructorInfo
-> ConstructorInfo -- ^ Constructor for which equality is to be determined
-> [Name]
-> ExpQ
-> Bool -- ^ wildcard case required
-> [ConstructorInfo] -- ^ All constructors (for determining if wildcard case required)
-> Bool -- ^ True if the equality arguments are the same type
-> ExpQ
mkEqF d pats con =
let dVars = dataParamTypes d -- the type arguments for the constructor
Expand All @@ -216,12 +298,18 @@ mkEqF d pats con =
-- forall x y . f x -> f y -> Maybe (x :~: y)
-- @
structuralTypeEquality :: TypeQ -> [(TypePat,ExpQ)] -> ExpQ
structuralTypeEquality tpq pats = do
structuralTypeEquality = structuralTypeEquality_ False

structuralTypeEquality_ :: Bool -> TypeQ -> [(TypePat,ExpQ)] -> ExpQ
structuralTypeEquality_ argsSameType tpq pats = do
d <- reifyDatatype =<< asTypeCon "structuralTypeEquality" =<< tpq

let multipleCons = not (null (drop 1 (datatypeCons d)))
let multipleCons = datatypeCons d
trueEqs yQ = [ do (xp,xv) <- conPat con "x"
match (pure xp) (normalB (mkEqF d pats con xv yQ multipleCons)) []
match (pure xp)
(normalB
(mkEqF d pats con xv yQ multipleCons argsSameType))
[]
| con <- datatypeCons d
]

Expand Down
68 changes: 60 additions & 8 deletions test/Test/TH.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
Expand All @@ -13,14 +14,15 @@ module Test.TH
)
where

import Test.Tasty
import Test.Tasty.HUnit
import Test.Tasty
import Test.Tasty.HUnit

import Control.Monad (when)
import Data.Parameterized.Classes
import Data.Parameterized.NatRepr
import Data.Parameterized.TH.GADT
import GHC.TypeNats
import Control.Monad (when)
import Data.Parameterized.Classes
import Data.Parameterized.NatRepr
import Data.Parameterized.SymbolRepr
import Data.Parameterized.TH.GADT
import GHC.TypeNats
Comment on lines -16 to +25
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Why did this formatting change? Would be good to have it in a separate commit, if necessary.


data T1 = A | B | C
$(mkRepr ''T1)
Expand All @@ -39,13 +41,41 @@ instance TestEquality T2Repr where
[ (AnyType, [|testEquality|]) ])
deriving instance Show (T2Repr t)

data T3 (is_a :: Symbol) where
T3_Int :: Int -> T3 "int"
T3_Bool :: Bool -> T3 "bool"
$(return [])
instance TestEquality T3 where
testEquality = $(structuralTypeEquality [t|T3|] [])
instance Eq (T3 s) where
(==) = $(structuralEquality [t|T3|] [])
deriving instance Show (T3 s)

data T4 b (is_a :: Symbol) where
T4_Int :: Int -> T4 b "int"
T4_Bool :: Bool -> T4 b "bool"
$(return [])
instance TestEquality (T4 b) where
testEquality = $(structuralTypeEquality [t|T4|] [])
instance Eq (T4 b s) where
(==) = $(structuralEquality [t|T4|] [])
deriving instance Show (T4 b s)

eqTest :: (TestEquality f, Show (f a), Show (f b)) => f a -> f b -> IO ()
eqTest a b =
when (not (isJust (testEquality a b))) $ assertFailure $ show a ++ " /= " ++ show b

neqTest :: (TestEquality f, Show (f a), Show (f b)) => f a -> f b -> IO ()
neqTest a b =
when (isJust (testEquality a b)) $ assertFailure $ show a ++ " == " ++ show b
when (isJust (testEquality a b))
$ assertFailure
$ show a <> " == " <> show b <> " but should not be!"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Looks like these are still strings, so I'd prefer keeping ++ rather than <>, as it's less polymorphic, and so lets the reader know the more specific types involved.


assertNotEqual :: (Eq a, Show a) => String -> a -> a -> IO ()
assertNotEqual msg a b =
when (a == b)
$ assertFailure
$ msg <> " " <> show a <> " == " <> show b <> " but should not be!"

thTests :: IO TestTree
thTests = testGroup "TH" <$> return
Expand All @@ -62,6 +92,28 @@ thTests = testGroup "TH" <$> return
T2_2Repr (knownNat @5) `neqTest` T2_2Repr (knownNat @9)
T2_1Repr BRepr `neqTest` T2_2Repr (knownNat @4)

, testCase "Instance tests" $ do
assertEqual "T3_Int values" (T3_Int 5) (T3_Int 5)
assertNotEqual "T3_Int values" (T3_Int 5) (T3_Int 54)
assertEqual "T3_Bool values" (T3_Bool True) (T3_Bool True)
assertNotEqual "T3_Bool values" (T3_Bool True) (T3_Bool False)

-- n.b. the following is not possible: 'T3 "int"' is not a 'T3 "bool"'
-- assertEqual "T3_Int/T3_Bool values" (T3_Int 1) (T3_Bool True)

T3_Int 1 `eqTest` T3_Int 1
T3_Int 1 `neqTest` T3_Int 3
T3_Int 1 `neqTest` T3_Bool True
T3_Bool False `neqTest` T3_Bool True
T3_Bool True `eqTest` T3_Bool True

assertEqual "T4_Int values" (T4_Int @String 5) (T4_Int @String 5)
assertNotEqual "T4_Int values" (T4_Int @String 5) (T4_Int @String 54)

T4_Int @String 1 `eqTest` T4_Int @String 1
T4_Int @String 1 `neqTest` T4_Int @String 2
Comment on lines +95 to +114
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could maybe get more coverage with some property tests:

forall (i j :: Int). (T3_Int i == T3_Int j) == (i == j)
[same for Bool]
forall (i j :: Int). (T3_Int i == T3_Int j) == isJust (testEquality (T3_Int i) (T3_Int j))
...



, testCase "KnownRepr test" $ do
-- T1
let aRepr = knownRepr :: T1Repr 'A
Expand Down