From 95d6a95a25b39bee955219c603c3fbbd05531bd7 Mon Sep 17 00:00:00 2001 From: Aaron Stannard Date: Mon, 8 Jan 2024 17:17:38 -0600 Subject: [PATCH] AK2001: detect when automatically handled messages are being handled inside MessageExtractor / IMessageExtractor (#43) --- ...dledMessagesInsideMessageExtractorFixer.cs | 80 +++ .../Akka.Analyzers.Fixes.csproj.DotSettings | 2 + .../Akka.Analyzers.Tests.csproj | 3 + ...MessagesInMessageExtractorAnalyzerSpecs.cs | 350 ++++++++++++ ...ledMessagesInMessageExtractorFixerSpecs.cs | 514 ++++++++++++++++++ .../Utility/ReferenceAssembliesHelper.cs | 3 +- ...dMessagesInsideMessageExtractorAnalyzer.cs | 159 ++++++ .../Utility/AkkaClusterContext.cs | 65 +++ .../Utility/AkkaClusterShardingContext.cs | 84 +++ src/Akka.Analyzers/Utility/AkkaContext.cs | 32 +- src/Akka.Analyzers/Utility/RuleDescriptors.cs | 5 + .../Utility/TypeSymbolFactory.cs | 48 ++ 12 files changed, 1343 insertions(+), 2 deletions(-) create mode 100644 src/Akka.Analyzers.Fixes/AK2000/MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.cs create mode 100644 src/Akka.Analyzers.Tests/Analyzers/AK2000/MustNotHandleAutomaticallyHandledMessagesInMessageExtractorAnalyzerSpecs.cs create mode 100644 src/Akka.Analyzers.Tests/Fixes/AK2000/MustNotHandleAutomaticallyHandledMessagesInMessageExtractorFixerSpecs.cs create mode 100644 src/Akka.Analyzers/AK2000/MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorAnalyzer.cs create mode 100644 src/Akka.Analyzers/Utility/AkkaClusterContext.cs create mode 100644 src/Akka.Analyzers/Utility/AkkaClusterShardingContext.cs diff --git a/src/Akka.Analyzers.Fixes/AK2000/MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.cs b/src/Akka.Analyzers.Fixes/AK2000/MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.cs new file mode 100644 index 0000000..94282d2 --- /dev/null +++ b/src/Akka.Analyzers.Fixes/AK2000/MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.cs @@ -0,0 +1,80 @@ +// ----------------------------------------------------------------------- +// +// Copyright (C) 2013-2024 .NET Foundation +// +// ----------------------------------------------------------------------- + +using System.Composition; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CodeActions; +using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.VisualBasic.Syntax; +using IfStatementSyntax = Microsoft.CodeAnalysis.CSharp.Syntax.IfStatementSyntax; + +namespace Akka.Analyzers.Fixes; + +[ExportCodeFixProvider(LanguageNames.CSharp)] +[Shared] +public class MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer() + : BatchedCodeFixProvider(RuleDescriptors.Ak2001DoNotUseAutomaticallyHandledMessagesInShardMessageExtractor.Id) +{ + public const string Key_FixAutomaticallyHandledShardedMessage = "AK2001_FixAutoShardMessage"; + + public override async Task RegisterCodeFixesAsync(CodeFixContext context) + { + var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false); + if (root is null) + return; + + var reportedNodes = new HashSet(); + + foreach (var diagnostic in context.Diagnostics) + { + var diagnosticSpan = diagnostic.Location.SourceSpan; + + // Find the token at the location of the diagnostic. + var token = root.FindToken(diagnosticSpan.Start); + + // Find the correct parent node to remove. + var nodeToRemove = FindParentNodeToRemove(token.Parent); + + // Check if the node has already been processed to avoid duplicates + if (nodeToRemove != null && reportedNodes.Add(nodeToRemove)) + { + context.RegisterCodeFix( + CodeAction.Create( + title: "Remove unnecessary message handling", + createChangedDocument: c => RemoveOffendingNode(context.Document, nodeToRemove, c), + equivalenceKey: Key_FixAutomaticallyHandledShardedMessage), + diagnostic); + } + } + } + + private static SyntaxNode? FindParentNodeToRemove(SyntaxNode? node) + { + while (node != null) + { + if (node is IfStatementSyntax || node is SwitchSectionSyntax || node is SwitchExpressionArmSyntax) + { + // special case - have to check for else if here + if(node.Parent is ElseClauseSyntax) + return node.Parent; + return node; + } + node = node.Parent; + } + return null; + } + + private static async Task RemoveOffendingNode(Document document, SyntaxNode nodeToRemove, CancellationToken cancellationToken) + { + var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false); + if(root == null) + return document; + + var newRoot = root.RemoveNode(nodeToRemove, SyntaxRemoveOptions.KeepNoTrivia); + return newRoot == null ? document : document.WithSyntaxRoot(newRoot); + } +} \ No newline at end of file diff --git a/src/Akka.Analyzers.Fixes/Akka.Analyzers.Fixes.csproj.DotSettings b/src/Akka.Analyzers.Fixes/Akka.Analyzers.Fixes.csproj.DotSettings index 6b4d807..5b2c6d7 100644 --- a/src/Akka.Analyzers.Fixes/Akka.Analyzers.Fixes.csproj.DotSettings +++ b/src/Akka.Analyzers.Fixes/Akka.Analyzers.Fixes.csproj.DotSettings @@ -1,2 +1,4 @@  + True + True True \ No newline at end of file diff --git a/src/Akka.Analyzers.Tests/Akka.Analyzers.Tests.csproj b/src/Akka.Analyzers.Tests/Akka.Analyzers.Tests.csproj index 58f8bbc..2db5c04 100644 --- a/src/Akka.Analyzers.Tests/Akka.Analyzers.Tests.csproj +++ b/src/Akka.Analyzers.Tests/Akka.Analyzers.Tests.csproj @@ -10,7 +10,10 @@ + + + diff --git a/src/Akka.Analyzers.Tests/Analyzers/AK2000/MustNotHandleAutomaticallyHandledMessagesInMessageExtractorAnalyzerSpecs.cs b/src/Akka.Analyzers.Tests/Analyzers/AK2000/MustNotHandleAutomaticallyHandledMessagesInMessageExtractorAnalyzerSpecs.cs new file mode 100644 index 0000000..d0631b6 --- /dev/null +++ b/src/Akka.Analyzers.Tests/Analyzers/AK2000/MustNotHandleAutomaticallyHandledMessagesInMessageExtractorAnalyzerSpecs.cs @@ -0,0 +1,350 @@ +// ----------------------------------------------------------------------- +// +// Copyright (C) 2013-2024 .NET Foundation +// +// ----------------------------------------------------------------------- + +using Microsoft.CodeAnalysis.Testing; +using Verify = Akka.Analyzers.Tests.Utility.AkkaVerifier; + +namespace Akka.Analyzers.Tests.Analyzers.AK2000; + +public class MustNotHandleAutomaticallyHandledMessagesInMessageExtractorAnalyzerSpecs +{ + public static readonly TheoryData SuccessCases = new() + { +""" +using Akka.Cluster.Sharding; + +public sealed class ShardMessageExtractor : HashCodeMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + if(message is string sharded) + { + return sharded; + } + + return null; + } +} +""", +""" +using Akka.Cluster.Sharding; + +public class MsgExtractorCreator{ + IMessageExtractor Create(){ + IMessageExtractor messageExtractor = HashCodeMessageExtractor.Create(100, msg => + { + if (msg is string s) { + return s; + } + else{ + return null; + } + }); + + return messageExtractor; + } +} +""" + }; + + [Theory] + [MemberData(nameof(SuccessCases))] + public Task SuccessCase(string code) + { + return Verify.VerifyAnalyzer(code); + } + + public static readonly + TheoryData<(string testData, (int startLine, int startColumn, int endLine, int endColumn)[] spanData)> + FailureCases = new() + { + ( +// Simple message extractor edge case - using `if` statements +""" +using Akka.Cluster.Sharding; +public sealed class ShardMessageExtractor : HashCodeMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + if(message is string sharded) + { + return sharded; + } + + if (message is ShardingEnvelope e) + { + return e.EntityId; + } + + return null; + } +} +""", new[]{(18, 24, 18, 42)}), + ( +""" +using Akka.Cluster.Sharding; +public sealed class ShardMessageExtractor : HashCodeMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + switch(message) + { + case string sharded: + return sharded; + case ShardingEnvelope e: + return e.EntityId; + default: + return null; + } + } +} +""", new[]{(17, 18, 17, 36)}), + + // message extractor that uses a switch expression + ( +""" +using Akka.Cluster.Sharding; +public sealed class ShardMessageExtractor : HashCodeMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + return message switch + { + string sharded => sharded, + ShardingEnvelope e => e.EntityId, + _ => null, + }; + } +} +""", new[]{(16, 13, 16, 31)}), + + // multiple violations (one in each method) + ( +""" +using Akka.Cluster.Sharding; + +public sealed class ShardMessageExtractor : HashCodeMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + switch (message) + { + case string sharded: + return sharded; + case ShardingEnvelope e: + return e.EntityId; + default: + return null; + } + } + + public override object EntityMessage(object message) + { + switch (message) + { + case string sharded: + return sharded; + case ShardRegion.StartEntity e: + return e; + default: + return null; + } + } +} +""", + new[] + { + (18, 9, 18, 27), + (31, 9, 31, 34) + }), + + // combo mode - handle both types of forbidden messages in both methods + ( +""" +using Akka.Cluster.Sharding; + +public sealed class ShardMessageExtractor : HashCodeMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + switch (message) + { + case string sharded: + return sharded; + case ShardingEnvelope e: + return e.EntityId; + case ShardRegion.StartEntity start: + return start.EntityId; + default: + return null; + } + } + + public override object EntityMessage(object message) + { + switch (message) + { + case string sharded: + return sharded; + case ShardingEnvelope e: + return e.Message; + case ShardRegion.StartEntity start: + return start; + default: + return null; + } + } +} +""", + new[] + { + (18, 9, 18, 27), + (20, 9, 20, 38), + (33, 9, 33, 27), + (35, 9, 35, 38), + }), + + // custom IMessageExtractor implementation + ( +""" +using Akka.Cluster.Sharding; +using System; + +public sealed class ShardMessageExtractor : IMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor() + { + } + + public string EntityId(object message) + { + switch (message) + { + case string sharded: + return sharded; + case ShardingEnvelope e: + return e.EntityId; + default: + return null; + } + } + + public object EntityMessage(object message) + { + switch (message) + { + case string sharded: + return sharded; + case ShardingEnvelope e: + return e.Message; + default: + return null; + } + } + + public string ShardId(object message) + { + return Random.Shared.Next(0,10).ToString(); + } +} +""", new[] +{ + (19, 9, 19, 27), + (32, 9, 32, 27) +} + ), + + // message extractor created by HashCode.MessageExtractor delegate + ( +""" +using Akka.Cluster.Sharding; + +public class MsgExtractorCreator{ + IMessageExtractor Create(){ + IMessageExtractor messageExtractor = HashCodeMessageExtractor.Create(100, msg => + { + if (msg is string s) { + return s; + } + else if (msg is ShardingEnvelope shard) { + return shard.EntityId; + } + else{ + return null; + } + }); + + return messageExtractor; + } +} +""", new[] +{ + (10, 26, 10, 48) +}) + }; + + [Theory] + [MemberData(nameof(FailureCases))] + public async Task FailureCase((string testData, (int startLine, int startColumn, int endLine, int endColumn)[] spanData) d) + { + var (testData, spanData) = d; + var expectedDiagnostics = new DiagnosticResult[spanData.Length]; + var currentDiagnosticIndex = 0; + + // there can be multiple violations per test case + foreach (var (startLine, startColumn, endLine, endColumn) in spanData) + { + expectedDiagnostics[currentDiagnosticIndex++] = Verify.Diagnostic().WithSpan(startLine, startColumn, endLine, endColumn); + } + + await Verify.VerifyAnalyzer(testData, expectedDiagnostics).ConfigureAwait(true); + } +} \ No newline at end of file diff --git a/src/Akka.Analyzers.Tests/Fixes/AK2000/MustNotHandleAutomaticallyHandledMessagesInMessageExtractorFixerSpecs.cs b/src/Akka.Analyzers.Tests/Fixes/AK2000/MustNotHandleAutomaticallyHandledMessagesInMessageExtractorFixerSpecs.cs new file mode 100644 index 0000000..1c6d55c --- /dev/null +++ b/src/Akka.Analyzers.Tests/Fixes/AK2000/MustNotHandleAutomaticallyHandledMessagesInMessageExtractorFixerSpecs.cs @@ -0,0 +1,514 @@ +// ----------------------------------------------------------------------- +// +// Copyright (C) 2013-2024 .NET Foundation +// +// ----------------------------------------------------------------------- + +using Akka.Analyzers.Fixes; +using Verify = Akka.Analyzers.Tests.Utility.AkkaVerifier; + +namespace Akka.Analyzers.Tests.Fixes.AK2000; + +public class MustNotHandleAutomaticallyHandledMessagesInMessageExtractorFixerSpecs +{ + [Fact] + public Task RemoveIfStatementFromMessageExtractor() + { + var before = +""" +using Akka.Cluster.Sharding; +public sealed class ShardMessageExtractor : HashCodeMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + if(message is string sharded) + { + return sharded; + } + + if (message is ShardingEnvelope e) + { + return e.EntityId; + } + + return null; + } +} +"""; + + var after = +""" +using Akka.Cluster.Sharding; +public sealed class ShardMessageExtractor : HashCodeMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + if(message is string sharded) + { + return sharded; + } + + return null; + } +} +"""; + + var expectedDiagnostic = Verify.Diagnostic() + .WithSpan(18, 24, 18, 42); + + return Verify.VerifyCodeFix(before, after, MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.Key_FixAutomaticallyHandledShardedMessage, + expectedDiagnostic); + } + + [Fact] + public Task RemoveIfWithAdditionalFiltering() + { + var before = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + if(message is string sharded) + { + return sharded; + } + + if (message is ShardingEnvelope e && e.EntityId.StartsWith("a")) + { + return e.EntityId; + } + + return null; + } + } + """; + + var after = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + if(message is string sharded) + { + return sharded; + } + + return null; + } + } + """; + + var expectedDiagnostic = Verify.Diagnostic() + .WithSpan(18, 24, 18, 42); + + return Verify.VerifyCodeFix(before, after, MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.Key_FixAutomaticallyHandledShardedMessage, + expectedDiagnostic); + } + + [Fact] + public Task RemoveElseIfStatementFromMessageExtractor() + { + var before = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + if(message is string sharded) + { + return sharded; + } + else if (message is ShardingEnvelope e) + { + return e.EntityId; + } + + return null; + } + } + """; + + var after = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + if(message is string sharded) + { + return sharded; + } + + return null; + } + } + """; + + var expectedDiagnostic = Verify.Diagnostic() + .WithSpan(17, 29, 17, 47); + + return Verify.VerifyCodeFix(before, after, MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.Key_FixAutomaticallyHandledShardedMessage, + expectedDiagnostic); + } + + [Fact] + public Task RemoveCaseStatementFromMessageExtractorWithSwitch() + { + var before = +""" +using Akka.Cluster.Sharding; +public sealed class ShardMessageExtractor : HashCodeMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + switch(message) + { + case string sharded: + return sharded; + case ShardingEnvelope e: + return e.EntityId; + default: + return null; + } + } +} +"""; + + var after = +""" +using Akka.Cluster.Sharding; +public sealed class ShardMessageExtractor : HashCodeMessageExtractor +{ + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + switch(message) + { + case string sharded: + return sharded; + default: + return null; + } + } +} +"""; + + var expectedDiagnostic = Verify.Diagnostic() + .WithSpan(17, 18, 17, 36); + + return Verify.VerifyCodeFix(before, after, MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.Key_FixAutomaticallyHandledShardedMessage, + expectedDiagnostic); + } + + [Fact] + public Task RemoveCaseStatementWithAdditionalFilteringFromMessageExtractorWithSwitch() + { + var before = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + switch(message) + { + case string sharded: + return sharded; + case ShardingEnvelope e when e.EntityId.StartsWith("a"): + return e.EntityId; + default: + return null; + } + } + } + """; + + var after = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + switch(message) + { + case string sharded: + return sharded; + default: + return null; + } + } + } + """; + + var expectedDiagnostic = Verify.Diagnostic() + .WithSpan(17, 18, 17, 36); + + return Verify.VerifyCodeFix(before, after, MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.Key_FixAutomaticallyHandledShardedMessage, + expectedDiagnostic); + } + + [Fact] + public Task RemoveTwoCaseStatementsFromMessageExtractorWithSwitch() + { + var before = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + switch(message) + { + case string sharded: + return sharded; + case ShardingEnvelope e: + return e.EntityId; + case ShardRegion.StartEntity start: + return start.EntityId; + default: + return null; + } + } + } + """; + + var after = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + switch(message) + { + case string sharded: + return sharded; + default: + return null; + } + } + } + """; + + var expectedDiagnostics = new[] + { + Verify.Diagnostic() + .WithSpan(17, 18, 17, 36), + Verify.Diagnostic() + .WithSpan(19, 18, 19, 47), + }; + + return Verify.VerifyCodeFix(before, after, MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.Key_FixAutomaticallyHandledShardedMessage, + expectedDiagnostics); + } + + [Fact] + public Task RemoveSwitchExpressionArm() + { + var before = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + return message switch + { + string sharded => sharded, + ShardingEnvelope e => e.EntityId, + _ => null, + }; + } + } + """; + + var after = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + return message switch + { + string sharded => sharded, + _ => null, + }; + } + } + """; + + var expectedDiagnostic = Verify.Diagnostic() + .WithSpan(16, 13, 16, 31); + + return Verify.VerifyCodeFix(before, after, MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.Key_FixAutomaticallyHandledShardedMessage, + expectedDiagnostic); + } + + [Fact] + public Task RemoveTwoSwitchExpressionArms() + { + var before = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + return message switch + { + string sharded => sharded, + ShardingEnvelope e => e.EntityId, + ShardRegion.StartEntity start => start.EntityId, + _ => null, + }; + } + } + """; + + var after = + """ + using Akka.Cluster.Sharding; + public sealed class ShardMessageExtractor : HashCodeMessageExtractor + { + /// + /// We only ever run with a maximum of two nodes, so ~10 shards per node + /// + public ShardMessageExtractor(int shardCount = 20) : base(shardCount) + { + } + + public override string EntityId(object message) + { + return message switch + { + string sharded => sharded, + _ => null, + }; + } + } + """; + + var expectedDiagnostics = new[] + { + Verify.Diagnostic() + .WithSpan(16, 13, 16, 31), + Verify.Diagnostic() + .WithSpan(17, 13, 17, 42), + }; + + return Verify.VerifyCodeFix(before, after, MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorFixer.Key_FixAutomaticallyHandledShardedMessage, + expectedDiagnostics); + } +} \ No newline at end of file diff --git a/src/Akka.Analyzers.Tests/Utility/ReferenceAssembliesHelper.cs b/src/Akka.Analyzers.Tests/Utility/ReferenceAssembliesHelper.cs index b775544..0a14ebf 100644 --- a/src/Akka.Analyzers.Tests/Utility/ReferenceAssembliesHelper.cs +++ b/src/Akka.Analyzers.Tests/Utility/ReferenceAssembliesHelper.cs @@ -31,8 +31,9 @@ static ReferenceAssembliesHelper() Path.Combine("ref", "net8.0") ); + // TODO: does this bring all other transitive dependencies? CurrentAkka = defaultAssemblies.AddPackages( - [new PackageIdentity("Akka", "1.5.14")] + [new PackageIdentity("Akka", "1.5.14"), new PackageIdentity("Akka.Cluster.Sharding", "1.5.14")] ); } } \ No newline at end of file diff --git a/src/Akka.Analyzers/AK2000/MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorAnalyzer.cs b/src/Akka.Analyzers/AK2000/MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorAnalyzer.cs new file mode 100644 index 0000000..f496a0b --- /dev/null +++ b/src/Akka.Analyzers/AK2000/MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorAnalyzer.cs @@ -0,0 +1,159 @@ +// ----------------------------------------------------------------------- +// +// Copyright (C) 2013-2024 .NET Foundation +// +// ----------------------------------------------------------------------- + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace Akka.Analyzers; + +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public class MustNotUseAutomaticallyHandledMessagesInsideMessageExtractorAnalyzer() + : AkkaDiagnosticAnalyzer(RuleDescriptors.Ak2001DoNotUseAutomaticallyHandledMessagesInShardMessageExtractor) +{ + public override void AnalyzeCompilation(CompilationStartAnalysisContext context, AkkaContext akkaContext) + { + Guard.AssertIsNotNull(context); + Guard.AssertIsNotNull(akkaContext); + + context.RegisterSyntaxNodeAction(ctx => + { + if (akkaContext.HasAkkaClusterShardingInstalled == false) + return; // exit early if we don't have Akka.Cluster.Sharding installed + + AnalyzeMethodDeclaration(ctx, akkaContext); + }, SyntaxKind.MethodDeclaration); + + context.RegisterSyntaxNodeAction(ctx => + { + if (akkaContext.HasAkkaClusterShardingInstalled == false) + return; // exit early if we don't have Akka.Cluster.Sharding installed + + var invocationExpr = (InvocationExpressionSyntax)ctx.Node; + var semanticModel = ctx.SemanticModel; + if (semanticModel.GetSymbolInfo(invocationExpr).Symbol is not IMethodSymbol methodSymbol) + return; // couldn't find the symbol, bail out quickly + + var hashCodeMessageExtractorSymbol = + context.Compilation.GetTypeByMetadataName("Akka.Cluster.Sharding.HashCodeMessageExtractor"); + if (hashCodeMessageExtractorSymbol == null) + return; // couldn't find the type + + if (SymbolEqualityComparer.Default.Equals(methodSymbol.ContainingType, hashCodeMessageExtractorSymbol) && + methodSymbol is { IsStatic: true, Name: "Create" }) + { + + // we are invoking the HashCodeMessageExtractor.Create method if we've made it this far + AnalyzeLambdaExpressions(invocationExpr.ArgumentList.Arguments, ctx, akkaContext); + } + }, SyntaxKind.InvocationExpression); + } + + private static void AnalyzeLambdaExpressions(SeparatedSyntaxList argumentListArguments, SyntaxNodeAnalysisContext ctx, AkkaContext akkaContext) + { + var forbiddenTypes = GetForbiddenTypes(akkaContext); + var reportedLocations = new HashSet(); + + foreach (var argument in argumentListArguments) + { + // if the argument is a lambda expression, we need to analyze it + if (argument.Expression is LambdaExpressionSyntax lambdaExpression) + { + var descendantNodes = lambdaExpression.DescendantNodes(); + + foreach (var node in descendantNodes) + { + AnalyzeDeclaredVariableNodes(ctx, node, forbiddenTypes, reportedLocations); + } + } + } + } + + private static void AnalyzeMethodDeclaration(SyntaxNodeAnalysisContext ctx, AkkaContext akkaContext) + { + var methodDeclaration = (MethodDeclarationSyntax)ctx.Node; + var semanticModel = ctx.SemanticModel; + var methodSymbol = semanticModel.GetDeclaredSymbol(methodDeclaration); + + INamedTypeSymbol? messageExtractorSymbol = akkaContext.AkkaClusterSharding.IMessageExtractorType; + + if (methodSymbol == null || messageExtractorSymbol == null) + return; + + var containingTypeIsMessageExtractor = methodSymbol.ContainingType.AllInterfaces.Any(i => + SymbolEqualityComparer.Default.Equals(i, messageExtractorSymbol)); + + if (!containingTypeIsMessageExtractor) + return; + + var messageExtractorMethods = messageExtractorSymbol.GetMembers() + .OfType() + .Where(m => m.Name is "EntityMessage" or "EntityId") + .ToArray(); + + INamedTypeSymbol?[] forbiddenTypes = GetForbiddenTypes(akkaContext); + + var reportedLocations = new HashSet(); + + // we know for sure that we are inside a message extractor now + foreach (var interfaceMember in methodSymbol.ContainingType.AllInterfaces.SelectMany(i => + i.GetMembers().OfType())) + { + foreach (var extractorMethod in messageExtractorMethods) + { + if (SymbolEqualityComparer.Default.Equals(interfaceMember, extractorMethod)) + { + // Retrieve all the descendant nodes of the method that are expressions + var descendantNodes = methodDeclaration.DescendantNodes(); + + foreach (var node in descendantNodes) + { + AnalyzeDeclaredVariableNodes(ctx, node, forbiddenTypes, reportedLocations); + } + } + } + } + } + + private static INamedTypeSymbol?[] GetForbiddenTypes(AkkaContext akkaContext) + { + var forbiddenTypes = new[] + { akkaContext.AkkaClusterSharding.StartEntityType, akkaContext.AkkaClusterSharding.ShardEnvelopeType }; + return forbiddenTypes; + } + + private static void AnalyzeDeclaredVariableNodes(SyntaxNodeAnalysisContext ctx, SyntaxNode node, + INamedTypeSymbol?[] forbiddenTypes, HashSet reportedLocations) + { + var semanticModel = ctx.SemanticModel; + switch (node) + { + case DeclarationPatternSyntax declarationPatternSyntax: + { + // get the symbol for the declarationPatternSyntax.Type + var variableType = semanticModel.GetTypeInfo(declarationPatternSyntax.Type).Type; + + if (forbiddenTypes.Any(t => SymbolEqualityComparer.Default.Equals(t, variableType))) + { + var location = declarationPatternSyntax.GetLocation(); + + // duplicate + if (reportedLocations.Contains(location)) + break; + var diagnostic = Diagnostic.Create( + RuleDescriptors + .Ak2001DoNotUseAutomaticallyHandledMessagesInShardMessageExtractor, + location); + ctx.ReportDiagnostic(diagnostic); + reportedLocations.Add(location); + } + + break; + } + } + } +} \ No newline at end of file diff --git a/src/Akka.Analyzers/Utility/AkkaClusterContext.cs b/src/Akka.Analyzers/Utility/AkkaClusterContext.cs new file mode 100644 index 0000000..9ccdfc1 --- /dev/null +++ b/src/Akka.Analyzers/Utility/AkkaClusterContext.cs @@ -0,0 +1,65 @@ +// ----------------------------------------------------------------------- +// <copyright file="AkkaClusterContext.cs" company="Akka.NET Project"> +// Copyright (C) 2009-2024 Lightbend Inc. <http://www.lightbend.com> +// Copyright (C) 2013-2024 .NET Foundation <https://github.com/akkadotnet/akka.net> +// </copyright> +// ----------------------------------------------------------------------- + +using Microsoft.CodeAnalysis; + +namespace Akka.Analyzers; + +public interface IAkkaClusterContext +{ + Version Version { get; } + + INamedTypeSymbol? ClusterType { get; } +} + +public sealed class EmptyClusterContext : IAkkaClusterContext +{ + private EmptyClusterContext() + { + } + + public static EmptyClusterContext Instance { get; } = new(); + + public Version Version { get; } = new(); + public INamedTypeSymbol? ClusterType => null; +} + +/// +/// Default AkkaClusterContext. +/// +/// +/// Used to indicate whether or not Akka.Cluster is present inside the solution being scanned and +/// provides access to some of the built-in type symbols that are used in analysis rules. +/// +public sealed class AkkaClusterContext : IAkkaClusterContext +{ + private readonly Lazy _lazyClusterType; + + private AkkaClusterContext(Compilation compilation, Version version) + { + Version = version; + _lazyClusterType = new Lazy(() => TypeSymbolFactory.AkkaCluster(compilation)); + } + + public static IAkkaClusterContext? Get(Compilation compilation, Version? versionOverride = null) + { + // assert that compilation is not null + Guard.AssertIsNotNull(compilation); + + var version = + versionOverride ?? + compilation + .ReferencedAssemblyNames + .FirstOrDefault(a => a.Name.Equals("Akka.Cluster", StringComparison.OrdinalIgnoreCase)) + ?.Version; + + return version is null ? null : new AkkaClusterContext(compilation, version); + } + + public Version Version { get; } + public INamedTypeSymbol? ClusterType => _lazyClusterType.Value; +} \ No newline at end of file diff --git a/src/Akka.Analyzers/Utility/AkkaClusterShardingContext.cs b/src/Akka.Analyzers/Utility/AkkaClusterShardingContext.cs new file mode 100644 index 0000000..0cdb12b --- /dev/null +++ b/src/Akka.Analyzers/Utility/AkkaClusterShardingContext.cs @@ -0,0 +1,84 @@ +// ----------------------------------------------------------------------- +// <copyright file="AkkaClusterShardingContext.cs" company="Akka.NET Project"> +// Copyright (C) 2013-2024 .NET Foundation <https://github.com/akkadotnet/akka.net> +// </copyright> +// ----------------------------------------------------------------------- + +using Microsoft.CodeAnalysis; + +namespace Akka.Analyzers; + +/// +/// Data about the Akka.Cluster.Sharding assembly in the solution being analyzed. +/// +public interface IAkkaClusterShardingContext +{ + Version Version { get; } + + INamedTypeSymbol? ClusterShardingType { get; } + + INamedTypeSymbol? IMessageExtractorType { get; } + + INamedTypeSymbol? ShardEnvelopeType { get; } + + INamedTypeSymbol? StartEntityType { get; } +} + +/// +/// INTERNAL API +/// +public sealed class EmptyAkkaClusterShardingContext : IAkkaClusterShardingContext +{ + private EmptyAkkaClusterShardingContext() + { + } + + public static EmptyAkkaClusterShardingContext Instance { get; } = new(); + + public Version Version { get; } = new(); + public INamedTypeSymbol? ClusterShardingType => null; + public INamedTypeSymbol? IMessageExtractorType => null; + public INamedTypeSymbol? ShardEnvelopeType => null; + public INamedTypeSymbol? StartEntityType => null; +} + +/// +/// INTERNAL API +/// +public sealed class AkkaClusterShardingContext : IAkkaClusterShardingContext +{ + private readonly Lazy _lazyClusterShardingType; + private readonly Lazy _lazyMessageExtractorType; + private readonly Lazy _lazyShardEnvelopeType; + private readonly Lazy _lazyStartEntityType; + + public Version Version { get; } + public INamedTypeSymbol? ClusterShardingType => _lazyClusterShardingType.Value; + public INamedTypeSymbol? IMessageExtractorType => _lazyMessageExtractorType.Value; + public INamedTypeSymbol? ShardEnvelopeType => _lazyShardEnvelopeType.Value; + public INamedTypeSymbol? StartEntityType => _lazyStartEntityType.Value; + + private AkkaClusterShardingContext(Compilation compilation, Version version) + { + Version = version; + _lazyClusterShardingType = new Lazy(() => TypeSymbolFactory.AkkaClusterSharding(compilation)); + _lazyMessageExtractorType = new Lazy(() => TypeSymbolFactory.AkkaMessageExtractor(compilation)); + _lazyShardEnvelopeType = new Lazy(() => TypeSymbolFactory.AkkaShardEnvelope(compilation)); + _lazyStartEntityType = new Lazy(() => TypeSymbolFactory.AkkaStartEntity(compilation)); + } + + public static IAkkaClusterShardingContext? Get(Compilation compilation, Version? versionOverride = null) + { + // assert that compilation is not null + Guard.AssertIsNotNull(compilation); + + var version = + versionOverride ?? + compilation + .ReferencedAssemblyNames + .FirstOrDefault(a => a.Name.Equals("Akka.Cluster.Sharding", StringComparison.OrdinalIgnoreCase)) + ?.Version; + + return version is null ? null : new AkkaClusterShardingContext(compilation, version); + } +} \ No newline at end of file diff --git a/src/Akka.Analyzers/Utility/AkkaContext.cs b/src/Akka.Analyzers/Utility/AkkaContext.cs index 4b0b652..e49c956 100644 --- a/src/Akka.Analyzers/Utility/AkkaContext.cs +++ b/src/Akka.Analyzers/Utility/AkkaContext.cs @@ -12,9 +12,11 @@ namespace Akka.Analyzers; /// Provides information about the Akka.NET context (i.e. which libraries, which versions) in which the analyzer is /// running. /// -public class AkkaContext +public sealed class AkkaContext { private IAkkaCoreContext? _akkaCore; + private IAkkaClusterContext? _akkaCluster; + private IAkkaClusterShardingContext? _akkaClusterSharding; /// /// Initializes a new instance of the class. @@ -26,6 +28,8 @@ public class AkkaContext public AkkaContext(Compilation compilation) { _akkaCore = AkkaCoreContext.Get(compilation); + _akkaCluster = AkkaClusterContext.Get(compilation); + _akkaClusterSharding = AkkaClusterShardingContext.Get(compilation); } private AkkaContext() @@ -44,4 +48,30 @@ public IAkkaCoreContext AkkaCore /// Does the current compilation context even have Akka.NET installed? /// public bool HasAkkaInstalled => AkkaCore != EmptyCoreContext.Instance; + + /// + /// Symbol data and availability for Akka.Cluster. + /// + public IAkkaClusterContext AkkaCluster + { + get { return _akkaCluster ??= EmptyClusterContext.Instance; } + } + + /// + /// Does the current compilation context have Akka.Cluster installed? + /// + public bool HasAkkaClusterInstalled => AkkaCluster != EmptyClusterContext.Instance; + + /// + /// Symbol data and availability for Akka.Cluster.Sharding. + /// + public IAkkaClusterShardingContext AkkaClusterSharding + { + get { return _akkaClusterSharding ??= EmptyAkkaClusterShardingContext.Instance; } + } + + /// + /// Does the current compilation context have Akka.Cluster.Sharding installed? + /// + public bool HasAkkaClusterShardingInstalled => AkkaClusterSharding != EmptyAkkaClusterShardingContext.Instance; } \ No newline at end of file diff --git a/src/Akka.Analyzers/Utility/RuleDescriptors.cs b/src/Akka.Analyzers/Utility/RuleDescriptors.cs index 7781e31..3c10ec5 100644 --- a/src/Akka.Analyzers/Utility/RuleDescriptors.cs +++ b/src/Akka.Analyzers/Utility/RuleDescriptors.cs @@ -41,6 +41,11 @@ private static DiagnosticDescriptor Rule( public static DiagnosticDescriptor Ak2000DoNotUseZeroTimeoutWithAsk { get; } = Rule("AK2000", "Do not use `Ask` with `TimeSpan.Zero` for timeout.", AnalysisCategory.ApiUsage, DiagnosticSeverity.Error, "When using `Ask`, you must always specify a timeout value greater than `TimeSpan.Zero`."); + + public static DiagnosticDescriptor Ak2001DoNotUseAutomaticallyHandledMessagesInShardMessageExtractor { get; } = Rule("AK2001", + "Do not use automatically handled messages in inside `Akka.Cluster.Sharding.IMessageExtractor`s.", AnalysisCategory.ApiUsage, DiagnosticSeverity.Warning, + "When using any implementation of `Akka.Cluster.Sharding.IMessageExtractor`, including `HashCodeMessageExtractor`, you should not use messages " + + "that are automatically handled by Akka.NET such as `Shard.StartEntity` and `ShardingEnvelope`."); #endregion diff --git a/src/Akka.Analyzers/Utility/TypeSymbolFactory.cs b/src/Akka.Analyzers/Utility/TypeSymbolFactory.cs index 9ef7178..73eb3e7 100644 --- a/src/Akka.Analyzers/Utility/TypeSymbolFactory.cs +++ b/src/Akka.Analyzers/Utility/TypeSymbolFactory.cs @@ -39,4 +39,52 @@ public static class TypeSymbolFactory return Guard.AssertIsNotNull(compilation) .GetTypeByMetadataName("Akka.Actor.IIndirectActorProducer"); } + + public static INamedTypeSymbol? AkkaCluster(Compilation compilation) + { + return Guard.AssertIsNotNull(compilation) + .GetTypeByMetadataName("Akka.Cluster.Cluster"); + } + + public static INamedTypeSymbol? AkkaClusterSingletonManager(Compilation compilation) + { + return Guard.AssertIsNotNull(compilation) + .GetTypeByMetadataName("Akka.Cluster.Tools.Singleton.ClusterSingletonManager"); + } + + public static INamedTypeSymbol? AkkaClusterSingletonProxy(Compilation compilation) + { + return Guard.AssertIsNotNull(compilation) + .GetTypeByMetadataName("Akka.Cluster.Tools.Singleton.ClusterSingletonProxy"); + } + + public static INamedTypeSymbol? AkkaClusterClient(Compilation compilation) + { + return Guard.AssertIsNotNull(compilation) + .GetTypeByMetadataName("Akka.Cluster.Tools.Client.ClusterClient"); + } + + public static INamedTypeSymbol? AkkaClusterSharding(Compilation compilation) + { + return Guard.AssertIsNotNull(compilation) + .GetTypeByMetadataName("Akka.Cluster.Sharding.ClusterSharding"); + } + + public static INamedTypeSymbol? AkkaMessageExtractor(Compilation compilation) + { + return Guard.AssertIsNotNull(compilation) + .GetTypeByMetadataName("Akka.Cluster.Sharding.IMessageExtractor"); + } + + public static INamedTypeSymbol? AkkaShardEnvelope(Compilation compilation) + { + return Guard.AssertIsNotNull(compilation) + .GetTypeByMetadataName("Akka.Cluster.Sharding.ShardingEnvelope"); + } + + public static INamedTypeSymbol? AkkaStartEntity(Compilation compilation) + { + return Guard.AssertIsNotNull(compilation) + .GetTypeByMetadataName("Akka.Cluster.Sharding.ShardRegion+StartEntity"); + } } \ No newline at end of file