diff --git a/rts/memory.c b/rts/memory.c index c946605e..52651b3f 100644 --- a/rts/memory.c +++ b/rts/memory.c @@ -60,9 +60,6 @@ void sixten_copy( uint32_t pointers, uint32_t non_pointer_bytes ) { - for (uint32_t i = 0; i < pointers; ++i) { - sixten_increase_reference_count(src.pointers[i]); - } memcpy(dst.pointers, src.pointers, sizeof(void*) * pointers); memcpy(dst.non_pointers, src.non_pointers, non_pointer_bytes); } diff --git a/src/Compiler.hs b/src/Compiler.hs index 11a6732b..c139394b 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -84,7 +84,7 @@ compile assemblyDir saveAssembly outputExecutableFile maybeOptimisationLevel pri else callProcess clang $ optimisationArgs <> ["-fPIC", "-Wno-override-module", "-o", outputExecutableFile, builtinCFile, memoryCFile] <> llvmFiles supportedLlvmVersions :: [Int] -supportedLlvmVersions = [17, 16, 15] +supportedLlvmVersions = [18, 17, 16, 15] -- | llvm-config is not available in current LLVM distribution for windows, so we -- need use @clang -print-prog-name=clang@ to get the full path of @clang@. diff --git a/src/Low/Pretty.hs b/src/Low/Pretty.hs index 3e9c651e..8425c75b 100644 --- a/src/Low/Pretty.hs +++ b/src/Low/Pretty.hs @@ -103,7 +103,7 @@ prettyLetOperation env = \case 2 ( vcat $ (prettyBranch env <$> branches) - <> [ "_" <+> "->" <+> prettyTerm env branch + <> [ "_" <+> "->" <> line <> indent 2 (prettyTerm env branch) | Just branch <- [defaultBranch] ] ) @@ -132,6 +132,8 @@ prettySeqOperation env = \case "copy" <+> commaSep [prettyOperand env dst, prettyOperand env src, prettyOperand env size] Syntax.IncreaseReferenceCount operand repr -> "increase_reference_count" <+> pretty repr <+> prettyOperand env operand + Syntax.IncreaseReferenceCounts operand repr -> + "increase_reference_counts" <+> prettyOperand env repr <+> prettyOperand env operand Syntax.DecreaseReferenceCount operand repr -> "decrease_reference_count" <+> pretty repr <+> prettyOperand env operand @@ -198,9 +200,9 @@ prettyBranch -> Doc ann prettyBranch env = \case Syntax.ConstructorBranch constr body -> - prettyConstr env constr <+> "->" <> prettyTerm env body + prettyConstr env constr <+> "->" <> line <> indent 2 (prettyTerm env body) Syntax.LiteralBranch lit body -> - pretty lit <+> "->" <> prettyTerm env body + pretty lit <+> "->" <> line <> indent 2 (prettyTerm env body) ------------------------------------------------------------------------------- diff --git a/src/Low/Syntax.hs b/src/Low/Syntax.hs index f37c6acf..3ec7074b 100644 --- a/src/Low/Syntax.hs +++ b/src/Low/Syntax.hs @@ -34,6 +34,7 @@ data SeqOperation v = Store !(Operand v) !(Operand v) !Representation | Copy !(Operand v) !(Operand v) !(Operand v) | IncreaseReferenceCount !(Operand v) !Representation + | IncreaseReferenceCounts !(Operand v) !(Operand v) | DecreaseReferenceCount !(Operand v) !Representation deriving (Eq, Show, Generic, Hashable) diff --git a/src/LowToLLVM.hs b/src/LowToLLVM.hs index 899a4806..3277920e 100644 --- a/src/LowToLLVM.hs +++ b/src/LowToLLVM.hs @@ -251,13 +251,13 @@ increaseReferenceCount repr o = "call void @sixten_increase_reference_count" <> parens ["i64 " <> varName extractedPointer] -decreaseReferenceCounts :: Operand -> Var -> Assembler () -decreaseReferenceCounts size reference = do - declareLLVMGlobal "sixten_decrease_reference_counts" "declare void @sixten_decrease_reference_counts(ptr, i32)" +increaseReferenceCounts :: Operand -> Operand -> Assembler () +increaseReferenceCounts size reference = do + declareLLVMGlobal "sixten_increase_reference_counts" "declare void @sixten_increase_reference_counts(ptr, i32)" (pointers, _) <- extractSizeParts (PassBy.Value Representation.type_, size) - (pointersPointer, _) <- extractParts (PassBy.Reference, Local reference) + (pointersPointer, _) <- extractParts (PassBy.Reference, reference) emitInstruction $ - "call void @sixten_decrease_reference_counts" + "call void @sixten_increase_reference_counts" <> parens ["ptr " <> operand pointersPointer, "i32 " <> varName pointers] decreaseReferenceCount :: Representation -> Operand -> Assembler () @@ -294,6 +294,15 @@ decreaseReferenceCount repr o = "call void @sixten_decrease_reference_count" <> parens ["i64 " <> varName extractedPointer] +decreaseReferenceCounts :: Operand -> Var -> Assembler () +decreaseReferenceCounts size reference = do + declareLLVMGlobal "sixten_decrease_reference_counts" "declare void @sixten_decrease_reference_counts(ptr, i32)" + (pointers, _) <- extractSizeParts (PassBy.Value Representation.type_, size) + (pointersPointer, _) <- extractParts (PassBy.Reference, Local reference) + emitInstruction $ + "call void @sixten_decrease_reference_counts" + <> parens ["ptr " <> operand pointersPointer, "i32 " <> varName pointers] + ------------------------------------------------------------------------------- assembleModule :: [(Name.Lowered, Syntax.Definition)] -> M Lazy.ByteString @@ -350,8 +359,7 @@ assembleFunction functionName env = \case let parameters = second fromLocal <$> Index.Seq.toSeq env entry <- freshVar "entry" startBlock entry - (result, stack) <- assembleTerm env Nothing passReturnBy term - mapM_ restoreStack stack + result <- assembleTerm env Nothing passReturnBy term endBlock case passReturnBy of PassBy.Value Representation.Empty -> "ret " <> llvmReturnType passReturnBy _ -> "ret " <> llvmReturnType passReturnBy <> " " <> operand result @@ -406,23 +414,27 @@ assembleTerm -> Maybe Name -> PassBy -> Syntax.Term v - -> Assembler (Operand, Maybe StackAllocation) + -> Assembler Operand assembleTerm env nameSuggestion passBy = \case Syntax.Operand o -> do (_, o') <- assembleOperand env o - pure (o', Nothing) - Syntax.Let passLetBy name term body -> do - (termResult, termStack) <- assembleTerm env (Just name) passLetBy term - (bodyResult, bodyStack) <- assembleTerm (env Index.Seq.:> (passLetBy, termResult)) nameSuggestion passBy body - mapM_ restoreStack termStack - mapM_ restoreStack bodyStack - pure (bodyResult, Nothing) - Syntax.Seq term1 term2 -> do - (_, stack1) <- assembleTerm env Nothing (PassBy.Value Representation.Empty) term1 - (result, stack2) <- assembleTerm env nameSuggestion passBy term2 - mapM_ restoreStack stack1 - mapM_ restoreStack stack2 - pure (result, Nothing) + pure o' + Syntax.Let passLetBy name operation body -> do + (operationResult, operationStack) <- assembleLetOperation env (Just name) passLetBy operation + bodyResult <- assembleTerm (env Index.Seq.:> (passLetBy, operationResult)) nameSuggestion passBy body + mapM_ restoreStack operationStack + pure bodyResult + Syntax.Seq operation body -> do + assembleSeqOperation env operation + assembleTerm env nameSuggestion passBy body + +assembleLetOperation + :: Environment v + -> Maybe Name + -> PassBy + -> Syntax.LetOperation v + -> Assembler (Operand, Maybe StackAllocation) +assembleLetOperation env nameSuggestion passBy = \case Syntax.Case scrutinee branches defaultBranch -> do scrutinee' <- assembleOperand env scrutinee branchLabels <- forM branches \case @@ -452,15 +464,13 @@ assembleTerm env nameSuggestion passBy = \case ] branchResults <- forM (zip branchLabels branches) \((_, branchLabel), branch) -> do startBlock branchLabel - (result, stack) <- assembleTerm env nameSuggestion passBy $ Syntax.branchTerm branch - mapM_ restoreStack stack + result <- assembleTerm env nameSuggestion passBy $ Syntax.branchTerm branch afterBranchLabel <- gets (.basicBlockName) endBlock $ "br label " <> varName afterSwitchLabel pure (afterBranchLabel, result) startBlock defaultLabel maybeDefaultResult <- forM defaultBranch \branch -> do - (result, stack) <- assembleTerm env nameSuggestion passBy branch - mapM_ restoreStack stack + result <- assembleTerm env nameSuggestion passBy branch afterBranchLabel <- gets (.basicBlockName) pure (afterBranchLabel, result) let defaultResult = fromMaybe (defaultLabel, Constant "undef") maybeDefaultResult @@ -601,36 +611,6 @@ assembleTerm env nameSuggestion passBy = \case "ptr" updatedNonPointerPointer pure (Local result, Nothing) - Syntax.Copy dst src size -> do - dst' <- assembleOperand env dst - src' <- assembleOperand env src - size' <- assembleOperand env size - (pointers, nonPointerBytes) <- extractSizeParts size' - declareLLVMGlobal "sixten_copy" "declare void @sixten_copy({ptr, ptr}, {ptr, ptr}, i32, i32)" - emitInstruction $ - "call void @sixten_copy" - <> parens - [ typedOperand dst' - , typedOperand src' - , "i32 " <> varName pointers - , "i32 " <> varName nonPointerBytes - ] - pure (Constant "undef", Nothing) - Syntax.Store dst src repr -> do - dst' <- assembleOperand env dst - src' <- assembleOperand env src - (dstPointerPointer, dstNonPointerPointer) <- extractParts dst' - case (pointerType repr.pointers, nonPointerType repr.nonPointerBytes) of - (Nothing, Nothing) -> pure () - (Just _, Nothing) -> - emitInstruction $ "store " <> typedOperand src' <> ", ptr " <> operand dstPointerPointer - (Nothing, Just _) -> - emitInstruction $ "store " <> typedOperand src' <> ", ptr " <> operand dstNonPointerPointer - (Just p, Just np) -> do - (pointerSrc, nonPointerSrc) <- extractParts src' - emitInstruction $ "store " <> p <> " " <> operand pointerSrc <> ", ptr " <> operand dstPointerPointer - emitInstruction $ "store " <> np <> " " <> operand nonPointerSrc <> ", ptr " <> operand dstNonPointerPointer - pure (Constant "undef", Nothing) Syntax.Load src repr -> do src' <- assembleOperand env src (srcPointerPointer, srcNonPointerPointer) <- extractParts src' @@ -653,14 +633,50 @@ assembleTerm env nameSuggestion passBy = \case pure $ Local result pure (result, Nothing) + +assembleSeqOperation + :: Environment v + -> Syntax.SeqOperation v + -> Assembler () +assembleSeqOperation env = \case + Syntax.Store dst src repr -> do + dst' <- assembleOperand env dst + src' <- assembleOperand env src + (dstPointerPointer, dstNonPointerPointer) <- extractParts dst' + case (pointerType repr.pointers, nonPointerType repr.nonPointerBytes) of + (Nothing, Nothing) -> pure () + (Just _, Nothing) -> + emitInstruction $ "store " <> typedOperand src' <> ", ptr " <> operand dstPointerPointer + (Nothing, Just _) -> + emitInstruction $ "store " <> typedOperand src' <> ", ptr " <> operand dstNonPointerPointer + (Just p, Just np) -> do + (pointerSrc, nonPointerSrc) <- extractParts src' + emitInstruction $ "store " <> p <> " " <> operand pointerSrc <> ", ptr " <> operand dstPointerPointer + emitInstruction $ "store " <> np <> " " <> operand nonPointerSrc <> ", ptr " <> operand dstNonPointerPointer + Syntax.Copy dst src size -> do + dst' <- assembleOperand env dst + src' <- assembleOperand env src + size' <- assembleOperand env size + (pointers, nonPointerBytes) <- extractSizeParts size' + declareLLVMGlobal "sixten_copy" "declare void @sixten_copy({ptr, ptr}, {ptr, ptr}, i32, i32)" + emitInstruction $ + "call void @sixten_copy" + <> parens + [ typedOperand dst' + , typedOperand src' + , "i32 " <> varName pointers + , "i32 " <> varName nonPointerBytes + ] Syntax.IncreaseReferenceCount val repr -> do (_, val') <- assembleOperand env val increaseReferenceCount repr val' - pure (Constant "undef", Nothing) + Syntax.IncreaseReferenceCounts val repr -> do + (_, val') <- assembleOperand env val + (_, repr') <- assembleOperand env repr + increaseReferenceCounts repr' val' Syntax.DecreaseReferenceCount val repr -> do (_, val') <- assembleOperand env val decreaseReferenceCount repr val' - pure (Constant "undef", Nothing) assembleOperand :: Environment v -> Syntax.Operand v -> Assembler (PassBy, Operand) assembleOperand env = \case diff --git a/src/Lower.hs b/src/Lower.hs index 119f2db9..83f3afb6 100644 --- a/src/Lower.hs +++ b/src/Lower.hs @@ -514,11 +514,64 @@ generateTerm context nameSuggestion indices term typeValue = case term of Value Representation.type_ CC.Syntax.Closure {} -> panic "TODO closure" CC.Syntax.ApplyClosure {} -> panic "TODO closure" - CC.Syntax.Case _scrutinee type_ _branches _maybeDefault -> do - size <- generateTypeSize context indices type_ - dst <- letReference "case_dst" $ StackAllocate size - _ <- storeTerm context indices dst term - pure $ OperandStorage dst $ Reference size + CC.Syntax.Case scrutinee type_ branches maybeDefault -> do + passTypeBy <- lift $ CC.Representation.passTypeBy (CC.toEnvironment context) typeValue + case passTypeBy of + PassBy.Reference -> do + size <- generateTypeSize context indices type_ + dst <- letReference "case_dst" $ StackAllocate size + _ <- storeTerm context indices dst term + pure $ OperandStorage dst $ Reference size + PassBy.Value repr -> do + scrutinee' <- generateTermWithoutType context indices scrutinee + branches' <- CC.Representation.compileBranches branches + result <- case branches' of + CC.Representation.TaggedConstructorBranches Unboxed constrBranches -> do + scrutineeRef <- forceReference Nothing scrutinee' + tag <- letLoad "tag" scrutineeRef Representation.int + let payload name = letOffset name scrutineeRef $ Representation Representation.int + constrBranches' <- forM constrBranches \(constr, constrBranch) -> + map (ConstructorBranch constr) $ + lift $ + collect $ + generateBranch context indices payload repr typeValue constrBranch + defaultBranch <- + forM maybeDefault $ \branch -> + lift $ collect $ do + branch' <- generateTerm context Nothing indices branch typeValue + forceValue repr branch' + letValue repr "result" $ Case tag constrBranches' defaultBranch + CC.Representation.TaggedConstructorBranches Boxed constrBranches -> do + scrutineeValue <- forceValue Representation.pointer scrutinee' + tag <- letValue Representation.int "tag" $ PointerTag scrutineeValue + let payload name = letReference name $ HeapPayload scrutineeValue + constrBranches' <- forM constrBranches \(constr, constrBranch) -> + map (ConstructorBranch constr) $ lift $ collect do + generateBranch context indices payload repr typeValue constrBranch + defaultBranch <- forM maybeDefault $ \branch -> lift $ collect $ do + branch' <- generateTerm context Nothing indices branch typeValue + forceValue repr branch' + letValue repr "result" $ Case tag constrBranches' defaultBranch + CC.Representation.UntaggedConstructorBranch Unboxed constrBranch -> do + let payload name = forceReference (Just name) scrutinee' + generateBranch context indices payload repr typeValue constrBranch + CC.Representation.UntaggedConstructorBranch Boxed constrBranch -> do + scrutineeValue <- forceValue Representation.pointer scrutinee' + let payload name = letReference name $ HeapPayload scrutineeValue + generateBranch context indices payload repr typeValue constrBranch + CC.Representation.LiteralBranches litBranches -> do + scrutineeValue <- forceValue Representation.int scrutinee' + litBranches' <- forM (OrderedHashMap.toList litBranches) \(lit, litBranch) -> + map (LiteralBranch lit) $ + lift $ + collect $ do + litBranch' <- generateTerm context Nothing indices litBranch typeValue + forceValue repr litBranch' + defaultBranch <- forM maybeDefault $ \branch -> lift $ collect $ do + branch' <- generateTerm context Nothing indices branch typeValue + forceValue repr branch' + letValue repr "result" $ Case scrutineeValue litBranches' defaultBranch + pure $ OperandStorage result $ Value repr storeCall :: CC.Context v @@ -545,7 +598,28 @@ storeCall context indices dst function args passArgsBy passReturnBy = do callResult <- letCall passReturnBy "call_result" function callArgs storeOperand dst $ OperandStorage callResult $ Value repr PassBy.Reference -> - letCall passReturnBy "call_result_size" function (dst : callArgs) + letCall (PassBy.Value Representation.type_) "call_result_size" function (dst : callArgs) + +generateBranch + :: CC.Context v + -> Index.Seq v OperandStorage + -> (Name -> Collect Operand) + -> Representation + -> CC.Domain.Type + -> Telescope Name CC.Syntax.Type CC.Syntax.Term v + -> Collect Operand +generateBranch context indices mpayload repr typeValue = \case + Telescope.Empty term -> do + term' <- generateTerm context Nothing indices term typeValue + forceValue repr term' + Telescope.Extend name type_ _plicity tele -> do + payload <- mpayload name + size <- generateTypeSize context indices type_ + fieldTypeValue <- lift $ CC.Domain.Lazy <$> lazy (Evaluation.evaluate (CC.toEnvironment context) type_) + (context', _) <- lift $ CC.extend context fieldTypeValue + let indices' = indices Index.Seq.:> OperandStorage payload (Reference size) + let payload' name' = letOffset name' payload size + generateBranch context' indices' payload' repr typeValue tele storeBranch :: CC.Context v diff --git a/src/ReferenceCounting.hs b/src/ReferenceCounting.hs index 13605052..2d9bde06 100644 --- a/src/ReferenceCounting.hs +++ b/src/ReferenceCounting.hs @@ -52,6 +52,7 @@ data SeqOperation = Store !Operand !Operand !Representation | Copy !Operand !Operand !Operand | IncreaseReferenceCount !Operand !Representation + | IncreaseReferenceCounts !Operand !Operand | DecreaseReferenceCount !Operand !Representation deriving (Show) @@ -86,14 +87,7 @@ referenceCountFunction env liveOut = \case ReferenceCountState { provenances = mempty } - $ do - (value', valueProvenance) <- referenceCount passBy value - case valueProvenance of - Nothing -> case passBy of - PassBy.Reference -> pure value' - PassBy.Value repr -> increase value' repr - Just (Owned _ _) -> pure value' - Just (Child _) -> panic "Returning child" + $ referenceCount passBy value pure $ Syntax.Body passBy $ readback env value' Syntax.Parameter name passBy function -> do var <- freshVar @@ -183,6 +177,7 @@ evaluateSeqOperation env liveOut = \case let (dst', liveIn) = evaluateOperand env srcLiveIn dst pure (Store dst' src' repr, liveIn) Syntax.IncreaseReferenceCount {} -> panic "RC operations before reference counting" + Syntax.IncreaseReferenceCounts {} -> panic "RC operations before reference counting" Syntax.DecreaseReferenceCount {} -> panic "RC operations before reference counting" evaluateOperand :: Index.Map v Var -> EnumSet Var -> Syntax.Operand v -> (Operand, EnumSet Var) @@ -217,20 +212,16 @@ data Provenance | Child !Var deriving (Show) -referenceCount :: PassBy -> Value -> ReferenceCount (Value, Maybe Provenance) +referenceCount :: PassBy -> Value -> ReferenceCount Value referenceCount passBy value = case value of - Operand operand -> case operand of - Var Killed var -> do - provenances <- gets (.provenances) - pure (value, EnumMap.lookup var provenances) - Var NotKilled _ -> do - maybeParent <- tryMakeParent operand - pure (value, Child <$> maybeParent) - Global _ _ -> pure (value, Nothing) - Literal _ -> pure (value, Nothing) - Representation _ -> pure (value, Nothing) - Tag _ -> pure (value, Nothing) - Undefined _ -> pure (value, Nothing) + Operand operand -> do + decrease <- referenceCountOperand operand + pure + if cancelOut operand decrease + then value + else case passBy of + PassBy.Reference -> decreases decrease value + PassBy.Value repr -> increase operand repr $ decreases decrease value Let passValBy name var dead operation body -> do (operation', maybeOperationProvenance, decreaseAfters) <- referenceCountLetOperation passValBy operation forM_ maybeOperationProvenance \valProvenance -> @@ -238,22 +229,13 @@ referenceCount passBy value = case value of decreaseVar <- case dead of NotDead -> pure Nothing Dead -> referenceCountOperand $ Var Killed var - (body', bodyProvenance) <- referenceCount passBy body + body' <- referenceCount passBy body modify \s -> s {provenances = EnumMap.delete var s.provenances} - case bodyProvenance of - Just (Child var') | var == var' -> - case passBy of - PassBy.Reference -> panic "Returning reference to value going out of scope" - PassBy.Value repr -> do - decreaseVar <- kill var - body'' <- increase body' repr - val''' <- decreaseAfter decreaseVar val'' passValBy - pure (Let passValBy name var dead val''' body'', Just $ Owned (PassBy.Value repr) 1) - _ -> pure (Let passValBy name var dead val'' body', bodyProvenance) + pure $ Let passValBy name var dead operation' $ decreases decreaseAfters $ decreases decreaseVar body' Seq lhs rhs -> do - (increaseBefores, decreaseAfters) <- referenceCountSeqOperation lhs - (rhs', rhsProvenance) <- referenceCount passBy rhs - pure (Seq lhs' rhs', rhsProvenance) + (increaseBefores, increaseRefsBefore, decreaseAfters) <- referenceCountSeqOperation lhs + rhs' <- referenceCount passBy rhs + pure $ increaseRefs increaseRefsBefore $ increases increaseBefores $ Seq lhs $ decreases decreaseAfters rhs' referenceCountLetOperation :: PassBy @@ -265,26 +247,23 @@ referenceCountLetOperation passBy operation = case operation of startingState <- get branches' <- forM branches \(killedVars, branch) -> do put startingState - decreases <- catMaybes <$> forM (EnumSet.toList killedVars) kill + kills <- catMaybes <$> forM (EnumSet.toList killedVars) kill branch' <- case branch of ConstructorBranch constr branchValue -> do - (branchValue', provenance) <- referenceCount passBy branchValue - when (isJust provenance) $ panic $ "Branch with provenance " <> show branchValue' - pure $ ConstructorBranch constr $ decreaseBefore decreases branchValue' + branchValue' <- referenceCount passBy branchValue + pure $ ConstructorBranch constr $ decreases kills branchValue' LiteralBranch lit branchValue -> do - (branchValue', provenance) <- referenceCount passBy branchValue - when (isJust provenance) $ panic $ "Branch with provenance " <> show branchValue' - pure $ LiteralBranch lit $ decreaseBefore decreases branchValue' + branchValue' <- referenceCount passBy branchValue + pure $ LiteralBranch lit $ decreases kills branchValue' pure (killedVars, branch') maybeDefaultBranch' <- forM maybeDefaultBranch \(killedVars, branch) -> do put startingState - decreases <- catMaybes <$> forM (EnumSet.toList killedVars) kill - (branch', provenance) <- referenceCount passBy branch - when (isJust provenance) $ panic $ "Branch with provenance " <> show branch' - pure (killedVars, decreaseBefore decreases branch') + kills <- catMaybes <$> forM (EnumSet.toList killedVars) kill + branch' <- referenceCount passBy branch + pure (killedVars, decreases kills branch') pure (Case scrutinee branches' maybeDefaultBranch', Nothing, maybeToList decreaseScrutinee) Call _ args -> do - decreases <- catMaybes <$> mapM referenceCountOperand args + decreaseArgs <- catMaybes <$> mapM referenceCountOperand args pure ( operation , case passBy of @@ -292,7 +271,7 @@ referenceCountLetOperation passBy operation = case operation of | needsReferenceCounting repr -> Just $ Owned (PassBy.Value repr) 1 | otherwise -> Nothing PassBy.Reference -> Nothing - , decreases + , decreaseArgs ) StackAllocate _ -> pure (operation, Just $ Owned PassBy.Reference 1, []) @@ -318,50 +297,66 @@ referenceCountLetOperation passBy operation = case operation of referenceCountSeqOperation :: SeqOperation - -> ReferenceCount ([(Operand, Either Representation Operand)], [(Var, Representation)]) + -> ReferenceCount (Maybe (Operand, Representation), Maybe (Operand, Operand), [(Var, Representation)]) referenceCountSeqOperation operation = case operation of Copy dst src repr -> do decreaseDst <- referenceCountOperand dst decreaseSrc <- referenceCountOperand src - pure case (src, decreaseSrc) of - (Var Killed srcVar, Just (srcVar', _)) - | srcVar == srcVar' -> ([], maybeToList decreaseDst) - _ -> ([(src, Right repr)], catMaybes [decreaseDst, decreaseSrc]) + pure + if cancelOut src decreaseSrc + then (Nothing, Nothing, maybeToList decreaseDst) + else (Nothing, Just (src, repr), catMaybes [decreaseDst, decreaseSrc]) Store dst src repr -> do decreaseDst <- referenceCountOperand dst decreaseSrc <- referenceCountOperand src - pure case (src, decreaseSrc) of - (Var Killed srcVar, Just (srcVar', _)) - | srcVar == srcVar' -> ([], maybeToList decreaseDst) - _ -> ([(src, Left repr)], catMaybes [decreaseDst, decreaseSrc]) + pure + if cancelOut src decreaseSrc + then (Nothing, Nothing, maybeToList decreaseDst) + else (Just (src, repr), Nothing, catMaybes [decreaseDst, decreaseSrc]) IncreaseReferenceCount {} -> panic "RC operations before reference counting" + IncreaseReferenceCounts {} -> panic "RC operations before reference counting" DecreaseReferenceCount {} -> panic "RC operations before reference counting" needsReferenceCounting :: Representation -> Bool needsReferenceCounting repr = repr.pointers > 0 -increaseBefore :: Operand -> Representation -> Value -> Value -increaseBefore operand repr value +representationOperandNeedsReferenceCounting :: Operand -> Bool +representationOperandNeedsReferenceCounting (Representation repr) = needsReferenceCounting repr +representationOperandNeedsReferenceCounting _ = True + +increase :: Operand -> Representation -> Value -> Value +increase operand repr value | needsReferenceCounting repr = Seq (IncreaseReferenceCount operand repr) value | otherwise = value -increase :: Value -> Representation -> ReferenceCount Value -increase value repr - | needsReferenceCounting repr = do - var <- lift freshVar - pure $ - Let (PassBy.Value repr) "temp" var NotDead value $ - Seq (IncreaseReferenceCount (Var Killed var) repr) $ - Operand $ - Var Killed var - | otherwise = pure value - -decreaseBefore +increases :: Foldable f => f (Operand, Representation) -> Value -> Value +increases operands value = + foldr + ( \(o, repr) -> + if needsReferenceCounting repr + then Seq $ IncreaseReferenceCount o repr + else identity + ) + value + operands + +increaseRefs :: Foldable f => f (Operand, Operand) -> Value -> Value +increaseRefs operands value = + foldr + ( \(o, repr) -> + if representationOperandNeedsReferenceCounting repr + then Seq $ IncreaseReferenceCounts o repr + else identity + ) + value + operands + +decreases :: (Foldable f) => f (Var, Representation) -> Value -> Value -decreaseBefore vars value = +decreases vars value = foldr ( \(v, repr) -> if needsReferenceCounting repr @@ -371,24 +366,9 @@ decreaseBefore vars value = value vars -decreaseAfter - :: (Foldable f) - => f (Var, Representation) - -> Value - -> PassBy - -> ReferenceCount Value -decreaseAfter vars value passBy = - case vars' of - [] -> pure value - _ -> do - var <- lift freshVar - pure $ - Let passBy "temp" var NotDead value $ - decreaseBefore vars' $ - Operand $ - Var Killed var - where - vars' = filter (needsReferenceCounting . snd) $ toList vars +cancelOut :: Operand -> Maybe (Var, Representation) -> Bool +cancelOut (Var _ var) (Just (var', _)) = var == var' +cancelOut _ _ = False tryMakeParent :: Operand -> ReferenceCount (Maybe Var) tryMakeParent = \case @@ -459,14 +439,17 @@ kill var = do readback :: Index.Map v Var -> Value -> Syntax.Term v readback env = \case Operand operand -> Syntax.Operand $ readbackOperand env operand - Let passBy name var _dead value value' -> + Let passBy name var _dead operation value' -> Syntax.Let passBy name - (readback env value) + (readbackLetOperation env operation) (readback (env Index.Map.:> var) value') - Seq value value' -> - Syntax.Seq (readback env value) (readback env value') + Seq operation value' -> + Syntax.Seq (readbackSeqOperation env operation) (readback env value') + +readbackLetOperation :: Index.Map v Var -> LetOperation -> Syntax.LetOperation v +readbackLetOperation env = \case Case scrutinee branches maybeDefaultBranch -> Syntax.Case (readbackOperand env scrutinee) @@ -481,14 +464,18 @@ readback env = \case Syntax.Offset (readbackOperand env base) (readbackOperand env offset) + Load src repr -> Syntax.Load (readbackOperand env src) repr + +readbackSeqOperation :: Index.Map v Var -> SeqOperation -> Syntax.SeqOperation v +readbackSeqOperation env = \case Copy dst src size -> Syntax.Copy (readbackOperand env dst) (readbackOperand env src) (readbackOperand env size) Store dst value repr -> Syntax.Store (readbackOperand env dst) (readbackOperand env value) repr - Load src repr -> Syntax.Load (readbackOperand env src) repr IncreaseReferenceCount operand repr -> Syntax.IncreaseReferenceCount (readbackOperand env operand) repr + IncreaseReferenceCounts operand repr -> Syntax.IncreaseReferenceCounts (readbackOperand env operand) (readbackOperand env repr) DecreaseReferenceCount operand repr -> Syntax.DecreaseReferenceCount (readbackOperand env operand) repr readbackOperand :: Index.Map v Var -> Operand -> Syntax.Operand v diff --git a/src/Rules.hs b/src/Rules.hs index 5cb11404..bcb76111 100644 --- a/src/Rules.hs +++ b/src/Rules.hs @@ -481,7 +481,7 @@ rules sourceDirectories files readFile_ (Writer (Writer query)) = forM definitions $ mapM $ runM . ReferenceCounting.referenceCountDefinition LLVMModule module_ -> noError do - assemblyDefinitions <- fetch $ LowModule module_ + assemblyDefinitions <- fetch $ ReferenceCountedLowModule module_ runM $ LowToLLVM.assembleModule assemblyDefinitions LLVMModuleInitModule -> noError do diff --git a/tests/compilation/unboxed-data.vix b/tests/compilation/unboxed-data.vix index 1a126c1a..f6c7e9dd 100644 --- a/tests/compilation/unboxed-data.vix +++ b/tests/compilation/unboxed-data.vix @@ -3,90 +3,90 @@ data Unit = Unit data Either a b = Left a | Right b data Tuple a b = Tuple a b -identity : forall a. a -> a +-- identity : forall a. a -> a fromRightWithDefault : forall a b. b -> Either a b -> b fromRightWithDefault default (Left _) = default fromRightWithDefault _ (Right b) = b -fromLeftWithDefault : forall a b. a -> Either a b -> a -fromLeftWithDefault _ (Left a) = a -fromLeftWithDefault default (Right _) = default +-- fromLeftWithDefault : forall a b. a -> Either a b -> a +-- fromLeftWithDefault _ (Left a) = a +-- fromLeftWithDefault default (Right _) = default first : forall a b. Tuple a b -> a first (Tuple a b) = a -second : forall a b. Tuple a b -> b -second (Tuple a b) = b +-- second : forall a b. Tuple a b -> b +-- second (Tuple a b) = b -absurd : forall a. Empty -> a -absurd e = case e of +-- absurd : forall a. Empty -> a +-- absurd e = case e of -testData1 : Either Unit (Either Unit Unit) -testData1 = Right (Right Unit) +-- testData1 : Either Unit (Either Unit Unit) +-- testData1 = Right (Right Unit) -testFunction1 : forall a. Either a (Either a a) -> a -testFunction1 (Left a) = a -testFunction1 (Right (Left a)) = a -testFunction1 (Right (Right a)) = a +-- testFunction1 : forall a. Either a (Either a a) -> a +-- testFunction1 (Left a) = a +-- testFunction1 (Right (Left a)) = a +-- testFunction1 (Right (Right a)) = a -test1 = case testFunction1 testData1 of - Unit -> printInt 610 -- prints 610 +-- test1 = case testFunction1 testData1 of +-- Unit -> printInt 610 -- prints 610 -testData2 : Either Unit Unit -testData2 = Left Unit +-- testData2 : Either Unit Unit +-- testData2 = Left Unit -testFunction2 : Either Unit Unit -> Unit -testFunction2 (Left a) = a -testFunction2 (Right a) = a +-- testFunction2 : Either Unit Unit -> Unit +-- testFunction2 (Left a) = a +-- testFunction2 (Right a) = a -test2 = case testFunction2 testData2 of - Unit -> printInt 611 -- prints 611 +-- test2 = case testFunction2 testData2 of +-- Unit -> printInt 611 -- prints 611 -testData3 : Either Unit (Either Unit Unit) -testData3 = Right (Left Unit) +-- testData3 : Either Unit (Either Unit Unit) +-- testData3 = Right (Left Unit) -testFunction3 : Either Unit (Either Unit Unit) -> Unit -testFunction3 (Left a) = a -testFunction3 (Right (Left a)) = a -testFunction3 (Right (Right a)) = a +-- testFunction3 : Either Unit (Either Unit Unit) -> Unit +-- testFunction3 (Left a) = a +-- testFunction3 (Right (Left a)) = a +-- testFunction3 (Right (Right a)) = a -test3 = case testFunction3 testData3 of - Unit -> printInt 612 -- prints 612 +-- test3 = case testFunction3 testData3 of +-- Unit -> printInt 612 -- prints 612 -testData4 : Either Int (Tuple Unit Int) -testData4 = Right (Tuple Unit 613) +-- testData4 : Either Int (Tuple Unit Int) +-- testData4 = Right (Tuple Unit 613) -testFunction4 : Either Int (Tuple Unit Int) -> Int -testFunction4 (Left a) = a -testFunction4 (Right (Tuple Unit a)) = a +-- testFunction4 : Either Int (Tuple Unit Int) -> Int +-- testFunction4 (Left a) = a +-- testFunction4 (Right (Tuple Unit a)) = a -test4 = printInt (testFunction4 testData4) -- prints 613 +-- test4 = printInt (testFunction4 testData4) -- prints 613 -testData5 : Either Unit (Tuple Int Int) -testData5 = Right (Tuple 613 614) +-- testData5 : Either Unit (Tuple Int Int) +-- testData5 = Right (Tuple 613 614) -testFunction5 : Either Unit (Tuple Int Int) -> Int -testFunction5 (Left Unit) = 0 -testFunction5 (Right (Tuple 613 a)) = a -testFunction5 (Right (Tuple _ _)) = 0 +-- testFunction5 : Either Unit (Tuple Int Int) -> Int +-- testFunction5 (Left Unit) = 0 +-- testFunction5 (Right (Tuple 613 a)) = a +-- testFunction5 (Right (Tuple _ _)) = 0 -test5 = printInt (testFunction5 testData5) -- prints 614 +-- test5 = printInt (testFunction5 testData5) -- prints 614 -testData6 : Tuple (Either Unit Unit) (Either Int Int) -testData6 = Tuple (Left Unit) (Right 615) +-- testData6 : Tuple (Either Unit Unit) (Either Int Int) +-- testData6 = Tuple (Left Unit) (Right 615) -testFunction6 : Tuple (Either Unit Unit) (Either Int Int) -> Int -testFunction6 (Tuple (Left Unit) (Left _)) = 0 -testFunction6 (Tuple (Left Unit) (Right n)) = n -testFunction6 (Tuple (Right Unit) (Left _)) = 0 -testFunction6 (Tuple (Right Unit) (Right _)) = 0 +-- testFunction6 : Tuple (Either Unit Unit) (Either Int Int) -> Int +-- testFunction6 (Tuple (Left Unit) (Left _)) = 0 +-- testFunction6 (Tuple (Left Unit) (Right n)) = n +-- testFunction6 (Tuple (Right Unit) (Left _)) = 0 +-- testFunction6 (Tuple (Right Unit) (Right _)) = 0 -test6 = printInt (testFunction6 testData6) -- prints 615 +-- test6 = printInt (testFunction6 testData6) -- prints 615 testData7 : Tuple (Either Int (Tuple Int Int)) (Either Unit Int) testData7 = Tuple (Right (Tuple 1 2)) (Right 3) test7 = printInt (first (fromRightWithDefault (Tuple 0 0) (first testData7))) -- prints 1 -test8 = printInt (second (fromRightWithDefault (Tuple 0 0) (first testData7))) -- prints 2 -test9 = printInt (fromRightWithDefault 0 (second testData7)) -- prints 3 +-- test8 = printInt (second (fromRightWithDefault (Tuple 0 0) (first testData7))) -- prints 2 +-- test9 = printInt (fromRightWithDefault 0 (second testData7)) -- prints 3