Skip to content

Commit

Permalink
Support loading module from command line (microsoft#679)
Browse files Browse the repository at this point in the history
* Initial draft.

* Validate module paths.

* Added tests.

* Refactoring.

* Avoid loading assembly if signing check fails.

---------

Co-authored-by: Badrish Chandramouli <[email protected]>
  • Loading branch information
yrajas and badrishc authored Sep 26, 2024
1 parent 1218678 commit b70553b
Show file tree
Hide file tree
Showing 13 changed files with 222 additions and 62 deletions.
7 changes: 6 additions & 1 deletion libs/host/Configuration/Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,10 @@ internal sealed class Options
[Option("extension-bin-paths", Separator = ',', Required = false, HelpText = "List of directories on server from which custom command binaries can be loaded by admin users")]
public IEnumerable<string> ExtensionBinPaths { get; set; }

[ModuleFilePathValidation(true, true, false)]
[Option("loadmodulecs", Separator = ',', Required = false, HelpText = "List of modules to be loaded")]
public IEnumerable<string> LoadModuleCS { get; set; }

[Option("extension-allow-unsigned", Required = false, HelpText = "Allow loading custom commands from digitally unsigned assemblies (not recommended)")]
public bool? ExtensionAllowUnsignedAssemblies { get; set; }

Expand Down Expand Up @@ -653,7 +657,8 @@ public GarnetServerOptions GetServerOptions(ILogger logger = null)
ExtensionBinPaths = ExtensionBinPaths?.ToArray(),
ExtensionAllowUnsignedAssemblies = ExtensionAllowUnsignedAssemblies.GetValueOrDefault(),
IndexResizeFrequencySecs = IndexResizeFrequencySecs,
IndexResizeThreshold = IndexResizeThreshold
IndexResizeThreshold = IndexResizeThreshold,
LoadModuleCS = LoadModuleCS
};
}

Expand Down
35 changes: 35 additions & 0 deletions libs/host/Configuration/OptionsValidators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,41 @@ protected override ValidationResult IsValid(object value, ValidationContext vali
}
}

[AttributeUsage(AttributeTargets.Property)]
internal class ModuleFilePathValidationAttribute : FilePathValidationAttribute
{
internal ModuleFilePathValidationAttribute(bool fileMustExist, bool directoryMustExist, bool isRequired, string[] acceptedFileExtensions = null) : base(fileMustExist, directoryMustExist, isRequired, acceptedFileExtensions)
{
}

protected override ValidationResult IsValid(object value, ValidationContext validationContext)
{
if (TryInitialValidation<IEnumerable<string>>(value, validationContext, out var initValidationResult, out var filePaths))
return initValidationResult;

var errorSb = new StringBuilder();
var isValid = true;
foreach (var filePathArg in filePaths)
{
var filePath = filePathArg.Split(' ')[0];
var result = base.IsValid(filePath, validationContext);
if (result != null && result != ValidationResult.Success)
{
isValid = false;
errorSb.AppendLine(result.ErrorMessage);
}
}

if (!isValid)
{
var errorMessage = $"Error(s) validating one or more file paths:{Environment.NewLine}{errorSb}";
return new ValidationResult(errorMessage, [validationContext.MemberName]);
}

return ValidationResult.Success;
}
}

/// <summary>
/// Validation logic for a string representing an IP address (either IPv4 or IPv6)
/// </summary>
Expand Down
28 changes: 28 additions & 0 deletions libs/host/GarnetServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using System;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using Garnet.cluster;
using Garnet.common;
Expand Down Expand Up @@ -211,6 +213,32 @@ private void InitializeServer()
Store = new StoreApi(storeWrapper);

server.Register(WireFormat.ASCII, Provider);

LoadModules(customCommandManager);
}

private void LoadModules(CustomCommandManager customCommandManager)
{
if (opts.LoadModuleCS == null)
return;

foreach (var moduleCS in opts.LoadModuleCS)
{
var moduleCSData = moduleCS.Split(' ', StringSplitOptions.RemoveEmptyEntries);
if (moduleCSData.Length < 1)
continue;

var modulePath = moduleCSData[0];
var moduleArgs = moduleCSData.Length > 1 ? moduleCSData.Skip(1).ToArray() : [];
if (ModuleUtils.LoadAssemblies([modulePath], null, true, out var loadedAssemblies, out var errorMsg))
{
ModuleRegistrar.Instance.LoadModule(customCommandManager, loadedAssemblies.ToList()[0], moduleArgs, logger, out errorMsg);
}
else
{
logger?.LogError("Module {0} failed to load with error {1}", modulePath, Encoding.UTF8.GetString(errorMsg));
}
}
}

private void CreateMainStore(IClusterFactory clusterFactory, out string checkpointDir)
Expand Down
3 changes: 3 additions & 0 deletions libs/host/defaults.conf
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,7 @@

/* Overflow bucket count over total index size in percentage to trigger index resize */
"IndexResizeThreshold": 50,

/* List of module paths to be loaded at startup */
"LoadModuleCS": null
}
8 changes: 4 additions & 4 deletions libs/server/Module/ModuleRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using System.Reflection;
using Microsoft.Extensions.Logging;

namespace Garnet.server.Module
namespace Garnet.server
{
/// <summary>
/// Abstract base class that all Garnet modules must inherit from.
Expand Down Expand Up @@ -171,11 +171,11 @@ public ModuleActionStatus RegisterProcedure(string name, CustomProcedure customS
}
}

internal sealed class ModuleRegistrar
public sealed class ModuleRegistrar
{
private static readonly Lazy<ModuleRegistrar> lazy = new Lazy<ModuleRegistrar>(() => new ModuleRegistrar());

internal static ModuleRegistrar Instance { get { return lazy.Value; } }
public static ModuleRegistrar Instance { get { return lazy.Value; } }

private ModuleRegistrar()
{
Expand All @@ -184,7 +184,7 @@ private ModuleRegistrar()

private readonly ConcurrentDictionary<string, ModuleLoadContext> modules;

internal bool LoadModule(CustomCommandManager customCommandManager, Assembly loadedAssembly, string[] moduleArgs, ILogger logger, out ReadOnlySpan<byte> errorMessage)
public bool LoadModule(CustomCommandManager customCommandManager, Assembly loadedAssembly, string[] moduleArgs, ILogger logger, out ReadOnlySpan<byte> errorMessage)
{
errorMessage = default;

Expand Down
82 changes: 82 additions & 0 deletions libs/server/Module/ModuleUtils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Reflection.Metadata;
using System.Reflection.PortableExecutable;
using Garnet.common;

namespace Garnet.server
{
public class ModuleUtils
{
public static bool LoadAssemblies(
IEnumerable<string> binaryPaths,
string[] allowedExtensionPaths,
bool allowUnsignedAssemblies,
out IEnumerable<Assembly> loadedAssemblies,
out ReadOnlySpan<byte> errorMessage)
{
loadedAssemblies = null;
errorMessage = default;

// Get all binary file paths from inputs binary paths
if (!FileUtils.TryGetFiles(binaryPaths, out var files, out _, [".dll", ".exe"], SearchOption.AllDirectories))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_GETTING_BINARY_FILES;
return false;
}

// Check that all binary files are contained in allowed binary paths
var binaryFiles = files.ToArray();
if (allowedExtensionPaths != null)
{
if (binaryFiles.Any(f =>
allowedExtensionPaths.All(p => !FileUtils.IsFileInDirectory(f, p))))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_BINARY_FILES_NOT_IN_ALLOWED_PATHS;
return false;
}
}

// If necessary, check that all assemblies are digitally signed
if (!allowUnsignedAssemblies)
{
foreach (var filePath in files)
{
using var fs = File.OpenRead(filePath);
using var peReader = new PEReader(fs);

var metadataReader = peReader.GetMetadataReader();
var assemblyPublicKeyHandle = metadataReader.GetAssemblyDefinition().PublicKey;

if (assemblyPublicKeyHandle.IsNil)
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_ASSEMBLY_NOT_SIGNED;
return false;
}

var publicKeyBytes = metadataReader.GetBlobBytes(assemblyPublicKeyHandle);
if (publicKeyBytes == null || publicKeyBytes.Length == 0)
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_ASSEMBLY_NOT_SIGNED;
return false;
}
}
}

// Get all assemblies from binary files
if (!FileUtils.TryLoadAssemblies(binaryFiles, out loadedAssemblies, out _))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_LOADING_ASSEMBLIES;
return false;
}

return true;
}
}
}
54 changes: 4 additions & 50 deletions libs/server/Resp/AdminCommands.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using Garnet.common;
using Garnet.server.Custom;
using Garnet.server.Module;

namespace Garnet.server
{
Expand Down Expand Up @@ -136,52 +134,6 @@ private bool NetworkMonitor()
return true;
}

private bool LoadAssemblies(IEnumerable<string> binaryPaths, out IEnumerable<Assembly> loadedAssemblies, out ReadOnlySpan<byte> errorMessage)
{
loadedAssemblies = null;
errorMessage = default;

// Get all binary file paths from inputs binary paths
if (!FileUtils.TryGetFiles(binaryPaths, out var files, out _, [".dll", ".exe"],
SearchOption.AllDirectories))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_GETTING_BINARY_FILES;
return false;
}

// Check that all binary files are contained in allowed binary paths
var binaryFiles = files.ToArray();
if (binaryFiles.Any(f =>
storeWrapper.serverOptions.ExtensionBinPaths.All(p => !FileUtils.IsFileInDirectory(f, p))))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_BINARY_FILES_NOT_IN_ALLOWED_PATHS;
return false;
}

// Get all assemblies from binary files
if (!FileUtils.TryLoadAssemblies(binaryFiles, out loadedAssemblies, out _))
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_LOADING_ASSEMBLIES;
return false;
}

// If necessary, check that all assemblies are digitally signed
if (!storeWrapper.serverOptions.ExtensionAllowUnsignedAssemblies)
{
foreach (var loadedAssembly in loadedAssemblies)
{
var publicKey = loadedAssembly.GetName().GetPublicKey();
if (publicKey == null || publicKey.Length == 0)
{
errorMessage = CmdStrings.RESP_ERR_GENERIC_ASSEMBLY_NOT_SIGNED;
return false;
}
}
}

return true;
}

/// <summary>
/// Register all custom commands / transactions
/// </summary>
Expand Down Expand Up @@ -231,7 +183,8 @@ private bool TryRegisterCustomCommands(
}
}

if (!LoadAssemblies(binaryPaths, out var loadedAssemblies, out errorMessage))
if (!ModuleUtils.LoadAssemblies(binaryPaths, storeWrapper.serverOptions.ExtensionBinPaths,
storeWrapper.serverOptions.ExtensionAllowUnsignedAssemblies, out var loadedAssemblies, out errorMessage))
return false;

foreach (var c in classNameToRegisterArgs.Keys)
Expand Down Expand Up @@ -488,7 +441,8 @@ private bool NetworkModuleLoad(CustomCommandManager customCommandManager)
for (var i = 0; i < moduleArgs.Length; i++)
moduleArgs[i] = parseState.GetArgSliceByRef(i + 1).ToString();

if (LoadAssemblies([modulePath], out var loadedAssemblies, out var errorMsg))
if (ModuleUtils.LoadAssemblies([modulePath], storeWrapper.serverOptions.ExtensionBinPaths,
storeWrapper.serverOptions.ExtensionAllowUnsignedAssemblies, out var loadedAssemblies, out var errorMsg))
{
Debug.Assert(loadedAssemblies != null);
var assembliesList = loadedAssemblies.ToList();
Expand Down
3 changes: 3 additions & 0 deletions libs/server/Servers/GarnetServerOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license.

using System;
using System.Collections.Generic;
using System.IO;
using Garnet.server.Auth.Settings;
using Garnet.server.TLS;
Expand Down Expand Up @@ -368,6 +369,8 @@ public class GarnetServerOptions : ServerOptions
/// </summary>
public bool ExtensionAllowUnsignedAssemblies;

public IEnumerable<string> LoadModuleCS;

/// <summary>
/// Constructor
/// </summary>
Expand Down
1 change: 0 additions & 1 deletion playground/GarnetJSON/Module.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license.

using Garnet.server;
using Garnet.server.Module;
using Microsoft.Extensions.Logging;

namespace GarnetJSON
Expand Down
1 change: 0 additions & 1 deletion playground/SampleModule/SampleModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using Garnet;
using Garnet.server;
using Garnet.server.Module;
using Microsoft.Extensions.Logging;

namespace SampleModule
Expand Down
4 changes: 3 additions & 1 deletion test/Garnet.test/GarnetServerConfigTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public void ImportExportConfigLocal()
// No import path, include command line args, export to file
// Check values from command line override values from defaults.conf
static string GetFullExtensionBinPath(string testProjectName) => Path.GetFullPath(testProjectName, TestUtils.RootTestsProjectPath);
var args = new string[] { "--config-export-path", configPath, "-p", "4m", "-m", "128m", "-s", "2g", "--recover", "--port", "53", "--reviv-obj-bin-record-count", "2", "--reviv-fraction", "0.5", "--extension-bin-paths", $"{GetFullExtensionBinPath("Garnet.test")},{GetFullExtensionBinPath("Garnet.test.cluster")}" };
var args = new string[] { "--config-export-path", configPath, "-p", "4m", "-m", "128m", "-s", "2g", "--recover", "--port", "53", "--reviv-obj-bin-record-count", "2", "--reviv-fraction", "0.5", "--extension-bin-paths", $"{GetFullExtensionBinPath("Garnet.test")},{GetFullExtensionBinPath("Garnet.test.cluster")}", "--loadmodulecs", $"{Assembly.GetExecutingAssembly().Location}" };
parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out options, out invalidOptions);
ClassicAssert.IsTrue(parseSuccessful);
ClassicAssert.AreEqual(invalidOptions.Count, 0);
Expand All @@ -95,6 +95,8 @@ public void ImportExportConfigLocal()
ClassicAssert.IsTrue(options.Recover);
ClassicAssert.IsTrue(File.Exists(configPath));
ClassicAssert.AreEqual(2, options.ExtensionBinPaths.Count());
ClassicAssert.AreEqual(1, options.LoadModuleCS.Count());
ClassicAssert.AreEqual(Assembly.GetExecutingAssembly().Location, options.LoadModuleCS.First());

// Import from previous export command, no command line args
// Check values from import path override values from default.conf
Expand Down
Loading

0 comments on commit b70553b

Please sign in to comment.