]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/netinet/flow_divert.c
xnu-3789.70.16.tar.gz
[apple/xnu.git] / bsd / netinet / flow_divert.c
index cc2e2c8fe54a2c4c19d149d6368f12d40040d747..1e46e42c4167c7534eaf071a04b196381f50e7b4 100644 (file)
@@ -63,6 +63,7 @@
 #include <dev/random/randomdev.h>
 #include <libkern/crypto/sha1.h>
 #include <libkern/crypto/crypto_internal.h>
+#include <os/log.h>
 
 #define FLOW_DIVERT_CONNECT_STARTED            0x00000001
 #define FLOW_DIVERT_READ_CLOSED                        0x00000002
 #define FLOW_DIVERT_TRANSFERRED                        0x00000020
 #define FLOW_DIVERT_HAS_HMAC            0x00000040
 
-#define FDLOG(level, pcb, format, ...) do {                                                                                    \
-       if (level <= (pcb)->log_level) {                                                                                                \
-               log((level > LOG_NOTICE ? LOG_NOTICE : level), "%s (%u): " format "\n", __FUNCTION__, (pcb)->hash, __VA_ARGS__);        \
-       }                                                                                                                                                               \
-} while (0)
+/*
+ * FDLOG(level, pcb, format, ...): emit a formatted message for a flow divert
+ * PCB through os_log_with_type() on OS_LOG_DEFAULT.
+ *  - level:  syslog-style priority (LOG_ERR, LOG_INFO, ...), converted with
+ *            flow_divert_syslog_type_to_oslog_type().
+ *  - pcb:    the message is prefixed with "(<pcb->hash>): ".
+ *  - format: must be a string literal; it is pasted into the os_log format.
+ * At least one variadic argument is required (the expansion ends in
+ * ", __VA_ARGS__"); use FDLOG0 for argument-less messages.
+ * Note: (pcb)->log_level is not consulted by this macro.
+ */
+#define FDLOG(level, pcb, format, ...) \
+       os_log_with_type(OS_LOG_DEFAULT, flow_divert_syslog_type_to_oslog_type(level), "(%u): " format "\n", (pcb)->hash, __VA_ARGS__)
 
-#define FDLOG0(level, pcb, msg) do {                                                                                           \
-       if (level <= (pcb)->log_level) {                                                                                                \
-               log((level > LOG_NOTICE ? LOG_NOTICE : level), "%s (%u): %s\n", __FUNCTION__, (pcb)->hash, msg);                                \
-       }                                                                                                                                                               \
-} while (0)
+/*
+ * FDLOG0(level, pcb, msg): argument-less variant of FDLOG.
+ * `msg` must be a string literal: it is concatenated into the os_log format
+ * string at compile time, so a runtime char * will not compile, and any '%'
+ * in the message is interpreted as a conversion (write "%%" for a literal).
+ */
+#define FDLOG0(level, pcb, msg) \
+       os_log_with_type(OS_LOG_DEFAULT, flow_divert_syslog_type_to_oslog_type(level), "(%u): " msg "\n", (pcb)->hash)
 
 #define FDRETAIN(pcb)                  if ((pcb) != NULL) OSIncrementAtomic(&(pcb)->ref_count)
 #define FDRELEASE(pcb)                                                                                                         \
@@ -95,7 +90,7 @@
 #define FDLOCK(pcb)                                            lck_mtx_lock(&(pcb)->mtx)
 #define FDUNLOCK(pcb)                                  lck_mtx_unlock(&(pcb)->mtx)
 
-#define FD_CTL_SENDBUFF_SIZE                   (2 * FLOW_DIVERT_CHUNK_SIZE)
+#define FD_CTL_SENDBUFF_SIZE                   (128 * 1024)
+/* flow_divert_input() rejects control packets whose length exceeds this. */
 #define FD_CTL_RCVBUFF_SIZE                            (128 * 1024)
 
 #define GROUP_BIT_CTL_ENQUEUE_BLOCKED  0
+/* Control units in a connect result must be below this value. */
 #define GROUP_COUNT_MAX                                        32
 #define FLOW_DIVERT_MAX_NAME_SIZE              4096
+/* Group-init keys of size 0 or larger than this are rejected. */
 #define FLOW_DIVERT_MAX_KEY_SIZE               1024
-
-#define DNS_SERVICE_GROUP_UNIT                 (GROUP_COUNT_MAX + 1)
+/* Upper bound on the bytes allocated for a group's signing-ID trie. */
+#define FLOW_DIVERT_MAX_TRIE_MEMORY            (1024 * 1024)
 
+/*
+ * One node of the signing-ID trie.
+ * NOTE(review): start/length appear to describe a run of bytes in the trie's
+ * byte storage and child_map an index into its child-map table (see how
+ * flow_divert_trie_insert() is called with bytes_free_next/sid_size) —
+ * confirm against the trie helpers.
+ */
 struct flow_divert_trie_node
 {
        uint16_t start;
        uint16_t length;
        uint16_t child_map;
-       uint32_t group_unit;
-};
-
-struct flow_divert_trie
-{
-       struct flow_divert_trie_node *nodes;
-       uint16_t *child_maps;
-       uint8_t *bytes;
-       void *memory;
-       size_t nodes_count;
-       size_t child_maps_count;
-       size_t bytes_count;
-       size_t nodes_free_next;
-       size_t child_maps_free_next;
-       size_t bytes_free_next;
-       uint16_t root;
 };
 
 #define CHILD_MAP_SIZE                 256
@@ -140,7 +118,6 @@ static struct flow_divert_pcb               nil_pcb;
 decl_lck_rw_data(static, g_flow_divert_group_lck);
 static struct flow_divert_group                **g_flow_divert_groups                  = NULL;
 static uint32_t                                                g_active_group_count                    = 0;
-static struct flow_divert_trie         g_signing_id_trie;
 
 static lck_grp_attr_t                          *flow_divert_grp_attr                   = NULL;
 static lck_attr_t                                      *flow_divert_mtx_attr                   = NULL;
@@ -186,6 +163,17 @@ flow_divert_has_pcb_local_address(const struct inpcb *inp);
 static void
 flow_divert_disconnect_socket(struct socket *so);
 
+/*
+ * Convert a syslog-style priority (LOG_EMERG ... LOG_DEBUG) to the os_log
+ * type used by FDLOG/FDLOG0.
+ *
+ * LOG_EMERG, LOG_ALERT and LOG_CRIT are more severe than LOG_ERR, so they are
+ * reported as OS_LOG_TYPE_ERROR too rather than being demoted to
+ * OS_LOG_TYPE_DEFAULT. LOG_WARNING, LOG_NOTICE and unrecognized values map to
+ * OS_LOG_TYPE_DEFAULT.
+ *
+ * Returns the os_log type as a uint8_t.
+ */
+static inline uint8_t
+flow_divert_syslog_type_to_oslog_type(int syslog_type)
+{
+       switch (syslog_type) {
+               case LOG_EMERG:
+               case LOG_ALERT:
+               case LOG_CRIT:
+               case LOG_ERR: return OS_LOG_TYPE_ERROR;
+               case LOG_INFO: return OS_LOG_TYPE_INFO;
+               case LOG_DEBUG: return OS_LOG_TYPE_DEBUG;
+               default: return OS_LOG_TYPE_DEFAULT;
+       }
+}
+
 static inline int
 flow_divert_pcb_cmp(const struct flow_divert_pcb *pcb_a, const struct flow_divert_pcb *pcb_b)
 {
@@ -211,8 +199,6 @@ flow_divert_packet_type2str(uint8_t packet_type)
                        return "read notification";
                case FLOW_DIVERT_PKT_PROPERTIES_UPDATE:
                        return "properties update";
-               case FLOW_DIVERT_PKT_APP_MAP_UPDATE:
-                       return "app map update";
                case FLOW_DIVERT_PKT_APP_MAP_CREATE:
                        return "app map create";
                default:
@@ -418,7 +404,7 @@ flow_divert_packet_append_tlv(mbuf_t packet, uint8_t type, uint32_t length, cons
 
        error = mbuf_copyback(packet, mbuf_pkthdr_len(packet), sizeof(net_length), &net_length, MBUF_DONTWAIT);
        if (error) {
-               FDLOG(LOG_ERR, &nil_pcb, "failed to append the length (%lu)", length);
+               FDLOG(LOG_ERR, &nil_pcb, "failed to append the length (%u)", length);
                return error;
        }
 
@@ -1005,6 +991,8 @@ flow_divert_create_connect_packet(struct flow_divert_pcb *fd_cb, struct sockaddr
        char                    *signing_id = NULL;
        int                             free_signing_id = 0;
        mbuf_t                  connect_packet = NULL;
+       proc_t                  src_proc = p;
+       int                             release_proc = 0;
 
        error = flow_divert_packet_init(fd_cb, FLOW_DIVERT_PKT_CONNECT, &connect_packet);
        if (error) {
@@ -1027,69 +1015,63 @@ flow_divert_create_connect_packet(struct flow_divert_pcb *fd_cb, struct sockaddr
        }
 
        socket_unlock(so, 0);
-       if (g_signing_id_trie.root != NULL_TRIE_IDX) {
-               proc_t src_proc = p;
-               int release_proc = 0;
-
-               if (signing_id == NULL) {
-                       release_proc = flow_divert_get_src_proc(so, &src_proc);
-                       if (src_proc != PROC_NULL) {
-                               proc_lock(src_proc);
-                               if (src_proc->p_csflags & CS_VALID) {
-                    const char * cs_id;
-                    cs_id = cs_identity_get(src_proc);
-                    signing_id = __DECONST(char *, cs_id);
-                               } else {
-                                       FDLOG0(LOG_WARNING, fd_cb, "Signature is invalid");
-                               }
+
+       if (signing_id == NULL) {
+               release_proc = flow_divert_get_src_proc(so, &src_proc);
+               if (src_proc != PROC_NULL) {
+                       proc_lock(src_proc);
+                       if (src_proc->p_csflags & (CS_VALID|CS_DEBUGGED)) {
+                               const char * cs_id;
+                               cs_id = cs_identity_get(src_proc);
+                               signing_id = __DECONST(char *, cs_id);
                        } else {
-                               FDLOG0(LOG_WARNING, fd_cb, "Failed to determine the current proc");
+                               FDLOG0(LOG_WARNING, fd_cb, "Signature is invalid");
                        }
                } else {
-                       src_proc = PROC_NULL;
+                       FDLOG0(LOG_WARNING, fd_cb, "Failed to determine the current proc");
                }
+       } else {
+               src_proc = PROC_NULL;
+       }
 
-               if (signing_id != NULL) {
-                       uint16_t result = NULL_TRIE_IDX;
-                       lck_rw_lock_shared(&g_flow_divert_group_lck);
-                       result = flow_divert_trie_search(&g_signing_id_trie, (uint8_t *)signing_id);
-                       lck_rw_done(&g_flow_divert_group_lck);
-                       if (result != NULL_TRIE_IDX) {
-                               error = 0;
-                               FDLOG(LOG_INFO, fd_cb, "%s matched", signing_id);
+       if (signing_id != NULL) {
+               uint16_t result = NULL_TRIE_IDX;
+               lck_rw_lock_shared(&fd_cb->group->lck);
+               result = flow_divert_trie_search(&fd_cb->group->signing_id_trie, (uint8_t *)signing_id);
+               lck_rw_done(&fd_cb->group->lck);
+               if (result != NULL_TRIE_IDX) {
+                       error = 0;
+                       FDLOG(LOG_INFO, fd_cb, "%s matched", signing_id);
 
-                               error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_SIGNING_ID, strlen(signing_id), signing_id);
-                               if (error == 0) {
-                                       if (src_proc != PROC_NULL) {
-                                               unsigned char cdhash[SHA1_RESULTLEN];
-                                               error = proc_getcdhash(src_proc, cdhash);
-                                               if (error == 0) {
-                                                       error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_CDHASH, sizeof(cdhash), cdhash);
-                                                       if (error) {
-                                                               FDLOG(LOG_ERR, fd_cb, "failed to append the cdhash: %d", error);
-                                                       }
-                                               } else {
-                                                       FDLOG(LOG_ERR, fd_cb, "failed to get the cdhash: %d", error);
+                       error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_SIGNING_ID, strlen(signing_id), signing_id);
+                       if (error == 0) {
+                               if (src_proc != PROC_NULL) {
+                                       unsigned char cdhash[SHA1_RESULTLEN];
+                                       error = proc_getcdhash(src_proc, cdhash);
+                                       if (error == 0) {
+                                               error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_CDHASH, sizeof(cdhash), cdhash);
+                                               if (error) {
+                                                       FDLOG(LOG_ERR, fd_cb, "failed to append the cdhash: %d", error);
                                                }
+                                       } else {
+                                               FDLOG(LOG_ERR, fd_cb, "failed to get the cdhash: %d", error);
                                        }
-                               } else {
-                                       FDLOG(LOG_ERR, fd_cb, "failed to append the signing ID: %d", error);
                                }
                        } else {
-                               FDLOG(LOG_WARNING, fd_cb, "%s did not match", signing_id);
+                               FDLOG(LOG_ERR, fd_cb, "failed to append the signing ID: %d", error);
                        }
                } else {
-                       FDLOG0(LOG_WARNING, fd_cb, "Failed to get the code signing identity");
+                       FDLOG(LOG_WARNING, fd_cb, "%s did not match", signing_id);
                }
+       } else {
+               FDLOG0(LOG_WARNING, fd_cb, "Failed to get the code signing identity");
+       }
 
-               if (src_proc != PROC_NULL) {
-                       proc_unlock(src_proc);
-                       if (release_proc) {
-                               proc_rele(src_proc);
-                       }
+       if (src_proc != PROC_NULL) {
+               proc_unlock(src_proc);
+               if (release_proc) {
+                       proc_rele(src_proc);
                }
-       } else {
-               FDLOG0(LOG_WARNING, fd_cb, "The signing ID trie is empty");
        }
        socket_lock(so, 0);
 
@@ -1465,7 +1447,7 @@ flow_divert_send_buffered_data(struct flow_divert_pcb *fd_cb, Boolean force)
                                }
                        }
                        data_len = mbuf_pkthdr_len(m);
-                       FDLOG(LOG_DEBUG, fd_cb, "mbuf_copym() data_len = %u", data_len);
+                       FDLOG(LOG_DEBUG, fd_cb, "mbuf_copym() data_len = %lu", data_len);
                        error = mbuf_copym(m, 0, data_len, MBUF_DONTWAIT, &data);
                        if (error) {
                                FDLOG(LOG_ERR, fd_cb, "mbuf_copym failed: %d", error);
@@ -1572,7 +1554,7 @@ flow_divert_send_app_data(struct flow_divert_pcb *fd_cb, mbuf_t data, struct soc
                if (to_send) {
                        error = flow_divert_send_data_packet(fd_cb, data, to_send, toaddr, FALSE);
                        if (error) {
-                               FDLOG(LOG_ERR, fd_cb, "flow_divert_send_data_packet failed. send data size = %u", to_send);
+                               FDLOG(LOG_ERR, fd_cb, "flow_divert_send_data_packet failed. send data size = %lu", to_send);
                        } else {
                                fd_cb->send_window -= to_send;
                        }
@@ -1675,6 +1657,7 @@ flow_divert_handle_connect_result(struct flow_divert_pcb *fd_cb, mbuf_t packet,
        int                                                     out_if_index            = 0;
        struct sockaddr_storage         remote_address;
        uint32_t                                        send_window;
+       uint32_t                                        app_data_length         = 0;
 
        memset(&local_address, 0, sizeof(local_address));
        memset(&remote_address, 0, sizeof(remote_address));
@@ -1695,32 +1678,37 @@ flow_divert_handle_connect_result(struct flow_divert_pcb *fd_cb, mbuf_t packet,
 
        error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_CTL_UNIT, sizeof(ctl_unit), &ctl_unit, NULL);
        if (error) {
-               FDLOG(LOG_ERR, fd_cb, "failed to get the control unit: %d", error);
-               return;
+               FDLOG0(LOG_INFO, fd_cb, "No control unit provided in the connect result");
        }
 
        error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_LOCAL_ADDR, sizeof(local_address), &local_address, NULL);
        if (error) {
-               FDLOG0(LOG_NOTICE, fd_cb, "No local address provided");
+               FDLOG0(LOG_INFO, fd_cb, "No local address provided");
        }
 
        error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_REMOTE_ADDR, sizeof(remote_address), &remote_address, NULL);
        if (error) {
-               FDLOG0(LOG_NOTICE, fd_cb, "No remote address provided");
+               FDLOG0(LOG_INFO, fd_cb, "No remote address provided");
        }
 
        error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_OUT_IF_INDEX, sizeof(out_if_index), &out_if_index, NULL);
        if (error) {
-               FDLOG0(LOG_NOTICE, fd_cb, "No output if index provided");
+               FDLOG0(LOG_INFO, fd_cb, "No output if index provided");
+       }
+
+       error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_APP_DATA, 0, NULL, &app_data_length);
+       if (error) {
+               FDLOG0(LOG_INFO, fd_cb, "No application data provided in connect result");
        }
 
+       error = 0;
        connect_error   = ntohl(connect_error);
        ctl_unit                = ntohl(ctl_unit);
 
        lck_rw_lock_shared(&g_flow_divert_group_lck);
 
-       if (connect_error == 0) {
-               if (ctl_unit == 0 || ctl_unit >= GROUP_COUNT_MAX) {
+       if (connect_error == 0 && ctl_unit > 0) {
+               if (ctl_unit >= GROUP_COUNT_MAX) {
                        FDLOG(LOG_ERR, fd_cb, "Connect result contains an invalid control unit: %u", ctl_unit);
                        error = EINVAL;
                } else if (g_flow_divert_groups == NULL || g_active_group_count == 0) {
@@ -1773,6 +1761,27 @@ flow_divert_handle_connect_result(struct flow_divert_pcb *fd_cb, mbuf_t packet,
                        goto set_socket_state;
                }
 
+               if (app_data_length > 0) {
+                       uint8_t *app_data = NULL;
+                       MALLOC(app_data, uint8_t *, app_data_length, M_TEMP, M_WAITOK);
+                       if (app_data != NULL) {
+                               error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_APP_DATA, app_data_length, app_data, NULL);
+                               if (error == 0) {
+                                       FDLOG(LOG_INFO, fd_cb, "Got %u bytes of app data from the connect result", app_data_length);
+                                       if (fd_cb->app_data != NULL) {
+                                               FREE(fd_cb->app_data, M_TEMP);
+                                       }
+                                       fd_cb->app_data = app_data;
+                                       fd_cb->app_data_length = app_data_length;
+                               } else {
+                                       FDLOG(LOG_ERR, fd_cb, "Failed to copy %u bytes of application data from the connect result packet", app_data_length);
+                                       FREE(app_data, M_TEMP);
+                               }
+                       } else {
+                               FDLOG(LOG_ERR, fd_cb, "Failed to allocate a buffer of size %u to hold the application data from the connect result", app_data_length);
+                       }
+               }
+
                ifnet_head_lock_shared();
                if (out_if_index > 0 && out_if_index <= if_index) {
                        ifp = ifindex2ifnet[out_if_index];
@@ -1794,20 +1803,22 @@ flow_divert_handle_connect_result(struct flow_divert_pcb *fd_cb, mbuf_t packet,
                        goto set_socket_state;
                }
 
-               old_group = fd_cb->group;
+               if (grp != NULL) {
+                       old_group = fd_cb->group;
 
-               lck_rw_lock_exclusive(&old_group->lck);
-               lck_rw_lock_exclusive(&grp->lck);
+                       lck_rw_lock_exclusive(&old_group->lck);
+                       lck_rw_lock_exclusive(&grp->lck);
 
-               RB_REMOVE(fd_pcb_tree, &old_group->pcb_tree, fd_cb);
-               if (RB_INSERT(fd_pcb_tree, &grp->pcb_tree, fd_cb) != NULL) {
-                       panic("group with unit %u already contains a connection with hash %u", grp->ctl_unit, fd_cb->hash);
-               }
+                       RB_REMOVE(fd_pcb_tree, &old_group->pcb_tree, fd_cb);
+                       if (RB_INSERT(fd_pcb_tree, &grp->pcb_tree, fd_cb) != NULL) {
+                               panic("group with unit %u already contains a connection with hash %u", grp->ctl_unit, fd_cb->hash);
+                       }
 
-               fd_cb->group = grp;
+                       fd_cb->group = grp;
 
-               lck_rw_done(&grp->lck);
-               lck_rw_done(&old_group->lck);
+                       lck_rw_done(&grp->lck);
+                       lck_rw_done(&old_group->lck);
+               }
 
                fd_cb->send_window = ntohl(send_window);
 
@@ -2038,7 +2049,7 @@ flow_divert_handle_group_init(struct flow_divert_group *group, mbuf_t packet, in
        }
 
        if (key_size == 0 || key_size > FLOW_DIVERT_MAX_KEY_SIZE) {
-               FDLOG(LOG_ERR, &nil_pcb, "Invalid key size: %lu", key_size);
+               FDLOG(LOG_ERR, &nil_pcb, "Invalid key size: %u", key_size);
                return;
        }
 
@@ -2167,7 +2178,7 @@ flow_divert_handle_properties_update(struct flow_divert_pcb *fd_cb, mbuf_t packe
 }
 
 static void
-flow_divert_handle_app_map_create(mbuf_t packet, int offset)
+flow_divert_handle_app_map_create(struct flow_divert_group *group, mbuf_t packet, int offset)
 {
        size_t bytes_mem_size;
        size_t child_maps_mem_size;
@@ -2178,21 +2189,27 @@ flow_divert_handle_app_map_create(mbuf_t packet, int offset)
        size_t nodes_mem_size;
        int prefix_count = 0;
        int signing_id_count = 0;
+       size_t trie_memory_size = 0;
 
-       lck_rw_lock_exclusive(&g_flow_divert_group_lck);
+       lck_rw_lock_exclusive(&group->lck);
 
        /* Re-set the current trie */
-       if (g_signing_id_trie.memory != NULL) {
-               FREE(g_signing_id_trie.memory, M_TEMP);
+       if (group->signing_id_trie.memory != NULL) {
+               FREE(group->signing_id_trie.memory, M_TEMP);
        }
-       memset(&g_signing_id_trie, 0, sizeof(g_signing_id_trie));
-       g_signing_id_trie.root = NULL_TRIE_IDX;
+       memset(&group->signing_id_trie, 0, sizeof(group->signing_id_trie));
+       group->signing_id_trie.root = NULL_TRIE_IDX;
 
        memset(&new_trie, 0, sizeof(new_trie));
 
        /* Get the number of shared prefixes in the new set of signing ID strings */
        flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_PREFIX_COUNT, sizeof(prefix_count), &prefix_count, NULL);
 
+       if (prefix_count < 0) {
+               lck_rw_done(&group->lck);
+               return;
+       }
+
        /* Compute the number of signing IDs and the total amount of bytes needed to store them */
        for (cursor = flow_divert_packet_find_tlv(packet, offset, FLOW_DIVERT_TLV_SIGNING_ID, &error, 0);
             cursor >= 0;
@@ -2205,7 +2222,7 @@ flow_divert_handle_app_map_create(mbuf_t packet, int offset)
        }
 
        if (signing_id_count == 0) {
-               lck_rw_done(&g_flow_divert_group_lck);
+               lck_rw_done(&group->lck);
                return;
        }
 
@@ -2219,10 +2236,18 @@ flow_divert_handle_app_map_create(mbuf_t packet, int offset)
        child_maps_mem_size = (sizeof(*new_trie.child_maps) * CHILD_MAP_SIZE * new_trie.child_maps_count);
        bytes_mem_size = (sizeof(*new_trie.bytes) * new_trie.bytes_count);
 
-       MALLOC(new_trie.memory, void *, nodes_mem_size + child_maps_mem_size + bytes_mem_size, M_TEMP, M_WAITOK);
+       trie_memory_size = nodes_mem_size + child_maps_mem_size + bytes_mem_size;
+       if (trie_memory_size > FLOW_DIVERT_MAX_TRIE_MEMORY) {
+               FDLOG(LOG_ERR, &nil_pcb, "Trie memory size (%lu) is too big (maximum is %u)", trie_memory_size, FLOW_DIVERT_MAX_TRIE_MEMORY);
+               lck_rw_done(&group->lck);
+               return;
+       }
+
+       MALLOC(new_trie.memory, void *, trie_memory_size, M_TEMP, M_WAITOK);
        if (new_trie.memory == NULL) {
                FDLOG(LOG_ERR, &nil_pcb, "Failed to allocate %lu bytes of memory for the signing ID trie",
                      nodes_mem_size + child_maps_mem_size + bytes_mem_size);
+               lck_rw_done(&group->lck);
                return;
        }
 
@@ -2249,20 +2274,10 @@ flow_divert_handle_app_map_create(mbuf_t packet, int offset)
                uint32_t sid_size = 0;
                flow_divert_packet_get_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, 0, NULL, &sid_size);
                if (new_trie.bytes_free_next + sid_size <= new_trie.bytes_count) {
-                       boolean_t is_dns;
                        uint16_t new_node_idx;
                        flow_divert_packet_get_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, sid_size, &TRIE_BYTE(&new_trie, new_trie.bytes_free_next), NULL);
-                       is_dns = (sid_size == sizeof(FLOW_DIVERT_DNS_SERVICE_SIGNING_ID) - 1 && 
-                                 !memcmp(&TRIE_BYTE(&new_trie, new_trie.bytes_free_next),
-                                         FLOW_DIVERT_DNS_SERVICE_SIGNING_ID,
-                                         sid_size));
                        new_node_idx = flow_divert_trie_insert(&new_trie, new_trie.bytes_free_next, sid_size);
-                       if (new_node_idx != NULL_TRIE_IDX) {
-                               if (is_dns) {
-                                       FDLOG(LOG_INFO, &nil_pcb, "Setting group unit for %s to %d", FLOW_DIVERT_DNS_SERVICE_SIGNING_ID, DNS_SERVICE_GROUP_UNIT);
-                                       TRIE_NODE(&new_trie, new_node_idx).group_unit = DNS_SERVICE_GROUP_UNIT;
-                               }
-                       } else {
+                       if (new_node_idx == NULL_TRIE_IDX) {
                                insert_error = EINVAL;
                                break;
                        }
@@ -2274,72 +2289,12 @@ flow_divert_handle_app_map_create(mbuf_t packet, int offset)
        }
 
        if (!insert_error) {
-               g_signing_id_trie = new_trie;
+               group->signing_id_trie = new_trie;
        } else {
                FREE(new_trie.memory, M_TEMP);
        }
 
-       lck_rw_done(&g_flow_divert_group_lck);
-}
-
-static void
-flow_divert_handle_app_map_update(struct flow_divert_group *group, mbuf_t packet, int offset)
-{
-       int error = 0;
-       int cursor;
-       size_t max_size = 0;
-       uint8_t *signing_id;
-       uint32_t ctl_unit;
-
-       lck_rw_lock_shared(&group->lck);
-       ctl_unit = group->ctl_unit;
        lck_rw_done(&group->lck);
-
-       for (cursor = flow_divert_packet_find_tlv(packet, offset, FLOW_DIVERT_TLV_SIGNING_ID, &error, 0);
-            cursor >= 0;
-            cursor = flow_divert_packet_find_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, &error, 1))
-       {
-               uint32_t sid_size = 0;
-               flow_divert_packet_get_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, 0, NULL, &sid_size);
-               if (sid_size > max_size) {
-                       max_size = sid_size;
-               }
-       }
-
-       MALLOC(signing_id, uint8_t *, max_size + 1, M_TEMP, M_WAITOK);
-       if (signing_id == NULL) {
-               FDLOG(LOG_ERR, &nil_pcb, "Failed to allocate a string to hold the signing ID (size %lu)", max_size);
-               return;
-       }
-
-       for (cursor = flow_divert_packet_find_tlv(packet, offset, FLOW_DIVERT_TLV_SIGNING_ID, &error, 0);
-            cursor >= 0;
-            cursor = flow_divert_packet_find_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, &error, 1))
-       {
-               uint32_t signing_id_len = 0;
-               uint16_t node;
-
-               flow_divert_packet_get_tlv(packet,
-                               cursor, FLOW_DIVERT_TLV_SIGNING_ID, max_size, signing_id, &signing_id_len);
-
-               signing_id[signing_id_len] = '\0';
-
-               lck_rw_lock_exclusive(&g_flow_divert_group_lck);
-
-               node = flow_divert_trie_search(&g_signing_id_trie, signing_id);
-               if (node != NULL_TRIE_IDX) {
-                       if (TRIE_NODE(&g_signing_id_trie, node).group_unit != DNS_SERVICE_GROUP_UNIT) {
-                               FDLOG(LOG_INFO, &nil_pcb, "Setting %s to ctl unit %u", signing_id, group->ctl_unit);
-                               TRIE_NODE(&g_signing_id_trie, node).group_unit = ctl_unit;
-                       }
-               } else {
-                       FDLOG(LOG_ERR, &nil_pcb, "Failed to find signing ID %s", signing_id);
-               }
-
-               lck_rw_done(&g_flow_divert_group_lck);
-       }
-
-       FREE(signing_id, M_TEMP);
 }
 
 static int
@@ -2356,7 +2311,7 @@ flow_divert_input(mbuf_t packet, struct flow_divert_group *group)
        }
 
        if (mbuf_pkthdr_len(packet) > FD_CTL_RCVBUFF_SIZE) {
-               FDLOG(LOG_ERR, &nil_pcb, "got a bad packet, length (%lu) > %lu", mbuf_pkthdr_len(packet), FD_CTL_RCVBUFF_SIZE);
+               FDLOG(LOG_ERR, &nil_pcb, "got a bad packet, length (%lu) > %d", mbuf_pkthdr_len(packet), FD_CTL_RCVBUFF_SIZE);
                error = EINVAL;
                goto done;
        }
@@ -2376,10 +2331,7 @@ flow_divert_input(mbuf_t packet, struct flow_divert_group *group)
                                flow_divert_handle_group_init(group, packet, sizeof(hdr));
                                break;
                        case FLOW_DIVERT_PKT_APP_MAP_CREATE:
-                               flow_divert_handle_app_map_create(packet, sizeof(hdr));
-                               break;
-                       case FLOW_DIVERT_PKT_APP_MAP_UPDATE:
-                               flow_divert_handle_app_map_update(group, packet, sizeof(hdr));
+                               flow_divert_handle_app_map_create(group, packet, sizeof(hdr));
                                break;
                        default:
                                FDLOG(LOG_WARNING, &nil_pcb, "got an unknown message type: %d", hdr.packet_type);
@@ -2452,6 +2404,7 @@ flow_divert_close_all(struct flow_divert_group *group)
                        flow_divert_pcb_remove(fd_cb);
                        flow_divert_update_closed_state(fd_cb, SHUT_RDWR, TRUE);
                        fd_cb->so->so_error = ECONNABORTED;
+                       flow_divert_disconnect_socket(fd_cb->so);
                        socket_unlock(fd_cb->so, 0);
                }
                FDUNLOCK(fd_cb);
@@ -2870,13 +2823,9 @@ done:
 }
 
 static int
-flow_divert_connectx_out_common(struct socket *so, int af,
-    struct sockaddr_list **src_sl, struct sockaddr_list **dst_sl,
-    struct proc *p, uint32_t ifscope __unused, sae_associd_t aid __unused,
-    sae_connid_t *pcid, uint32_t flags __unused, void *arg __unused,
-    uint32_t arglen __unused, struct uio *auio, user_ssize_t *bytes_written)
+flow_divert_connectx_out_common(struct socket *so, struct sockaddr *dst,
+    struct proc *p, sae_connid_t *pcid, struct uio *auio, user_ssize_t *bytes_written)
 {
-       struct sockaddr_entry *src_se = NULL, *dst_se = NULL;
        struct inpcb *inp = sotoinpcb(so);
        int error;
 
@@ -2884,20 +2833,9 @@ flow_divert_connectx_out_common(struct socket *so, int af,
                return (EINVAL);
        }
 
-       VERIFY(dst_sl != NULL);
+       VERIFY(dst != NULL);
 
-       /* select source (if specified) and destination addresses */
-       error = in_selectaddrs(af, src_sl, &src_se, dst_sl, &dst_se);
-       if (error != 0) {
-               return (error);
-       }
-
-       VERIFY(*dst_sl != NULL && dst_se != NULL);
-       VERIFY(src_se == NULL || *src_sl != NULL);
-       VERIFY(dst_se->se_addr->sa_family == af);
-       VERIFY(src_se == NULL || src_se->se_addr->sa_family == af);
-
-       error = flow_divert_connect_out(so, dst_se->se_addr, p);
+       error = flow_divert_connect_out(so, dst, p);
 
        if (error != 0) {
                return error;
@@ -2940,26 +2878,22 @@ flow_divert_connectx_out_common(struct socket *so, int af,
 }
 
 static int
-flow_divert_connectx_out(struct socket *so, struct sockaddr_list **src_sl,
-    struct sockaddr_list **dst_sl, struct proc *p, uint32_t ifscope,
-    sae_associd_t aid, sae_connid_t *pcid, uint32_t flags, void *arg,
-    uint32_t arglen, struct uio *uio, user_ssize_t *bytes_written)
+flow_divert_connectx_out(struct socket *so, struct sockaddr *src __unused,
+    struct sockaddr *dst, struct proc *p, uint32_t ifscope __unused,
+    sae_associd_t aid __unused, sae_connid_t *pcid, uint32_t flags __unused, void *arg __unused,
+    uint32_t arglen __unused, struct uio *uio, user_ssize_t *bytes_written)
 {
-#pragma unused(uio, bytes_written)
-       return (flow_divert_connectx_out_common(so, AF_INET, src_sl, dst_sl,
-           p, ifscope, aid, pcid, flags, arg, arglen, uio, bytes_written));
+       return (flow_divert_connectx_out_common(so, dst, p, pcid, uio, bytes_written));
 }
 
 #if INET6
 static int
-flow_divert_connectx6_out(struct socket *so, struct sockaddr_list **src_sl,
-    struct sockaddr_list **dst_sl, struct proc *p, uint32_t ifscope,
-    sae_associd_t aid, sae_connid_t *pcid, uint32_t flags, void *arg,
-    uint32_t arglen, struct uio *uio, user_ssize_t *bytes_written)
+flow_divert_connectx6_out(struct socket *so, struct sockaddr *src __unused,
+    struct sockaddr *dst, struct proc *p, uint32_t ifscope __unused,
+    sae_associd_t aid __unused, sae_connid_t *pcid, uint32_t flags __unused, void *arg __unused,
+    uint32_t arglen __unused, struct uio *uio, user_ssize_t *bytes_written)
 {
-#pragma unused(uio, bytes_written)
-       return (flow_divert_connectx_out_common(so, AF_INET6, src_sl, dst_sl,
-           p, ifscope, aid, pcid, flags, arg, arglen, uio, bytes_written));
+       return (flow_divert_connectx_out_common(so, dst, p, pcid, uio, bytes_written));
 }
 #endif /* INET6 */
 
@@ -3443,6 +3377,9 @@ flow_divert_token_set(struct socket *so, struct sockopt *sopt)
        error = flow_divert_packet_get_tlv(token, 0, FLOW_DIVERT_TLV_KEY_UNIT, sizeof(key_unit), (void *)&key_unit, NULL);
        if (!error) {
                key_unit = ntohl(key_unit);
+               if (key_unit >= GROUP_COUNT_MAX) {
+                       key_unit = 0;
+               }
        } else if (error != ENOENT) {
                FDLOG(LOG_ERR, &nil_pcb, "Failed to get the key unit from the token: %d", error);
                goto done;
@@ -3595,6 +3532,12 @@ flow_divert_token_get(struct socket *so, struct sockopt *sopt)
                goto done;
        }
 
+       if (sopt->sopt_val == USER_ADDR_NULL) {
+               /* If the caller passed NULL to getsockopt, just set the size of the token and return */
+               sopt->sopt_valsize = mbuf_pkthdr_len(token);
+               goto done;
+       }
+
        error = soopt_mcopyout(sopt, token);
        if (error) {
                token = NULL;   /* For some reason, soopt_mcopyout() frees the mbuf if it fails */
@@ -3634,6 +3577,7 @@ flow_divert_kctl_connect(kern_ctl_ref kctlref __unused, struct sockaddr_ctl *sac
        RB_INIT(&new_group->pcb_tree);
        new_group->ctl_unit = sac->sc_unit;
        MBUFQ_INIT(&new_group->send_queue);
+       new_group->signing_id_trie.root = NULL_TRIE_IDX;
 
        lck_rw_lock_exclusive(&g_flow_divert_group_lck);
 
@@ -3670,7 +3614,6 @@ flow_divert_kctl_disconnect(kern_ctl_ref kctlref __unused, uint32_t unit, void *
 {
        struct flow_divert_group        *group  = NULL;
        errno_t                                         error   = 0;
-       uint16_t                                        node    = 0;
 
        if (unit >= GROUP_COUNT_MAX) {
                return EINVAL;
@@ -3699,6 +3642,14 @@ flow_divert_kctl_disconnect(kern_ctl_ref kctlref __unused, uint32_t unit, void *
                        group->token_key = NULL;
                        group->token_key_size = 0;
                }
+
+               /* Re-set the current trie */
+               if (group->signing_id_trie.memory != NULL) {
+                       FREE(group->signing_id_trie.memory, M_TEMP);
+               }
+               memset(&group->signing_id_trie, 0, sizeof(group->signing_id_trie));
+               group->signing_id_trie.root = NULL_TRIE_IDX;
+
                FREE_ZONE(group, sizeof(*group), M_FLOW_DIVERT_GROUP);
                g_flow_divert_groups[unit] = NULL;
                g_active_group_count--;
@@ -3711,13 +3662,6 @@ flow_divert_kctl_disconnect(kern_ctl_ref kctlref __unused, uint32_t unit, void *
                g_flow_divert_groups = NULL;
        }
 
-       /* Remove all signing IDs that point to this unit */
-       for (node = 0; node < g_signing_id_trie.nodes_count; node++) {
-               if (TRIE_NODE(&g_signing_id_trie, node).group_unit == unit) {
-                       TRIE_NODE(&g_signing_id_trie, node).group_unit = 0;
-               }
-       }
-
        lck_rw_done(&g_flow_divert_group_lck);
 
        return error;
@@ -3978,9 +3922,6 @@ flow_divert_init(void)
 
        lck_rw_init(&g_flow_divert_group_lck, flow_divert_mtx_grp, flow_divert_mtx_attr);
 
-       memset(&g_signing_id_trie, 0, sizeof(g_signing_id_trie));
-       g_signing_id_trie.root = NULL_TRIE_IDX;
-
 done:
        if (g_init_result != 0) {
                if (flow_divert_mtx_attr != NULL) {