Skip to content

Commit

Permalink
Optimize the hook by parsing the packet before processing it
Browse files Browse the repository at this point in the history
  • Loading branch information
itaispiegel committed Feb 2, 2024
1 parent 55ff497 commit 40a1976
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 109 deletions.
2 changes: 1 addition & 1 deletion module/Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
obj-m += firewall.o
firewall-objs := module.o rules_table.o logs.o netfilter_hook.o
firewall-objs := module.o rules_table.o logs.o netfilter_hook.o parser.o

all:
make -C /lib/modules/$(shell uname -r)/build M=$(PWD) modules
Expand Down
159 changes: 51 additions & 108 deletions module/netfilter_hook.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,14 @@
#include "fw.h"
#include "logs.h"
#include "netfilter_hook.h"

const __be32 LOOPBACK_PREFIX = 0x7f000000;
const __be32 LOOPBACK_MASK = 0xff000000;
#include "parser.h"

extern struct list_head logs_list;
extern size_t logs_count;

extern rule_t rules[MAX_RULES];
extern __u8 rules_count;

struct ports_tuple {
__be16 sport;
__be16 dport;
};

static unsigned int forward_hook_func(void *priv, struct sk_buff *skb,
const struct nf_hook_state *state);

Expand All @@ -29,37 +22,14 @@ static const struct nf_hook_ops forward_hook = {
.hooknum = NF_INET_FORWARD,
};

static inline bool is_loopback_addr(__be32 addr) {
return (addr & LOOPBACK_MASK) == LOOPBACK_PREFIX;
}

static inline bool is_loopback_skb(struct sk_buff *skb) {
struct iphdr *ip_header = ip_hdr(skb);
return is_loopback_addr(ip_header->saddr) ||
is_loopback_addr(ip_header->daddr);
}

static inline bool is_unhandled_protocol_skb(struct sk_buff *skb) {
struct iphdr *ip_header = ip_hdr(skb);
return ip_header->protocol != PROT_TCP && ip_header->protocol != PROT_UDP &&
ip_header->protocol != PROT_ICMP;
}

static inline bool is_xmas_skb(struct sk_buff *skb) {
struct iphdr *ip_header = ip_hdr(skb);
struct tcphdr *tcp_header = tcp_hdr(skb);
return ip_header->protocol == PROT_TCP &&
(tcp_header->fin && tcp_header->urg && tcp_header->psh);
}

static inline bool match_direction(rule_t *rule, struct sk_buff *skb) {
static inline bool match_direction(rule_t *rule, packet_t *packet) {
// This is a bit confusing, but the packets going outside are received
// on the IN device and vice versa, so the direction is reversed.
return rule->direction == DIRECTION_ANY ||
(rule->direction == DIRECTION_IN &&
strcmp(skb->dev->name, OUT_NET_DEVICE_NAME) == 0) ||
strcmp(packet->dev_name, OUT_NET_DEVICE_NAME) == 0) ||
(rule->direction == DIRECTION_OUT &&
strcmp(skb->dev->name, IN_NET_DEVICE_NAME) == 0);
strcmp(packet->dev_name, IN_NET_DEVICE_NAME) == 0);
}

static inline bool match_rule_ports(__be16 rule_port, __be16 skb_port) {
Expand All @@ -69,48 +39,40 @@ static inline bool match_rule_ports(__be16 rule_port, __be16 skb_port) {
(rule_port == PORT_ABOVE_1023_BE && be16_to_cpu(skb_port) > 1023));
}

static inline bool match_ip_addrs(rule_t *rule, struct iphdr *ip_header) {
static inline bool match_ip_addrs(rule_t *rule, packet_t *packet) {
return (rule->src_ip == 0 ||
(rule->src_prefix_size != 0 &&
(rule->src_ip & rule->src_prefix_mask) ==
(ip_header->saddr & rule->src_prefix_mask))) &&
(packet->src_ip & rule->src_prefix_mask))) &&
(rule->dst_ip == 0 ||
(rule->dst_prefix_size != 0 &&
(rule->dst_ip & rule->dst_prefix_mask) ==
(ip_header->daddr & rule->dst_prefix_mask)));
(packet->dst_ip & rule->dst_prefix_mask)));
}

static inline bool match_ports(rule_t *rule, struct sk_buff *skb) {
char *transport_header = skb_transport_header(skb);
static inline bool match_ports(rule_t *rule, packet_t *packet) {
return (rule->protocol == PROT_UDP &&
match_rule_ports(rule->src_port,
((struct udphdr *)transport_header)->source) &&
match_rule_ports(rule->dst_port,
((struct udphdr *)transport_header)->dest)) ||
match_rule_ports(rule->src_port, packet->src_port) &&
match_rule_ports(rule->dst_port, packet->dst_port)) ||
(rule->protocol == PROT_TCP &&
match_rule_ports(rule->src_port,
((struct tcphdr *)transport_header)->source) &&
match_rule_ports(rule->dst_port,
((struct tcphdr *)transport_header)->dest));
match_rule_ports(rule->src_port, packet->src_port) &&
match_rule_ports(rule->dst_port, packet->dst_port));
}

static inline bool match_protocol(rule_t *rule, struct iphdr *ip_header) {
return rule->protocol == PROT_ANY || rule->protocol == ip_header->protocol;
static inline bool match_protocol(rule_t *rule, packet_t *packet) {
return rule->protocol == PROT_ANY || rule->protocol == packet->protocol;
}

static inline bool match_ack(rule_t *rule, struct iphdr *ip_header,
struct tcphdr *tcp_header) {
return rule->ack == ACK_ANY ||
(ip_header->protocol == PROT_TCP &&
((rule->ack == ACK_YES && tcp_header->ack) ||
(rule->ack == ACK_NO && !tcp_header->ack)));
static inline bool match_ack(rule_t *rule, packet_t *packet) {
return rule->ack == ACK_ANY || (packet->protocol == PROT_TCP &&
((rule->ack == ACK_YES && packet->ack) ||
(rule->ack == ACK_NO && !packet->ack)));
}

static inline bool match_rule_skb(rule_t *rule, struct sk_buff *skb) {
struct iphdr *ip_header = ip_hdr(skb);
return match_direction(rule, skb) && match_ip_addrs(rule, ip_header) &&
match_ports(rule, skb) && match_protocol(rule, ip_header) &&
match_ack(rule, ip_header, tcp_hdr(skb));
static inline bool match_rule_packet(rule_t *rule, packet_t *packet) {
return match_direction(rule, packet) && match_ip_addrs(rule, packet) &&
match_ports(rule, packet) && match_protocol(rule, packet) &&
match_ack(rule, packet);
}

static inline bool log_match_rule(log_row_t *log_row, rule_t *rule) {
Expand All @@ -121,24 +83,6 @@ static inline bool log_match_rule(log_row_t *log_row, rule_t *rule) {
log_row->dst_port == rule->dst_port;
}

static inline struct ports_tuple ports_from_skb(struct sk_buff *skb) {
struct ports_tuple p;
struct iphdr *ip_header = ip_hdr(skb);

// Note that we the store the exact ports, even if they're above 1023.
if (ip_header->protocol == PROT_UDP) {
p.sport = udp_hdr(skb)->source;
p.dport = udp_hdr(skb)->dest;
} else if (ip_header->protocol == PROT_TCP) {
p.sport = tcp_hdr(skb)->source;
p.dport = tcp_hdr(skb)->dest;
} else {
p.sport = 0;
p.dport = 0;
}
return p;
}

static inline log_row_t new_log_row_by_rule(rule_t *rule, reason_t reason) {
return (log_row_t){
.timestamp = ktime_get_real_seconds(),
Expand All @@ -153,19 +97,16 @@ static inline log_row_t new_log_row_by_rule(rule_t *rule, reason_t reason) {
};
}

static inline log_row_t new_log_row_by_skb(struct sk_buff *skb,
reason_t reason) {
struct ports_tuple ports = ports_from_skb(skb);
struct iphdr *ip_header = ip_hdr(skb);

static inline log_row_t new_log_row_by_packet(packet_t *packet,
reason_t reason) {
return (log_row_t){
.timestamp = ktime_get_real_seconds(),
.protocol = ip_header->protocol,
.protocol = packet->protocol,
.action = FW_POLICY,
.src_ip = ip_header->saddr,
.dst_ip = ip_header->daddr,
.src_port = ports.sport,
.dst_port = ports.dport,
.src_ip = packet->src_ip,
.dst_ip = packet->dst_ip,
.src_port = packet->src_port,
.dst_port = packet->dst_port,
.reason = reason,
.count = 1,
};
Expand Down Expand Up @@ -196,25 +137,22 @@ static void update_log_entry_by_matching_rule(rule_t *rule, reason_t reason) {
logs_count++;
}

static bool log_entry_matches_skb(struct log_entry *log_entry,
struct sk_buff *skb) {
struct ports_tuple ports = ports_from_skb(skb);
struct iphdr *ip_header = ip_hdr(skb);

return log_entry->log_row.protocol == ip_header->protocol &&
log_entry->log_row.src_ip == ip_header->saddr &&
log_entry->log_row.dst_ip == ip_header->daddr &&
log_entry->log_row.src_port == ports.sport &&
log_entry->log_row.dst_port == ports.dport;
static bool log_entry_matches_packet(struct log_entry *log_entry,
packet_t *packet) {
return log_entry->log_row.protocol == packet->protocol &&
log_entry->log_row.src_ip == packet->src_ip &&
log_entry->log_row.dst_ip == packet->dst_ip &&
log_entry->log_row.src_port == packet->src_port &&
log_entry->log_row.dst_port == packet->dst_port;
}

static void update_log_entry_by_skb(struct sk_buff *skb, reason_t reason) {
static void update_log_entry_by_packet(packet_t *packet, reason_t reason) {
struct log_entry *log_entry;
struct list_head *pos;
list_for_each(pos, &logs_list) {
log_entry = list_entry(pos, struct log_entry, list);
if (log_entry->log_row.reason == reason &&
log_entry_matches_skb(log_entry, skb)) {
log_entry_matches_packet(log_entry, packet)) {
log_entry->log_row.count++;
log_entry->log_row.timestamp = ktime_get_real_seconds();
return;
Expand All @@ -230,31 +168,36 @@ static void update_log_entry_by_skb(struct sk_buff *skb, reason_t reason) {
return;
}

log_entry->log_row = new_log_row_by_skb(skb, reason);
log_entry->log_row = new_log_row_by_packet(packet, reason);
list_add_tail(&log_entry->list, &logs_list);
logs_count++;
}

static unsigned int forward_hook_func(void *priv, struct sk_buff *skb,
const struct nf_hook_state *state) {
__u8 i;
if (is_loopback_skb(skb)) {
return NF_ACCEPT;
} else if (is_unhandled_protocol_skb(skb)) {
packet_t packet;

parse_packet(&packet, skb);

if (packet.type == PACKET_TYPE_LOOPBACK ||
packet.type == PACKET_TYPE_UNHANDLED_PROTOCOL) {
// In this case we want to accept the packet without logging it.
return NF_ACCEPT;
} else if (is_xmas_skb(skb)) {
update_log_entry_by_skb(skb, REASON_XMAS_PACKET);
} else if (packet.type == PACKET_TYPE_XMAS) {
update_log_entry_by_packet(&packet, REASON_XMAS_PACKET);
return NF_DROP;
}

// In this case the packet must be a normal packet.
for (i = 0; i < rules_count; i++) {
if (match_rule_skb(&rules[i], skb)) {
if (match_rule_packet(&rules[i], &packet)) {
update_log_entry_by_matching_rule(&rules[i], i);
return rules[i].action;
}
}

update_log_entry_by_skb(skb, REASON_NO_MATCHING_RULE);
update_log_entry_by_packet(&packet, REASON_NO_MATCHING_RULE);
return FW_POLICY;
}

Expand Down
53 changes: 53 additions & 0 deletions module/parser.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "parser.h"

static const __be32 LOOPBACK_PREFIX = 0x7f000000;
static const __be32 LOOPBACK_MASK = 0xff000000;

static inline bool is_loopback_addr(__be32 addr) {
return (addr & LOOPBACK_MASK) == LOOPBACK_PREFIX;
}

void parse_packet(packet_t *packet, struct sk_buff *skb) {
struct iphdr *ip_header = ip_hdr(skb);
struct tcphdr *tcp_header;
struct udphdr *udp_header;

packet->src_ip = ip_header->saddr;
packet->dst_ip = ip_header->daddr;
packet->dev_name = skb->dev->name;

if (is_loopback_addr(packet->src_ip) || is_loopback_addr(packet->dst_ip)) {
// In this case we don't care about the rest of the fields, and they
// might contain garbage.
packet->type = PACKET_TYPE_LOOPBACK;
return;
}

packet->protocol = ip_header->protocol;

// Noteice that we the store the exact ports, even if they're above 1023.
if (packet->protocol == PROT_TCP) {
tcp_header = tcp_hdr(skb);
packet->src_port = tcp_header->source;
packet->dst_port = tcp_header->dest;
packet->ack = tcp_header->ack;
if (tcp_header->fin && tcp_header->psh && tcp_header->urg) {
packet->type = PACKET_TYPE_XMAS;
} else {
packet->type = PACKET_TYPE_NORMAL;
}
} else if (ip_header->protocol == PROT_UDP) {
udp_header = udp_hdr(skb);
packet->type = PACKET_TYPE_NORMAL;
packet->src_port = udp_header->source;
packet->dst_port = udp_header->dest;
} else if (ip_header->protocol == PROT_ICMP) {
packet->type = PACKET_TYPE_NORMAL;
packet->src_port = 0;
packet->dst_port = 0;
} else {
packet->type = PACKET_TYPE_UNHANDLED_PROTOCOL;
// In this case we don't care about the ports and they might contain
// garbage.
}
}
26 changes: 26 additions & 0 deletions module/parser.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef _PARSER_H_
#define _PARSER_H_

#include "fw.h"

typedef enum {
PACKET_TYPE_NORMAL,
PACKET_TYPE_XMAS,
PACKET_TYPE_UNHANDLED_PROTOCOL,
PACKET_TYPE_LOOPBACK
} packet_type;

typedef struct {
packet_type type;
char *dev_name;
__be32 src_ip;
__be32 dst_ip;
__be16 src_port;
__be16 dst_port;
__u8 protocol;
unsigned short ack;
} packet_t;

void parse_packet(packet_t *packet, struct sk_buff *skb);

#endif // _PARSER_H_

0 comments on commit 40a1976

Please sign in to comment.