Skip to content

Commit

Permalink
wip: pass query field names in doc level queries during monitor creat…
Browse files Browse the repository at this point in the history
…ion/updation

Signed-off-by: Surya Sashank Nistala <[email protected]>
  • Loading branch information
eirsep committed Jan 31, 2024
1 parent 0bbdb31 commit 86f956e
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 27 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ dependencies {
implementation group: 'org.apache.commons', name: 'commons-lang3', version: "${versions.commonslang}"
implementation "org.antlr:antlr4-runtime:4.10.1"
implementation "com.cronutils:cron-utils:9.1.6"
api "org.opensearch:common-utils:${common_utils_version}@jar"
api files("/Users/snistala/Documents/opensearch/common-utils/build/libs/common-utils-3.0.0.0-SNAPSHOT.jar")
api "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
implementation "org.jetbrains.kotlin:kotlin-stdlib:${kotlin_version}"
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@

package org.opensearch.securityanalytics.mapper;

import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.securityanalytics.util.SecurityAnalyticsException;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.securityanalytics.util.SecurityAnalyticsException;

public class MapperUtils {

Expand Down Expand Up @@ -246,7 +247,6 @@ public void onError(String error) {
}
});
mappingsTraverser.traverse();

return presentPathsMappings;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,12 @@ public Object convertConditionFieldEqValQueryExpr(ConditionFieldEqualsValueExpre

@Override
public Object convertConditionValStr(ConditionValueExpression condition) throws SigmaValueError {
String field = getFinalValueField();
ruleQueryFields.put(field, Map.of("type", "text", "analyzer", "rule_analyzer"));
SigmaString value = (SigmaString) condition.getValue();
boolean containsWildcard = value.containsWildcard();
return String.format(Locale.getDefault(), (containsWildcard? this.unboundWildcardExpression: this.unboundValueStrExpression), this.convertValueStr((SigmaString) condition.getValue()));
return String.format(Locale.getDefault(), (containsWildcard? this.unboundWildcardExpression: this.unboundValueStrExpression),
this.convertValueStr((SigmaString) condition.getValue()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,11 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
Expand Down Expand Up @@ -249,7 +251,8 @@ public void onFailure(Exception e) {
});
}

private void createMonitorFromQueries(List<Pair<String, Rule>> rulesById, Detector detector, ActionListener<List<IndexMonitorResponse>> listener, WriteRequest.RefreshPolicy refreshPolicy) {
private void createMonitorFromQueries(List<Pair<String, Rule>> rulesById, Detector detector, ActionListener<List<IndexMonitorResponse>> listener, WriteRequest.RefreshPolicy refreshPolicy,
List<String> queryFieldNames) {
List<Pair<String, Rule>> docLevelRules = rulesById.stream().filter(it -> !it.getRight().isAggregationRule()).collect(
Collectors.toList());
List<Pair<String, Rule>> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect(
Expand All @@ -262,7 +265,7 @@ public void onResponse(List<DocLevelQuery> dlqs) {
List<IndexMonitorRequest> monitorRequests = new ArrayList<>();

if (!docLevelRules.isEmpty() || detector.getThreatIntelEnabled()) {
monitorRequests.add(createDocLevelMonitorRequest(docLevelRules, dlqs != null ? dlqs : List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST));
monitorRequests.add(createDocLevelMonitorRequest(docLevelRules, dlqs != null ? dlqs : List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST, queryFieldNames));
}

if (!bucketLevelRules.isEmpty()) {
Expand Down Expand Up @@ -416,7 +419,12 @@ public void onFailure(Exception e) {
}
}

private void updateMonitorFromQueries(String index, List<Pair<String, Rule>> rulesById, Detector detector, ActionListener<List<IndexMonitorResponse>> listener, WriteRequest.RefreshPolicy refreshPolicy) throws Exception {
private void updateMonitorFromQueries(String index,
List<Pair<String, Rule>> rulesById,
Detector detector,
ActionListener<List<IndexMonitorResponse>> listener,
WriteRequest.RefreshPolicy refreshPolicy,
List<String> queryFieldNames) {
List<IndexMonitorRequest> monitorsToBeUpdated = new ArrayList<>();

List<Pair<String, Rule>> bucketLevelRules = rulesById.stream().filter(it -> it.getRight().isAggregationRule()).collect(
Expand Down Expand Up @@ -472,9 +480,9 @@ public void onResponse(Map<String, Map<String, String>> ruleFieldMappings) {
// Process doc level monitors
if (!docLevelRules.isEmpty() || detector.getThreatIntelEnabled()) {
if (detector.getDocLevelMonitorId() == null) {
monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST));
monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST, queryFieldNames));
} else {
monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT));
monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT, queryFieldNames));
}
}

Expand All @@ -500,9 +508,9 @@ public void onFailure(Exception e) {
// Process doc level monitors
if (!docLevelRules.isEmpty() || detector.getThreatIntelEnabled()) {
if (detector.getDocLevelMonitorId() == null) {
monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST));
monitorsToBeAdded.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, Monitor.NO_ID, Method.POST, queryFieldNames));
} else {
monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT));
monitorsToBeUpdated.add(createDocLevelMonitorRequest(docLevelRules, docLevelQueries != null? docLevelQueries: List.of(), detector, refreshPolicy, detector.getDocLevelMonitorId(), Method.PUT, queryFieldNames));
}
}

Expand Down Expand Up @@ -663,7 +671,7 @@ public void onFailure(Exception e) {
}
}

private IndexMonitorRequest createDocLevelMonitorRequest(List<Pair<String, Rule>> queries, List<DocLevelQuery> threatIntelQueries, Detector detector, WriteRequest.RefreshPolicy refreshPolicy, String monitorId, RestRequest.Method restMethod) {
private IndexMonitorRequest createDocLevelMonitorRequest(List<Pair<String, Rule>> queries, List<DocLevelQuery> threatIntelQueries, Detector detector, RefreshPolicy refreshPolicy, String monitorId, Method restMethod, List<String> queryFieldNames) {
List<DocLevelMonitorInput> docLevelMonitorInputs = new ArrayList<>();

List<DocLevelQuery> docLevelQueries = new ArrayList<>();
Expand All @@ -673,15 +681,14 @@ private IndexMonitorRequest createDocLevelMonitorRequest(List<Pair<String, Rule>

Rule rule = query.getRight();
String name = query.getLeft();

String actualQuery = rule.getQueries().get(0).getValue();

List<String> tags = new ArrayList<>();
tags.add(rule.getLevel());
tags.add(rule.getCategory());
tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList()));

DocLevelQuery docLevelQuery = new DocLevelQuery(id, name, Collections.emptyList(), actualQuery, tags);
DocLevelQuery docLevelQuery = new DocLevelQuery(id, name, Collections.emptyList(), actualQuery, tags, queryFieldNames);
docLevelQueries.add(docLevelQuery);
}
docLevelQueries.addAll(threatIntelQueries);
Expand Down Expand Up @@ -1389,11 +1396,7 @@ public void onResponse(SearchResponse response) {
} else if (detectorInput.getCustomRules().size() > 0) {
onFailures(new OpenSearchStatusException("Custom Rule Index not found", RestStatus.NOT_FOUND));
} else {
if (request.getMethod() == RestRequest.Method.POST) {
createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy());
} else if (request.getMethod() == RestRequest.Method.PUT) {
updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy());
}
upsertMonitorFromQueries(queries, detector, logIndex, listener);
}
} catch (Exception e) {
onFailures(e);
Expand All @@ -1407,6 +1410,53 @@ public void onFailure(Exception e) {
});
}

private void upsertMonitorFromQueries(List<Pair<String, Rule>> queries, Detector detector, String logIndex, ActionListener<List<IndexMonitorResponse>> listener) throws Exception {
Set<String> ruleFieldNames = new HashSet<>();
for (Pair<String, Rule> query : queries) {
List<String> queryFieldNames = query.getValue().getQueryFieldNames().stream().map(Value::getValue).collect(Collectors.toList());
ruleFieldNames.addAll(queryFieldNames);
}

CountDownLatch indexMappingsLatch = new CountDownLatch(1);
client.execute(GetIndexMappingsAction.INSTANCE, new GetIndexMappingsRequest(logIndex), new ActionListener<>() {
@Override
public void onResponse(GetIndexMappingsResponse getMappingsViewResponse) {
try {
List<Pair<String, String>> aliasPathPairs;

aliasPathPairs = MapperUtils.getAllAliasPathPairs(getMappingsViewResponse.getMappings().get(logIndex));
for (Pair<String, String> aliasPathPair : aliasPathPairs) {
if (ruleFieldNames.contains(aliasPathPair.getLeft())) {
ruleFieldNames.remove(aliasPathPair.getLeft());
ruleFieldNames.add(aliasPathPair.getRight());
}
}
} catch (Exception e) {
logger.error("Failure in parsing rule field names/aliases while " +
detector.getId() == null ? "creating" : "updating" +
" detector. Not optimizing detector queries with relevant fields", e);
ruleFieldNames.clear();
} finally {
indexMappingsLatch.countDown();
}

}

@Override
public void onFailure(Exception e) {
log.error("Failed to fetch mappings view response for log index " + logIndex, e);
listener.onFailure(e);
indexMappingsLatch.countDown();
}
});
indexMappingsLatch.await();
if (request.getMethod() == Method.POST) {
createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy(), new ArrayList<>(ruleFieldNames));
} else if (request.getMethod() == Method.PUT) {
updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy(), new ArrayList<>(ruleFieldNames));
}
}

@SuppressWarnings("unchecked")
public void importCustomRules(Detector detector, DetectorInput detectorInput, List<Pair<String, Rule>> queries, ActionListener<List<IndexMonitorResponse>> listener) {
final String logIndex = detectorInput.getIndices().get(0);
Expand Down Expand Up @@ -1443,11 +1493,7 @@ public void onResponse(SearchResponse response) {
queries.add(Pair.of(id, rule));
}

if (request.getMethod() == RestRequest.Method.POST) {
createMonitorFromQueries(queries, detector, listener, request.getRefreshPolicy());
} else if (request.getMethod() == RestRequest.Method.PUT) {
updateMonitorFromQueries(logIndex, queries, detector, listener, request.getRefreshPolicy());
}
upsertMonitorFromQueries(queries, detector, logIndex, listener);
} catch (Exception ex) {
onFailures(ex);
}
Expand Down
3 changes: 3 additions & 0 deletions src/main/resources/mappings/finding_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
}
}
},
"query_field_names": {
"type": "keyword"
},
"fields": {
"type": "text"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2056,7 +2056,7 @@ public void testCreateDetectorWithCloudtrailAggrRuleWithEcsFields() throws IOExc
// both req params and req body are supported
createMappingRequest.setJsonEntity(
"{\n" +
" \"index_name\": \"" + index + "\",\n" +
" \"index_name\": \"cloudtrail\",\n" +
" \"rule_topic\": \"cloudtrail\",\n" +
" \"partial\": true,\n" +
" \"alias_mappings\": {\n" +
Expand Down

0 comments on commit 86f956e

Please sign in to comment.