Skip to content

Commit

Permalink
Add support for TLS protocol tracing
Browse files Browse the repository at this point in the history
Signed-off-by: Dom Del Nano <[email protected]>
  • Loading branch information
ddelnano committed Jan 29, 2025
1 parent 845b3d5 commit 3e556d6
Show file tree
Hide file tree
Showing 19 changed files with 404 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/shared/protocols/protocols.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ enum class Protocol {
kKafka = 10,
kMux = 11,
kAMQP = 12,
kTLS = 13,
};

} // namespace protocols
Expand Down
2 changes: 1 addition & 1 deletion src/stirling/binaries/stirling_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ DEFINE_string(trace, "",
"Dynamic trace to deploy. Either (1) the path to a file containing PxL or IR trace "
"spec, or (2) <path to object file>:<symbol_name> for full-function tracing.");
DEFINE_string(print_record_batches,
"http_events,mysql_events,pgsql_events,redis_events,cql_events,dns_events",
"http_events,mysql_events,pgsql_events,redis_events,cql_events,dns_events,tls_events",
"Comma-separated list of tables to print.");
DEFINE_bool(init_only, false, "If true, only runs the init phase and exits. For testing.");
DEFINE_int32(timeout_secs, -1,
Expand Down
21 changes: 21 additions & 0 deletions src/stirling/source_connectors/socket_tracer/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,27 @@ pl_cc_bpf_test(
],
)

pl_cc_bpf_test(
name = "tls_trace_bpf_test",
timeout = "long",
srcs = ["tls_trace_bpf_test.cc"],
flaky = True,
shard_count = 2,
tags = [
"cpu:16",
"no_asan",
"requires_bpf",
],
deps = [
":cc_library",
"//src/common/testing/test_utils:cc_library",
"//src/stirling/source_connectors/socket_tracer/testing:cc_library",
"//src/stirling/source_connectors/socket_tracer/testing/container_images:curl_container",
"//src/stirling/source_connectors/socket_tracer/testing/container_images:nginx_openssl_3_0_8_container",
"//src/stirling/testing:cc_library",
],
)

pl_cc_bpf_test(
name = "dyn_lib_trace_bpf_test",
timeout = "moderate",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pl_cc_test(
"ENABLE_NATS_TRACING=true",
"ENABLE_MONGO_TRACING=true",
"ENABLE_AMQP_TRACING=true",
"ENABLE_TLS_TRACING=true",
],
deps = [
"//src/stirling/bpf_tools/bcc_bpf:headers",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,46 @@ static __inline enum message_type_t infer_http_message(const char* buf, size_t c
return kUnknown;
}

static __inline enum message_type_t infer_tls_message(const char* buf, size_t count) {
if (count < 6) {
return kUnknown;
}

uint8_t content_type = buf[0];
// TLS content types correspond to the following:
// 0x14: ChangeCipherSpec
// 0x15: Alert
// 0x16: Handshake
// 0x17: ApplicationData
// 0x18: Heartbeat
if (content_type != 0x16) {
return kUnknown;
}

uint16_t legacy_version = buf[1] << 8 | buf[2];
// TLS versions correspond to the following:
// 0x0300: SSL 3.0
// 0x0301: TLS 1.0
// 0x0302: TLS 1.1
// 0x0303: TLS 1.2
// 0x0304: TLS 1.3
if (legacy_version < 0x0300 || legacy_version > 0x0304) {
return kUnknown;
}

uint8_t handshake_type = buf[5];
// Check for ServerHello
if (handshake_type == 2) {
return kResponse;
}
// Check for ClientHello
if (handshake_type == 1) {
return kRequest;
}

return kUnknown;
}

// Cassandra frame:
// 0 8 16 24 32 40
// +---------+---------+---------+---------+---------+
Expand Down Expand Up @@ -699,7 +739,16 @@ static __inline struct protocol_message_t infer_protocol(const char* buf, size_t
// role by considering which side called accept() vs connect(). Once the clean-up
// above is done, the code below can be turned into a chained ternary.
// PROTOCOL_LIST: Requires update on new protocols.
if (ENABLE_HTTP_TRACING && (inferred_message.type = infer_http_message(buf, count)) != kUnknown) {
//
// TODO(ddelnano): TLS tracing should be handled differently in the future as we want to be able
// to trace the handshake and the application data separately (gh#2095). The current connection
// tracker model only works with one or the other, meaning if TLS tracing is enabled, tracing the
// plaintext within an encrypted conn will not work. ENABLE_TLS_TRACING will default to false
// until this is revisted.
if (ENABLE_TLS_TRACING && (inferred_message.type = infer_tls_message(buf, count)) != kUnknown) {
inferred_message.protocol = kProtocolTLS;
} else if (ENABLE_HTTP_TRACING &&
(inferred_message.type = infer_http_message(buf, count)) != kUnknown) {
inferred_message.protocol = kProtocolHTTP;
} else if (ENABLE_CQL_TRACING &&
(inferred_message.type = infer_cql_message(buf, count)) != kUnknown) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,27 @@ TEST(ProtocolInferenceTest, AMQPResponse) {
EXPECT_EQ(protocol_message.protocol, kProtocolAMQP);
EXPECT_EQ(protocol_message.type, kResponse);
}

TEST(ProtocolInferenceTest, TLSRequest) {
struct conn_info_t conn_info = {};
// TLS Client Hello
constexpr uint8_t kReqFrame[] = {
0x16, 0x03, 0x01, 0x00, 0x01, 0x01, 0x00, 0x01, 0xfc, 0x03, 0x03, 0x7b, 0x7b, 0x7b,
};
auto protocol_message =
infer_protocol(reinterpret_cast<const char*>(kReqFrame), sizeof(kReqFrame), &conn_info);
EXPECT_EQ(protocol_message.protocol, kProtocolTLS);
EXPECT_EQ(protocol_message.type, kRequest);
}

TEST(ProtocolInferenceTest, TLSResponse) {
struct conn_info_t conn_info = {};
// TLS Server Hello
constexpr uint8_t kRespFrame[] = {
0x16, 0x03, 0x01, 0x00, 0x01, 0x02, 0x00, 0x00, 0xfc, 0x03, 0x03, 0x7b, 0x7b, 0x7b,
};
auto protocol_message =
infer_protocol(reinterpret_cast<const char*>(kRespFrame), sizeof(kRespFrame), &conn_info);
EXPECT_EQ(protocol_message.protocol, kProtocolTLS);
EXPECT_EQ(protocol_message.type, kResponse);
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ enum traffic_protocol_t {
kProtocolKafka = 10,
kProtocolMux = 11,
kProtocolAMQP = 12,
kProtocolTLS = 13,
// We use magic enum to iterate through protocols in C++ land,
// and don't want the C-enum-size trick to show up there.
#ifndef __cplusplus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ auto CreateTraceRoles() {
res.Set(kProtocolKafka, {kRoleServer});
res.Set(kProtocolMux, {kRoleServer});
res.Set(kProtocolAMQP, {kRoleServer});
res.Set(kProtocolTLS, {kRoleServer});

DCHECK(res.AreAllKeysSet());
return res;
Expand Down
4 changes: 4 additions & 0 deletions src/stirling/source_connectors/socket_tracer/data_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ template void DataStream::ProcessBytesToFrames<protocols::amqp::channel_id, prot
template void DataStream::ProcessBytesToFrames<
protocols::mongodb::stream_id_t, protocols::mongodb::Frame, protocols::mongodb::StateWrapper>(
message_type_t type, protocols::mongodb::StateWrapper* state);

template void DataStream::ProcessBytesToFrames<protocols::tls::stream_id_t, protocols::tls::Frame,
protocols::NoState>(message_type_t type,
protocols::NoState* state);
void DataStream::Reset() {
data_buffer_.Reset();
has_new_events_ = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,6 @@ pl_cc_library(
"//src/stirling/source_connectors/socket_tracer/protocols/nats:cc_library",
"//src/stirling/source_connectors/socket_tracer/protocols/pgsql:cc_library",
"//src/stirling/source_connectors/socket_tracer/protocols/redis:cc_library",
"//src/stirling/source_connectors/socket_tracer/protocols/tls:cc_library",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@
#include "src/stirling/source_connectors/socket_tracer/protocols/nats/stitcher.h" // IWYU pragma: export
#include "src/stirling/source_connectors/socket_tracer/protocols/pgsql/stitcher.h" // IWYU pragma: export
#include "src/stirling/source_connectors/socket_tracer/protocols/redis/stitcher.h" // IWYU pragma: export
#include "src/stirling/source_connectors/socket_tracer/protocols/tls/stitcher.h" // IWYU pragma: export
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "src/stirling/source_connectors/socket_tracer/protocols/nats/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/pgsql/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/redis/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/tls/types.h"

namespace px {
namespace stirling {
Expand All @@ -53,7 +54,8 @@ using FrameDequeVariant = std::variant<std::monostate,
absl::flat_hash_map<kafka::correlation_id_t, std::deque<kafka::Packet>>,
absl::flat_hash_map<nats::stream_id_t, std::deque<nats::Message>>,
absl::flat_hash_map<amqp::channel_id, std::deque<amqp::Frame>>,
absl::flat_hash_map<mongodb::stream_id_t, std::deque<mongodb::Frame>>>;
absl::flat_hash_map<mongodb::stream_id_t, std::deque<mongodb::Frame>>,
absl::flat_hash_map<tls::stream_id_t, std::deque<tls::Frame>>>;
// clang-format off

} // namespace protocols
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ DEFINE_int32(stirling_enable_mongodb_tracing,
gflags::Int32FromEnv("PX_STIRLING_ENABLE_MONGODB_TRACING",
px::stirling::TraceMode::OnForNewerKernel),
"If true, stirling will trace and process MongoDB messages");
DEFINE_int32(
stirling_enable_tls_tracing,
gflags::Int32FromEnv("PX_STIRLING_ENABLE_TLS_TRACING", px::stirling::TraceMode::Off),
"If true, stirling will trace and process TLS protocol (not the TLS payload) messages. Note: "
"this disables tracing the plaintext within encrypted connections until gh#2095 is addressed.");
DEFINE_bool(stirling_disable_golang_tls_tracing,
gflags::BoolFromEnv("PX_STIRLING_DISABLE_GOLANG_TLS_TRACING", false),
"If true, stirling will not trace TLS traffic for Go applications. This implies "
Expand Down Expand Up @@ -283,6 +288,10 @@ void SocketTraceConnector::InitProtocolTransferSpecs() {
kAMQPTableNum,
{kRoleClient, kRoleServer},
TRANSFER_STREAM_PROTOCOL(amqp)}},
{kProtocolTLS, TransferSpec{FLAGS_stirling_enable_tls_tracing,
kTLSTableNum,
{kRoleClient, kRoleServer},
TRANSFER_STREAM_PROTOCOL(tls)}},
{kProtocolUnknown, TransferSpec{/* trace_mode */ px::stirling::TraceMode::Off,
/* table_num */ static_cast<uint32_t>(-1),
/* trace_roles */ {},
Expand Down Expand Up @@ -491,6 +500,7 @@ Status SocketTraceConnector::InitBPF() {
absl::StrCat("-DENABLE_NATS_TRACING=", protocol_transfer_specs_[kProtocolNATS].enabled),
absl::StrCat("-DENABLE_AMQP_TRACING=", protocol_transfer_specs_[kProtocolAMQP].enabled),
absl::StrCat("-DENABLE_MONGO_TRACING=", protocol_transfer_specs_[kProtocolMongo].enabled),
absl::StrCat("-DENABLE_TLS_TRACING=", protocol_transfer_specs_[kProtocolTLS].enabled),
absl::StrCat("-DBPF_LOOP_LIMIT=", FLAGS_stirling_bpf_loop_limit),
absl::StrCat("-DBPF_CHUNK_LIMIT=", FLAGS_stirling_bpf_chunk_limit),
};
Expand Down Expand Up @@ -1686,6 +1696,35 @@ void SocketTraceConnector::AppendMessage(ConnectorContext* ctx, const ConnTracke
#endif
}

template <>
void SocketTraceConnector::AppendMessage(ConnectorContext* ctx, const ConnTracker& conn_tracker,
protocols::tls::Record record, DataTable* data_table) {
protocols::tls::Frame& req_message = record.req;
protocols::tls::Frame& resp_message = record.resp;

md::UPID upid(ctx->GetASID(), conn_tracker.conn_id().upid.pid,
conn_tracker.conn_id().upid.start_time_ticks);

DataTable::RecordBuilder<&kTLSTable> r(data_table, resp_message.timestamp_ns);
r.Append<r.ColIndex("time_")>(resp_message.timestamp_ns);
r.Append<r.ColIndex("upid")>(upid.value());
// Note that there is a string copy here,
// But std::move is not allowed because we re-use conn object.
r.Append<r.ColIndex("remote_addr")>(conn_tracker.remote_endpoint().AddrStr());
r.Append<r.ColIndex("remote_port")>(conn_tracker.remote_endpoint().port());
r.Append<r.ColIndex("local_addr")>(conn_tracker.local_endpoint().AddrStr());
r.Append<r.ColIndex("local_port")>(conn_tracker.local_endpoint().port());
r.Append<r.ColIndex("trace_role")>(conn_tracker.role());
r.Append<r.ColIndex("req_type")>(static_cast<uint64_t>(req_message.content_type));
r.Append<r.ColIndex("version")>(static_cast<uint64_t>(req_message.legacy_version));
r.Append<r.ColIndex("extensions")>(ToJSONString(req_message.extensions), kMaxHTTPHeadersBytes);
r.Append<r.ColIndex("latency")>(
CalculateLatency(req_message.timestamp_ns, resp_message.timestamp_ns));
#ifndef NDEBUG
r.Append<r.ColIndex("px_info_")>(PXInfoString(conn_tracker, record));
#endif
}

void SocketTraceConnector::SetupOutput(const std::filesystem::path& path) {
DCHECK(!path.empty());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ DECLARE_int32(stirling_enable_kafka_tracing);
DECLARE_int32(stirling_enable_mux_tracing);
DECLARE_int32(stirling_enable_amqp_tracing);
DECLARE_int32(stirling_enable_mongodb_tracing);
DECLARE_int32(stirling_enable_tls_tracing);
DECLARE_bool(stirling_disable_self_tracing);
DECLARE_string(stirling_role_to_trace);

Expand Down Expand Up @@ -95,9 +96,9 @@ class SocketTraceConnector : public BCCSourceConnector {
public:
static constexpr std::string_view kName = "socket_tracer";
// PROTOCOL_LIST
static constexpr auto kTables =
MakeArray(kConnStatsTable, kHTTPTable, kMySQLTable, kCQLTable, kPGSQLTable, kDNSTable,
kRedisTable, kNATSTable, kKafkaTable, kMuxTable, kAMQPTable, kMongoDBTable);
static constexpr auto kTables = MakeArray(
kConnStatsTable, kHTTPTable, kMySQLTable, kCQLTable, kPGSQLTable, kDNSTable, kRedisTable,
kNATSTable, kKafkaTable, kMuxTable, kAMQPTable, kMongoDBTable, kTLSTable);

static constexpr uint32_t kConnStatsTableNum = TableNum(kTables, kConnStatsTable);
static constexpr uint32_t kHTTPTableNum = TableNum(kTables, kHTTPTable);
Expand All @@ -111,6 +112,7 @@ class SocketTraceConnector : public BCCSourceConnector {
static constexpr uint32_t kMuxTableNum = TableNum(kTables, kMuxTable);
static constexpr uint32_t kAMQPTableNum = TableNum(kTables, kAMQPTable);
static constexpr uint32_t kMongoDBTableNum = TableNum(kTables, kMongoDBTable);
static constexpr uint32_t kTLSTableNum = TableNum(kTables, kTLSTable);

static constexpr auto kSamplingPeriod = std::chrono::milliseconds{200};
// TODO(yzhao): This is not used right now. Eventually use this to control data push frequency.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@
#include "src/stirling/source_connectors/socket_tracer/nats_table.h"
#include "src/stirling/source_connectors/socket_tracer/pgsql_table.h"
#include "src/stirling/source_connectors/socket_tracer/redis_table.h"
#include "src/stirling/source_connectors/socket_tracer/tls_table.h"
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "src/stirling/source_connectors/socket_tracer/testing/protocol_checkers.h"

#include "src/stirling/source_connectors/socket_tracer/http_table.h"
#include "src/stirling/source_connectors/socket_tracer/tls_table.h"
#include "src/stirling/testing/common.h"

namespace px {
Expand All @@ -28,6 +29,7 @@ namespace testing {
namespace http = protocols::http;
namespace mux = protocols::mux;
namespace mongodb = protocols::mongodb;
namespace tls = protocols::tls;

//-----------------------------------------------------------------------------
// HTTP Checkers
Expand Down Expand Up @@ -105,6 +107,20 @@ std::vector<mongodb::Record> GetTargetRecords(const types::ColumnWrapperRecordBa
return ToRecordVector<mongodb::Record>(record_batch, target_record_indices);
}

template <>
std::vector<tls::Record> ToRecordVector(const types::ColumnWrapperRecordBatch& rb,
const std::vector<size_t>& indices) {
std::vector<tls::Record> result;

for (const auto& idx : indices) {
auto version = rb[kTLSVersionIdx]->Get<types::Int64Value>(idx);
tls::Record r;
r.req.legacy_version = static_cast<tls::LegacyVersion>(version.val);
result.push_back(r);
}
return result;
}

} // namespace testing
} // namespace stirling
} // namespace px
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "src/stirling/source_connectors/socket_tracer/protocols/http/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/mongodb/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/mux/types.h"
#include "src/stirling/source_connectors/socket_tracer/protocols/tls/types.h"

namespace px {
namespace stirling {
Expand Down
Loading

0 comments on commit 3e556d6

Please sign in to comment.