Skip to content

Commit

Permalink
Assisted inject for datanode migration context (#19926)
Browse files Browse the repository at this point in the history
* Assisted inject for datanode migration context

* add changelog

* code cleanup
  • Loading branch information
todvora authored Jul 18, 2024
1 parent 5a98253 commit 1fc7f56
Show file tree
Hide file tree
Showing 11 changed files with 99 additions and 90 deletions.
4 changes: 4 additions & 0 deletions changelog/unreleased/pr-19926.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
type = "f"
message = "Fix concurrency problem with datanode migration context injection"

pulls = ["19926"]
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
*/
package org.graylog.plugins.views.storage.migration;

import com.google.inject.assistedinject.FactoryModuleBuilder;
import org.graylog.plugins.views.storage.migration.state.actions.MigrationActions;
import org.graylog.plugins.views.storage.migration.state.actions.MigrationActionsFactory;
import org.graylog.plugins.views.storage.migration.state.actions.MigrationActionsImpl;
import org.graylog.plugins.views.storage.migration.state.machine.MigrationStateMachine;
import org.graylog.plugins.views.storage.migration.state.machine.MigrationStateMachineProvider;
Expand All @@ -29,8 +31,9 @@ public class DatanodeMigrationBindings extends Graylog2Module {
@Override
protected void configure() {
addSystemRestResource(MigrationStateResource.class);
bind(MigrationStateMachine.class).toProvider(MigrationStateMachineProvider.class);
bind(DatanodeMigrationPersistence.class).to(DatanodeMigrationConfigurationImpl.class);
bind(MigrationActions.class).to(MigrationActionsImpl.class);
install(new FactoryModuleBuilder().implement(MigrationActions.class, MigrationActionsImpl.class).build(
MigrationActionsFactory.class));
bind(MigrationStateMachine.class).toProvider(MigrationStateMachineProvider.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
*/
package org.graylog.plugins.views.storage.migration.state.actions;

import org.graylog.plugins.views.storage.migration.state.machine.MigrationStateMachineContext;

/**
* Set of callbacks used during the migration in different states.
*/
Expand Down Expand Up @@ -55,10 +53,6 @@ public interface MigrationActions {

boolean dataNodeStartupFinished();

void setStateMachineContext(MigrationStateMachineContext context);

MigrationStateMachineContext getStateMachineContext();

void startRemoteReindex();

void requestMigrationStatus();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright (C) 2020 Graylog, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the Server Side Public License, version 1,
* as published by MongoDB, Inc.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Server Side Public License for more details.
*
* You should have received a copy of the Server Side Public License
* along with this program. If not, see
* <http://www.mongodb.com/licensing/server-side-public-license>.
*/
package org.graylog.plugins.views.storage.migration.state.actions;

import org.graylog.plugins.views.storage.migration.state.machine.MigrationStateMachineContext;

public interface MigrationActionsFactory {
MigrationActions create(MigrationStateMachineContext context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import com.codahale.metrics.Counter;
import com.codahale.metrics.MetricRegistry;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.inject.assistedinject.Assisted;
import jakarta.inject.Inject;
import jakarta.inject.Named;
import jakarta.inject.Singleton;
import jakarta.ws.rs.core.MultivaluedHashMap;
import org.graylog.plugins.views.storage.migration.state.machine.MigrationStateMachineContext;
import org.graylog.security.certutil.CaKeystore;
Expand Down Expand Up @@ -58,7 +58,6 @@
import java.util.Optional;
import java.util.stream.Collectors;

@Singleton
public class MigrationActionsImpl implements MigrationActions {
private static final Logger LOG = LoggerFactory.getLogger(MigrationActionsImpl.class);

Expand All @@ -68,7 +67,7 @@ public class MigrationActionsImpl implements MigrationActions {
private final CaKeystore caKeystore;
private final PreflightConfigService preflightConfigService;

private MigrationStateMachineContext stateMachineContext;
private final MigrationStateMachineContext stateMachineContext;
private final DataNodeCommandService dataNodeCommandService;

private final RemoteReindexingMigrationAdapter migrationService;
Expand All @@ -82,7 +81,8 @@ public class MigrationActionsImpl implements MigrationActions {
private final Version graylogVersion = Version.CURRENT_CLASSPATH;

@Inject
public MigrationActionsImpl(final ClusterConfigService clusterConfigService, NodeService<DataNodeDto> nodeService,
public MigrationActionsImpl(@Assisted MigrationStateMachineContext stateMachineContext,
final ClusterConfigService clusterConfigService, NodeService<DataNodeDto> nodeService,
final CaKeystore caKeystore, DataNodeCommandService dataNodeCommandService,
RemoteReindexingMigrationAdapter migrationService,
final ClusterProcessingControlFactory clusterProcessingControlFactory,
Expand All @@ -92,6 +92,7 @@ public MigrationActionsImpl(final ClusterConfigService clusterConfigService, Nod
ElasticsearchVersionProvider searchVersionProvider,
@Named("elasticsearch_hosts") List<URI> elasticsearchHosts,
final ObjectMapper objectMapper) {
this.stateMachineContext = stateMachineContext;
this.clusterConfigService = clusterConfigService;
this.nodeService = nodeService;
this.caKeystore = caKeystore;
Expand Down Expand Up @@ -119,9 +120,9 @@ public void runDirectoryCompatibilityCheck() {
return new CompatibilityResult(node.getHostname(), "unknown", new CompatibilityResult.IndexerDirectoryInformation(List.of(), "unknown"), List.of(e.getMessage()));
}
}).collect(Collectors.toList());
getStateMachineContext().addExtendedState(MigrationStateMachineContext.KEY_COMPATIBILITY_CHECK_PASSED,
stateMachineContext.addExtendedState(MigrationStateMachineContext.KEY_COMPATIBILITY_CHECK_PASSED,
results.stream().allMatch(r -> r.compatibilityErrors().isEmpty()));
getStateMachineContext().setResponse(results);
stateMachineContext.setResponse(results);
}

@Override
Expand All @@ -133,12 +134,12 @@ public boolean isOldClusterStopped() {
@Override
public void rollingUpgradeSelected() {
Counter traffic = (Counter) metricRegistry.getMetrics().get(GlobalMetricNames.INPUT_TRAFFIC);
getStateMachineContext().addExtendedState(TrafficSnapshot.TRAFFIC_SNAPSHOT, new TrafficSnapshot(traffic.getCount()));
stateMachineContext.addExtendedState(TrafficSnapshot.TRAFFIC_SNAPSHOT, new TrafficSnapshot(traffic.getCount()));
}

@Override
public boolean directoryCompatibilityCheckOk() {
return getStateMachineContext().getExtendedState(MigrationStateMachineContext.KEY_COMPATIBILITY_CHECK_PASSED, Boolean.class).orElse(false);
return stateMachineContext.getExtendedState(MigrationStateMachineContext.KEY_COMPATIBILITY_CHECK_PASSED, Boolean.class).orElse(false);
}

@Override
Expand Down Expand Up @@ -256,32 +257,32 @@ public boolean dataNodeStartupFinished() {

@Override
public void startRemoteReindex() {
final String allowlist = getStateMachineContext().getActionArgumentOpt("allowlist", String.class).orElse(null);
String host = StringUtils.requireNonBlank(getStateMachineContext().getActionArgument("hostname", String.class), "hostname has to be provided");
final String allowlist = stateMachineContext.getActionArgumentOpt("allowlist", String.class).orElse(null);
String host = StringUtils.requireNonBlank(stateMachineContext.getActionArgument("hostname", String.class), "hostname has to be provided");
if (host.endsWith("/")) {
host = host.substring(0, host.length() - 1);
}
final URI hostname = URI.create(host);
final String user = getStateMachineContext().getActionArgumentOpt("user", String.class).orElse(null);
final String password = getStateMachineContext().getActionArgumentOpt("password", String.class).orElse(null);
final List<String> indices = getStateMachineContext().getActionArgumentOpt("indices", List.class).orElse(Collections.emptyList()); // todo: generics!
final boolean trustUnknownCerts = getStateMachineContext().getActionArgumentOpt("trust_unknown_certs", Boolean.class).orElse(false);
final int threadsCount = getStateMachineContext().getActionArgumentOpt("threads", Integer.class).orElse(4);
final String user = stateMachineContext.getActionArgumentOpt("user", String.class).orElse(null);
final String password = stateMachineContext.getActionArgumentOpt("password", String.class).orElse(null);
final List<String> indices = stateMachineContext.getActionArgumentOpt("indices", List.class).orElse(Collections.emptyList()); // todo: generics!
final boolean trustUnknownCerts = stateMachineContext.getActionArgumentOpt("trust_unknown_certs", Boolean.class).orElse(false);
final int threadsCount = stateMachineContext.getActionArgumentOpt("threads", Integer.class).orElse(4);
final String migrationID = migrationService.start(new RemoteReindexRequest(allowlist, hostname, user, password, indices, threadsCount, trustUnknownCerts));
getStateMachineContext().addExtendedState(MigrationStateMachineContext.KEY_MIGRATION_ID, migrationID);
stateMachineContext.addExtendedState(MigrationStateMachineContext.KEY_MIGRATION_ID, migrationID);
}

@Override
public void requestMigrationStatus() {
getStateMachineContext().getExtendedState(MigrationStateMachineContext.KEY_MIGRATION_ID, String.class)
stateMachineContext.getExtendedState(MigrationStateMachineContext.KEY_MIGRATION_ID, String.class)
.map(migrationService::status)
.ifPresent(status -> getStateMachineContext().setResponse(status));
.ifPresent(status -> stateMachineContext.setResponse(status));
}

@Override
public void calculateTrafficEstimate() {
Counter currentTraffic = (Counter) metricRegistry.getMetrics().get(GlobalMetricNames.INPUT_TRAFFIC);
MigrationStateMachineContext context = getStateMachineContext();
MigrationStateMachineContext context = stateMachineContext;
if (context.getExtendedState(TrafficSnapshot.ESTIMATED_TRAFFIC_PER_MINUTE) == null) {
context.getExtendedState(TrafficSnapshot.TRAFFIC_SNAPSHOT, TrafficSnapshot.class)
.ifPresent(traffic -> context.addExtendedState(TrafficSnapshot.ESTIMATED_TRAFFIC_PER_MINUTE, traffic.calculateEstimatedTrafficPerMinute(currentTraffic.getCount())));
Expand All @@ -290,12 +291,12 @@ public void calculateTrafficEstimate() {

@Override
public void verifyRemoteIndexerConnection() {
final URI hostname = Objects.requireNonNull(URI.create(getStateMachineContext().getActionArgument("hostname", String.class)), "hostname has to be provided");
final String user = getStateMachineContext().getActionArgumentOpt("user", String.class).orElse(null);
final String password = getStateMachineContext().getActionArgumentOpt("password", String.class).orElse(null);
final boolean trustUnknownCerts = getStateMachineContext().getActionArgumentOpt("trust_unknown_certs", Boolean.class).orElse(false);
final String allowlist = getStateMachineContext().getActionArgumentOpt("allowlist", String.class).orElse(null);
getStateMachineContext().setResponse(migrationService.checkConnection(hostname, user, password, allowlist, trustUnknownCerts));
final URI hostname = Objects.requireNonNull(URI.create(stateMachineContext.getActionArgument("hostname", String.class)), "hostname has to be provided");
final String user = stateMachineContext.getActionArgumentOpt("user", String.class).orElse(null);
final String password = stateMachineContext.getActionArgumentOpt("password", String.class).orElse(null);
final boolean trustUnknownCerts = stateMachineContext.getActionArgumentOpt("trust_unknown_certs", Boolean.class).orElse(false);
final String allowlist = stateMachineContext.getActionArgumentOpt("allowlist", String.class).orElse(null);
stateMachineContext.setResponse(migrationService.checkConnection(hostname, user, password, allowlist, trustUnknownCerts));
}

@Override
Expand All @@ -305,28 +306,18 @@ public boolean isCompatibleInPlaceMigrationVersion() {

@Override
public void getElasticsearchHosts() {
getStateMachineContext().setResponse(Map.of(
stateMachineContext.setResponse(Map.of(
"elasticsearch_hosts", elasticsearchHosts.stream().map(URI::toString).collect(Collectors.joining(",")),
"allowlist_hosts", elasticsearchHosts.stream().map(host -> host.getHost() + ":" + host.getPort()).collect(Collectors.joining(","))
));
}

@Override
public boolean isRemoteReindexingFinished() {
return Optional.ofNullable(getStateMachineContext())
return Optional.ofNullable(stateMachineContext)
.flatMap(ctx -> ctx.getExtendedState(MigrationStateMachineContext.KEY_MIGRATION_ID, String.class))
.map(migrationService::status)
.filter(m -> m.status() == RemoteReindexingMigrationAdapter.Status.FINISHED)
.isPresent();
}

@Override
public void setStateMachineContext(MigrationStateMachineContext context) {
this.stateMachineContext = context;
}

@Override
public MigrationStateMachineContext getStateMachineContext() {
return stateMachineContext;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@

public class MigrationStateMachineImpl implements MigrationStateMachine {
private final StateMachine<MigrationState, MigrationStep> stateMachine;
private final MigrationActions migrationActions;
private final DatanodeMigrationPersistence persistenceService;
private MigrationStateMachineContext context;
private final MigrationStateMachineContext context;

public MigrationStateMachineImpl(StateMachine<MigrationState, MigrationStep> stateMachine, MigrationActions migrationActions, DatanodeMigrationPersistence persistenceService) {
public MigrationStateMachineImpl(
StateMachine<MigrationState, MigrationStep> stateMachine,
DatanodeMigrationPersistence persistenceService,
MigrationStateMachineContext context) {
this.stateMachine = stateMachine;
this.migrationActions = migrationActions;
this.persistenceService = persistenceService;
this.context = persistenceService.getStateMachineContext().orElse(new MigrationStateMachineContext());
migrationActions.setStateMachineContext(context);
this.context = context;
}

@Override
Expand All @@ -49,14 +49,12 @@ public CurrentStateInformation trigger(MigrationStep step, Map<String, Object> a
if (Objects.nonNull(args) && !args.isEmpty()) {
context.addActionArguments(step, args);
}
migrationActions.setStateMachineContext(context);
String errorMessage = null;
try {
stateMachine.fire(step);
} catch (Exception e) {
errorMessage = Objects.nonNull(e.getMessage()) ? e.getMessage() : e.toString();
}
context = migrationActions.getStateMachineContext();
persistenceService.saveStateMachineContext(context);
return new CurrentStateInformation(getState(), nextSteps(), errorMessage, context.getResponse());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,26 @@
import jakarta.inject.Provider;
import jakarta.inject.Singleton;
import org.graylog.plugins.views.storage.migration.state.actions.MigrationActions;
import org.graylog.plugins.views.storage.migration.state.actions.MigrationActionsFactory;
import org.graylog.plugins.views.storage.migration.state.persistence.DatanodeMigrationPersistence;

@Singleton
public class MigrationStateMachineProvider implements Provider<MigrationStateMachine> {

private final DatanodeMigrationPersistence persistenceService;
private final MigrationActions migrationActions;
private final MigrationActionsFactory migrationActionsFactory;

@Inject
public MigrationStateMachineProvider(DatanodeMigrationPersistence persistenceService, MigrationActions migrationActions) {
public MigrationStateMachineProvider(DatanodeMigrationPersistence persistenceService, MigrationActionsFactory migrationActionsFactory) {
this.persistenceService = persistenceService;
this.migrationActions = migrationActions;
this.migrationActionsFactory = migrationActionsFactory;
}

@Override
public MigrationStateMachine get() {
final MigrationStateMachineContext context = persistenceService.getStateMachineContext().orElseGet(MigrationStateMachineContext::new);
final MigrationActions migrationActions = migrationActionsFactory.create(context);
final StateMachine<MigrationState, MigrationStep> stateMachine = MigrationStateMachineBuilder.buildFromPersistedState(persistenceService, migrationActions);
return new MigrationStateMachineImpl(stateMachine, migrationActions, persistenceService);
return new MigrationStateMachineImpl(stateMachine, persistenceService, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ class MigrationStateMachineTest {
@Test
void testPersistence() {
final InMemoryStateMachinePersistence persistence = new InMemoryStateMachinePersistence();

final MigrationActionsAdapter migrationActions = new MigrationActionsAdapter() {};

final MigrationStateMachine migrationStateMachine = new MigrationStateMachineProvider(persistence, migrationActions).get();
final MigrationStateMachine migrationStateMachine = new MigrationStateMachineProvider(persistence, MigrationActionsAdapter::new).get();
migrationStateMachine.trigger(MigrationStep.SELECT_MIGRATION, Collections.emptyMap());


Expand All @@ -55,7 +52,7 @@ void testPersistence() {
@Test
void testReset() {
final InMemoryStateMachinePersistence persistence = new InMemoryStateMachinePersistence();
final MigrationStateMachineProvider provider = new MigrationStateMachineProvider(persistence, new MigrationActionsAdapter());
final MigrationStateMachineProvider provider = new MigrationStateMachineProvider(persistence, MigrationActionsAdapter::new);
final MigrationStateMachine sm = provider.get();
sm.trigger(MigrationStep.SELECT_MIGRATION, Collections.emptyMap());

Expand All @@ -69,7 +66,7 @@ void testReset() {

@Test
void testSerialization() {
final MigrationStateMachine migrationStateMachine = new MigrationStateMachineProvider(new InMemoryStateMachinePersistence(), new MigrationActionsAdapter()).get();
final MigrationStateMachine migrationStateMachine = new MigrationStateMachineProvider(new InMemoryStateMachinePersistence(), MigrationActionsAdapter::new).get();
final String serialized = migrationStateMachine.serialize();
Assertions.assertThat(serialized).isNotEmpty().startsWith("digraph G {");
final String fragment = URLEncoder.encode(serialized, StandardCharsets.UTF_8).replace("+", "%20");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,17 @@

public class MigrationActionsAdapter implements MigrationActions {

MigrationStateMachineContext context;
protected final MigrationStateMachineContext context;

public MigrationActionsAdapter() {
this.context = new MigrationStateMachineContext();
public MigrationActionsAdapter(MigrationStateMachineContext context) {
this.context = context;
}

@Override
public boolean dataNodeStartupFinished() {
return false;
}

public void setStateMachineContext(MigrationStateMachineContext context) {
this.context = context;
}

@Override
public MigrationStateMachineContext getStateMachineContext() {
return context;
}

@Override
public void startRemoteReindex() {

Expand Down
Loading

0 comments on commit 1fc7f56

Please sign in to comment.