Skip to content

Commit

Permalink
Add wildcard field support to KQL (elastic#1139)
Browse files Browse the repository at this point in the history
  • Loading branch information
rw-access authored Apr 22, 2021
1 parent cabe923 commit 8d8bcfb
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
2 changes: 1 addition & 1 deletion kql/kql2eql.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def value(self, tree):
if eql.utils.is_string(value) and "*" in value:
return eql.ast.FunctionCall("wildcard", [field, value_ast])

if self.get_field_type(field_name) == "ip" and "/" in value:
if self.get_field_types(field_name) == {"ip"} and "/" in value:
return eql.ast.FunctionCall("cidrMatch", [field, value_ast])

return eql.ast.Comparison(field, "==", value_ast)
Expand Down
39 changes: 34 additions & 5 deletions kql/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import contextlib
import os
import re
from typing import Optional, Set

import eql
from lark import Token # noqa: F401
Expand Down Expand Up @@ -37,6 +38,10 @@ def child_tokens(self):

lark_parser = Lark(grammar, propagate_positions=True, tree_class=KvTree, start=['query'], parser='lalr')

def wildcard2regex(wc: str) -> re.Pattern:
    """Compile a ``*`` wildcard expression into an anchored regular expression.

    Each run of one or more ``*`` characters matches any (possibly empty)
    sequence of characters other than a newline; every other character in
    *wc* is matched literally.

    Args:
        wc: Wildcard expression, e.g. ``"common.*"``.

    Returns:
        A compiled pattern wrapped in ``^`` ... ``$`` anchors.
    """
    # Consecutive stars are equivalent to a single star. Collapsing them here
    # avoids emitting adjacent ``.*?`` groups, which accept exactly the same
    # strings but backtrack polynomially on input that does not match.
    literals = re.split(r"\*+", wc)
    return re.compile("^{regex}$".format(regex=".*?".join(re.escape(w) for w in literals)))


class BaseKqlParser(Interpreter):
NON_SPACE_WS = re.compile(r"[^\S ]+")
Expand All @@ -62,9 +67,7 @@ def __init__(self, text, schema=None):
if schema:
for field, field_type in schema.items():
if "*" in field:
parts = field.split("*")
pattern = re.compile("^{regex}$".format(regex=".*?".join(re.escape(w) for w in parts)))
self.star_fields.append(pattern)
self.star_fields.append(wildcard2regex(field))

def assert_lower_token(self, *tokens):
for token in tokens:
Expand Down Expand Up @@ -123,6 +126,24 @@ def get_field_type(self, dotted_path, lark_tree=None):

return self.mapping_schema.get(dotted_path)

def get_field_types(self, wildcard_dotted_path, lark_tree=None) -> Optional[Set[str]]:
    """Return the set of schema types for a field path that may contain ``*``.

    A path without ``*`` is resolved through ``get_field_type`` and wrapped in
    a one-element set. A path with ``*`` is compiled with ``wildcard2regex``
    and matched against every key of ``self.mapping_schema``.

    Args:
        wildcard_dotted_path: Field path, optionally containing ``*``.
        lark_tree: Parse-tree node forwarded to ``get_field_type`` and to
            ``self.error`` when reporting a problem.

    Returns:
        The set of matching field types, or ``None`` when the single-field
        lookup yields ``None`` or when no mapping schema is set.

    Raises:
        The exception built by ``self.error`` when a wildcard path matches no
        field in the schema.
    """
    # NOTE(review): block nesting was reconstructed from a listing that lost
    # its indentation - confirm against the real file.

    # Plain paths defer to the single-field lookup.
    if "*" not in wildcard_dotted_path:
        single_type = self.get_field_type(wildcard_dotted_path, lark_tree=lark_tree)
        return None if single_type is None else {single_type}

    # Without a schema there is nothing to resolve the wildcard against.
    if self.mapping_schema is None:
        return None

    pattern = wildcard2regex(wildcard_dotted_path)
    matched_types = {
        schema_type
        for schema_field, schema_type in self.mapping_schema.items()
        if pattern.fullmatch(schema_field) is not None
    }

    if not matched_types:
        raise self.error(lark_tree, "Unknown field")

    return matched_types

@staticmethod
def get_literal_type(literal_value):
if isinstance(literal_value, bool):
Expand All @@ -140,9 +161,17 @@ def get_literal_type(literal_value):
raise NotImplementedError("Unknown literal type: {}".format(type(literal_value).__name__))

def convert_value(self, field_name, python_value, value_tree):
field_type = self.get_field_type(field_name)
field_type = None
field_types = self.get_field_types(field_name)
value_type = self.get_literal_type(python_value)

if field_types is not None:
if len(field_types) == 1:
field_type = list(field_types)[0]
elif len(field_types) > 1:
raise self.error(value_tree,
f"{field_name} has multiple types {', '.join(field_types)}")

if field_type is not None and field_type != value_type:
if field_type in STRING_FIELDS:
return eql.utils.to_unicode(python_value)
Expand Down Expand Up @@ -228,7 +257,7 @@ def field_value_expression(self, tree):

with self.scope(self.visit(field_tree)) as field:
# check the field against the schema
self.get_field_type(field.name, field_tree)
self.get_field_types(field.name, field_tree)
return FieldComparison(field, self.visit(expr))

def field_range_expression(self, tree):
Expand Down
8 changes: 8 additions & 0 deletions tests/kuery/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ def test_list_equals(self):
def test_number_exists(self):
    """``foo:*`` against a ``long`` field should parse to an ``Exists`` comparison."""
    parsed = kql.parse("foo:*", schema={"foo": "long"})
    expected = FieldComparison(Field("foo"), Exists())
    self.assertEqual(parsed, expected)

def test_multiple_types_success(self):
    """Validate a wildcard field whose two schema matches are both ``keyword``."""
    same_type_schema = {"common.a": "keyword", "common.b": "keyword"}
    query = "common.* : \"hello\""
    expected = FieldComparison(Field("common.*"), String("hello"))
    self.validate(query, expected, schema=same_type_schema)

def test_multiple_types_fail(self):
    """Expect ``KqlParseError`` when a wildcard field matches ``keyword`` and ``ip`` fields."""
    mixed_type_schema = {"common.a": "keyword", "common.b": "ip"}
    query = "common.* : \"hello\""
    with self.assertRaises(kql.KqlParseError):
        kql.parse(query, schema=mixed_type_schema)

def test_number_wildcard_fail(self):
with self.assertRaises(kql.KqlParseError):
kql.parse("foo:*wc", schema={"foo": "long"})
Expand Down

0 comments on commit 8d8bcfb

Please sign in to comment.