Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pgsq: add cancel request message - v1 #10000

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions rust/src/pgsql/logger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ fn log_request(req: &PgsqlFEMessage, flags: u32) -> Result<JsonBuilder, JsonErro
}) => {
js.set_string_from_bytes(req.to_str(), payload)?;
}
PgsqlFEMessage::CancelRequest(CancelRequestMessage {
pid,
backend_key,
}) => {
js.set_uint("pid", (*pid).into())?;
js.set_uint("backend_key", (*backend_key).into())?;
}
PgsqlFEMessage::Terminate(TerminationMessage {
identifier: _,
length: _,
Expand Down
84 changes: 69 additions & 15 deletions rust/src/pgsql/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ use nom7::{Err, IResult};
pub const PGSQL_LENGTH_FIELD: u32 = 4;

pub const PGSQL_DUMMY_PROTO_MAJOR: u16 = 1234; // 0x04d2
pub const PGSQL_DUMMY_PROTO_CANCEL_REQUEST: u16 = 5678; // 0x162e
pub const PGSQL_DUMMY_PROTO_MINOR_SSL: u16 = 5679; //0x162f
pub const _PGSQL_DUMMY_PROTO_MINOR_GSSAPI: u16 = 5680; // 0x1630

fn parse_length(i: &[u8]) -> IResult<&[u8], u32> {
verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)
}

#[derive(Debug, PartialEq, Eq)]
pub enum PgsqlParameters {
// startup parameters
Expand Down Expand Up @@ -311,6 +316,12 @@ pub struct TerminationMessage {
pub length: u32,
}

#[derive(Debug, PartialEq, Eq)]
pub struct CancelRequestMessage {
pub pid: u32,
pub backend_key: u32,
}

#[derive(Debug, PartialEq, Eq)]
pub enum PgsqlFEMessage {
SSLRequest(DummyStartupPacket),
Expand All @@ -319,6 +330,7 @@ pub enum PgsqlFEMessage {
SASLInitialResponse(SASLInitialResponsePacket),
SASLResponse(RegularPacket),
SimpleQuery(RegularPacket),
CancelRequest(CancelRequestMessage),
Terminate(TerminationMessage),
}

Expand All @@ -331,6 +343,7 @@ impl PgsqlFEMessage {
PgsqlFEMessage::SASLInitialResponse(_) => "sasl_initial_response",
PgsqlFEMessage::SASLResponse(_) => "sasl_response",
PgsqlFEMessage::SimpleQuery(_) => "simple_query",
PgsqlFEMessage::CancelRequest(_) => "cancel_request",
PgsqlFEMessage::Terminate(_) => "termination_message",
}
}
Expand Down Expand Up @@ -562,7 +575,7 @@ fn parse_sasl_initial_response_payload(i: &[u8]) -> IResult<&[u8], (SASLAuthenti

pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
let (i, length) = parse_length(i)?;
let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), parse_sasl_initial_response_payload)(i)?;
Ok((i, PgsqlFEMessage::SASLInitialResponse(
SASLInitialResponsePacket {
Expand All @@ -576,7 +589,7 @@ pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {

pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
let (i, length) = parse_length(i)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let resp = PgsqlFEMessage::SASLResponse(
RegularPacket {
Expand Down Expand Up @@ -605,16 +618,20 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
},
PGSQL_DUMMY_PROTO_MAJOR => {
let (b, proto_major) = be_u16(b)?;
let (b, proto_minor) = all_consuming(be_u16)(b)?;
let _message = match proto_minor {
PGSQL_DUMMY_PROTO_MINOR_SSL => (len, proto_major, proto_minor),
let (b, proto_minor) = be_u16(b)?;
let (b, message) = match proto_minor {
PGSQL_DUMMY_PROTO_CANCEL_REQUEST => {
parse_cancel_request(b)?
},
PGSQL_DUMMY_PROTO_MINOR_SSL => (b, PgsqlFEMessage::SSLRequest(DummyStartupPacket{
length: len,
proto_major,
proto_minor
})),
_ => return Err(Err::Error(make_error(b, ErrorKind::Switch))),
};

(b, PgsqlFEMessage::SSLRequest(DummyStartupPacket{
length: len,
proto_major,
proto_minor}))
(b, message)
}
_ => return Err(Err::Error(make_error(b, ErrorKind::Switch))),
};
Expand All @@ -636,7 +653,7 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
// Password can be encrypted or in cleartext
pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?;
let (i, length) = parse_length(i)?;
let (i, password) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
take_until1("\x00")
Expand All @@ -651,7 +668,7 @@ pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {

fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?;
let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
let (i, length) = parse_length(i)?;
let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?;
Ok((i, PgsqlFEMessage::SimpleQuery(RegularPacket {
identifier,
Expand All @@ -660,9 +677,18 @@ fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
})))
}

fn parse_cancel_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, pid) = be_u32(i)?;
let (i, backend_key) = be_u32(i)?;
Ok((i, PgsqlFEMessage::CancelRequest(CancelRequestMessage {
pid,
backend_key,
})))
}

fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?;
let (i, length) = verify(be_u32, |&x| x == PGSQL_LENGTH_FIELD)(i)?;
let (i, length) = parse_length(i)?;
Ok((i, PgsqlFEMessage::Terminate(TerminationMessage { identifier, length })))
}

Expand Down Expand Up @@ -760,7 +786,7 @@ fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], Pgsq

fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?;
let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?;
let (i, length) = parse_length(i)?;
let (i, param) = map_parser(take(length - PGSQL_LENGTH_FIELD), pgsql_parse_generic_parameter)(i)?;
Ok((i, PgsqlBEMessage::ParameterStatus(ParameterStatusMessage {
identifier,
Expand Down Expand Up @@ -791,7 +817,7 @@ fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {

fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?;
let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
let (i, length) = parse_length(i)?;
let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?;
Ok((i, PgsqlBEMessage::CommandComplete(RegularPacket {
identifier,
Expand Down Expand Up @@ -1247,9 +1273,37 @@ mod tests {
let result = parse_request(&buf[0..3]);
assert!(result.is_err());

// TODO add other messages
}

#[test]
fn test_cancel_request_message() {
// A cancel request message
let buf: &[u8] = &[
0x00, 0x00, 0x00, 0x10, // length: 16 (fixed)
0x04, 0xd2, 0x16, 0x2e, // 1234.5678 - identifies a cancel request
0x00, 0x00, 0x76, 0x31, // PID: 30257
0x23, 0x84, 0xf7, 0x2d]; // Backend key: 595916589
let result = parse_cancel_request(buf);
assert!(result.is_ok());

let result = parse_cancel_request(&buf[0..3]);
assert!(result.is_err());

let result = pgsql_parse_startup_packet(buf);
assert!(result.is_ok());

let fail_result = pgsql_parse_startup_packet(&buf[0..3]);
assert!(fail_result.is_err());

let result = parse_request(buf);
assert!(result.is_ok());

let fail_result = parse_request(&buf[0..3]);
assert!(fail_result.is_err());
}



#[test]
fn test_parse_error_response_code() {
let buf: &[u8] = &[0x43, 0x32, 0x38, 0x30, 0x30, 0x30, 0x00];
Expand Down
4 changes: 3 additions & 1 deletion rust/src/pgsql/pgsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ pub enum PgsqlStateProgress {
DataRowReceived,
CommandCompletedReceived,
ErrorMessageReceived,
CancelRequestReceived,
ConnectionTerminated,
#[cfg(test)]
UnknownState,
Expand Down Expand Up @@ -151,7 +152,7 @@ impl Default for PgsqlState {
Self::new()
}
}

impl PgsqlState {
pub fn new() -> Self {
Self {
Expand Down Expand Up @@ -280,6 +281,7 @@ impl PgsqlState {

// Important to keep in mind that: "In simple Query mode, the format of retrieved values is always text, except when the given command is a FETCH from a cursor declared with the BINARY option. In that case, the retrieved values are in binary format. The format codes given in the RowDescription message tell which format is being used." (from pgsql official documentation)
}
PgsqlFEMessage::CancelRequest(_) => Some(PgsqlStateProgress::CancelRequestReceived),
PgsqlFEMessage::Terminate(_) => {
SCLogDebug!("Match: Terminate message");
Some(PgsqlStateProgress::ConnectionTerminated)
Expand Down
Loading