Bug 1752896: In NAT simulator, check that incoming packets are arriving at the right port in addition to checking that they are arriving _from_ the right port. r=mjf

Also, some logging that was helpful in diagnosing the problem.

Differential Revision: https://phabricator.services.mozilla.com/D137529
This commit is contained in:
Byron Campen [:bwc]
2022-02-08 23:37:54 +00:00
parent ac39525ec7
commit 09eeb78b97
5 changed files with 118 additions and 13 deletions

View File

@@ -3,6 +3,7 @@
* You can obtain one at http://mozilla.org/MPL/2.0/. */ * You can obtain one at http://mozilla.org/MPL/2.0/. */
#include <string> #include <string>
#include <set> #include <set>
#include <iomanip>
extern "C" { extern "C" {
#include "nr_api.h" #include "nr_api.h"
@@ -10,6 +11,7 @@ extern "C" {
#include "stun.h" #include "stun.h"
} }
#include "logging.h"
#include "mozilla/Attributes.h" #include "mozilla/Attributes.h"
#include "mozilla/net/DNS.h" #include "mozilla/net/DNS.h"
#include "stun_socket_filter.h" #include "stun_socket_filter.h"
@@ -17,6 +19,8 @@ extern "C" {
namespace { namespace {
MOZ_MTLOG_MODULE("mtransport")
class NetAddrCompare { class NetAddrCompare {
public: public:
bool operator()(const mozilla::net::NetAddr& lhs, bool operator()(const mozilla::net::NetAddr& lhs,
@@ -81,6 +85,19 @@ class PendingSTUNRequest {
const bool is_id_set_; const bool is_id_set_;
}; };
static uint16_t GetPortInfallible(const mozilla::net::NetAddr& aAddr) {
uint16_t result = 0;
(void)aAddr.GetPort(&result);
return result;
}
static std::ostream& operator<<(std::ostream& aStream, UINT12 aId) {
for (int octet : aId.octet) {
aStream << std::hex << std::setfill('0') << std::setw(2) << octet;
}
return aStream;
}
class STUNUDPSocketFilter : public nsISocketFilter { class STUNUDPSocketFilter : public nsISocketFilter {
public: public:
STUNUDPSocketFilter() : white_list_(), pending_requests_() {} STUNUDPSocketFilter() : white_list_(), pending_requests_() {}
@@ -127,6 +144,9 @@ bool STUNUDPSocketFilter::filter_incoming_packet(
uint32_t len) { uint32_t len) {
// Check white list // Check white list
if (white_list_.find(*remote_addr) != white_list_.end()) { if (white_list_.find(*remote_addr) != white_list_.end()) {
MOZ_MTLOG(ML_DEBUG, __func__ << this << " Address in whitelist: "
<< remote_addr->ToString() << ":"
<< GetPortInfallible(*remote_addr));
return true; return true;
} }
@@ -143,6 +163,12 @@ bool STUNUDPSocketFilter::filter_incoming_packet(
pending_requests_.erase(it); pending_requests_.erase(it);
response_allowed_.erase(pending_req); response_allowed_.erase(pending_req);
white_list_.insert(*remote_addr); white_list_.insert(*remote_addr);
MOZ_MTLOG(ML_DEBUG, __func__ << this
<< " Allowing known STUN response, "
"remembering address in whitelist: "
<< remote_addr->ToString() << ":"
<< GetPortInfallible(*remote_addr)
<< " id=" << msg->id);
return true; return true;
} }
} }
@@ -153,6 +179,12 @@ bool STUNUDPSocketFilter::filter_incoming_packet(
const nr_stun_message_header* msg = const nr_stun_message_header* msg =
reinterpret_cast<const nr_stun_message_header*>(data); reinterpret_cast<const nr_stun_message_header*>(data);
response_allowed_.insert(PendingSTUNRequest(*remote_addr, msg->id)); response_allowed_.insert(PendingSTUNRequest(*remote_addr, msg->id));
MOZ_MTLOG(
ML_DEBUG,
__func__ << this
<< " Allowing STUN request, will allow packets in return: "
<< remote_addr->ToString() << ":"
<< GetPortInfallible(*remote_addr) << " id=" << msg->id);
return true; return true;
} }
// Lastly if we have send a STUN request to the destination of this // Lastly if we have send a STUN request to the destination of this
@@ -161,9 +193,21 @@ bool STUNUDPSocketFilter::filter_incoming_packet(
std::set<PendingSTUNRequest>::iterator it = std::set<PendingSTUNRequest>::iterator it =
pending_requests_.find(PendingSTUNRequest(*remote_addr)); pending_requests_.find(PendingSTUNRequest(*remote_addr));
if (it != pending_requests_.end()) { if (it != pending_requests_.end()) {
MOZ_MTLOG(
ML_DEBUG,
__func__
<< this
<< " Allowing packet from source while waiting for a response: "
<< remote_addr->ToString() << ":"
<< GetPortInfallible(*remote_addr));
return true; return true;
} }
MOZ_MTLOG(
ML_DEBUG,
__func__
<< " Disallowing packet that is neither a STUN request or response: "
<< remote_addr->ToString() << ":" << GetPortInfallible(*remote_addr));
return false; return false;
} }
@@ -172,6 +216,9 @@ bool STUNUDPSocketFilter::filter_outgoing_packet(
uint32_t len) { uint32_t len) {
// Check white list // Check white list
if (white_list_.find(*remote_addr) != white_list_.end()) { if (white_list_.find(*remote_addr) != white_list_.end()) {
MOZ_MTLOG(ML_DEBUG, __func__ << this << " Address in whitelist: "
<< remote_addr->ToString() << ":"
<< GetPortInfallible(*remote_addr));
return true; return true;
} }
@@ -182,6 +229,12 @@ bool STUNUDPSocketFilter::filter_outgoing_packet(
const nr_stun_message_header* msg = const nr_stun_message_header* msg =
reinterpret_cast<const nr_stun_message_header*>(data); reinterpret_cast<const nr_stun_message_header*>(data);
pending_requests_.insert(PendingSTUNRequest(*remote_addr, msg->id)); pending_requests_.insert(PendingSTUNRequest(*remote_addr, msg->id));
MOZ_MTLOG(
ML_DEBUG,
__func__ << this
<< " Allowing STUN request, will allow packets in return: "
<< remote_addr->ToString() << ":"
<< GetPortInfallible(*remote_addr) << " id=" << msg->id);
return true; return true;
} }
@@ -196,10 +249,27 @@ bool STUNUDPSocketFilter::filter_outgoing_packet(
if (it != response_allowed_.end()) { if (it != response_allowed_.end()) {
white_list_.insert(*remote_addr); white_list_.insert(*remote_addr);
response_allowed_.erase(it); response_allowed_.erase(it);
MOZ_MTLOG(ML_DEBUG, __func__ << this
<< " Allowing known STUN response, "
"remembering address in whitelist: "
<< remote_addr->ToString() << ":"
<< GetPortInfallible(*remote_addr)
<< " id=" << msg->id);
return true; return true;
} }
MOZ_MTLOG(ML_DEBUG,
__func__ << this << " Disallowing unknown STUN response: "
<< remote_addr->ToString() << ":"
<< GetPortInfallible(*remote_addr) << " id=" << msg->id);
return false;
} }
MOZ_MTLOG(
ML_DEBUG,
__func__
<< " Disallowing packet that is neither a STUN request or response: "
<< remote_addr->ToString() << ":" << GetPortInfallible(*remote_addr));
return false; return false;
} }

View File

@@ -392,10 +392,11 @@ int TestNrSocket::recvfrom(void* buf, size_t maxlen, size_t* len, int flags,
if (readable_socket_) { if (readable_socket_) {
// If any of the external sockets got data, see if it will be passed through // If any of the external sockets got data, see if it will be passed through
r = readable_socket_->recvfrom(buf, maxlen, len, 0, from); r = readable_socket_->recvfrom(buf, maxlen, len, 0, from);
const nr_transport_addr to = readable_socket_->my_addr();
readable_socket_ = nullptr; readable_socket_ = nullptr;
if (!r) { if (!r) {
PortMapping* port_mapping_used; PortMapping* port_mapping_used;
ingress_allowed = allow_ingress(*from, &port_mapping_used); ingress_allowed = allow_ingress(to, *from, &port_mapping_used);
if (ingress_allowed) { if (ingress_allowed) {
r_log(LOG_GENERIC, LOG_DEBUG, "TestNrSocket %s received from %s via %s", r_log(LOG_GENERIC, LOG_DEBUG, "TestNrSocket %s received from %s via %s",
internal_socket_->my_addr().as_string, from->as_string, internal_socket_->my_addr().as_string, from->as_string,
@@ -440,17 +441,34 @@ int TestNrSocket::recvfrom(void* buf, size_t maxlen, size_t* len, int flags,
return r; return r;
} }
bool TestNrSocket::allow_ingress(const nr_transport_addr& from, bool TestNrSocket::allow_ingress(const nr_transport_addr& to,
const nr_transport_addr& from,
PortMapping** port_mapping_used) const { PortMapping** port_mapping_used) const {
// This is only called for traffic arriving at a port mapping // This is only called for traffic arriving at a port mapping
MOZ_ASSERT(nat_->enabled_); MOZ_ASSERT(nat_->enabled_);
MOZ_ASSERT(!nat_->is_an_internal_tuple(from)); MOZ_ASSERT(!nat_->is_an_internal_tuple(from));
*port_mapping_used = get_port_mapping(from, nat_->filtering_type_); // Find the port mapping (if any) that this packet landed on
if (!(*port_mapping_used)) { for (PortMapping* port_mapping : port_mappings_) {
if (!nr_transport_addr_cmp(&to, &port_mapping->external_socket_->my_addr(),
NR_TRANSPORT_ADDR_CMP_MODE_ALL)) {
*port_mapping_used = port_mapping;
}
}
if (NS_WARN_IF(!(*port_mapping_used))) {
MOZ_ASSERT(false);
r_log(LOG_GENERIC, LOG_INFO, r_log(LOG_GENERIC, LOG_INFO,
"TestNrSocket %s denying ingress from %s: " "TestNrSocket %s denying ingress from %s: "
"Filtered", "No port mapping for this local port! What?",
internal_socket_->my_addr().as_string, from.as_string);
return false;
}
if (!port_mapping_matches(**port_mapping_used, from, nat_->filtering_type_)) {
r_log(LOG_GENERIC, LOG_INFO,
"TestNrSocket %s denying ingress from %s: "
"Filtered (no port mapping for source)",
internal_socket_->my_addr().as_string, from.as_string); internal_socket_->my_addr().as_string, from.as_string);
return false; return false;
} }
@@ -872,6 +890,18 @@ bool TestNrSocket::is_tcp_connection_behind_nat() const {
TestNrSocket::PortMapping* TestNrSocket::get_port_mapping( TestNrSocket::PortMapping* TestNrSocket::get_port_mapping(
const nr_transport_addr& remote_address, const nr_transport_addr& remote_address,
TestNat::NatBehavior filter) const { TestNat::NatBehavior filter) const {
for (PortMapping* port_mapping : port_mappings_) {
if (port_mapping_matches(*port_mapping, remote_address, filter)) {
return port_mapping;
}
}
return nullptr;
}
/* static */
bool TestNrSocket::port_mapping_matches(const PortMapping& port_mapping,
const nr_transport_addr& remote_addr,
TestNat::NatBehavior filter) {
int compare_flags; int compare_flags;
switch (filter) { switch (filter) {
case TestNat::ENDPOINT_INDEPENDENT: case TestNat::ENDPOINT_INDEPENDENT:
@@ -885,13 +915,8 @@ TestNrSocket::PortMapping* TestNrSocket::get_port_mapping(
break; break;
} }
for (PortMapping* port_mapping : port_mappings_) { return !nr_transport_addr_cmp(&remote_addr, &port_mapping.remote_address_,
if (!nr_transport_addr_cmp(&remote_address, &port_mapping->remote_address_, compare_flags);
compare_flags)) {
return port_mapping;
}
}
return nullptr;
} }
TestNrSocket::PortMapping* TestNrSocket::create_port_mapping( TestNrSocket::PortMapping* TestNrSocket::create_port_mapping(

View File

@@ -310,7 +310,7 @@ class TestNrSocket : public NrSocketBase {
}; };
bool is_port_mapping_stale(const PortMapping& port_mapping) const; bool is_port_mapping_stale(const PortMapping& port_mapping) const;
bool allow_ingress(const nr_transport_addr& from, bool allow_ingress(const nr_transport_addr& to, const nr_transport_addr& from,
PortMapping** port_mapping_used) const; PortMapping** port_mapping_used) const;
void destroy_stale_port_mappings(); void destroy_stale_port_mappings();
@@ -330,6 +330,9 @@ class TestNrSocket : public NrSocketBase {
PortMapping* get_port_mapping(const nr_transport_addr& remote_addr, PortMapping* get_port_mapping(const nr_transport_addr& remote_addr,
TestNat::NatBehavior filter) const; TestNat::NatBehavior filter) const;
static bool port_mapping_matches(const PortMapping& port_mapping,
const nr_transport_addr& remote_addr,
TestNat::NatBehavior filter);
PortMapping* create_port_mapping( PortMapping* create_port_mapping(
const nr_transport_addr& remote_addr, const nr_transport_addr& remote_addr,
const RefPtr<NrSocketBase>& external_socket) const; const RefPtr<NrSocketBase>& external_socket) const;

View File

@@ -104,6 +104,7 @@ mozilla::ipc::IPCResult UDPSocketParent::RecvBind(
UDPSOCKET_LOG( UDPSOCKET_LOG(
("%s: SendCallbackOpened: %s:%u", __FUNCTION__, addr.get(), port)); ("%s: SendCallbackOpened: %s:%u", __FUNCTION__, addr.get(), port));
mAddress = {addr, port};
mozilla::Unused << SendCallbackOpened(UDPAddressInfo(addr, port)); mozilla::Unused << SendCallbackOpened(UDPAddressInfo(addr, port));
return IPC_OK(); return IPC_OK();
@@ -319,6 +320,9 @@ mozilla::ipc::IPCResult UDPSocketParent::RecvOutgoingData(
bool allowed; bool allowed;
const nsTArray<uint8_t>& data(aData.get_ArrayOfuint8_t()); const nsTArray<uint8_t>& data(aData.get_ArrayOfuint8_t());
UDPSOCKET_LOG(("%s(%s:%d): Filtering outgoing packet", __FUNCTION__,
mAddress.addr().get(), mAddress.port()));
rv = mFilter->FilterPacket(&aAddr.get_NetAddr(), data.Elements(), rv = mFilter->FilterPacket(&aAddr.get_NetAddr(), data.Elements(),
data.Length(), nsISocketFilter::SF_OUTGOING, data.Length(), nsISocketFilter::SF_OUTGOING,
&allowed); &allowed);
@@ -489,6 +493,8 @@ UDPSocketParent::OnPacketReceived(nsIUDPSocket* aSocket,
bool allowed; bool allowed;
mozilla::net::NetAddr addr; mozilla::net::NetAddr addr;
fromAddr->GetNetAddr(&addr); fromAddr->GetNetAddr(&addr);
UDPSOCKET_LOG(("%s(%s:%d): Filtering incoming packet", __FUNCTION__,
mAddress.addr().get(), mAddress.port()));
nsresult rv = mFilter->FilterPacket(&addr, (const uint8_t*)buffer, len, nsresult rv = mFilter->FilterPacket(&addr, (const uint8_t*)buffer, len,
nsISocketFilter::SF_INCOMING, &allowed); nsISocketFilter::SF_INCOMING, &allowed);
// Receiving unallowed data, drop. // Receiving unallowed data, drop.

View File

@@ -75,6 +75,7 @@ class UDPSocketParent : public mozilla::net::PUDPSocketParent,
nsCOMPtr<nsIUDPSocket> mSocket; nsCOMPtr<nsIUDPSocket> mSocket;
nsCOMPtr<nsISocketFilter> mFilter; nsCOMPtr<nsISocketFilter> mFilter;
nsCOMPtr<nsIPrincipal> mPrincipal; nsCOMPtr<nsIPrincipal> mPrincipal;
UDPAddressInfo mAddress;
}; };
} // namespace dom } // namespace dom