Skip to content

Commit

Permalink
Merge pull request #25590 from vespa-engine/revert-25588-revert-25586…
Browse files Browse the repository at this point in the history
…-andreer/wg-wip-3

Reapply "open wireguard port for config servers"
  • Loading branch information
baldersheim authored Jan 16, 2023
2 parents c18caef + 2dd2e2b commit 09f909c
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,28 @@
*/
public class Acl {

public static final Acl EMPTY = new Acl(Set.of(), Set.of(), Set.of());
public static final Acl EMPTY = new Acl(Set.of(), Set.of(), Set.of(), Set.of());

private final Set<Node> trustedNodes;
private final Set<Integer> trustedPorts;
private final Set<Integer> trustedUdpPorts;
private final Set<String> trustedNetworks;

/**
* @param trustedPorts Ports to trust
* @param trustedPorts TCP Ports to trust
* @param trustedUdpPorts UDP ports to trust
* @param trustedNodes Nodes to trust
* @param trustedNetworks Networks (in CIDR notation) to trust
*/
public Acl(Set<Integer> trustedPorts, Set<Node> trustedNodes, Set<String> trustedNetworks) {
public Acl(Set<Integer> trustedPorts, Set<Integer> trustedUdpPorts, Set<Node> trustedNodes, Set<String> trustedNetworks) {
this.trustedNodes = copyOfNullable(trustedNodes);
this.trustedPorts = copyOfNullable(trustedPorts);
this.trustedUdpPorts = copyOfNullable(trustedUdpPorts);
this.trustedNetworks = copyOfNullable(trustedNetworks);
}

public Acl(Set<Integer> trustedPorts, Set<Node> trustedNodes) {
this(trustedPorts, trustedNodes, Set.of());
this(trustedPorts, Set.of(), trustedNodes, Set.of());
}

public List<String> toRules(IPVersion ipVersion) {
Expand All @@ -66,6 +69,11 @@ public List<String> toRules(IPVersion ipVersion) {
rules.add("-A INPUT -p tcp -m multiport --dports " + joinPorts(trustedPorts) + " -j ACCEPT");
}

// Allow trusted UDP ports if any
if (!trustedUdpPorts.isEmpty()) {
rules.add("-A INPUT -p udp -m multiport --dports " + joinPorts(trustedUdpPorts) + " -j ACCEPT");
}

// Allow traffic from trusted nodes, limited to specific ports, if any
getTrustedNodes(ipVersion).stream()
.map(node -> {
Expand Down Expand Up @@ -113,8 +121,8 @@ public Set<Integer> getTrustedPorts() {
return trustedPorts;
}

public Set<Integer> getTrustedPorts(IPVersion ipVersion) {
return trustedPorts;
public Set<Integer> getTrustedUdpPorts() {
return trustedUdpPorts;
}

@Override
Expand All @@ -124,19 +132,21 @@ public boolean equals(Object o) {
Acl acl = (Acl) o;
return trustedNodes.equals(acl.trustedNodes) &&
trustedPorts.equals(acl.trustedPorts) &&
trustedUdpPorts.equals(acl.trustedUdpPorts) &&
trustedNetworks.equals(acl.trustedNetworks);
}

@Override
public int hashCode() {
return Objects.hash(trustedNodes, trustedPorts, trustedNetworks);
return Objects.hash(trustedNodes, trustedPorts, trustedUdpPorts, trustedNetworks);
}

@Override
public String toString() {
return "Acl{" +
"trustedNodes=" + trustedNodes +
", trustedPorts=" + trustedPorts +
", trustedUdpPorts=" + trustedUdpPorts +
", trustedNetworks=" + trustedNetworks +
'}';
}
Expand Down Expand Up @@ -175,6 +185,7 @@ public static class Builder {

private final Set<Node> trustedNodes = new HashSet<>();
private final Set<Integer> trustedPorts = new HashSet<>();
private final Set<Integer> trustedUdpPorts = new HashSet<>();
private final Set<String> trustedNetworks = new HashSet<>();

public Builder() { }
Expand Down Expand Up @@ -207,13 +218,18 @@ public Builder withTrustedPorts(Integer... ports) {
return this;
}

public Builder withTrustedUdpPorts(Integer... ports) {
trustedUdpPorts.addAll(List.of(ports));
return this;
}

public Builder withTrustedNetworks(Set<String> networks) {
trustedNetworks.addAll(networks);
return this;
}

public Acl build() {
return new Acl(trustedPorts, trustedNodes, trustedNetworks);
return new Acl(trustedPorts, trustedUdpPorts, trustedNodes, trustedNetworks);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ public Map<String, Acl> getAcls(String hostName) {
GetAclResponse.Port::getTrustedBy,
Collectors.mapping(port -> port.port, Collectors.toSet())));

// Group UDP ports by container hostname that trusts them
Map<String, Set<Integer>> trustedUdpPorts = response.trustedUdpPorts.stream()
.collect(Collectors.groupingBy(
GetAclResponse.Port::getTrustedBy,
Collectors.mapping(port -> port.port, Collectors.toSet())));

// Group node ip-addresses by container hostname that trusts them
Map<String, Set<Acl.Node>> trustedNodes = response.trustedNodes.stream()
.collect(Collectors.groupingBy(
Expand All @@ -106,12 +112,14 @@ public Map<String, Acl> getAcls(String hostName) {


// For each hostname create an ACL
return Stream.of(trustedNodes.keySet(), trustedPorts.keySet(), trustedNetworks.keySet())
return Stream.of(trustedNodes.keySet(), trustedPorts.keySet(), trustedUdpPorts.keySet(), trustedNetworks.keySet())
.flatMap(Set::stream)
.distinct()
.collect(Collectors.toMap(
Function.identity(),
hostname -> new Acl(trustedPorts.get(hostname), trustedNodes.get(hostname),
hostname -> new Acl(trustedPorts.get(hostname),
trustedUdpPorts.get(hostname),
trustedNodes.get(hostname),
trustedNetworks.get(hostname))));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@ public class GetAclResponse {
@JsonProperty("trustedPorts")
public final List<Port> trustedPorts;

@JsonProperty("trustedUdpPorts")
public final List<Port> trustedUdpPorts;

@JsonCreator
public GetAclResponse(@JsonProperty("trustedNodes") List<Node> trustedNodes,
@JsonProperty("trustedNetworks") List<Network> trustedNetworks,
@JsonProperty("trustedPorts") List<Port> trustedPorts) {
@JsonProperty("trustedPorts") List<Port> trustedPorts,
@JsonProperty("trustedUdpPorts") List<Port> trustedUdpPorts) {
this.trustedNodes = trustedNodes == null ? List.of() : List.copyOf(trustedNodes);
this.trustedNetworks = trustedNetworks == null ? List.of() : List.copyOf(trustedNetworks);
this.trustedPorts = trustedPorts == null ? List.of() : List.copyOf(trustedPorts);
this.trustedUdpPorts = trustedUdpPorts == null ? List.of() : List.copyOf(trustedUdpPorts);
}

@JsonIgnoreProperties(ignoreUnknown = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import java.util.List;

/**
* An editor that assumes all rules in the filter table are exactly as the the wanted rules
* An editor that assumes all rules in the filter table are exactly as the wanted rules
*
* @author smorgrav
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,67 +19,72 @@
public class AclTest {

private static final Acl aclCommon = new Acl(
Set.of(1234, 453),
Set.of(1234, 453), Set.of(4321),
testNodes(Set.of(), "192.1.2.2", "fb00::1", "fe80::2", "fe80::3"),
Set.of());

private static final Acl aclWithoutPorts = new Acl(
Set.of(),
Set.of(), Set.of(),
testNodes(Set.of(), "192.1.2.2", "fb00::1", "fe80::2"),
Set.of());

@Test
void no_trusted_ports() {
String listRulesIpv4 = String.join("\n", aclWithoutPorts.toRules(IPVersion.IPv4));
assertEquals(
"-P INPUT ACCEPT\n" +
"-P FORWARD ACCEPT\n" +
"-P OUTPUT ACCEPT\n" +
"-A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT\n" +
"-A INPUT -i lo -j ACCEPT\n" +
"-A INPUT -p icmp -j ACCEPT\n" +
"-A INPUT -s 192.1.2.2/32 -j ACCEPT\n" +
"-A INPUT -j REJECT --reject-with icmp-port-unreachable",
"""
-P INPUT ACCEPT
-P FORWARD ACCEPT
-P OUTPUT ACCEPT
-A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT
-A INPUT -i lo -j ACCEPT
-A INPUT -p icmp -j ACCEPT
-A INPUT -s 192.1.2.2/32 -j ACCEPT
-A INPUT -j REJECT --reject-with icmp-port-unreachable""",
listRulesIpv4);
}

@Test
void ipv4_rules() {
String listRulesIpv4 = String.join("\n", aclCommon.toRules(IPVersion.IPv4));
assertEquals(
"-P INPUT ACCEPT\n" +
"-P FORWARD ACCEPT\n" +
"-P OUTPUT ACCEPT\n" +
"-A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT\n" +
"-A INPUT -i lo -j ACCEPT\n" +
"-A INPUT -p icmp -j ACCEPT\n" +
"-A INPUT -p tcp -m multiport --dports 453,1234 -j ACCEPT\n" +
"-A INPUT -s 192.1.2.2/32 -j ACCEPT\n" +
"-A INPUT -j REJECT --reject-with icmp-port-unreachable",
"""
-P INPUT ACCEPT
-P FORWARD ACCEPT
-P OUTPUT ACCEPT
-A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT
-A INPUT -i lo -j ACCEPT
-A INPUT -p icmp -j ACCEPT
-A INPUT -p tcp -m multiport --dports 453,1234 -j ACCEPT
-A INPUT -p udp -m multiport --dports 4321 -j ACCEPT
-A INPUT -s 192.1.2.2/32 -j ACCEPT
-A INPUT -j REJECT --reject-with icmp-port-unreachable""",
listRulesIpv4);
}

@Test
void ipv6_rules() {
String listRulesIpv6 = String.join("\n", aclCommon.toRules(IPVersion.IPv6));
assertEquals(
"-P INPUT ACCEPT\n" +
"-P FORWARD ACCEPT\n" +
"-P OUTPUT ACCEPT\n" +
"-A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT\n" +
"-A INPUT -i lo -j ACCEPT\n" +
"-A INPUT -p ipv6-icmp -j ACCEPT\n" +
"-A INPUT -p tcp -m multiport --dports 453,1234 -j ACCEPT\n" +
"-A INPUT -s fb00::1/128 -j ACCEPT\n" +
"-A INPUT -s fe80::2/128 -j ACCEPT\n" +
"-A INPUT -s fe80::3/128 -j ACCEPT\n" +
"-A INPUT -j REJECT --reject-with icmp6-port-unreachable", listRulesIpv6);
"""
-P INPUT ACCEPT
-P FORWARD ACCEPT
-P OUTPUT ACCEPT
-A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT
-A INPUT -i lo -j ACCEPT
-A INPUT -p ipv6-icmp -j ACCEPT
-A INPUT -p tcp -m multiport --dports 453,1234 -j ACCEPT
-A INPUT -p udp -m multiport --dports 4321 -j ACCEPT
-A INPUT -s fb00::1/128 -j ACCEPT
-A INPUT -s fe80::2/128 -j ACCEPT
-A INPUT -s fe80::3/128 -j ACCEPT
-A INPUT -j REJECT --reject-with icmp6-port-unreachable""", listRulesIpv6);
}

@Test
void ipv6_rules_stable_order() {
Acl aclCommonDifferentOrder = new Acl(
Set.of(453, 1234),
Set.of(453, 1234), Set.of(4321),
testNodes(Set.of(), "fe80::2", "192.1.2.2", "fb00::1", "fe80::3"),
Set.of());

Expand All @@ -90,29 +95,31 @@ void ipv6_rules_stable_order() {

@Test
void trusted_networks() {
Acl acl = new Acl(Set.of(4080), testNodes(Set.of(), "127.0.0.1"), Set.of("10.0.0.0/24", "2001:db8::/32"));

assertEquals("-P INPUT ACCEPT\n" +
"-P FORWARD ACCEPT\n" +
"-P OUTPUT ACCEPT\n" +
"-A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT\n" +
"-A INPUT -i lo -j ACCEPT\n" +
"-A INPUT -p icmp -j ACCEPT\n" +
"-A INPUT -p tcp -m multiport --dports 4080 -j ACCEPT\n" +
"-A INPUT -s 127.0.0.1/32 -j ACCEPT\n" +
"-A INPUT -s 10.0.0.0/24 -j ACCEPT\n" +
"-A INPUT -j REJECT --reject-with icmp-port-unreachable",
Acl acl = new Acl(Set.of(4080), Set.of(), testNodes(Set.of(), "127.0.0.1"), Set.of("10.0.0.0/24", "2001:db8::/32"));

assertEquals("""
-P INPUT ACCEPT
-P FORWARD ACCEPT
-P OUTPUT ACCEPT
-A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT
-A INPUT -i lo -j ACCEPT
-A INPUT -p icmp -j ACCEPT
-A INPUT -p tcp -m multiport --dports 4080 -j ACCEPT
-A INPUT -s 127.0.0.1/32 -j ACCEPT
-A INPUT -s 10.0.0.0/24 -j ACCEPT
-A INPUT -j REJECT --reject-with icmp-port-unreachable""",
String.join("\n", acl.toRules(IPVersion.IPv4)));

assertEquals("-P INPUT ACCEPT\n" +
"-P FORWARD ACCEPT\n" +
"-P OUTPUT ACCEPT\n" +
"-A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT\n" +
"-A INPUT -i lo -j ACCEPT\n" +
"-A INPUT -p ipv6-icmp -j ACCEPT\n" +
"-A INPUT -p tcp -m multiport --dports 4080 -j ACCEPT\n" +
"-A INPUT -s 2001:db8::/32 -j ACCEPT\n" +
"-A INPUT -j REJECT --reject-with icmp6-port-unreachable",
assertEquals("""
-P INPUT ACCEPT
-P FORWARD ACCEPT
-P OUTPUT ACCEPT
-A INPUT -m state --state RELATED,ESTABLISHED -j ACCEPT
-A INPUT -i lo -j ACCEPT
-A INPUT -p ipv6-icmp -j ACCEPT
-A INPUT -p tcp -m multiport --dports 4080 -j ACCEPT
-A INPUT -s 2001:db8::/32 -j ACCEPT
-A INPUT -j REJECT --reject-with icmp6-port-unreachable""",
String.join("\n", acl.toRules(IPVersion.IPv6)));
}

Expand All @@ -121,7 +128,7 @@ void config_server_acl() {
Set<Acl.Node> testNodes = Stream.concat(testNodes(NodeType.config, Set.of(), "172.17.0.41", "172.17.0.42", "172.17.0.43").stream(),
testNodes(NodeType.tenant, Set.of(19070), "172.17.0.81", "172.17.0.82", "172.17.0.83").stream())
.collect(Collectors.toSet());
Acl acl = new Acl(Set.of(22, 4443), testNodes, Set.of());
Acl acl = new Acl(Set.of(22, 4443), Set.of(), testNodes, Set.of());
assertEquals("""
-P INPUT ACCEPT
-P FORWARD ACCEPT
Expand All @@ -142,7 +149,7 @@ void config_server_acl() {
Set<Acl.Node> testNodes2 = Stream.concat(testNodes(NodeType.config, Set.of(), "2001:db8::41", "2001:db8::42", "2001:db8::43").stream(),
testNodes(NodeType.tenant, Set.of(19070), "2001:db8::81", "2001:db8::82", "2001:db8::83").stream())
.collect(Collectors.toSet());
Acl acl2 = new Acl(Set.of(22, 4443), testNodes2, Set.of());
Acl acl2 = new Acl(Set.of(22, 4443), Set.of(), testNodes2, Set.of());

assertEquals("""
-P INPUT ACCEPT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
public record NodeAcl(Node node,
Set<TrustedNode> trustedNodes,
Set<String> trustedNetworks,
Set<Integer> trustedPorts) {
Set<Integer> trustedPorts,
Set<Integer> trustedUdpPorts) {

private static final Set<Integer> RPC_PORTS = Set.of(19070);
private static final int WIREGUARD_PORT = 51820;

public NodeAcl {
Objects.requireNonNull(node, "node must be non-null");
Expand All @@ -40,6 +42,7 @@ public record NodeAcl(Node node,
public static NodeAcl from(Node node, NodeList allNodes, LoadBalancers loadBalancers) {
Set<TrustedNode> trustedNodes = new TreeSet<>(Comparator.comparing(TrustedNode::hostname));
Set<Integer> trustedPorts = new LinkedHashSet<>();
Set<Integer> trustedUdpPorts = new LinkedHashSet<>();
Set<String> trustedNetworks = new LinkedHashSet<>();

// For all cases below, trust:
Expand Down Expand Up @@ -86,10 +89,12 @@ public static NodeAcl from(Node node, NodeList allNodes, LoadBalancers loadBalan
// - port 19070 (RPC) from all tenant nodes (and their hosts, in case traffic is NAT-ed via parent)
// - port 19070 (RPC) from all proxy nodes (and their hosts, in case traffic is NAT-ed via parent)
// - port 4443 from the world
// - udp port 51820 from the world
trustedNodes.addAll(TrustedNode.of(allNodes.nodeType(NodeType.host, NodeType.tenant,
NodeType.proxyhost, NodeType.proxy),
RPC_PORTS));
trustedPorts.add(4443);
trustedUdpPorts.add(WIREGUARD_PORT);
}
case proxy -> {
// Proxy nodes trust:
Expand All @@ -109,7 +114,7 @@ public static NodeAcl from(Node node, NodeList allNodes, LoadBalancers loadBalan
default -> throw new IllegalArgumentException("Don't know how to create ACL for " + node +
" of type " + node.type());
}
return new NodeAcl(node, trustedNodes, trustedNetworks, trustedPorts);
return new NodeAcl(node, trustedNodes, trustedNetworks, trustedPorts, trustedUdpPorts);
}

public record TrustedNode(String hostname, NodeType type, Set<String> ipAddresses, Set<Integer> ports) {
Expand Down

0 comments on commit 09f909c

Please sign in to comment.